-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature request: add finalization method for workflow sets #164
Comments
Related to #45. What would you like to do with the list of finalized models? The use cases I imagine here all involve some sort of model selection, which should ideally be carried out using resampled models. That is, the flow we recommend is:
Rather than:
That's because there is no way to evaluate the list of models resulting from step 3b): the test set can only be used once.
library('tidymodels')
library('workflowsets')
tidymodels_prefer()

data(parabolic)
str(parabolic)
#> tibble [500 × 3] (S3: tbl_df/tbl/data.frame)
#> $ X1 : num [1:500] 3.29 1.47 1.66 1.6 2.17 ...
#> $ X2 : num [1:500] 1.661 0.414 0.791 0.276 3.166 ...
#> $ class: Factor w/ 2 levels "Class1","Class2": 1 2 2 2 1 1 2 1 2 1 ...

# Hold out a test set; all tuning below uses only the training portion.
set.seed(1)
split <- initial_split(parabolic)
# train_set <- training(split)
# test_set <- testing(split)

# Preprocessing recipes ----

# Base recipe: the raw predictors plus their X1:X2 interaction.
rec <- recipe(class ~ ., data = training(split)) %>%
  step_interact(terms = ~ X1:X2)

bake(prep(rec), new_data = training(split))
#> # A tibble: 375 × 4
#> X1 X2 class X1_x_X2
#> <dbl> <dbl> <fct> <dbl>
#> 1 1.17 0.627 Class2 0.733
#> 2 -0.769 -1.29 Class2 0.993
#> 3 1.17 1.02 Class1 1.20
#> 4 0.510 -2.10 Class2 -1.07
#> 5 1.38 -0.974 Class2 -1.34
#> 6 -0.0549 -1.77 Class2 0.0972
#> 7 0.703 1.24 Class1 0.868
#> 8 1.50 0.418 Class2 0.625
#> 9 -0.219 -3.08 Class2 0.675
#> 10 0.606 -0.960 Class2 -0.582
#> # ℹ 365 more rows

rec_norm <- rec %>%
  step_normalize(all_numeric_predictors())

# rec_pca and rec_umap tune `num_comp`, so the number of predictors the model
# receives is only known once a candidate value is chosen during tuning.
rec_pca <- rec_norm %>%
  step_pca(all_numeric_predictors(), num_comp = tune())

library(embed)
rec_umap <- rec_norm %>%
  step_umap(
    all_numeric_predictors(),
    outcome = "class",
    num_comp = tune(),
    neighbors = tune(),
    min_dist = tune()
  )

# Model specifications ----
library(discrim)

mars_disc_spec <-
  discrim_flexible(prod_degree = tune()) %>%
  set_engine("earth")

reg_disc_spec <-
  discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>%
  set_engine("klaR")

