diff --git a/DESCRIPTION b/DESCRIPTION index 53a9a06f2..6aba52c38 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -62,7 +62,7 @@ Imports: digest, lgr, mlr3 (>= 0.20.0), - mlr3misc (>= 0.9.0), + mlr3misc (>= 0.16.0), paradox, R6, withr diff --git a/R/po.R b/R/po.R index e8c6315a3..7c23c44cf 100644 --- a/R/po.R +++ b/R/po.R @@ -83,7 +83,7 @@ po.PipeOp = function(.obj, ...) { #' @export po.character = function(.obj, ...) { - dictionary_sugar_inc_get(dict = mlr_pipeops, .key = .obj, ...) + dictionary_sugar_inc_get(dict = mlr_pipeops, .key = .obj, ..., .dicts_suggest = list("ppl()" = mlr_graphs)) } #' @export @@ -111,7 +111,7 @@ pos.NULL = function(.objs, ...) { #' @export pos.character = function(.objs, ...) { - dictionary_sugar_inc_mget(dict = mlr_pipeops, .keys = .objs, ...) + dictionary_sugar_inc_mget(dict = mlr_pipeops, .keys = .objs, ..., .dicts_suggest = list("ppls()" = mlr_graphs)) } #' @export diff --git a/R/ppl.R b/R/ppl.R index 1548d25ce..374c4f17d 100644 --- a/R/ppl.R +++ b/R/ppl.R @@ -23,12 +23,12 @@ #' gr = ppl("bagging", graph = po(lrn("regr.rpart")), #' averager = po("regravg", collect_multiplicity = TRUE)) ppl = function(.key, ...) { - dictionary_sugar_get(dict = mlr_graphs, .key = .key, ...) + dictionary_sugar_get(dict = mlr_graphs, .key = .key, ..., .dicts_suggest = list("po()" = mlr_pipeops)) } #' @export #' @rdname ppl ppls = function(.keys, ...) { if (missing(.keys)) return(mlr_graphs) - map(.x = .keys, .f = dictionary_sugar_get, dict = mlr_graphs, ...) + map(.x = .keys, .f = dictionary_sugar_get, dict = mlr_graphs, ..., .dicts_suggest = list("pos()" = mlr_pipeops)) } diff --git a/tests/testthat/test_po.R b/tests/testthat/test_po.R index f52ea538a..eaff5c07f 100644 --- a/tests/testthat/test_po.R +++ b/tests/testthat/test_po.R @@ -217,3 +217,11 @@ test_that("Incrementing ids works", { xs = pos(c("pca_1", "pca_2")) assert_true(all(names(xs) == c("pca_1", "pca_2"))) }) + +test_that("po - dictionary suggest works", { + + # test that correct dictionary is checked against + expect_error(po("robustify"), "ppl\\(\\): 'robustify'") + expect_error(pos("robustify"), "ppls\\(\\): 'robustify'") + +}) diff --git a/tests/testthat/test_ppl.R b/tests/testthat/test_ppl.R index 700eff726..0c1771dc9 100644 --- a/tests/testthat/test_ppl.R +++ b/tests/testthat/test_ppl.R @@ -26,7 +26,7 @@ test_that("mlr_graphs access works", { }) -test_that("mlr_pipeops multi-access works", { +test_that("mlr_graphs multi-access works", { expect_equal( ppls("robustify"), @@ -73,3 +73,11 @@ test_that("mlr3book authors don't sleepwalk through life", { bmr = benchmark(benchmark_grid(tasks, learners, rsmp("cv", folds = 2))) }) + +test_that("ppl - dictionary suggest works", { + + # test that correct dictionary is checked against + expect_error(ppl("adas"), "po\\(\\): 'adas'") + expect_error(ppls("adas"), "pos\\(\\): 'adas'") + +})