Skip to content

Commit

Permalink
Merge branch 'main' into callback
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Feb 10, 2025
2 parents 648ef79 + ad8e233 commit e44e39a
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 19 deletions.
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# mlr3 (development version)

* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions (#685)
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* Column names with UTF-8 characters are now allowed by default.
The option `mlr3.allow_utf8_names` is removed.
* BREAKING CHANGE: `Learner$predict_types` is read-only now.
* docs: Clear up behavior of `Learner$predict_type` after training.
* feat: Add callbacks to `resample()` and `benchmark()`.
* fix: Internal tuning and validation now works when the model requires marshaling (#1256)

# mlr3 0.22.1

Expand Down
11 changes: 8 additions & 3 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@ assert_learners = function(learners, task = NULL, task_type = NULL, properties =

# this does not check the validation task, as this is only possible once the validation set is known,
# which happens during worker(), so it cannot be checked before that
assert_task_learner = function(task, learner, cols = NULL) {
assert_task_learner = function(task, learner, param_values = NULL, cols = NULL) {
pars = learner$param_set$get_values(type = "only_token", check_required = FALSE)
# remove pars that are covered by param_values
pars = pars[names(pars) %nin% names(param_values)]
if (length(pars) > 0) {
stopf("%s cannot be trained with TuneToken present in hyperparameter: %s", learner$format(), str_collapse(names(pars)))
}
Expand Down Expand Up @@ -161,12 +163,15 @@ assert_task_learner = function(task, learner, cols = NULL) {
}

#' @export
#' @param param_values (`list()`)\cr
#' TuneToken are not allowed in the parameter set of the learner.
#' If the `param_values` overwrite the TuneToken, the assertion will pass.
#' @rdname mlr_assertions
assert_learnable = function(task, learner) {
assert_learnable = function(task, learner, param_values = NULL) {
if (task$task_type == "unsupervised") {
stopf("%s cannot be trained with %s", learner$format(), task$format())
}
assert_task_learner(task, learner)
assert_task_learner(task, learner, param_values)
}

#' @export
Expand Down
5 changes: 1 addition & 4 deletions R/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps
if (length(learner_types) > 1) {
stopf("Multiple learner types detected, but mixing types is not supported: %s", str_collapse(learner_types))
}
assert_task_learner(design$task[[1]], design$learner[[1]])

setDT(design)
task = learner = resampling = NULL
Expand All @@ -127,13 +126,11 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps

# expand the design: add rows for each resampling iteration and param_values
grid = pmap_dtr(design, function(task, learner, resampling, param_values) {
# learner = assert_learner(as_learner(learner, clone = TRUE))
assert_learnable(task, learner)

iters = resampling$iters
n_params = max(1L, length(param_values))
# insert constant values
param_values = map(param_values, function(values) insert_named(learner$param_set$values, values))
assert_learnable(task, learner, unlist(param_values, recursive = FALSE))

data.table(
task = list(task), learner = list(learner), resampling = list(resampling),
Expand Down
42 changes: 32 additions & 10 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,35 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
stopf("Learner '%s' on task '%s' returned NULL during internal %s()", learner$id, task$id, mode)
}


# In order to avoid unnecessary (un-)marshaling steps,
# we already extract the internal tuned values and validation scores here.
# They should only operate on the model and the param_vals so the
# information above should be enough.
# In the future, we might want to refactor this, so the extractors get directly
# called with the model and param_vals
learner$state$model = model
learner$state$param_vals = learner$param_set$values

# Extract internal valid scores and tuned values if applicable.
internal_valid_scores = if (!is.null(get0("validate", learner)) &&
exists(".extract_internal_valid_scores", get_private(learner))) {
get_private(learner)$.extract_internal_valid_scores()
}

internal_tuned_values = if (exists(".extract_internal_tuned_values", get_private(learner))) {
get_private(learner)$.extract_internal_tuned_values()
}

if (learner$encapsulation[["train"]] == "callr") {
model = marshal_model(model, inplace = TRUE)
}

model
list(
model = model,
internal_valid_scores = internal_valid_scores,
internal_tuned_values = internal_tuned_values
)
}

assert_choice(mode, c("train", "hotstart"))
Expand Down Expand Up @@ -79,33 +103,31 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
train_time = result$elapsed

learner$state = set_class(insert_named(learner$state, list(
model = result$result,
model = result$result$model,
log = log,
train_time = train_time,
param_vals = learner$param_set$values,
task_hash = task$hash,
feature_names = task$feature_names,
validate = validate,
validate = get0("validate", learner),
mlr3_version = mlr_reflections$package_version
)), c("learner_state", "list"))

# store the results of the internal tuning / internal validation in the learner's state
# otherwise this information is only available with store_models = TRUE
if (!is.null(validate)) {
learner$state$internal_valid_scores = get_private(learner)$.extract_internal_valid_scores()
if (!is.null(result$result$internal_valid_scores)) {
learner$state$internal_valid_scores = result$result$internal_valid_scores
learner$state$internal_valid_task_hash = task$internal_valid_task$hash
}

if (exists(".extract_internal_tuned_values", get_private(learner))) {
learner$state$internal_tuned_values = get_private(learner)$.extract_internal_tuned_values()
}
learner$state$internal_tuned_values = result$result$internal_tuned_values

if (is.null(result$result)) {
if (is.null(result$result$model)) {
lg$info("Learner '%s' on task '%s' failed to %s a model",
learner$id, task$id, mode, learner = learner$clone(), messages = result$log$msg)
} else {
lg$debug("Learner '%s' on task '%s' succeeded to %s a model",
learner$id, task$id, mode, learner = learner$clone(), result = result$result, messages = result$log$msg)
learner$id, task$id, mode, learner = learner$clone(), result = result$result$model, messages = result$log$msg)
}

# fit fallback learner
Expand Down
6 changes: 5 additions & 1 deletion man/mlr_assertions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -735,3 +735,13 @@ test_that("predict_newdata creates column info correctly", {
expect_true("row_id" %in% learner$model$task_predict$col_info$id)
})


test_that("marshaling and internal tuning", {
l = lrn("classif.debug", validate = 0.3, early_stopping = TRUE, iter = 100)
l$encapsulate("evaluate", lrn("classif.featureless"))
task = tsk("iris")
l$train(task)
expect_list(l$internal_tuned_values, types = "integer")
expect_list(l$internal_valid_scores, types = "numeric")

})
11 changes: 11 additions & 0 deletions tests/testthat/test_benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,14 @@ test_that("benchmark_grid only allows unique learner ids", {
expect_error(benchmark_grid(task, list(learner, learner), resampling), "unique")
})

test_that("benchmark allows that param_values overwrites tune token", {

learner = lrn("classif.rpart", cp = to_tune(0.01, 0.1))
design = benchmark_grid(tsk("pima"), learner, rsmp("cv", folds = 3), param_values = list(list(list(cp = 0.01))))
expect_benchmark_result(benchmark(design))

learner = lrn("classif.rpart", cp = to_tune(0.01, 0.1))
design = benchmark_grid(tsk("pima"), learner, rsmp("cv", folds = 3))
expect_error(benchmark(design), "cannot be trained with TuneToken present in hyperparameter")
})

0 comments on commit e44e39a

Please sign in to comment.