Skip to content

Commit

Permalink
fix(marshaling): extract internal tuning info before marshaling (#1257)
Browse files Browse the repository at this point in the history
* fix(marshaling): extract internal tuning info before marshaling

Resolves Issue #1256

* cleanup previous commit

* fix tests

* ...

* fix bug
  • Loading branch information
sebffischer authored Feb 6, 2025
1 parent 556661f commit ad8e233
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 10 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
The option `mlr3.allow_utf8_names` is removed.
* BREAKING CHANGE: `Learner$predict_types` is read-only now.
* docs: Clear up behavior of `Learner$predict_type` after training.
* fix: Internal tuning and validation now work when the model requires marshaling (#1256)

# mlr3 0.22.1

Expand Down
42 changes: 32 additions & 10 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,35 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
stopf("Learner '%s' on task '%s' returned NULL during internal %s()", learner$id, task$id, mode)
}


# To avoid unnecessary (un-)marshaling steps, we extract the internal tuned
# values and internal validation scores here, before the model is marshaled.
# The extractors should only operate on the model and the param_vals, so
# setting these two fields of the learner's state (directly below) should be enough.
# In the future, we might want to refactor this so that the extractors are
# called directly with the model and param_vals.
learner$state$model = model
learner$state$param_vals = learner$param_set$values

# Extract internal valid scores and tuned values if applicable.
internal_valid_scores = if (!is.null(get0("validate", learner)) &&
exists(".extract_internal_valid_scores", get_private(learner))) {
get_private(learner)$.extract_internal_valid_scores()
}

internal_tuned_values = if (exists(".extract_internal_tuned_values", get_private(learner))) {
get_private(learner)$.extract_internal_tuned_values()
}

if (learner$encapsulation[["train"]] == "callr") {
model = marshal_model(model, inplace = TRUE)
}

model
list(
model = model,
internal_valid_scores = internal_valid_scores,
internal_tuned_values = internal_tuned_values
)
}

assert_choice(mode, c("train", "hotstart"))
Expand Down Expand Up @@ -79,33 +103,31 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
train_time = result$elapsed

learner$state = set_class(insert_named(learner$state, list(
model = result$result,
model = result$result$model,
log = log,
train_time = train_time,
param_vals = learner$param_set$values,
task_hash = task$hash,
feature_names = task$feature_names,
validate = validate,
validate = get0("validate", learner),
mlr3_version = mlr_reflections$package_version
)), c("learner_state", "list"))

# store the results of the internal tuning / internal validation in the learner's state
# otherwise this information is only available with store_models = TRUE
if (!is.null(validate)) {
learner$state$internal_valid_scores = get_private(learner)$.extract_internal_valid_scores()
if (!is.null(result$result$internal_valid_scores)) {
learner$state$internal_valid_scores = result$result$internal_valid_scores
learner$state$internal_valid_task_hash = task$internal_valid_task$hash
}

if (exists(".extract_internal_tuned_values", get_private(learner))) {
learner$state$internal_tuned_values = get_private(learner)$.extract_internal_tuned_values()
}
learner$state$internal_tuned_values = result$result$internal_tuned_values

if (is.null(result$result)) {
if (is.null(result$result$model)) {
lg$info("Learner '%s' on task '%s' failed to %s a model",
learner$id, task$id, mode, learner = learner$clone(), messages = result$log$msg)
} else {
lg$debug("Learner '%s' on task '%s' succeeded to %s a model",
learner$id, task$id, mode, learner = learner$clone(), result = result$result, messages = result$log$msg)
learner$id, task$id, mode, learner = learner$clone(), result = result$result$model, messages = result$log$msg)
}

# fit fallback learner
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -735,3 +735,13 @@ test_that("predict_newdata creates column info correctly", {
expect_true("row_id" %in% learner$model$task_predict$col_info$id)
})


test_that("marshaling and internal tuning", {
  # Setup: iris task plus a debug classifier configured with `validate = 0.3`
  # and early stopping, so that training produces internal validation scores
  # and internal tuned values.
  iris_task = tsk("iris")
  debug_learner = lrn("classif.debug", validate = 0.3, early_stopping = TRUE, iter = 100)

  # Train under "evaluate" encapsulation; the featureless learner is passed as
  # the second argument (presumably the fallback learner -- confirm against
  # Learner$encapsulate()).
  debug_learner$encapsulate("evaluate", lrn("classif.featureless"))
  debug_learner$train(iris_task)

  # Both results must be available on the learner after training.
  expect_list(debug_learner$internal_tuned_values, types = "integer")
  expect_list(debug_learner$internal_valid_scores, types = "numeric")
})

0 comments on commit ad8e233

Please sign in to comment.