From 3cd9119123dfc56525df6bdef2cc0fe466ecfcb8 Mon Sep 17 00:00:00 2001 From: Toby Dylan Hocking Date: Sun, 21 Jan 2024 12:51:43 -0700 Subject: [PATCH 1/4] test fails for train strata counts --- tests/testthat/test-CRAN.R | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/testthat/test-CRAN.R b/tests/testthat/test-CRAN.R index d6fcdc3..dc3c006 100644 --- a/tests/testthat/test-CRAN.R +++ b/tests/testthat/test-CRAN.R @@ -1,3 +1,5 @@ +library(testthat) +library(data.table) test_that("resampling error if no group", { itask <- mlr3::TaskClassif$new("iris", iris, target="Species") same_other <- mlr3resampling::ResamplingSameOtherCV$new() @@ -92,6 +94,30 @@ test_that("error for 10 data", { fixed=TRUE) }) +test_that("strata respected in all sizes", { + size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new() + size_cv$param_set$values$min_train_data <- 5 + size_cv$param_set$values$folds <- 5 + N <- 100 + imbalance <- 4 + strat.vec <- ifelse((1:imbalance) Date: Tue, 23 Jan 2024 19:08:52 +0000 Subject: [PATCH 2/4] VariableTrainSize respects strata for all train sets --- DESCRIPTION | 2 +- NEWS | 4 ++ R/ResamplingVariableSizeTrainCV.R | 75 +++++++++++++++++----------- man/ResamplingVariableSizeTrainCV.Rd | 3 +- tests/testthat/test-CRAN.R | 23 +++++++-- 5 files changed, 70 insertions(+), 37 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 5e60caf..b07bf74 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: mlr3resampling Type: Package Title: Resampling Algorithms for 'mlr3' Framework -Version: 2024.1.8 +Version: 2024.1.23 Authors@R: c( person("Toby", "Hocking", email="toby.hocking@r-project.org", diff --git a/NEWS b/NEWS index 85fe607..f63534f 100644 --- a/NEWS +++ b/NEWS @@ -1,3 +1,7 @@ +Changes in version 2024.1.23 + +- ResamplingVariableSizeTrainCV outputs train sets which respect strata. + Changes in version 2024.1.8 - Rename Simulations vignette to ResamplingSameOtherCV. diff --git a/R/ResamplingVariableSizeTrainCV.R b/R/ResamplingVariableSizeTrainCV.R index f781f69..1e5fc72 100644 --- a/R/ResamplingVariableSizeTrainCV.R +++ b/R/ResamplingVariableSizeTrainCV.R @@ -21,55 +21,70 @@ ResamplingVariableSizeTrainCV = R6::R6Class( }, instantiate = function(task) { task = mlr3::assert_task(mlr3::as_task(task)) - reserved.names <- c( - "row_id", "fold", "group", "display_row", - "train.groups", "test.fold", "test.group", "iteration", - "test", "train", "algorithm", "uhash", "nr", "task", "task_id", - "learner", "learner_id", "resampling", "resampling_id", - "prediction") - ## bad.names <- group.name.vec[group.name.vec %in% reserved.names] - ## if(length(bad.names)){ - ## first.bad <- bad.names[1] - ## stop(sprintf("col with role group must not be named %s; please fix by renaming %s col", first.bad, first.bad)) - ## } - ## orig.group.dt <- task$data(cols=group.name.vec) strata <- if(is.null(task$strata)){ data.dt <- task$data() data.table(N=nrow(data.dt), row_id=list(1:nrow(data.dt))) }else task$strata - sample.list <- lapply(strata$row_id, private$.sample, task = task) - folds = private$.combine(sample.list)[order(row_id)] + strata.list <- lapply(strata$row_id, private$.sample, task = task) + folds = private$.combine(strata.list)[order(row_id)] + max.train.vec <- sapply(strata.list, nrow) + small.strat.i <- which.min(max.train.vec) min_train_data <- self$param_set$values[["min_train_data"]] - if(task$nrow <= min_train_data){ - stop(sprintf( - "task$nrow=%d but should be larger than min_train_data=%d", - task$nrow, min_train_data)) - } uniq.folds <- sort(unique(folds$fold)) iteration.dt.list <- list() for(test.fold in uniq.folds){ - is.set.fold <- list( - test=folds[["fold"]] == test.fold) - is.set.fold[["train"]] <- !is.set.fold[["test"]] - i.set.list <- lapply(is.set.fold, which) - max_train_data <- length(i.set.list$train) + train.strata.list <- lapply(strata.list, function(DT)DT[fold != test.fold]) + max_train_data <- nrow(train.strata.list[[small.strat.i]]) + if(max_train_data <= min_train_data){ + stop(sprintf( + "max_train_data=%d (in smallest stratum) but should be larger than min_train_data=%d, please fix by decreasing min_train_data", + max_train_data, min_train_data)) + } log.range.data <- log(c(min_train_data, max_train_data)) seq.args <- c(as.list(log.range.data), list(l=self$param_set$values[["train_sizes"]])) log.train.sizes <- do.call(seq, seq.args) - train.size.vec <- unique(as.integer(round(exp(log.train.sizes)))) + train.size.vec <- as.integer(round(exp(log.train.sizes))) + size.tab <- table(train.size.vec) + if(any(size.tab>1)){ + stop("train sizes not unique, please decrease train_sizes") + } for(seed in 1:self$param_set$values[["random_seeds"]]){ set.seed(seed) - ord.i.vec <- sample(i.set.list$train) + train.seed.list <- lapply(train.strata.list, function(DT)DT[sample(.N)][, `:=`( + row_seed = .I, + prop = .I/.N + )][]) + test.index.vec <- do.call(c, lapply( + strata.list, function(DT)DT[fold == test.fold, row_id])) + train.prop.dt <- train.seed.list[[small.strat.i]][train.size.vec, data.table(prop)] + train.i.list <- lapply(train.seed.list, function(DT)DT[ + train.prop.dt, + .(train.i=lapply(row_seed, function(last)DT$row_id[1:last])), + on="prop", + roll="nearest"]) + train.index.list <- list() + for(train.size.i in seq_along(train.size.vec)){ + strata.index.list <- lapply(train.i.list, function(DT)DT[["train.i"]][[train.size.i]]) + train.index.list[[train.size.i]] <- do.call(c, strata.index.list) + } iteration.dt.list[[paste(test.fold, seed)]] <- data.table( test.fold, seed, - train_size=train.size.vec, - train=lapply(train.size.vec, function(last)ord.i.vec[1:last]), - test=list(i.set.list$test)) + small_stratum_size=train.size.vec, + train_size_i=seq_along(train.size.vec), + train_size=sapply(train.index.list, length), + train=train.index.list, + test=list(test.index.vec)) } } self$instance <- list( - iteration.dt=rbindlist(iteration.dt.list)[, iteration := .I][], + iteration.dt=rbindlist( + iteration.dt.list + )[ + , iteration := .I + ][ + , train_min_size := min(train_size), by=train_size_i + ][], id.dt=folds) self$task_hash = task$hash self$task_nrow = task$nrow diff --git a/man/ResamplingVariableSizeTrainCV.Rd b/man/ResamplingVariableSizeTrainCV.Rd index 4f56325..997a992 100644 --- a/man/ResamplingVariableSizeTrainCV.Rd +++ b/man/ResamplingVariableSizeTrainCV.Rd @@ -48,7 +48,8 @@ orderings of the train set that are considered. For each random order of the train set, the \code{min_train_data} - parameter controls the smallest train set size considered. + parameter controls the size of the smallest stratum in the smallest + train set considered. To determine the other train set sizes, we use an equally spaced grid on the log scale, from \code{min_train_data} to the largest train set diff --git a/tests/testthat/test-CRAN.R b/tests/testthat/test-CRAN.R index dc3c006..dd3790e 100644 --- a/tests/testthat/test-CRAN.R +++ b/tests/testthat/test-CRAN.R @@ -83,15 +83,28 @@ test_that("error for group named test", { }, "col with role group must not be named test; please fix by renaming test col") }) -test_that("error for 10 data", { +test_that("errors and result for 10 train data in small stratum", { size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new() - i10.dt <- data.table(iris)[1:10] - i10.task <- mlr3::TaskClassif$new("i10", i10.dt, target="Species") + size_cv$param_set$values$folds <- 2 + i10.dt <- data.table(iris)[1:70] + i10.task <- mlr3::TaskClassif$new( + "i10", i10.dt, target="Species" + )$set_col_roles("Species",c("target","stratum")) expect_error({ size_cv$instantiate(i10.task) }, - "task$nrow=10 but should be larger than min_train_data=10", + "max_train_data=10 (in smallest stratum) but should be larger than min_train_data=10, please fix by decreasing min_train_data", fixed=TRUE) + size_cv$param_set$values$min_train_data <- 9 + expect_error({ + size_cv$instantiate(i10.task) + }, + "train sizes not unique, please decrease train_sizes", + fixed=TRUE) + size_cv$param_set$values$train_sizes <- 2 + size_cv$instantiate(i10.task) + size.tab <- table(size_cv$instance$iteration.dt[["small_stratum_size"]]) + expect_identical(names(size.tab), c("9","10")) }) test_that("strata respected in all sizes", { @@ -114,7 +127,7 @@ test_that("strata respected in all sizes", { min.row <- min.dt[min.i] train.i <- min.row$train[[1]] strat.tab <- table(istrat.dt[train.i, strat]) - expect_identical(strat.tab, smallest.size.tab) + expect_equal(strat.tab, smallest.size.tab) } }) From 545eea8dfbf549556ed97c0658623f660e126433 Mon Sep 17 00:00:00 2001 From: Toby Dylan Hocking Date: Tue, 23 Jan 2024 16:32:40 -0700 Subject: [PATCH 3/4] lgr off --- vignettes/ResamplingSameOtherCV.Rmd | 4 +++- vignettes/ResamplingVariableSizeTrainCV.Rmd | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/vignettes/ResamplingSameOtherCV.Rmd b/vignettes/ResamplingSameOtherCV.Rmd index b91d518..87582a4 100644 --- a/vignettes/ResamplingSameOtherCV.Rmd +++ b/vignettes/ResamplingSameOtherCV.Rmd @@ -156,9 +156,10 @@ In the code below, we execute the benchmark experiment (in parallel using the multisession future plan). ```{r} -if(FALSE){ +if(FALSE){#for CRAN. if(require(future))plan("multisession") } +lgr::get_logger("mlr3")$set_threshold("warn") (reg.bench.result <- mlr3::benchmark( reg.bench.grid, store_models = TRUE)) ``` @@ -495,6 +496,7 @@ iteration can be parallelized by declaring a future plan. if(FALSE){ if(require(future))plan("multisession") } +lgr::get_logger("mlr3")$set_threshold("warn") (class.bench.result <- mlr3::benchmark( class.bench.grid, store_models = TRUE)) ``` diff --git a/vignettes/ResamplingVariableSizeTrainCV.Rmd b/vignettes/ResamplingVariableSizeTrainCV.Rmd index 9d4c66b..77d673e 100644 --- a/vignettes/ResamplingVariableSizeTrainCV.Rmd +++ b/vignettes/ResamplingVariableSizeTrainCV.Rmd @@ -173,6 +173,7 @@ using the multisession future plan). if(FALSE){ if(require(future))plan("multisession") } +lgr::get_logger("mlr3")$set_threshold("warn") (reg.bench.result <- mlr3::benchmark( reg.bench.grid, store_models = TRUE)) ``` @@ -500,6 +501,7 @@ defined by our benchmark grid: if(FALSE){ if(require(future))plan("multisession") } +lgr::get_logger("mlr3")$set_threshold("warn") (class.bench.result <- mlr3::benchmark( class.bench.grid, store_models = TRUE)) ``` From d95e44aebc334a1ca1d38c43972580cb257ddae0 Mon Sep 17 00:00:00 2001 From: Toby Dylan Hocking Date: Tue, 23 Jan 2024 16:38:35 -0700 Subject: [PATCH 4/4] Suggest lgr --- DESCRIPTION | 1 + vignettes/ResamplingSameOtherCV.Rmd | 4 ++-- vignettes/ResamplingVariableSizeTrainCV.Rmd | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index b07bf74..02369db 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -67,6 +67,7 @@ Imports: mlr3misc Suggests: animint2, + lgr, future, testthat, knitr, diff --git a/vignettes/ResamplingSameOtherCV.Rmd b/vignettes/ResamplingSameOtherCV.Rmd index 87582a4..6d017c2 100644 --- a/vignettes/ResamplingSameOtherCV.Rmd +++ b/vignettes/ResamplingSameOtherCV.Rmd @@ -159,7 +159,7 @@ using the multisession future plan). if(FALSE){#for CRAN. if(require(future))plan("multisession") } -lgr::get_logger("mlr3")$set_threshold("warn") +if(require(lgr))get_logger("mlr3")$set_threshold("warn") (reg.bench.result <- mlr3::benchmark( reg.bench.grid, store_models = TRUE)) ``` @@ -496,7 +496,7 @@ iteration can be parallelized by declaring a future plan. if(FALSE){ if(require(future))plan("multisession") } -lgr::get_logger("mlr3")$set_threshold("warn") +if(require(lgr))get_logger("mlr3")$set_threshold("warn") (class.bench.result <- mlr3::benchmark( class.bench.grid, store_models = TRUE)) ``` diff --git a/vignettes/ResamplingVariableSizeTrainCV.Rmd b/vignettes/ResamplingVariableSizeTrainCV.Rmd index 77d673e..51389d3 100644 --- a/vignettes/ResamplingVariableSizeTrainCV.Rmd +++ b/vignettes/ResamplingVariableSizeTrainCV.Rmd @@ -173,7 +173,7 @@ using the multisession future plan). if(FALSE){ if(require(future))plan("multisession") } -lgr::get_logger("mlr3")$set_threshold("warn") +if(require(lgr))get_logger("mlr3")$set_threshold("warn") (reg.bench.result <- mlr3::benchmark( reg.bench.grid, store_models = TRUE)) ``` @@ -501,7 +501,7 @@ defined by our benchmark grid: if(FALSE){ if(require(future))plan("multisession") } -lgr::get_logger("mlr3")$set_threshold("warn") +if(require(lgr))get_logger("mlr3")$set_threshold("warn") (class.bench.result <- mlr3::benchmark( class.bench.grid, store_models = TRUE)) ```