Look into using {butcher} to reduce fitted object size #302

Open
dshemetov opened this issue Mar 28, 2024 · 0 comments

dshemetov commented Mar 28, 2024

> Aside: if we are super concerned with space (not clear we are at the moment, seems a "nice to have" rather than a "mandatory for adding new features"), we may want to investigate ways to use {butcher} for existing workflows.

Originally posted by @dajmcdon in #293 (comment)

It's a bit strange that the lm fit object here (2.80 MB) is much larger than the training dataset (661.51 kB). I'm guessing there's per-quantile duplication going on.
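
One possible angle on the `model.*` entries that survive `butcher()` in the reprex below: base R's `lm()` takes a `model` argument, and `model = FALSE` skips storing the model frame on the fitted object. The following is a minimal sketch using plain `lm()` on `mtcars`. The dataset, the formula, and the helper `component_bytes()` are stand-ins for illustration, not part of the epipredict workflow. Whether the same argument can be passed through the parsnip engine into `epi_workflow()` is untested here.

```r
# Minimal sketch with base R only. `mtcars` and the formula are stand-ins
# (assumptions for illustration), not the epipredict workflow from this issue.
fit_full <- lm(mpg ~ wt + hp, data = mtcars)
fit_slim <- lm(mpg ~ wt + hp, data = mtcars, model = FALSE)

# With model = FALSE the fitted object carries no `model` component.
has_model_full <- "model" %in% names(fit_full)
has_model_slim <- "model" %in% names(fit_slim)

# Predictions on new data do not need the stored model frame.
new_rows <- head(mtcars)
same_preds <- isTRUE(all.equal(
  predict(fit_full, new_rows),
  predict(fit_slim, new_rows)
))

# Hypothetical helper: per-component sizes in bytes, largest first.
component_bytes <- function(fit) {
  sizes <- vapply(
    unclass(fit),
    function(x) as.numeric(object.size(x)),
    numeric(1)
  )
  sort(sizes, decreasing = TRUE)
}

print(component_bytes(fit_full))
print(component_bytes(fit_slim))
```

If this carries over to the workflow fit, the large pieces left would be `qr`, `effects`, and `residuals`, which is roughly what the second `weigh()` table in the reprex shows apart from the model columns.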

``` r
library(epipredict)
#> Loading required package: epiprocess
#> 
#> Attaching package: 'epiprocess'
#> The following object is masked from 'package:stats':
#> 
#>     filter
#> Loading required package: parsnip


# Basic fitting example
jhu <- case_death_rate_subset
r <- epi_recipe(jhu) %>%
  step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(death_rate, ahead = 7) %>%
  step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
  step_epi_naomit()
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
latest <- jhu %>% dplyr::filter(time_value >= max(time_value) - 14)
preds <- predict(wf, latest)

# Recursively apply a function and flatten the result
apply_nested_flatten <- function(nested_list, func, depth = 1) {
  lapply(nested_list, function(item) {
    if (is.list(item) && depth > 1) {
      apply_nested_flatten(item, func, depth - 1)
    } else {
      func(item)
    }
  }) %>% `names<-`(names(nested_list)) %>% purrr::list_flatten()
}

# Get a sense for the object sizes
lobstr::obj_sizes(!!!list(jhu = jhu, r = r, wf = wf, latest = latest, preds = preds))
#> jhu   : 661.51 kB
#> r     :  16.74 kB
#> wf    :   4.55 MB
#> latest:  27.70 kB
#> preds :   2.26 kB
# Inspect the wf object, note that the fit object is the largest
lobstr::obj_sizes(!!!apply_nested_flatten(wf, function(x) x, 3))
#> pre_actions_recipe     :  20.75 kB
#> pre_mold_predictors    : 928.59 kB
#> pre_mold_outcomes      : 155.10 kB
#> pre_mold_blueprint     :  14.94 kB
#> pre_mold_extras_roles  : 622.83 kB
#> pre_case_weights       :       0 B
#> fit_actions_model      :   2.58 kB
#> fit_fit_lvl            :       0 B
#> fit_fit_spec           :  19.69 kB
#> fit_fit_fit            :   2.80 MB
#> fit_fit_preproc_x_var  :       0 B
#> fit_fit_preproc_y_var  :       0 B
#> fit_fit_elapsed_elapsed:      56 B
#> fit_meta_max_time_value:     112 B
#> fit_meta_as_of         :       0 B
#> trained                :       0 B
# Go deeper
lobstr::obj_sizes(!!!apply_nested_flatten(wf, function(x) x, 4))
#> pre_actions_recipe_recipe                      :  18.54 kB
#> pre_actions_recipe_blueprint                   :   1.65 kB
#> pre_mold_predictors_lag_0_death_rate           : 154.61 kB
#> pre_mold_predictors_lag_7_death_rate           : 154.61 kB
#> pre_mold_predictors_lag_14_death_rate          : 154.61 kB
#> pre_mold_predictors_lag_0_case_rate            : 154.61 kB
#> pre_mold_predictors_lag_7_case_rate            : 154.61 kB
#> pre_mold_predictors_lag_14_case_rate           : 154.61 kB
#> pre_mold_outcomes_ahead_7_death_rate           : 154.61 kB
#> pre_mold_blueprint_intercept                   :       0 B
#> pre_mold_blueprint_allow_novel_levels          :       0 B
#> pre_mold_blueprint_composition                 :       0 B
#> pre_mold_blueprint_ptypes_predictors           :     392 B
#> pre_mold_blueprint_ptypes_outcomes             :     296 B
#> pre_mold_blueprint_fresh                       :       0 B
#> pre_mold_blueprint_strings_as_factors          :       0 B
#> pre_mold_blueprint_recipe                      :  12.58 kB
#> pre_mold_blueprint_extra_role_ptypes_time_value:     464 B
#> pre_mold_blueprint_extra_role_ptypes_geo_value :     408 B
#> pre_mold_blueprint_extra_role_ptypes_raw       :     472 B
#> pre_mold_extras_roles_time_value               : 155.08 kB
#> pre_mold_extras_roles_geo_value                : 157.97 kB
#> pre_mold_extras_roles_raw                      : 309.57 kB
#> pre_case_weights                               :       0 B
#> fit_actions_model_spec                         :   2.08 kB
#> fit_actions_model_formula                      :       0 B
#> fit_fit_lvl                                    :       0 B
#> fit_fit_spec_args_penalty                      :       0 B
#> fit_fit_spec_args_mixture                      :       0 B
#> fit_fit_spec_mode                              :       0 B
#> fit_fit_spec_user_specified_mode               :       0 B
#> fit_fit_spec_method_libs                       :     112 B
#> fit_fit_spec_method_fit                        :   1.52 kB
#> fit_fit_spec_method_pred                       :  17.37 kB
#> fit_fit_spec_engine                            :       0 B
#> fit_fit_spec_user_specified_engine             :       0 B
#> fit_fit_fit_coefficients                       :     344 B
#> fit_fit_fit_residuals                          : 155.98 kB
#> fit_fit_fit_effects                            : 309.27 kB
#> fit_fit_fit_rank                               :      56 B
#> fit_fit_fit_fitted.values                      : 154.66 kB
#> fit_fit_fit_assign                             :      80 B
#> fit_fit_fit_qr                                 :   1.08 MB
#> fit_fit_fit_df.residual                        :      56 B
#> fit_fit_fit_call                               :   7.92 kB
#> fit_fit_fit_terms                              :   4.55 kB
#> fit_fit_fit_model                              :   1.08 MB
#> fit_fit_preproc_x_var                          :       0 B
#> fit_fit_preproc_y_var                          :      56 B
#> fit_fit_elapsed_elapsed                        :      56 B
#> fit_meta_max_time_value                        :     112 B
#> fit_meta_as_of                                 :       0 B
#> trained                                        :       0 B

# Use butcher to reduce the lm fit's memory footprint
# (weigh() sizes below are in MB, its default unit)
small_lm <- butcher::butcher(wf$fit$fit$fit, verbose = TRUE)
#> ✔ Memory released: 1.24 MB
#> ✖ Disabled: `print()`, `summary()`, and `fitted()`
butcher::weigh(wf$fit$fit$fit) %>% print(n=20)
#> # A tibble: 21 × 2
#>    object                      size
#>    <chr>                      <dbl>
#>  1 terms                   1.10    
#>  2 call                    1.09    
#>  3 qr.qr                   1.08    
#>  4 effects                 0.310   
#>  5 residuals               0.156   
#>  6 fitted.values           0.156   
#>  7 model...y               0.155   
#>  8 model.lag_0_death_rate  0.155   
#>  9 model.lag_7_death_rate  0.155   
#> 10 model.lag_14_death_rate 0.155   
#> 11 model.lag_0_case_rate   0.155   
#> 12 model.lag_7_case_rate   0.155   
#> 13 model.lag_14_case_rate  0.155   
#> 14 coefficients            0.000848
#> 15 qr.qraux                0.000112
#> 16 assign                  0.00008 
#> 17 qr.pivot                0.00008 
#> 18 rank                    0.000056
#> 19 qr.tol                  0.000056
#> 20 qr.rank                 0.000056
#> # ℹ 1 more row
# Still a lot of memory used after butcher's cleanup:
# qr.qr, effects, residuals, and the model columns remain
butcher::weigh(small_lm) %>% print(n=20)
#> # A tibble: 21 × 2
#>    object                      size
#>    <chr>                      <dbl>
#>  1 qr.qr                   1.08    
#>  2 effects                 0.310   
#>  3 residuals               0.156   
#>  4 model...y               0.155   
#>  5 model.lag_0_death_rate  0.155   
#>  6 model.lag_7_death_rate  0.155   
#>  7 model.lag_14_death_rate 0.155   
#>  8 model.lag_0_case_rate   0.155   
#>  9 model.lag_7_case_rate   0.155   
#> 10 model.lag_14_case_rate  0.155   
#> 11 terms                   0.00570 
#> 12 coefficients            0.000848
#> 13 qr.qraux                0.000112
#> 14 call                    0.000112
#> 15 assign                  0.00008 
#> 16 qr.pivot                0.00008 
#> 17 rank                    0.000056
#> 18 qr.tol                  0.000056
#> 19 qr.rank                 0.000056
#> 20 df.residual             0.000056
#> # ℹ 1 more row
```

Created on 2024-03-28 with reprex v2.0.2
