feature request: add finalization method for workflow sets #164

Open
jkylearmstrong opened this issue Oct 7, 2024 · 4 comments

@jkylearmstrong

I think it would be helpful to add a function that finalizes all the workflows in a workflow set. So far, something like this seems to work pretty well:

#' Finalize Parameter Grid for a Single Workflow
#'
#' @param workflow_id workflow id
#' @param workflow_sets workflow sets
#' @param data data frame to use for \code{finalize}
#' @param ... not used
#'
#' @return updated workflow set with parameter grid added to the workflow

.finalize_workflow_set <- function(workflow_id, workflow_sets, data, ...) {
  param_set <- workflow_sets %>%
    workflowsets::extract_parameter_set_dials(id = workflow_id)

  if (nrow(param_set) > 0) {
    # Finalize the parameter set
    finalized_param_set <- param_set %>%
      dials::finalize(data)

    workflow_sets_with_finalized_params <- workflow_sets %>%
      workflowsets::option_add(param_info = finalized_param_set, id = workflow_id) %>%
      dplyr::filter(wflow_id == workflow_id)

    return(workflow_sets_with_finalized_params)
  } else {
    tibble::tibble()  # Return an empty tibble if no parameters
  }
}
#' Finalize the workflowset
#'
#' @param x workflow set
#' @param data data frame to use for finalization
#' @param ... additional arguments passed to \code{.finalize_workflow_set}
#'
#' @return updated workflow set with parameter grid added to each workflow

finalize_workflow_set <- function(x, data, ...){
  purrr::map(
    purrr::set_names(x$wflow_id),
    \(z) {
      .finalize_workflow_set(workflow_id = z, workflow_sets = x, data = data, ...)
    }
      ) %>%
    purrr::list_rbind()
}
@simonpcouch
Contributor

Related to #45.

What would you like to do with the list of finalized models? The use cases I imagine here all involve some sort of model selection, which should ideally be carried out using resampled models.

i.e. the flow we recommend is:

  1. Resample models
    2a) Choose one workflow configuration to finalize
    3a) Fit it on the entire training set

Rather than:

  1. Resample models
    2b) Choose a configuration of each workflow to finalize
    3b) Presumably fit all of them on the entire training set
    4b) ...

The second flow is a problem because there's no way to evaluate the list of models resulting from 3b): the test set can only be used once.
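A minimal sketch of the recommended flow (steps 1, 2a, 3a), assuming a small workflow set built on the `parabolic` data. The object names (`wf_set_res`, `best_id`, `best_params`, `final_fit`) and the choice of accuracy as the ranking metric are illustrative assumptions, not something prescribed in this thread.

```r
library(tidymodels)

data(parabolic, package = "modeldata")

set.seed(1)
split <- initial_split(parabolic)
folds <- vfold_cv(training(split), v = 5)

rec <- recipe(class ~ ., data = training(split))
rec_norm <- rec %>% step_normalize(all_numeric_predictors())

cart_spec <-
  decision_tree(cost_complexity = tune(), min_n = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# 1. Resample models
wf_set_res <-
  workflow_set(
    preproc = list(plain = rec, norm = rec_norm),
    models = list(cart = cart_spec)
  ) %>%
  workflow_map(resamples = folds, grid = 5, seed = 2)

# 2a. Choose one workflow configuration to finalize
best_id <- wf_set_res %>%
  rank_results(rank_metric = "accuracy", select_best = TRUE) %>%
  dplyr::filter(.metric == "accuracy") %>%
  dplyr::slice_min(rank, n = 1, with_ties = FALSE) %>%
  dplyr::pull(wflow_id)

best_params <- wf_set_res %>%
  extract_workflow_set_result(best_id) %>%
  select_best(metric = "accuracy")

# 3a. Fit it on the entire training set and evaluate once on the test set
final_fit <- wf_set_res %>%
  extract_workflow(best_id) %>%
  finalize_workflow(best_params) %>%
  last_fit(split)

collect_metrics(final_fit)
```

Only one workflow reaches `last_fit()`, so the test set is touched a single time.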

@jkylearmstrong
Author

jkylearmstrong commented Oct 8, 2024

library('tidymodels')
library('workflowsets')
tidymodels_prefer()
data(parabolic)
parabolic <- parabolic 
str(parabolic)
#> tibble [500 × 3] (S3: tbl_df/tbl/data.frame)
#>  $ X1   : num [1:500] 3.29 1.47 1.66 1.6 2.17 ...
#>  $ X2   : num [1:500] 1.661 0.414 0.791 0.276 3.166 ...
#>  $ class: Factor w/ 2 levels "Class1","Class2": 1 2 2 2 1 1 2 1 2 1 ...
set.seed(1)
split <- initial_split(parabolic)

#train_set <- training(split)
#test_set <- testing(split)
rec <- recipe(class ~ ., data = training(split)) %>%
  step_interact(terms = ~ X1:X2) 
bake(prep(rec), new_data = training(split))
#> # A tibble: 375 × 4
#>         X1     X2 class  X1_x_X2
#>      <dbl>  <dbl> <fct>    <dbl>
#>  1  1.17    0.627 Class2  0.733 
#>  2 -0.769  -1.29  Class2  0.993 
#>  3  1.17    1.02  Class1  1.20  
#>  4  0.510  -2.10  Class2 -1.07  
#>  5  1.38   -0.974 Class2 -1.34  
#>  6 -0.0549 -1.77  Class2  0.0972
#>  7  0.703   1.24  Class1  0.868 
#>  8  1.50    0.418 Class2  0.625 
#>  9 -0.219  -3.08  Class2  0.675 
#> 10  0.606  -0.960 Class2 -0.582 
#> # ℹ 365 more rows
rec_norm <- rec %>%
  step_normalize(all_numeric_predictors())
rec_pca <- rec_norm %>%
  step_pca(all_numeric_predictors(), 
           num_comp = tune()
           )

library('embed')

rec_umap <- rec_norm %>%
  step_umap(all_numeric_predictors(), 
            outcome = "class",
            num_comp = tune(),
            neighbors = tune(),
            min_dist = tune()
            )
library('discrim')

mars_disc_spec <- 
  discrim_flexible(prod_degree = tune()) %>% 
  set_engine("earth")

reg_disc_sepc <- 
  discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>% 
  set_engine("klaR")

cart_spec <- 
  decision_tree(cost_complexity = tune(), min_n = tune()) %>% 
  set_engine("rpart") %>% 
  set_mode("classification")

xgboost_spec <-
  boost_tree(
    mtry = tune(),
    trees = tune(),
    min_n = tune(),
    tree_depth = tune(),
    learn_rate = tune(),
    loss_reduction = tune(),
    sample_size = tune(),
    stop_iter = tune()
  ) %>%
  set_engine("xgboost") %>%
  set_mode("classification")
set.seed(2)
folds <- vfold_cv(training(split), v = 5)
all_workflows <- 
  workflow_set(
    preproc = list(rec = rec,
                   rec_norm = rec_norm,
                   rec_pca = rec_pca,
                   rec_umap = rec_umap),
    models = list(regularized = reg_disc_sepc, 
                  mars = mars_disc_spec, 
                  cart = cart_spec,
                  xgboost_spec = xgboost_spec)
  )

@simonpcouch see error below:

all_workflows_res <- 
  all_workflows %>% 
  workflow_map(resamples = folds, 
               verbose = TRUE,
               control = control_grid(
                 save_pred = TRUE,
                 parallel_over = "everything",
                 save_workflow = TRUE)
               )
#> i  1 of 16 tuning:     rec_regularized
#> ✔  1 of 16 tuning:     rec_regularized (9.8s)
#> i  2 of 16 tuning:     rec_mars
#> ✔  2 of 16 tuning:     rec_mars (1s)
#> i  3 of 16 tuning:     rec_cart
#> ✔  3 of 16 tuning:     rec_cart (3.1s)
#> i  4 of 16 tuning:     rec_xgboost_spec
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔  4 of 16 tuning:     rec_xgboost_spec (26.2s)
#> i  5 of 16 tuning:     rec_norm_regularized
#> ✔  5 of 16 tuning:     rec_norm_regularized (10.1s)
#> i  6 of 16 tuning:     rec_norm_mars
#> ✔  6 of 16 tuning:     rec_norm_mars (821ms)
#> i  7 of 16 tuning:     rec_norm_cart
#> ✔  7 of 16 tuning:     rec_norm_cart (3.5s)
#> i  8 of 16 tuning:     rec_norm_xgboost_spec
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔  8 of 16 tuning:     rec_norm_xgboost_spec (26.6s)
#> i  9 of 16 tuning:     rec_pca_regularized
#> ✔  9 of 16 tuning:     rec_pca_regularized (10.3s)
#> i 10 of 16 tuning:     rec_pca_mars
#> ✔ 10 of 16 tuning:     rec_pca_mars (3.5s)
#> i 11 of 16 tuning:     rec_pca_cart
#> ✔ 11 of 16 tuning:     rec_pca_cart (5s)
#> i 12 of 16 tuning:     rec_pca_xgboost_spec
#> ✖ 12 of 16 tuning:     rec_pca_xgboost_spec failed with: Error in check_parameters(workflow, pset = pset, data = resamples$splits[[1]]$data,  :   Some model parameters require finalization but there are recipe parameters that require tuning. Please use  `extract_parameter_set_dials()` to set parameter ranges  manually and supply the output to the `param_info` argument.
#> i 13 of 16 tuning:     rec_umap_regularized
#> ✔ 13 of 16 tuning:     rec_umap_regularized (1m 32.6s)
#> i 14 of 16 tuning:     rec_umap_mars
#> ✔ 14 of 16 tuning:     rec_umap_mars (1m 28.6s)
#> i 15 of 16 tuning:     rec_umap_cart
#> ✔ 15 of 16 tuning:     rec_umap_cart (1m 27s)
#> i 16 of 16 tuning:     rec_umap_xgboost_spec
#> ✖ 16 of 16 tuning:     rec_umap_xgboost_spec failed with: Error in check_parameters(workflow, pset = pset, data = resamples$splits[[1]]$data,  :   Some model parameters require finalization but there are recipe parameters that require tuning. Please use  `extract_parameter_set_dials()` to set parameter ranges  manually and supply the output to the `param_info` argument.
all_workflows_res %>%
  autoplot()
all_workflows_res
#> # A workflow set/tibble: 16 × 4
#>    wflow_id              info             option    result        
#>    <chr>                 <list>           <list>    <list>        
#>  1 rec_regularized       <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  2 rec_mars              <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  3 rec_cart              <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  4 rec_xgboost_spec      <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  5 rec_norm_regularized  <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  6 rec_norm_mars         <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  7 rec_norm_cart         <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  8 rec_norm_xgboost_spec <tibble [1 × 4]> <opts[2]> <tune[+]>     
#>  9 rec_pca_regularized   <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 10 rec_pca_mars          <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 11 rec_pca_cart          <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 12 rec_pca_xgboost_spec  <tibble [1 × 4]> <opts[2]> <try-errr [1]>
#> 13 rec_umap_regularized  <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 14 rec_umap_mars         <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 15 rec_umap_cart         <tibble [1 × 4]> <opts[2]> <tune[+]>     
#> 16 rec_umap_xgboost_spec <tibble [1 × 4]> <opts[2]> <try-errr [1]>

@simonpcouch - The workflows that fail are the ones that contain tuning parameters for both PCA / UMAP and xgboost / random forest. Basically, this happens when the preprocessor has to tune num_comp, the number of components: that becomes the new number of columns, which in turn sets the upper bound of the mtry tuning parameter.
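To illustrate why `mtry` needs finalization at all, here is a small sketch using dials directly. It assumes `mtcars` (with the outcome column dropped) as stand-in predictor data; the names `predictors` and `mtry_final` are illustrative only.

```r
library(dials)

# The upper bound of mtry depends on the number of predictors,
# so it starts out as unknown().
mtry()
has_unknowns(mtry())

# Stand-in predictor data: mtcars without the outcome column (mpg)
predictors <- mtcars[, -1]

# finalize() fills in the upper bound from the number of columns in the data
mtry_final <- finalize(mtry(), predictors)
range_get(mtry_final)
```

When a recipe step such as PCA tunes `num_comp`, the number of columns reaching the model varies by candidate, so tune cannot pick this upper bound automatically and asks for `param_info` instead.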

# filter to only workflows that have results 

all_workflows_res <- all_workflows_res %>%
  filter(! map_lgl(result, ~ inherits(.x, "try-error"))) %>%
  filter(! map_lgl(result, ~ identical(.x, list())))

@simonpcouch - on a slightly different topic, one issue with autoplot() is that it defaults to using the data stored in the `info` column:

all_workflows_res %>%
  autoplot(metric = "accuracy") 

@simonpcouch - so we have to do a remapping into the info column to see the different preprocessor/model combinations

all_workflows_res <- all_workflows_res %>%
  mutate(info = map2(info, wflow_id, function(info, wflow_id) {
    info %>%
      mutate(preproc = 
               case_when(
                 stringr::str_detect(wflow_id, "rec_norm_") ~ "norm",
                 stringr::str_detect(wflow_id, "rec_pca_") ~ "PCA",
                 stringr::str_detect(wflow_id, "rec_umap_") ~ "UMAP",
                 TRUE ~ "recipe"
                 )
           )
  }))
all_workflows_res %>%
  autoplot(metric = "accuracy") +
  facet_grid(~preprocessor) +
  theme(legend.position = "bottom") + 
  guides(
    color = guide_legend(ncol = 2),
    shape = guide_legend(ncol = 2)
         )

Here are the functions to address the unfinalized workflows and create a new workflow set:

#' Finalize Parameter Grid for a Single Workflow
#'
#' @param workflow_id workflow id
#' @param workflow_sets workflow sets
#' @param data data frame to use for \code{finalize}
#' @param ... not used
#'
#' @return updated workflow set with parameter grid added to the workflow


.finalize_workflow_set <- function(workflow_id, workflow_sets, data, ...) {
  param_set <- workflow_sets %>%
    workflowsets::extract_parameter_set_dials(id = workflow_id)

  if (nrow(param_set) > 0) {
    # Finalize the parameter set
    finalized_param_set <- param_set %>%
      dials::finalize(data)

    workflow_sets_with_finalized_params <- workflow_sets %>%
      workflowsets::option_add(param_info = finalized_param_set, id = workflow_id) %>%
      dplyr::filter(wflow_id == workflow_id)

    return(workflow_sets_with_finalized_params)
  } else {
    tibble::tibble()  # Return an empty tibble if no parameters
  }
}
#' Finalize the workflowset
#'
#' @param x workflow set
#' @param data data frame to use for finalization
#' @param ... additional arguments passed to \code{.finalize_workflow_set}
#'
#' @return updated workflow set with parameter grid added to each workflow


finalize_workflow_set <- function(x, data, ...){
  purrr::map(
    purrr::set_names(x$wflow_id),
    \(z) {
      .finalize_workflow_set(workflow_id = z, workflow_sets = x, data = data, ...)
    }
      ) %>%
    purrr::list_rbind()
}
all_workflows <- all_workflows %>%
  finalize_workflow_set(training(split))

Now we can run over all the model combinations:

all_workflows_res <- 
  all_workflows %>% 
  workflow_map(resamples = folds, 
               verbose = TRUE,
               control = control_grid(
                 save_pred = TRUE,
                 parallel_over = "everything",
                 save_workflow = TRUE)
               )
#> i  1 of 16 tuning:     rec_regularized
#> ✔  1 of 16 tuning:     rec_regularized (9.8s)
#> i  2 of 16 tuning:     rec_mars
#> ✔  2 of 16 tuning:     rec_mars (710ms)
#> i  3 of 16 tuning:     rec_cart
#> ✔  3 of 16 tuning:     rec_cart (3.1s)
#> i  4 of 16 tuning:     rec_xgboost_spec
#> ✔  4 of 16 tuning:     rec_xgboost_spec (34.7s)
#> i  5 of 16 tuning:     rec_norm_regularized
#> ✔  5 of 16 tuning:     rec_norm_regularized (10.7s)
#> i  6 of 16 tuning:     rec_norm_mars
#> ✔  6 of 16 tuning:     rec_norm_mars (938ms)
#> i  7 of 16 tuning:     rec_norm_cart
#> ✔  7 of 16 tuning:     rec_norm_cart (3.7s)
#> i  8 of 16 tuning:     rec_norm_xgboost_spec
#> ✔  8 of 16 tuning:     rec_norm_xgboost_spec (30.5s)
#> i  9 of 16 tuning:     rec_pca_regularized
#> ✔  9 of 16 tuning:     rec_pca_regularized (10.9s)
#> i 10 of 16 tuning:     rec_pca_mars
#> ✔ 10 of 16 tuning:     rec_pca_mars (4s)
#> i 11 of 16 tuning:     rec_pca_cart
#> ✔ 11 of 16 tuning:     rec_pca_cart (5s)
#> i 12 of 16 tuning:     rec_pca_xgboost_spec
#> ✔ 12 of 16 tuning:     rec_pca_xgboost_spec (36.4s)
#> i 13 of 16 tuning:     rec_umap_regularized
#> ✔ 13 of 16 tuning:     rec_umap_regularized (1m 35.8s)
#> i 14 of 16 tuning:     rec_umap_mars
#> ✔ 14 of 16 tuning:     rec_umap_mars (1m 31.9s)
#> i 15 of 16 tuning:     rec_umap_cart
#> ✔ 15 of 16 tuning:     rec_umap_cart (1m 31.6s)
#> i 16 of 16 tuning:     rec_umap_xgboost_spec
#> ✔ 16 of 16 tuning:     rec_umap_xgboost_spec (1m 53.4s)
# we have to remap the `info` column so that autoplot() shows the preprocessors:

all_workflows_res <- all_workflows_res %>%
  mutate(info = map2(info, wflow_id, function(info, wflow_id) {
    info %>%
      mutate(preproc = 
               case_when(
                 stringr::str_detect(wflow_id, "rec_norm_") ~ "norm",
                 stringr::str_detect(wflow_id, "rec_pca_") ~ "PCA",
                 stringr::str_detect(wflow_id, "rec_umap_") ~ "UMAP",
                 TRUE ~ "recipe"
                 )
           )
  }))
all_workflows_res %>%
  autoplot(metric = "accuracy") +
  facet_grid(~preprocessor) +
  theme(legend.position = "bottom") + 
  guides(
    color = guide_legend(ncol = 2),
    shape = guide_legend(ncol = 2)
         )

sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 20.04.6 LTS
#> 
#> Matrix products: default
#> BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.8.so;  LAPACK version 3.9.0
#> 
#> locale:
#>  [1] LC_CTYPE=C.UTF-8       LC_NUMERIC=C           LC_TIME=C.UTF-8       
#>  [4] LC_COLLATE=C.UTF-8     LC_MONETARY=C.UTF-8    LC_MESSAGES=C.UTF-8   
#>  [7] LC_PAPER=C.UTF-8       LC_NAME=C              LC_ADDRESS=C          
#> [10] LC_TELEPHONE=C         LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C   
#> 
#> time zone: UTC
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] uwot_0.2.2         Matrix_1.7-0       xgboost_1.7.8.1    rpart_4.1.23      
#>  [5] earth_5.3.4        plotmo_3.6.4       plotrix_3.8-4      Formula_1.2-5     
#>  [9] mda_0.5-4          class_7.3-22       klaR_1.7-3         MASS_7.3-60.2     
#> [13] discrim_1.0.1      embed_1.1.4        yardstick_1.3.1    workflowsets_1.1.0
#> [17] workflows_1.1.4    tune_1.2.1         tidyr_1.3.1        tibble_3.2.1      
#> [21] rsample_1.2.1      recipes_1.1.0      purrr_1.0.2        parsnip_1.2.1     
#> [25] modeldata_1.4.0    infer_1.0.7        ggplot2_3.5.1      dplyr_1.1.4       
#> [29] dials_1.3.0        scales_1.3.0       broom_1.0.7        tidymodels_1.2.0  
#> 
#> loaded via a namespace (and not attached):
#>  [1] conflicted_1.2.0    rlang_1.1.4         magrittr_2.0.3     
#>  [4] furrr_0.3.1         RcppAnnoy_0.0.22    compiler_4.4.1     
#>  [7] vctrs_0.6.5         combinat_0.0-8      stringr_1.5.1      
#> [10] lhs_1.2.0           pkgconfig_2.0.3     fastmap_1.2.0      
#> [13] backports_1.5.0     labeling_0.4.3      utf8_1.2.4         
#> [16] promises_1.3.0      rmarkdown_2.28      prodlim_2024.06.25 
#> [19] haven_2.5.4         xfun_0.48           reprex_2.1.1       
#> [22] cachem_1.1.0        labelled_2.13.0     jsonlite_1.8.9     
#> [25] highr_0.11          later_1.3.2         irlba_2.3.5.1      
#> [28] parallel_4.4.1      prettyunits_1.2.0   R6_2.5.1           
#> [31] stringi_1.8.4       parallelly_1.38.0   lubridate_1.9.3    
#> [34] Rcpp_1.0.13         iterators_1.0.14    knitr_1.48         
#> [37] future.apply_1.11.2 httpuv_1.6.15       splines_4.4.1      
#> [40] nnet_7.3-19         timechange_0.3.0    tidyselect_1.2.1   
#> [43] rstudioapi_0.16.0   yaml_2.3.10         timeDate_4041.110  
#> [46] codetools_0.2-20    miniUI_0.1.1.1      curl_5.2.3         
#> [49] listenv_0.9.1       lattice_0.22-6      shiny_1.9.1        
#> [52] withr_3.0.1         evaluate_1.0.0      future_1.34.0      
#> [55] survival_3.6-4      xml2_1.3.6          pillar_1.9.0       
#> [58] foreach_1.5.2       generics_0.1.3      hms_1.1.3          
#> [61] munsell_0.5.1       globals_0.16.3      xtable_1.8-4       
#> [64] glue_1.8.0          tools_4.4.1         data.table_1.16.0  
#> [67] gower_1.0.1         forcats_1.0.0       fs_1.6.4           
#> [70] grid_4.4.1          ipred_0.9-15        colorspace_2.1-1   
#> [73] cli_3.6.3           DiceDesign_1.10     fansi_1.0.6        
#> [76] lava_1.8.0          gtable_0.3.5        GPfit_1.0-8        
#> [79] digest_0.6.37       farver_2.1.2        memoise_2.0.1      
#> [82] htmltools_0.5.8.1   questionr_0.7.8     lifecycle_1.0.4    
#> [85] hardhat_1.4.0       mime_0.12

Created on 2024-10-08 with reprex v2.1.1

@simonpcouch
Contributor

Ah, I see. A slightly more minimal reprex:

library(tidymodels)

mtcars <- mtcars[rep(1:32, 10),]

# create a workflow ------------------------
rec <-
  recipe(mpg ~ ., mtcars) %>%
  step_pca(all_numeric_predictors(), num_comp = tune())

spec <- boost_tree(mode = "regression", mtry = tune())

wflow <- workflow(rec, spec)

folds <- vfold_cv(mtcars)

# attempt to tune ---------------------------
wflow_res <- tune_grid(wflow, folds)
#> Error in `check_parameters()`:
#> ! Some model parameters require finalization but there are recipe parameters that require tuning. Please use  `extract_parameter_set_dials()` to set parameter ranges  manually and supply the output to the `param_info` argument.

# finalize ----------------------------------
wflow_pset <- extract_parameter_set_dials(wflow)

wflow_pset_finalized <- finalize(wflow_pset, mtcars)

wflow_res <- tune_grid(wflow, folds, param_info = wflow_pset_finalized)

# analogous issue with workflow sets -------
wflow_set <- workflow_set(list(rec), list(spec))

wflow_set_res <- workflow_map(wflow_set, resamples = folds)

wflow_set_res$result[[1]][1]
#> [1] "Error in check_parameters(workflow, pset = pset, data = resamples$splits[[1]]$data,  : \n  Some model parameters require finalization but there are recipe parameters that require tuning. Please use  `extract_parameter_set_dials()` to set parameter ranges  manually and supply the output to the `param_info` argument.\n"

Created on 2024-10-15 with reprex v2.1.1

This feels like it ought to be a finalize.workflow_set() method to me, except that finalize() takes parameter information as its first argument rather than dispatching on the data.
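One possible shape for such a method, as a rough sketch only: `finalize.workflow_set()` below is hypothetical and is not part of workflowsets or dials. It assumes the workflow set is passed as the first argument (`object`) and the data as the second (`x`), mirroring the `dials::finalize()` generic, and it leaves workflows without tunable parameters untouched.

```r
library(tidymodels)

# Hypothetical S3 method; NOT an existing workflowsets/dials function.
finalize.workflow_set <- function(object, x, ...) {
  for (id in object$wflow_id) {
    pset <- extract_parameter_set_dials(object, id = id)
    # Only workflows with tunable parameters need param_info
    if (nrow(pset) > 0) {
      object <- option_add(
        object,
        param_info = finalize(pset, x, ...),
        id = id
      )
    }
  }
  object
}

rec <-
  recipe(mpg ~ ., mtcars) %>%
  step_pca(all_numeric_predictors(), num_comp = tune())

spec <- boost_tree(mode = "regression", mtry = tune())

wflow_set <- workflow_set(list(pca = rec), list(boost = spec))

# Dispatches to the hypothetical method defined above
wflow_set_final <- finalize(wflow_set, mtcars)

# Inspect the finalized range stored in the workflow's options
pset <- wflow_set_final$option[[1]]$param_info
mtry_param <- pset$object[[which(pset$name == "mtry")]]
range_get(mtry_param)
```

The result could then be passed to `workflow_map()` as usual, since the finalized `param_info` travels with each workflow's options.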

@simonpcouch
Contributor

Related to #157.
