Skip to content

Commit

Permalink
refactor: allow learner with tune token in benchmark() when param val…
Browse files Browse the repository at this point in the history
…ues overwrites them (#1251)

* refactor: allow learner with tune token in benchmark() when param values overwrites them

* ...

* ...

* ...
  • Loading branch information
be-marc authored Feb 6, 2025
1 parent 8e80e11 commit 556661f
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 8 deletions.
11 changes: 8 additions & 3 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@ assert_learners = function(learners, task = NULL, task_type = NULL, properties =

# this does not check the validation task, as this is only possible once the validation set is known,
# which happens during worker(), so it cannot be checked before that
assert_task_learner = function(task, learner, cols = NULL) {
assert_task_learner = function(task, learner, param_values = NULL, cols = NULL) {
pars = learner$param_set$get_values(type = "only_token", check_required = FALSE)
# remove pars that are covered by param_values
pars = pars[names(pars) %nin% names(param_values)]
if (length(pars) > 0) {
stopf("%s cannot be trained with TuneToken present in hyperparameter: %s", learner$format(), str_collapse(names(pars)))
}
Expand Down Expand Up @@ -161,12 +163,15 @@ assert_task_learner = function(task, learner, cols = NULL) {
}

#' @export
#' @param param_values (`list()`)\cr
#' TuneToken are not allowed in the parameter set of the learner.
#' If the `param_values` overwrite the TuneToken, the assertion will pass.
#' @rdname mlr_assertions
assert_learnable = function(task, learner) {
assert_learnable = function(task, learner, param_values = NULL) {
if (task$task_type == "unsupervised") {
stopf("%s cannot be trained with %s", learner$format(), task$format())
}
assert_task_learner(task, learner)
assert_task_learner(task, learner, param_values)
}

#' @export
Expand Down
5 changes: 1 addition & 4 deletions R/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps
if (length(learner_types) > 1) {
stopf("Multiple learner types detected, but mixing types is not supported: %s", str_collapse(learner_types))
}
assert_task_learner(design$task[[1]], design$learner[[1]])

setDT(design)
task = learner = resampling = NULL
Expand All @@ -125,13 +124,11 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps

# expand the design: add rows for each resampling iteration and param_values
grid = pmap_dtr(design, function(task, learner, resampling, param_values) {
# learner = assert_learner(as_learner(learner, clone = TRUE))
assert_learnable(task, learner)

iters = resampling$iters
n_params = max(1L, length(param_values))
# insert constant values
param_values = map(param_values, function(values) insert_named(learner$param_set$values, values))
assert_learnable(task, learner, unlist(param_values, recursive = FALSE))

data.table(
task = list(task), learner = list(learner), resampling = list(resampling),
Expand Down
6 changes: 5 additions & 1 deletion man/mlr_assertions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test_benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,14 @@ test_that("benchmark_grid only allows unique learner ids", {
expect_error(benchmark_grid(task, list(learner, learner), resampling), "unique")
})

test_that("benchmark allows that param_values overwrites tune token", {

learner = lrn("classif.rpart", cp = to_tune(0.01, 0.1))
design = benchmark_grid(tsk("pima"), learner, rsmp("cv", folds = 3), param_values = list(list(list(cp = 0.01))))
expect_benchmark_result(benchmark(design))

learner = lrn("classif.rpart", cp = to_tune(0.01, 0.1))
design = benchmark_grid(tsk("pima"), learner, rsmp("cv", folds = 3))
expect_error(benchmark(design), "cannot be trained with TuneToken present in hyperparameter")
})

0 comments on commit 556661f

Please sign in to comment.