cart_spec <-
  decision_tree(cost_complexity = tune(), min_n = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

xgboost_spec <-
  boost_tree(
    mtry = tune(),
    trees = tune(),
    min_n = tune(),
    tree_depth = tune(),
    learn_rate = tune(),
    loss_reduction = tune(),
    sample_size = tune(),
    stop_iter = tune()
  ) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

# Resamples and workflow set ----
set.seed(2)
folds <- vfold_cv(training(split), v = 5)

# Every preprocessor crossed with every model: 4 x 4 = 16 workflows.
all_workflows <-
  workflow_set(
    preproc = list(
      rec = rec,
      rec_norm = rec_norm,
      rec_pca = rec_pca,
      rec_umap = rec_umap
    ),
    models = list(
      regularized = reg_disc_spec,
      mars = mars_disc_spec,
      cart = cart_spec,
      xgboost_spec = xgboost_spec
    )
  )

# @simonpcouch see error below:
all_workflows_res <-
all_workflows %>%
workflow_map(resamples = folds,
verbose = TRUE,
control = control_grid(
save_pred = TRUE,
parallel_over = "everything",
save_workflow = TRUE)
)
#> i 1 of 16 tuning: rec_regularized
#> ✔ 1 of 16 tuning: rec_regularized (9.8s)
#> i 2 of 16 tuning: rec_mars
#> ✔ 2 of 16 tuning: rec_mars (1s)
#> i 3 of 16 tuning: rec_cart
#> ✔ 3 of 16 tuning: rec_cart (3.1s)
#> i 4 of 16 tuning: rec_xgboost_spec
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 4 of 16 tuning: rec_xgboost_spec (26.2s)
#> i 5 of 16 tuning: rec_norm_regularized
#> ✔ 5 of 16 tuning: rec_norm_regularized (10.1s)
#> i 6 of 16 tuning: rec_norm_mars
#> ✔ 6 of 16 tuning: rec_norm_mars (821ms)
#> i 7 of 16 tuning: rec_norm_cart
#> ✔ 7 of 16 tuning: rec_norm_cart (3.5s)
#> i 8 of 16 tuning: rec_norm_xgboost_spec
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ✔ 8 of 16 tuning: rec_norm_xgboost_spec (26.6s)
#> i 9 of 16 tuning: rec_pca_regularized
#> ✔ 9 of 16 tuning: rec_pca_regularized (10.3s)
#> i 10 of 16 tuning: rec_pca_mars
#> ✔ 10 of 16 tuning: rec_pca_mars (3.5s)
#> i 11 of 16 tuning: rec_pca_cart
#> ✔ 11 of 16 tuning: rec_pca_cart (5s)
#> i 12 of 16 tuning: rec_pca_xgboost_spec
#> ✖ 12 of 16 tuning: rec_pca_xgboost_spec failed with: Error in check_parameters(workflow, pset = pset, data = resamples$splits[[1]]$data, : Some model parameters require finalization but there are recipe parameters that require tuning. Please use `extract_parameter_set_dials()` to set parameter ranges manually and supply the output to the `param_info` argument.
#> i 13 of 16 tuning: rec_umap_regularized
#> ✔ 13 of 16 tuning: rec_umap_regularized (1m 32.6s)
#> i 14 of 16 tuning: rec_umap_mars
#> ✔ 14 of 16 tuning: rec_umap_mars (1m 28.6s)
#> i 15 of 16 tuning: rec_umap_cart
#> ✔ 15 of 16 tuning: rec_umap_cart (1m 27s)
#> i 16 of 16 tuning: rec_umap_xgboost_spec
#> ✖ 16 of 16 tuning: rec_umap_xgboost_spec failed with: Error in check_parameters(workflow, pset = pset, data = resamples$splits[[1]]$data, : Some model parameters require finalization but there are recipe parameters that require tuning. Please use `extract_parameter_set_dials()` to set parameter ranges manually and supply the output to the `param_info` argument.
# Print the tuned workflow set. Workflows that failed to tune hold a try-error
# in the `result` column (rec_pca_xgboost_spec and rec_umap_xgboost_spec below).
all_workflows_res
#> # A workflow set/tibble: 16 × 4
#> wflow_id info option result
#> <chr> <list> <list> <list>
#> 1 rec_regularized <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 2 rec_mars <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 3 rec_cart <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 4 rec_xgboost_spec <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 5 rec_norm_regularized <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 6 rec_norm_mars <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 7 rec_norm_cart <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 8 rec_norm_xgboost_spec <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 9 rec_pca_regularized <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 10 rec_pca_mars <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 11 rec_pca_cart <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 12 rec_pca_xgboost_spec <tibble [1 × 4]> <opts[2]> <try-errr [1]>
#> 13 rec_umap_regularized <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 14 rec_umap_mars <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 15 rec_umap_cart <tibble [1 × 4]> <opts[2]> <tune[+]>
#> 16 rec_umap_xgboost_spec <tibble [1 × 4]> <opts[2]> <try-errr [1]>
# @simonpcouch - The workflows with errors are the ones that have tuning
# parameters both in the preprocessor (PCA / UMAP) and in the model
# (xgboost / random forest). Basically, this happens when a tuned preprocessing
# step determines how many predictors the model (and hence `mtry`) will see.
# filter to only workflows that have results
# Drop workflows whose tuning failed (try-error) or never produced results.
all_workflows_res <- all_workflows_res %>%
  filter(!map_lgl(result, \(res) inherits(res, "try-error"))) %>%
  filter(!map_lgl(result, \(res) identical(res, list())))

# @simonpcouch - on a slightly different topic, one issue with autoplot() is
# that it defaults to using the data stored in the `info` column:
all_workflows_res %>%
  autoplot(metric = "accuracy")

# @simonpcouch - so we have to do a remapping into the `info` column, deriving
# each workflow's preprocessor label from its id. The closure arguments are
# named so they do not shadow the `info` / `wflow_id` columns.
all_workflows_res <- all_workflows_res %>%
  mutate(info = map2(info, wflow_id, \(wf_info, wf_id) {
    wf_info %>%
      mutate(
        preproc = case_when(
          stringr::str_detect(wf_id, "rec_norm_") ~ "norm",
          stringr::str_detect(wf_id, "rec_pca_") ~ "PCA",
          stringr::str_detect(wf_id, "rec_umap_") ~ "UMAP",
          TRUE ~ "recipe"
        )
      )
  }))

all_workflows_res %>%
  autoplot(metric = "accuracy") +
  facet_grid(~preprocessor) +
  theme(legend.position = "bottom") +
  guides(
    color = guide_legend(ncol = 2),
    shape = guide_legend(ncol = 2)
  )

# Here are the functions to address the unfinalized workflows and create a new
# workflow set:

#' Finalize the Parameter Set of a Single Workflow
#'
#' Finalizes the tuning-parameter set of one workflow in a workflow set and
#' stores it as that workflow's `param_info` option. This lets
#' `workflow_map()` tune workflows whose model has data-dependent parameters
#' (e.g. `mtry`) while the recipe also contains `tune()` placeholders.
#'
#' @param workflow_id Workflow id (one value of `workflow_sets$wflow_id`).
#' @param workflow_sets A workflow set.
#' @param data Data frame that `dials::finalize()` uses to resolve unknown
#'   parameter ranges.
#' @param ... Passed on to `dials::finalize()`.
#' @return A one-row workflow set for `workflow_id`. When the workflow has
#'   tuning parameters, its `param_info` option holds the finalized parameter
#'   set; otherwise the row is returned unchanged, so the workflow is not lost
#'   when the per-workflow results are row-bound back together.
.finalize_workflow_set <- function(workflow_id, workflow_sets, data, ...) {
  # Work on this workflow's row only, so option_add() never touches the rest
  # of the set.
  wflow_row <- workflow_sets %>%
    dplyr::filter(wflow_id == workflow_id)

  param_set <- workflow_sets %>%
    workflowsets::extract_parameter_set_dials(id = workflow_id)

  # Nothing to tune means nothing to finalize: keep the workflow as-is instead
  # of returning an empty tibble, which would drop it from the rebuilt set.
  if (nrow(param_set) == 0) {
    return(wflow_row)
  }

  finalized_param_set <- param_set %>%
    dials::finalize(data, ...)

  wflow_row %>%
    workflowsets::option_add(param_info = finalized_param_set, id = workflow_id)
}

#' Finalize Every Workflow in a Workflow Set
#'
#' Runs `.finalize_workflow_set()` on each workflow in `x` and row-binds the
#' one-row results back into a single workflow set, in the original order.
#'
#' @param x A workflow set.
#' @param data Data frame to use for finalization.
#' @param ... Passed on to `.finalize_workflow_set()` for every workflow.
#' @return A workflow set whose rows are the results of
#'   `.finalize_workflow_set()`. A zero-row `x` is returned unchanged.
finalize_workflow_set <- function(x, data, ...) {
  # Row-binding an empty list cannot rebuild a workflow set, so short-circuit.
  if (nrow(x) == 0) {
    return(x)
  }

  # list_rbind() ignores list names by default, so no set_names() is needed.
  x$wflow_id %>%
    purrr::map(
      \(id) .finalize_workflow_set(
        workflow_id = id,
        workflow_sets = x,
        data = data,
        ...
      )
    ) %>%
    purrr::list_rbind()
}

all_workflows <- all_workflows %>%
finalize_workflow_set(training(split))
# NOTE(review): `training(split)` still contains the outcome column `class`,
# and the recipes add or replace predictors (interaction term, PCA/UMAP
# components). The finalized ranges (e.g. the `mtry` upper bound) are
# therefore only an approximation. Confirm they suit each workflow.

# Now we can run over all the model combinations:
all_workflows_res <-
  all_workflows %>%
  workflow_map(
    resamples = folds,
    verbose = TRUE,
    control = control_grid(
      save_pred = TRUE,
      parallel_over = "everything",
      save_workflow = TRUE
    )
  )
#> i 1 of 16 tuning: rec_regularized
#> ✔ 1 of 16 tuning: rec_regularized (9.8s)
#> i 2 of 16 tuning: rec_mars
#> ✔ 2 of 16 tuning: rec_mars (710ms)
#> i 3 of 16 tuning: rec_cart
#> ✔ 3 of 16 tuning: rec_cart (3.1s)
#> i 4 of 16 tuning: rec_xgboost_spec
#> ✔ 4 of 16 tuning: rec_xgboost_spec (34.7s)
#> i 5 of 16 tuning: rec_norm_regularized
#> ✔ 5 of 16 tuning: rec_norm_regularized (10.7s)
#> i 6 of 16 tuning: rec_norm_mars
#> ✔ 6 of 16 tuning: rec_norm_mars (938ms)
#> i 7 of 16 tuning: rec_norm_cart
#> ✔ 7 of 16 tuning: rec_norm_cart (3.7s)
#> i 8 of 16 tuning: rec_norm_xgboost_spec
#> ✔ 8 of 16 tuning: rec_norm_xgboost_spec (30.5s)
#> i 9 of 16 tuning: rec_pca_regularized
#> ✔ 9 of 16 tuning: rec_pca_regularized (10.9s)
#> i 10 of 16 tuning: rec_pca_mars
#> ✔ 10 of 16 tuning: rec_pca_mars (4s)
#> i 11 of 16 tuning: rec_pca_cart
#> ✔ 11 of 16 tuning: rec_pca_cart (5s)
#> i 12 of 16 tuning: rec_pca_xgboost_spec
#> ✔ 12 of 16 tuning: rec_pca_xgboost_spec (36.4s)
#> i 13 of 16 tuning: rec_umap_regularized
#> ✔ 13 of 16 tuning: rec_umap_regularized (1m 35.8s)
#> i 14 of 16 tuning: rec_umap_mars
#> ✔ 14 of 16 tuning: rec_umap_mars (1m 31.9s)
#> i 15 of 16 tuning: rec_umap_cart
#> ✔ 15 of 16 tuning: rec_umap_cart (1m 31.6s)
#> i 16 of 16 tuning: rec_umap_xgboost_spec
#> ✔ 16 of 16 tuning: rec_umap_xgboost_spec (1m 53.4s)
# We have to remap the `info` column so that autoplot() can tell the
# preprocessors apart:
# Label each workflow's preprocessor from its id. autoplot() appears to take
# the preprocessor label from `info$preproc`, which presumably reads "recipe"
# for every workflow here. The closure arguments are named so they do not
# shadow the `info` / `wflow_id` columns.
all_workflows_res <- all_workflows_res %>%
  mutate(info = map2(info, wflow_id, \(wf_info, wf_id) {
    wf_info %>%
      mutate(
        preproc = case_when(
          stringr::str_detect(wf_id, "rec_norm_") ~ "norm",
          stringr::str_detect(wf_id, "rec_pca_") ~ "PCA",
          stringr::str_detect(wf_id, "rec_umap_") ~ "UMAP",
          TRUE ~ "recipe"
        )
      )
  }))

all_workflows_res %>%
  autoplot(metric = "accuracy") +
  facet_grid(~preprocessor) +
  theme(legend.position = "bottom") +
  guides(
    color = guide_legend(ncol = 2),
    shape = guide_legend(ncol = 2)
  )

sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 20.04.6 LTS
#>
#> Matrix products: default
#> BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.8.so; LAPACK version 3.9.0
#>
#> locale:
#> [1] LC_CTYPE=C.UTF-8 LC_NUMERIC=C LC_TIME=C.UTF-8
#> [4] LC_COLLATE=C.UTF-8 LC_MONETARY=C.UTF-8 LC_MESSAGES=C.UTF-8
#> [7] LC_PAPER=C.UTF-8 LC_NAME=C LC_ADDRESS=C
#> [10] LC_TELEPHONE=C LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C
#>
#> time zone: UTC
#> tzcode source: system (glibc)
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] uwot_0.2.2 Matrix_1.7-0 xgboost_1.7.8.1 rpart_4.1.23
#> [5] earth_5.3.4 plotmo_3.6.4 plotrix_3.8-4 Formula_1.2-5
#> [9] mda_0.5-4 class_7.3-22 klaR_1.7-3 MASS_7.3-60.2
#> [13] discrim_1.0.1 embed_1.1.4 yardstick_1.3.1 workflowsets_1.1.0
#> [17] workflows_1.1.4 tune_1.2.1 tidyr_1.3.1 tibble_3.2.1
#> [21] rsample_1.2.1 recipes_1.1.0 purrr_1.0.2 parsnip_1.2.1
#> [25] modeldata_1.4.0 infer_1.0.7 ggplot2_3.5.1 dplyr_1.1.4
#> [29] dials_1.3.0 scales_1.3.0 broom_1.0.7 tidymodels_1.2.0
#>
#> loaded via a namespace (and not attached):
#> [1] conflicted_1.2.0 rlang_1.1.4 magrittr_2.0.3
#> [4] furrr_0.3.1 RcppAnnoy_0.0.22 compiler_4.4.1
#> [7] vctrs_0.6.5 combinat_0.0-8 stringr_1.5.1
#> [10] lhs_1.2.0 pkgconfig_2.0.3 fastmap_1.2.0
#> [13] backports_1.5.0 labeling_0.4.3 utf8_1.2.4
#> [16] promises_1.3.0 rmarkdown_2.28 prodlim_2024.06.25
#> [19] haven_2.5.4 xfun_0.48 reprex_2.1.1
#> [22] cachem_1.1.0 labelled_2.13.0 jsonlite_1.8.9
#> [25] highr_0.11 later_1.3.2 irlba_2.3.5.1
#> [28] parallel_4.4.1 prettyunits_1.2.0 R6_2.5.1
#> [31] stringi_1.8.4 parallelly_1.38.0 lubridate_1.9.3
#> [34] Rcpp_1.0.13 iterators_1.0.14 knitr_1.48
#> [37] future.apply_1.11.2 httpuv_1.6.15 splines_4.4.1
#> [40] nnet_7.3-19 timechange_0.3.0 tidyselect_1.2.1
#> [43] rstudioapi_0.16.0 yaml_2.3.10 timeDate_4041.110
#> [46] codetools_0.2-20 miniUI_0.1.1.1 curl_5.2.3
#> [49] listenv_0.9.1 lattice_0.22-6 shiny_1.9.1
#> [52] withr_3.0.1 evaluate_1.0.0 future_1.34.0
#> [55] survival_3.6-4 xml2_1.3.6 pillar_1.9.0
#> [58] foreach_1.5.2 generics_0.1.3 hms_1.1.3
#> [61] munsell_0.5.1 globals_0.16.3 xtable_1.8-4
#> [64] glue_1.8.0 tools_4.4.1 data.table_1.16.0
#> [67] gower_1.0.1 forcats_1.0.0 fs_1.6.4
#> [70] grid_4.4.1 ipred_0.9-15 colorspace_2.1-1
#> [73] cli_3.6.3 DiceDesign_1.10 fansi_1.0.6
#> [76] lava_1.8.0 gtable_0.3.5 GPfit_1.0-8
#> [79] digest_0.6.37 farver_2.1.2 memoise_2.0.1
#> [82] htmltools_0.5.8.1 questionr_0.7.8 lifecycle_1.0.4
#> [85] hardhat_1.4.0 mime_0.12 Created on 2024-10-08 with reprex v2.1.1 |
Ah, I see. A slightly more minimal reprex: library(tidymodels)
# Minimal reprex: a recipe with a tuned `num_comp` plus a model with
# `mtry = tune()` cannot be tuned unless a finalized `param_info` is supplied.
# Stack 10 copies of mtcars (320 rows); presumably so every resample has
# enough rows to fit on.
mtcars <- mtcars[rep(1:32, 10),]
# create a workflow ------------------------
rec <-
recipe(mpg ~ ., mtcars) %>%
step_pca(all_numeric_predictors(), num_comp = tune())
spec <- boost_tree(mode = "regression", mtry = tune())
wflow <- workflow(rec, spec)
folds <- vfold_cv(mtcars)
# attempt to tune ---------------------------
# This errors: mtry's range is unknown until the predictor count is known,
# and the tuned num_comp in step_pca means tune cannot work that out itself.
wflow_res <- tune_grid(wflow, folds)
#> Error in `check_parameters()`:
#> ! Some model parameters require finalization but there are recipe parameters that require tuning. Please use `extract_parameter_set_dials()` to set parameter ranges manually and supply the output to the `param_info` argument.
# finalize ----------------------------------
# Workaround: finalize the parameter set against the data and pass it as
# param_info. NOTE(review): mtcars still includes the outcome `mpg`, so the
# finalized mtry range presumably counts it as a predictor.
wflow_pset <- extract_parameter_set_dials(wflow)
wflow_pset_finalized <- finalize(wflow_pset, mtcars)
wflow_res <- tune_grid(wflow, folds, param_info = wflow_pset_finalized)
# analogous issue with workflow sets -------
# workflow_map() hits the same error but records it in `result` instead of
# stopping; the last line prints the stored error message.
wflow_set <- workflow_set(list(rec), list(spec))
wflow_set_res <- workflow_map(wflow_set, resamples = folds)
wflow_set_res$result[[1]][1]
#> [1] "Error in check_parameters(workflow, pset = pset, data = resamples$splits[[1]]$data, : \n Some model parameters require finalization but there are recipe parameters that require tuning. Please use `extract_parameter_set_dials()` to set parameter ranges manually and supply the output to the `param_info` argument.\n" Created on 2024-10-15 with reprex v2.1.1 This feels like it ought to be a |
Related to #157. |
I think it would be helpful to add a function that finalizes all the workflows in a workflow set; so far, something like this seems to work pretty well:
The text was updated successfully, but these errors were encountered: