From ad8e233986ca157b2d15b01b158009e9fb4f0540 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 6 Feb 2025 15:39:49 +0100 Subject: [PATCH] fix(marshaling): extract internal tuning info before marshaling (#1257) * fix(marshaling): extract internal tuning info before marshaling Resolves Issue #1256 * cleanup previous commit * fix tests * ... * fix bug --- NEWS.md | 1 + R/worker.R | 42 ++++++++++++++++++++++++++--------- tests/testthat/test_Learner.R | 10 +++++++++ 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/NEWS.md b/NEWS.md index 89892f152..7fbedb09c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,7 @@ The option `mlr3.allow_utf8_names` is removed. * BREAKING CHANGE: `Learner$predict_types` is read-only now. * docs: Clear up behavior of `Learner$predict_type` after training. +* fix: Internal tuning and validation now works when the model requires marshaling (#1256) # mlr3 0.22.1 diff --git a/R/worker.R b/R/worker.R index b5a794329..a0b60bf62 100644 --- a/R/worker.R +++ b/R/worker.R @@ -18,11 +18,35 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL stopf("Learner '%s' on task '%s' returned NULL during internal %s()", learner$id, task$id, mode) } + + # In order to avoid unnecessary (un-)marshaling steps, + # we already extract the internal tuned values and validation scores here. + # They should only operate on the model and the param_vals so the + # information above should be enough. + # In the future, we might want to refactor this, so the extractors get directly + # called with the model and param_vals + learner$state$model = model + learner$state$param_vals = learner$param_set$values + + # Extract internal valid scores and tuned values if applicable. + internal_valid_scores = if (!is.null(get0("validate", learner)) && + exists(".extract_internal_valid_scores", get_private(learner))) { + get_private(learner)$.extract_internal_valid_scores() + } + + internal_tuned_values = if (exists(".extract_internal_tuned_values", get_private(learner))) { + get_private(learner)$.extract_internal_tuned_values() + } + if (learner$encapsulation[["train"]] == "callr") { model = marshal_model(model, inplace = TRUE) } - model + list( + model = model, + internal_valid_scores = internal_valid_scores, + internal_tuned_values = internal_tuned_values + ) } assert_choice(mode, c("train", "hotstart")) @@ -79,33 +103,31 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL train_time = result$elapsed learner$state = set_class(insert_named(learner$state, list( - model = result$result, + model = result$result$model, log = log, train_time = train_time, param_vals = learner$param_set$values, task_hash = task$hash, feature_names = task$feature_names, - validate = validate, + validate = get0("validate", learner), mlr3_version = mlr_reflections$package_version )), c("learner_state", "list")) # store the results of the internal tuning / internal validation in the learner's state # otherwise this information is only available with store_models = TRUE - if (!is.null(validate)) { - learner$state$internal_valid_scores = get_private(learner)$.extract_internal_valid_scores() + if (!is.null(result$result$internal_valid_scores)) { + learner$state$internal_valid_scores = result$result$internal_valid_scores learner$state$internal_valid_task_hash = task$internal_valid_task$hash } - if (exists(".extract_internal_tuned_values", get_private(learner))) { - learner$state$internal_tuned_values = get_private(learner)$.extract_internal_tuned_values() - } + learner$state$internal_tuned_values = result$result$internal_tuned_values - if (is.null(result$result)) { + if (is.null(result$result$model)) { lg$info("Learner '%s' on task '%s' failed to %s a model", learner$id, task$id, mode, learner = learner$clone(), messages = result$log$msg) } else { lg$debug("Learner '%s' on task '%s' succeeded to %s a model", - learner$id, task$id, mode, learner = learner$clone(), result = result$result, messages = result$log$msg) + learner$id, task$id, mode, learner = learner$clone(), result = result$result$model, messages = result$log$msg) } # fit fallback learner diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 65349dbd7..0105e86d5 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -735,3 +735,13 @@ test_that("predict_newdata creates column info correctly", { expect_true("row_id" %in% learner$model$task_predict$col_info$id) }) + +test_that("marshaling and internal tuning", { + l = lrn("classif.debug", validate = 0.3, early_stopping = TRUE, iter = 100) + l$encapsulate("evaluate", lrn("classif.featureless")) + task = tsk("iris") + l$train(task) + expect_list(l$internal_tuned_values, types = "integer") + expect_list(l$internal_valid_scores, types = "numeric") + +})