From 556661f4ef42bc7931a06c140118b0f3620b71b4 Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Thu, 6 Feb 2025 12:16:37 +0100 Subject: [PATCH 1/2] refactor: allow learner with tune token in benchmark() when param values overwrites them (#1251) * refactor: allow learner with tune token in benchmark() when param values overwrites them * ... * ... * ... --- R/assertions.R | 11 ++++++++--- R/benchmark.R | 5 +---- man/mlr_assertions.Rd | 6 +++++- tests/testthat/test_benchmark.R | 11 +++++++++++ 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/R/assertions.R b/R/assertions.R index 2e5d6c853..ec4f82059 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -119,8 +119,10 @@ assert_learners = function(learners, task = NULL, task_type = NULL, properties = # this does not check the validation task, as this is only possible once the validation set is known, # which happens during worker(), so it cannot be checked before that -assert_task_learner = function(task, learner, cols = NULL) { +assert_task_learner = function(task, learner, param_values = NULL, cols = NULL) { pars = learner$param_set$get_values(type = "only_token", check_required = FALSE) + # remove pars that are covered by param_values + pars = pars[names(pars) %nin% names(param_values)] if (length(pars) > 0) { stopf("%s cannot be trained with TuneToken present in hyperparameter: %s", learner$format(), str_collapse(names(pars))) } @@ -161,12 +163,15 @@ assert_task_learner = function(task, learner, cols = NULL) { } #' @export +#' @param param_values (`list()`)\cr +#' TuneToken are not allowed in the parameter set of the learner. +#' If the `param_values` overwrite the TuneToken, the assertion will pass. #' @rdname mlr_assertions -assert_learnable = function(task, learner) { +assert_learnable = function(task, learner, param_values = NULL) { if (task$task_type == "unsupervised") { stopf("%s cannot be trained with %s", learner$format(), task$format()) } - assert_task_learner(task, learner) + assert_task_learner(task, learner, param_values) } #' @export diff --git a/R/benchmark.R b/R/benchmark.R index 63b41f180..0ac8229e4 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -106,7 +106,6 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps if (length(learner_types) > 1) { stopf("Multiple learner types detected, but mixing types is not supported: %s", str_collapse(learner_types)) } - assert_task_learner(design$task[[1]], design$learner[[1]]) setDT(design) task = learner = resampling = NULL @@ -125,13 +124,11 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps # expand the design: add rows for each resampling iteration and param_values grid = pmap_dtr(design, function(task, learner, resampling, param_values) { - # learner = assert_learner(as_learner(learner, clone = TRUE)) - assert_learnable(task, learner) - iters = resampling$iters n_params = max(1L, length(param_values)) # insert constant values param_values = map(param_values, function(values) insert_named(learner$param_set$values, values)) + assert_learnable(task, learner, unlist(param_values, recursive = FALSE)) data.table( task = list(task), learner = list(learner), resampling = list(resampling), diff --git a/man/mlr_assertions.Rd b/man/mlr_assertions.Rd index 21a497c43..aff6e73b3 100644 --- a/man/mlr_assertions.Rd +++ b/man/mlr_assertions.Rd @@ -56,7 +56,7 @@ assert_learners( .var.name = vname(learners) ) -assert_learnable(task, learner) +assert_learnable(task, learner, param_values = NULL) assert_predictable(task, learner) @@ -124,6 +124,10 @@ Set of required task properties.} \item{learners}{(list of \link{Learner}).} +\item{param_values}{(\code{list()})\cr +TuneToken are not allowed in the parameter set of the learner. +If the \code{param_values} overwrite the TuneToken, the assertion will pass.} + \item{measure}{(\link{Measure}).} \item{prediction}{(\link{Prediction}).} diff --git a/tests/testthat/test_benchmark.R b/tests/testthat/test_benchmark.R index 8a14a6015..433253958 100644 --- a/tests/testthat/test_benchmark.R +++ b/tests/testthat/test_benchmark.R @@ -590,3 +590,14 @@ test_that("benchmark_grid only allows unique learner ids", { expect_error(benchmark_grid(task, list(learner, learner), resampling), "unique") }) +test_that("benchmark allows that param_values overwrites tune token", { + + learner = lrn("classif.rpart", cp = to_tune(0.01, 0.1)) + design = benchmark_grid(tsk("pima"), learner, rsmp("cv", folds = 3), param_values = list(list(list(cp = 0.01)))) + expect_benchmark_result(benchmark(design)) + + learner = lrn("classif.rpart", cp = to_tune(0.01, 0.1)) + design = benchmark_grid(tsk("pima"), learner, rsmp("cv", folds = 3)) + expect_error(benchmark(design), "cannot be trained with TuneToken present in hyperparameter") +}) + From ad8e233986ca157b2d15b01b158009e9fb4f0540 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 6 Feb 2025 15:39:49 +0100 Subject: [PATCH 2/2] fix(marshaling): extract internal tuning info before marshaling (#1257) * fix(marshaling): extract internal tuning info before marshaling Resolves Issue #1256 * cleanup previous commit * fix tests * ... * fix bug --- NEWS.md | 1 + R/worker.R | 42 ++++++++++++++++++++++++++--------- tests/testthat/test_Learner.R | 10 +++++++++ 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/NEWS.md b/NEWS.md index 89892f152..7fbedb09c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,7 @@ The option `mlr3.allow_utf8_names` is removed. * BREAKING CHANGE: `Learner$predict_types` is read-only now. * docs: Clear up behavior of `Learner$predict_type` after training. +* fix: Internal tuning and validation now works when the model requires marshaling (#1256) # mlr3 0.22.1 diff --git a/R/worker.R b/R/worker.R index b5a794329..a0b60bf62 100644 --- a/R/worker.R +++ b/R/worker.R @@ -18,11 +18,35 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL stopf("Learner '%s' on task '%s' returned NULL during internal %s()", learner$id, task$id, mode) } + + # In order to avoid unnecessary (un-)marshaling steps, + # we already extract the internal tuned values and validation scores here. + # They should only operate on the model and the param_vals so the + # information above should be enough. + # In the future, we might want to refactor this, so the extractors get directly + # called with the model and param_vals + learner$state$model = model + learner$state$param_vals = learner$param_set$values + + # Extract internal valid scores and tuned values if applicable. + internal_valid_scores = if (!is.null(get0("validate", learner)) && + exists(".extract_internal_valid_scores", get_private(learner))) { + get_private(learner)$.extract_internal_valid_scores() + } + + internal_tuned_values = if (exists(".extract_internal_tuned_values", get_private(learner))) { + get_private(learner)$.extract_internal_tuned_values() + } + if (learner$encapsulation[["train"]] == "callr") { model = marshal_model(model, inplace = TRUE) } - model + list( + model = model, + internal_valid_scores = internal_valid_scores, + internal_tuned_values = internal_tuned_values + ) } assert_choice(mode, c("train", "hotstart")) @@ -79,33 +103,31 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL train_time = result$elapsed learner$state = set_class(insert_named(learner$state, list( - model = result$result, + model = result$result$model, log = log, train_time = train_time, param_vals = learner$param_set$values, task_hash = task$hash, feature_names = task$feature_names, - validate = validate, + validate = get0("validate", learner), mlr3_version = mlr_reflections$package_version )), c("learner_state", "list")) # store the results of the internal tuning / internal validation in the learner's state # otherwise this information is only available with store_models = TRUE - if (!is.null(validate)) { - learner$state$internal_valid_scores = get_private(learner)$.extract_internal_valid_scores() + if (!is.null(result$result$internal_valid_scores)) { + learner$state$internal_valid_scores = result$result$internal_valid_scores learner$state$internal_valid_task_hash = task$internal_valid_task$hash } - if (exists(".extract_internal_tuned_values", get_private(learner))) { - learner$state$internal_tuned_values = get_private(learner)$.extract_internal_tuned_values() - } + learner$state$internal_tuned_values = result$result$internal_tuned_values - if (is.null(result$result)) { + if (is.null(result$result$model)) { lg$info("Learner '%s' on task '%s' failed to %s a model", learner$id, task$id, mode, learner = learner$clone(), messages = result$log$msg) } else { lg$debug("Learner '%s' on task '%s' succeeded to %s a model", - learner$id, task$id, mode, learner = learner$clone(), result = result$result, messages = result$log$msg) + learner$id, task$id, mode, learner = learner$clone(), result = result$result$model, messages = result$log$msg) } # fit fallback learner diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 65349dbd7..0105e86d5 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -735,3 +735,13 @@ test_that("predict_newdata creates column info correctly", { expect_true("row_id" %in% learner$model$task_predict$col_info$id) }) + +test_that("marshaling and internal tuning", { + l = lrn("classif.debug", validate = 0.3, early_stopping = TRUE, iter = 100) + l$encapsulate("evaluate", lrn("classif.featureless")) + task = tsk("iris") + l$train(task) + expect_list(l$internal_tuned_values, types = "integer") + expect_list(l$internal_valid_scores, types = "numeric") + +})