diff --git a/DESCRIPTION b/DESCRIPTION
index 40f960a..bb6ccbc 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,7 +1,7 @@
Package: mlr3resampling
Type: Package
Title: Resampling Algorithms for 'mlr3' Framework
-Version: 2023.12.23
+Version: 2023.12.28
Authors@R: c(
person("Toby", "Hocking",
email="toby.hocking@r-project.org",
@@ -52,6 +52,9 @@ Description: A supervised learning algorithm inputs a train set,
For more information,
describes the method in depth.
+ How many train samples are required to get accurate predictions on a
+ test set? Cross-validation can be used to answer this question, with
+ variable size train sets.
License: GPL-3
URL: https://github.com/tdhock/mlr3resampling
BugReports: https://github.com/tdhock/mlr3resampling/issues
diff --git a/NAMESPACE b/NAMESPACE
index e49d1ce..97f7836 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -1,3 +1,3 @@
import(R6, checkmate, data.table, mlr3, mlr3misc, paradox)
-export(ResamplingSameOtherCV, score)
+export(ResamplingSameOtherCV, score, ResamplingVariableSizeTrainCV)
diff --git a/NEWS b/NEWS
index 1bae8b4..884315f 100644
--- a/NEWS
+++ b/NEWS
@@ -1,3 +1,8 @@
+Changes in version 2023.12.28
+
+- Rename Simulations vignette to ResamplingSameOtherCV.
+- New ResamplingVariableSizeTrainCV class and vignette.
+
Changes in version 2023.12.23
- To get data set names in Simulations vignette, use task data names instead of learner$state$data_prototype.
diff --git a/R/ResamplingBase.R b/R/ResamplingBase.R
new file mode 100644
index 0000000..4fdbbbc
--- /dev/null
+++ b/R/ResamplingBase.R
@@ -0,0 +1,83 @@
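+## Base class with functionality shared by the resampling classes in
+## this package; not exported in NAMESPACE.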
+ResamplingBase = R6::R6Class(
+ "Resampling",
+ public = list(
+ id = NULL,
+ label = NULL,
+ param_set = NULL,
+ instance = NULL,
+ task_hash = NA_character_,
+ task_nrow = NA_integer_,
+ duplicated_ids = NULL,
+ man = NULL,
+ initialize = function(id, param_set = ps(), duplicated_ids = FALSE, label = NA_character_, man = NA_character_) {
+ self$id = checkmate::assert_string(id, min.chars = 1L)
+ self$label = checkmate::assert_string(label, na.ok = TRUE)
+ self$param_set = paradox::assert_param_set(param_set)
+ self$duplicated_ids = checkmate::assert_flag(duplicated_ids)
+ self$man = checkmate::assert_string(man, na.ok = TRUE)
+ },
+ format = function(...) {
+ sprintf("<%s>", class(self)[1L])
+ },
+ print = function(...) {
+ cat(
+ format(self),
+ if (is.null(self$label) || is.na(self$label))
+ "" else paste0(": ", self$label)
+ )
+ cat("\n* Iterations:", self$iters)
+ cat("\n* Instantiated:", self$is_instantiated)
+ cat("\n* Parameters:\n")
+ str(self$param_set$values)
+ },
+ help = function() {
+ self$man
+ },
+ train_set = function(i) {
+ self$instance$iteration.dt$train[[i]]
+ },
+ test_set = function(i) {
+ self$instance$iteration.dt$test[[i]]
+ }
+ ),
+ active = list(
+ iters = function(rhs) {
+ nrow(self$instance$iteration.dt)
+ },
+ is_instantiated = function(rhs) {
+ !is.null(self$instance)
+ },
+ hash = function(rhs) {
+ if (!self$is_instantiated) {
+ return(NA_character_)
+ }
+ mlr3misc::calculate_hash(list(
+ class(self),
+ self$id,
+ self$param_set$values,
+ self$instance))
+ }
+ ),
+ private = list(
+ .sample = function(ids, ...) {
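+      ## Balanced fold assignment: 0..(n-1) mod folds gives equal
+      ## counts, +1L for 1-based fold ids, sample() shuffles the order.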
+ data.table(
+ row_id = ids,
+ fold = sample(
+ seq(0, length(ids)-1) %%
+ as.integer(self$param_set$values$folds) + 1L
+ ),
+ key = "fold"
+ )
+ },
+ .combine = function(instances) {
+ rbindlist(instances, use.names = TRUE)
+ },
+ deep_clone = function(name, value) {
+ switch(name,
+ "instance" = copy(value),
+ "param_set" = value$clone(deep = TRUE),
+ value
+ )
+ }
+ )
+)
diff --git a/R/ResamplingSameOther.R b/R/ResamplingSameOtherCV.R
similarity index 60%
rename from R/ResamplingSameOther.R
rename to R/ResamplingSameOtherCV.R
index 50ccb82..36c15cd 100644
--- a/R/ResamplingSameOther.R
+++ b/R/ResamplingSameOtherCV.R
@@ -1,33 +1,17 @@
-ResamplingSameOther = R6::R6Class(
- "Resampling",
+ResamplingSameOtherCV = R6::R6Class(
+ "ResamplingSameOtherCV",
+ inherit=ResamplingBase,
public = list(
- id = NULL,
- label = NULL,
- param_set = NULL,
- instance = NULL,
- task_hash = NA_character_,
- task_nrow = NA_integer_,
- duplicated_ids = NULL,
- man = NULL,
- initialize = function(id, param_set = ps(), duplicated_ids = FALSE, label = NA_character_, man = NA_character_) {
- self$id = checkmate::assert_string(id, min.chars = 1L)
- self$label = checkmate::assert_string(label, na.ok = TRUE)
- self$param_set = paradox::assert_param_set(param_set)
- self$duplicated_ids = checkmate::assert_flag(duplicated_ids)
- self$man = checkmate::assert_string(man, na.ok = TRUE)
- },
- format = function(...) {
- sprintf("<%s>", class(self)[1L])
- },
- print = function(...) {
- cat(format(self), if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))
- cat("\n* Iterations:", self$iters)
- cat("\n* Instantiated:", self$is_instantiated)
- cat("\n* Parameters:\n")
- str(self$param_set$values)
- },
- help = function() {
- self$man
+ initialize = function() {
+ ps = paradox::ps(
+ folds = paradox::p_int(2L, tags = "required")
+ )
+ ps$values = list(folds = 3L)
+ super$initialize(
+ id = "same_other_cv",
+ param_set = ps,
+ label = "Same versus Other Cross-Validation",
+ man = "ResamplingSameOtherCV")
},
instantiate = function(task) {
task = mlr3::assert_task(mlr3::as_task(task))
@@ -125,83 +109,6 @@ ResamplingSameOther = R6::R6Class(
self$task_hash = task$hash
self$task_nrow = task$nrow
invisible(self)
- },
- train_set = function(i) {
- self$instance$iteration.dt$train[[i]]
- },
- test_set = function(i) {
- self$instance$iteration.dt$test[[i]]
- }
- ),
- active = list(
- is_instantiated = function(rhs) {
- !is.null(self$instance)
- },
- hash = function(rhs) {
- if (!self$is_instantiated) {
- return(NA_character_)
- }
- mlr3misc::calculate_hash(list(class(self), self$id, self$param_set$values, self$instance))
- }
- )
-)
-
-ResamplingSameOtherCV = R6::R6Class(
- "ResamplingSameOtherCV",
- inherit = ResamplingSameOther,
- public = list(
- initialize = function() {
- ps = paradox::ps(
- folds = paradox::p_int(2L, tags = "required")
- )
- ps$values = list(folds = 3L)
- super$initialize(
- id = "same_other_cv",
- param_set = ps,
- label = "Cross-Validation",
- man = "ResamplingSameOtherCV")
- }
- ),
- active = list(
- iters = function(rhs) {
- nrow(self$instance$iteration.dt)
- }
- ),
- private = list(
- .sample = function(ids, ...) {
- data.table(
- row_id = ids,
- fold = sample(
- seq(0, length(ids)-1) %%
- as.integer(self$param_set$values$folds) + 1L
- ),
- key = "fold"
- )
- },
- .combine = function(instances) {
- rbindlist(instances, use.names = TRUE)
- },
- deep_clone = function(name, value) {
- switch(name,
- "instance" = copy(value),
- "param_set" = value$clone(deep = TRUE),
- value
- )
}
)
)
-
-score <- function(bench.result, ...){
- algorithm <- learner_id <- NULL
- ## Above to avoid CRAN NOTE.
- bench.score <- bench.result$score(...)
- out.dt.list <- list()
- for(score.i in 1:nrow(bench.score)){
- bench.row <- bench.score[score.i]
- it.dt <- bench.row$resampling[[1]]$instance$iteration.dt
- out.dt.list[[score.i]] <- it.dt[
- bench.row, on="iteration"
- ][, algorithm := sub(".*[.]", "", learner_id)]
- }
- rbindlist(out.dt.list)
-}
diff --git a/R/ResamplingVariableSizeTrainCV.R b/R/ResamplingVariableSizeTrainCV.R
new file mode 100644
index 0000000..3cd52d2
--- /dev/null
+++ b/R/ResamplingVariableSizeTrainCV.R
@@ -0,0 +1,80 @@
+ResamplingVariableSizeTrainCV = R6::R6Class(
+ "ResamplingVariableSizeTrainCV",
+ inherit=ResamplingBase,
+ public = list(
+ initialize = function() {
+ ps = paradox::ps(
+ folds = paradox::p_int(2L, tags = "required"),
+ min_train_data=paradox::p_int(1L, tags = "required"),
+ random_seeds=paradox::p_int(1L, tags = "required"),
+ train_sizes = paradox::p_int(2L, tags = "required"))
+ ps$values = list(
+ folds = 3L,
+ min_train_data=10L,
+ random_seeds=3L,
+ train_sizes=5L)
+ super$initialize(
+ id = "variable_size_train_cv",
+ param_set = ps,
+ label = "Cross-Validation with variable size train sets",
+ man = "ResamplingVariableSizeTrainCV")
+ },
+ instantiate = function(task) {
+ task = mlr3::assert_task(mlr3::as_task(task))
+ reserved.names <- c(
+ "row_id", "fold", "group", "display_row",
+ "train.groups", "test.fold", "test.group", "iteration",
+ "test", "train", "algorithm", "uhash", "nr", "task", "task_id",
+ "learner", "learner_id", "resampling", "resampling_id",
+ "prediction")
+ ## bad.names <- group.name.vec[group.name.vec %in% reserved.names]
+ ## if(length(bad.names)){
+ ## first.bad <- bad.names[1]
+ ## stop(sprintf("col with role group must not be named %s; please fix by renaming %s col", first.bad, first.bad))
+ ## }
+ ## orig.group.dt <- task$data(cols=group.name.vec)
+ strata <- if(is.null(task$strata)){
+ data.dt <- task$data()
+ data.table(N=nrow(data.dt), row_id=list(1:nrow(data.dt)))
+ }else task$strata
+ folds = private$.combine(
+ lapply(strata$row_id, private$.sample, task = task)
+ )[order(row_id)]
+ min_train_data <- self$param_set$values[["min_train_data"]]
+ if(task$nrow <= min_train_data){
+ stop(sprintf(
+ "task$nrow=%d but should be larger than min_train_data=%d",
+ task$nrow, min_train_data))
+ }
+ uniq.folds <- sort(unique(folds$fold))
+ iteration.dt.list <- list()
+ for(test.fold in uniq.folds){
+ is.set.fold <- list(
+ test=folds[["fold"]] == test.fold)
+ is.set.fold[["train"]] <- !is.set.fold[["test"]]
+ i.set.list <- lapply(is.set.fold, which)
+ max_train_data <- length(i.set.list$train)
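+      ## Train set sizes: equally spaced on the log scale from
+      ## min_train_data to max, rounded, duplicates removed.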
+ log.range.data <- log(c(min_train_data, max_train_data))
+ seq.args <- c(as.list(log.range.data), list(l=self$param_set$values[["train_sizes"]]))
+ log.train.sizes <- do.call(seq, seq.args)
+ train.size.vec <- unique(as.integer(round(exp(log.train.sizes))))
+ for(seed in 1:self$param_set$values[["random_seeds"]]){
+ set.seed(seed)
+ ord.i.vec <- sample(i.set.list$train)
+ iteration.dt.list[[paste(test.fold, seed)]] <- data.table(
+ test.fold,
+ seed,
+ train_size=train.size.vec,
+ train=lapply(train.size.vec, function(last)ord.i.vec[1:last]),
+ test=list(i.set.list$test))
+ }
+ }
+ self$instance <- list(
+ iteration.dt=rbindlist(iteration.dt.list)[, iteration := .I][],
+ id.dt=folds)
+ self$task_hash = task$hash
+ self$task_nrow = task$nrow
+ invisible(self)
+ }
+ )
+)
diff --git a/R/score.R b/R/score.R
new file mode 100644
index 0000000..c739e91
--- /dev/null
+++ b/R/score.R
@@ -0,0 +1,14 @@
+score <- function(bench.result, ...){
+ algorithm <- learner_id <- NULL
+ ## Above to avoid CRAN NOTE.
+ bench.score <- bench.result$score(...)
+ out.dt.list <- list()
+ for(score.i in 1:nrow(bench.score)){
+ bench.row <- bench.score[score.i]
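+    ## Join the resampling instance meta-data (e.g. test.fold, seed,
+    ## train_size) onto this row of the mlr3 score table.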
+ it.dt <- bench.row$resampling[[1]]$instance$iteration.dt
+ out.dt.list[[score.i]] <- it.dt[
+ bench.row, on="iteration"
+ ][, algorithm := sub(".*[.]", "", learner_id)]
+ }
+ rbindlist(out.dt.list)
+}
diff --git a/man/ResamplingSameOther.Rd b/man/ResamplingSameOtherCV.Rd
similarity index 89%
rename from man/ResamplingSameOther.Rd
rename to man/ResamplingSameOtherCV.Rd
index 1450f84..192e543 100644
--- a/man/ResamplingSameOther.Rd
+++ b/man/ResamplingSameOtherCV.Rd
@@ -1,22 +1,20 @@
-\name{ResamplingSameOther}
-\alias{ResamplingSameOther}
+\name{ResamplingSameOtherCV}
\alias{ResamplingSameOtherCV}
\title{Resampling for comparing training on same or other groups}
\description{
-\code{\link{ResamplingSameOther}} is the abstract base class for
-\code{\link{ResamplingSameOtherCV}},
-which defines how a task is partitioned for
-resampling, for example in
-\code{\link[mlr3:resample]{resample()}} or
-\code{\link[mlr3:benchmark]{benchmark()}}.
-
-Resampling objects can be instantiated on a
-\code{\link[mlr3:Task]{Task}},
-which should define at least one group variable.
-
-After instantiation, sets can be accessed via
-\verb{$train_set(i)} and
-\verb{$test_set(i)}, respectively.
+ \code{\link{ResamplingSameOtherCV}}
+ defines how a task is partitioned for
+ resampling, for example in
+ \code{\link[mlr3:resample]{resample()}} or
+ \code{\link[mlr3:benchmark]{benchmark()}}.
+
+ Resampling objects can be instantiated on a
+ \code{\link[mlr3:Task]{Task}},
+ which should define at least one group variable.
+
+ After instantiation, sets can be accessed via
+ \verb{$train_set(i)} and
+ \verb{$test_set(i)}, respectively.
}
\details{
A supervised learning algorithm inputs a train set, and outputs a
@@ -50,7 +48,7 @@ each combination of the values of the stratification variables forms a stratum.
The grouping variable is assumed to be discrete,
and must be stored in the \link{Task} with column role \code{"group"}.
-Then number of cross-validation folds K should be defined as the
+The number of cross-validation folds K should be defined as the
\code{fold} parameter.
In each group, there will be about an equal number of observations
diff --git a/man/ResamplingVariableSizeTrainCV.Rd b/man/ResamplingVariableSizeTrainCV.Rd
new file mode 100644
index 0000000..4f56325
--- /dev/null
+++ b/man/ResamplingVariableSizeTrainCV.Rd
@@ -0,0 +1,155 @@
+\name{ResamplingVariableSizeTrainCV}
+\alias{ResamplingVariableSizeTrainCV}
+\title{Resampling with variable size train sets}
+\description{
+ \code{\link{ResamplingVariableSizeTrainCV}}
+ defines how a task is partitioned for
+ resampling, for example in
+ \code{\link[mlr3:resample]{resample()}} or
+ \code{\link[mlr3:benchmark]{benchmark()}}.
+
+ Resampling objects can be instantiated on a
+ \code{\link[mlr3:Task]{Task}}.
+
+ After instantiation, sets can be accessed via
+ \verb{$train_set(i)} and
+ \verb{$test_set(i)}, respectively.
+}
+\details{
+ A supervised learning algorithm inputs a train set, and outputs a
+ prediction function, which can be used on a test set.
+ How many train samples are required to get accurate predictions on a
+ test set? Cross-validation can be used to answer this question, with
+ variable size train sets.
+}
+\section{Stratification}{
+ \code{\link{ResamplingVariableSizeTrainCV}} supports stratified sampling.
+ The stratification variables are assumed to be discrete,
+ and must be stored in the \link{Task} with column role \code{"stratum"}.
+ In case of multiple stratification variables,
+ each combination of the values of the stratification variables forms a stratum.
+}
+
+\section{Grouping}{
+ \code{\link{ResamplingVariableSizeTrainCV}}
+ does not support grouping of observations.
+}
+
+\section{Hyper-parameters}{
+
+  The number of cross-validation folds should be defined as the
+  \code{folds} parameter.
+
+ For each fold ID, the corresponding observations are considered the
+ test set, and a variable number of other observations are considered
+ the train set.
+
+ The \code{random_seeds} parameter controls the number of random
+ orderings of the train set that are considered.
+
+ For each random order of the train set, the \code{min_train_data}
+ parameter controls the smallest train set size considered.
+
+ To determine the other train set sizes, we use an equally spaced grid
+ on the log scale, from \code{min_train_data} to the largest train set
+ size (all data not in test set). The
+ number of train set sizes in this grid is determined by the
+ \code{train_sizes} parameter.
+
+}
+
+\examples{
+(var_sizes <- mlr3resampling::ResamplingVariableSizeTrainCV$new())
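+## Hyper-parameter values can be changed before instantiation, as in
+## the package vignette (an illustration; any valid integers work):
+var_sizes$param_set$values$train_sizes <- 10L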
+}
+\concept{Resampling}
+\section{Methods}{
+\subsection{Public methods}{
+\itemize{
+\item \href{#method-Resampling-new}{\code{Resampling$new()}}
+\item \href{#method-Resampling-train_set}{\code{Resampling$train_set()}}
+\item \href{#method-Resampling-test_set}{\code{Resampling$test_set()}}
+}
+}
+
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-Resampling-new"></a>}}
+\if{latex}{\out{\hypertarget{method-Resampling-new}{}}}
+\subsection{Method \code{new()}}{
+Creates a new instance of this \link[R6:R6Class]{R6} class.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{Resampling$new(
+ id,
+ param_set = ps(),
+ duplicated_ids = FALSE,
+ label = NA_character_,
+ man = NA_character_
+)}\if{html}{\out{</div>}}
+}
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{id}}{(\code{character(1)})\cr
+Identifier for the new instance.}
+
+\item{\code{param_set}}{(\link[paradox:ParamSet]{paradox::ParamSet})\cr
+Set of hyperparameters.}
+
+\item{\code{duplicated_ids}}{(\code{logical(1)})\cr
+Set to \code{TRUE} if this resampling strategy may have duplicated row ids in a single training set or test set.
+
+Note that this object is typically constructed via a derived class, e.g. \link{ResamplingCV} or \link{ResamplingHoldout}.}
+
+\item{\code{label}}{(\code{character(1)})\cr
+Label for the new instance.}
+
+\item{\code{man}}{(\code{character(1)})\cr
+String in the format \verb{[pkg]::[topic]} pointing to a manual page for this object.
+The referenced help package can be opened via method \verb{$help()}.}
+}
+\if{html}{\out{</div>}}
+}
+}
+
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-Resampling-train_set"></a>}}
+\if{latex}{\out{\hypertarget{method-Resampling-train_set}{}}}
+\subsection{Method \code{train_set()}}{
+Returns the row ids of the i-th training set.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{Resampling$train_set(i)}\if{html}{\out{</div>}}
+}
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{i}}{(\code{integer(1)})\cr
+Iteration.}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+(\code{integer()}) of row ids.
+}
+}
+
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-Resampling-test_set"></a>}}
+\if{latex}{\out{\hypertarget{method-Resampling-test_set}{}}}
+\subsection{Method \code{test_set()}}{
+Returns the row ids of the i-th test set.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{Resampling$test_set(i)}\if{html}{\out{</div>}}
+}
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{i}}{(\code{integer(1)})\cr
+Iteration.}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+(\code{integer()}) of row ids.
+}
+}
+
+}
diff --git a/tests/testthat/test-CRAN.R b/tests/testthat/test-CRAN.R
index dcb86fb..d6fcdc3 100644
--- a/tests/testthat/test-CRAN.R
+++ b/tests/testthat/test-CRAN.R
@@ -80,3 +80,41 @@ test_that("error for group named test", {
same_other$instantiate(itask)
}, "col with role group must not be named test; please fix by renaming test col")
})
+
+test_that("error for 10 data", {
+ size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
+ i10.dt <- data.table(iris)[1:10]
+ i10.task <- mlr3::TaskClassif$new("i10", i10.dt, target="Species")
+ expect_error({
+ size_cv$instantiate(i10.task)
+ },
+ "task$nrow=10 but should be larger than min_train_data=10",
+ fixed=TRUE)
+})
+
+test_that("train set max size 67 for 100 data", {
+ size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
+ i100.dt <- data.table(iris)[1:100]
+  i100.task <- mlr3::TaskClassif$new("i100", i100.dt, target="Species")
+ size_cv$instantiate(i100.task)
+ inst <- size_cv$instance
+ computed.counts <- inst$id.dt[, .(rows=.N), keyby=fold]
+ expected.counts <- data.table(
+ fold=1:3,
+ rows=as.integer(c(34,33,33)),
+ key="fold")
+ expect_equal(computed.counts, expected.counts)
+ l.train <- sapply(inst$iteration.dt$train, length)
+ expect_equal(l.train, inst$iteration.dt$train_size)
+ expect_equal(max(l.train), 67)
+})
+
+test_that("test fold 1 for iteration 1", {
+ set.seed(1)
+ size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
+ i100.dt <- data.table(iris)[1:100]
+  i100.task <- mlr3::TaskClassif$new("i100", i100.dt, target="Species")
+ size_cv$instantiate(i100.task)
+ inst <- size_cv$instance
+ expect_equal(inst$iteration.dt$test.fold[1], 1)
+})
diff --git a/vignettes/Simulations.Rmd b/vignettes/ResamplingSameOtherCV.Rmd
similarity index 98%
rename from vignettes/Simulations.Rmd
rename to vignettes/ResamplingSameOtherCV.Rmd
index 97a032d..b91d518 100644
--- a/vignettes/Simulations.Rmd
+++ b/vignettes/ResamplingSameOtherCV.Rmd
@@ -1,8 +1,8 @@
---
-title: "Simulations"
+title: "Comparing training on same or other groups"
author: "Toby Dylan Hocking"
vignette: >
- %\VignetteIndexEntry{Simulations}
+ %\VignetteIndexEntry{Comparing training on same or other groups}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---
@@ -332,7 +332,7 @@ if(require(animint2)){
"Split number / cross-validation iteration")+
scale_y_continuous(
"Row number"),
- source="https://github.com/tdhock/mlr3resampling/blob/d7cbe06fdb562ff8c1c8b4ad0bc055b65bd7b2c6/vignettes/Simulations.Rmd")
+ source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingSameOtherCV.Rmd")
viz
}
if(FALSE){
@@ -700,7 +700,7 @@ if(require(animint2)){
"Split number / cross-validation iteration")+
scale_y_continuous(
"Row number"),
- source="https://github.com/tdhock/mlr3resampling/blob/942048163960cfeb90483224f52aa8315b3f50f5/vignettes/Simulations.Rmd")
+ source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingSameOtherCV.Rmd")
viz
}
if(FALSE){
diff --git a/vignettes/ResamplingVariableSizeTrainCV.Rmd b/vignettes/ResamplingVariableSizeTrainCV.Rmd
new file mode 100644
index 0000000..88b5a34
--- /dev/null
+++ b/vignettes/ResamplingVariableSizeTrainCV.Rmd
@@ -0,0 +1,719 @@
+---
+title: "Comparing train set sizes"
+author: "Toby Dylan Hocking"
+vignette: >
+ %\VignetteIndexEntry{Comparing train set sizes}
+ %\VignetteEngine{knitr::rmarkdown}
+ %\VignetteEncoding{UTF-8}
+---
+
+
+
+```{r setup, include = FALSE}
+knitr::opts_chunk$set(
+ fig.width=6,
+ fig.height=6)
+data.table::setDTthreads(1)
+## output: rmarkdown::html_vignette above creates html where figures are limited to 700px wide.
+## Above CSS from https://stackoverflow.com/questions/34906002/increase-width-of-entire-html-rmarkdown-output main-container is for html_document, body is for html_vignette
+knitr::opts_chunk$set(
+ collapse = TRUE,
+ comment = "#>"
+)
+```
+
+The goal of this vignette is to explain how to use
+`ResamplingVariableSizeTrainCV`, which can be used to determine how
+many train data are necessary to provide accurate predictions on a
+given test set.
+
+## Simulated regression problems
+
+The code below creates data for simulated regression problems. First
+we define a vector of input values,
+
+```{r}
+N <- 300
+abs.x <- 10
+set.seed(1)
+x.vec <- runif(N, -abs.x, abs.x)
+str(x.vec)
+```
+
+Below we define a list of two true regression functions (tasks in mlr3
+terminology) for our simulated data,
+
+```{r}
+reg.pattern.list <- list(
+ sin=sin,
+ constant=function(x)0)
+```
+
+The constant function represents a regression problem which can be
+solved by always predicting the mean value of outputs (featureless is
+the best possible learning algorithm). The sin function will be used
+to generate data with a non-linear pattern that will need to be
+learned. Below we use a for loop over these two functions/tasks, to
+simulate the data which will be used as input to the learning
+algorithms:
+
+```{r}
+library(data.table)
+reg.task.list <- list()
+reg.data.list <- list()
+for(task_id in names(reg.pattern.list)){
+ f <- reg.pattern.list[[task_id]]
+ task.dt <- data.table(
+ x=x.vec,
+ y = f(x.vec)+rnorm(N,sd=0.5))
+ reg.data.list[[task_id]] <- data.table(task_id, task.dt)
+ reg.task.list[[task_id]] <- mlr3::TaskRegr$new(
+ task_id, task.dt, target="y"
+ )
+}
+(reg.data <- rbindlist(reg.data.list))
+```
+
+In the table above, the input is x, and the output is y. Below we
+visualize these data, with one task in each facet/panel:
+
+```{r}
+if(require(animint2)){
+ ggplot()+
+ geom_point(aes(
+ x, y),
+ data=reg.data)+
+ facet_grid(task_id ~ ., labeller=label_both)
+}
+```
+
+In the plot above we can see two different simulated data sets
+(constant and sin). Note that the code above used the `animint2`
+package, which provides interactive extensions to the static graphics
+of the `ggplot2` package (see the Interactive data viz section below).
+
+### Visualizing instance table
+
+In the code below, we define a K-fold cross-validation experiment,
+with K=3 folds.
+
+```{r}
+reg_size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
+reg_size_cv$param_set$values$train_sizes <- 6
+reg_size_cv
+```
+
+In the output above we can see the parameters of the resampling
+object, all of which should be integer scalars:
+
+* `folds` is the number of cross-validation folds.
+* `min_train_data` is the minimum number of train data to consider.
+* `random_seeds` is the number of random seeds, each of which
+ determines a different random ordering of the train data. The random
+ ordering determines which data are included in small train set
+ sizes.
+* `train_sizes` is the number of train set sizes, evenly spaced on a
+ log scale, from `min_train_data` to the max number of train data
+ (determined by `folds`).
+
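+To make the last point concrete, the sketch below mirrors the internal
+computation of the train set sizes (a minimal illustration, not part
+of the package API; the helper name `size_grid` is made up). With the
+default `min_train_data=10`, a max train size of 200 (300 rows minus
+one test fold of 100), and `train_sizes=6` as set above:
+
+```{r}
+size_grid <- function(min_train_data, max_train_data, train_sizes){
+  log.range <- log(c(min_train_data, max_train_data))
+  ## Equally spaced sizes on the log scale, rounded to unique integers.
+  unique(as.integer(round(exp(
+    seq(log.range[1], log.range[2], l=train_sizes)))))
+}
+size_grid(10, 200, 6)
+```
+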
+Below we instantiate the resampling on one of the tasks:
+
+```{r}
+reg_size_cv$instantiate(reg.task.list[["sin"]])
+reg_size_cv$instance
+```
+
+Above we see the instance, which the user typically does not need to
+examine directly; for reference, it contains the following data:
+
+* `iteration.dt` has one row for each train/test split,
+* `id.dt` has one row for each data point.
+
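+The usual mlr3 accessors also work after instantiation; for example,
+the quick check below (assuming the instantiation above) shows the row
+ids of the first train/test split.
+
+```{r}
+str(reg_size_cv$train_set(1))
+str(reg_size_cv$test_set(1))
+```
+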
+### Benchmark: computing test error
+
+In the code below, we define two learners to compare,
+
+```{r}
+(reg.learner.list <- list(
+ if(requireNamespace("rpart"))mlr3::LearnerRegrRpart$new(),
+ mlr3::LearnerRegrFeatureless$new()))
+```
+
+The code above defines
+
+* `regr.rpart`: Regression Tree learning algorithm, which should be
+ able to learn the non-linear pattern in the sin data (if there are
+ enough data in the train set).
+* `regr.featureless`: Featureless Regression learning algorithm, which
+ should be optimal for the constant data, and can be used as a
+ baseline in the sin data. When the rpart learner gets smaller
+ prediction error rates than featureless, then we know that it has
+ learned some non-trivial relationship between inputs and outputs.
+
+In the code below, we define the benchmark grid, which is all
+combinations of tasks (constant and sin), learners (rpart and
+featureless), and the one resampling method.
+
+```{r}
+(reg.bench.grid <- mlr3::benchmark_grid(
+ reg.task.list,
+ reg.learner.list,
+ reg_size_cv))
+```
+
+In the code below, we execute the benchmark experiment (optionally in parallel
+using the multisession future plan).
+
+```{r}
+if(FALSE){
+ if(require(future))plan("multisession")
+}
+(reg.bench.result <- mlr3::benchmark(
+ reg.bench.grid, store_models = TRUE))
+```
+
+The code below computes the test error for each split, and visualizes
+the information stored in the first row of the result:
+
+```{r}
+reg.bench.score <- mlr3resampling::score(reg.bench.result)
+reg.bench.score[1]
+```
+
+The output above contains all of the results related to a particular
+train/test split. In particular for our purposes, the interesting
+columns are:
+
+* `test.fold` is the cross-validation fold ID.
+* `seed` is the random seed used to determine the train set order.
+* `train_size` is the number of data in the train set.
+* `train` and `test` are vectors of row numbers assigned to each set.
+* `iteration` is an ID for the train/test split, for a particular
+ learning algorithm and task. It is the row number of `iteration.dt`
+ (see instance above), which has one row for each unique combination
+ of `test.fold`, `seed`, and `train_size`.
+* `learner` is the mlr3 learner object, which can be used to compute
+ predictions on new data (including a grid of inputs, to show
+ predictions in the visualization below).
+* `regr.mse` is the mean squared error on the test set.
+* `algorithm` is the name of the learning algorithm (same as
+ `learner_id` but without `regr.` prefix).
+
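+For example, the quick check below (assuming the score table computed
+above) prints some of those columns for the first two splits.
+
+```{r}
+reg.bench.score[1:2, .(test.fold, seed, train_size, iteration, regr.mse, algorithm)]
+```
+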
+The code below visualizes the resulting test error values.
+
+```{r}
+train_size_vec <- unique(reg.bench.score$train_size)
+if(require(animint2)){
+ ggplot()+
+ scale_x_log10(
+ breaks=train_size_vec)+
+ scale_y_log10()+
+ geom_line(aes(
+ train_size, regr.mse,
+ group=paste(algorithm, seed),
+ color=algorithm),
+ data=reg.bench.score)+
+ geom_point(aes(
+ train_size, regr.mse, color=algorithm),
+ shape=1,
+ data=reg.bench.score)+
+ facet_grid(
+ test.fold~task_id,
+ labeller=label_both,
+ scales="free")
+}
+```
+
+Above we plot the test error for each fold and train set size.
+There is a different panel for each task and test fold.
+Each line represents a random seed (ordering of data in train set),
+and each dot represents a specific train set size.
+So the plot above shows that some variation in test error, for a given test fold,
+is due to the random ordering of the train data.
+
+Below we summarize each train set size, by taking the mean and standard deviation over the random seeds.
+
+```{r}
+reg.mean.dt <- dcast(
+ reg.bench.score,
+ task_id + train_size + test.fold + algorithm ~ .,
+ list(mean, sd),
+ value.var="regr.mse")
+if(require(animint2)){
+ ggplot()+
+ scale_x_log10(
+ breaks=train_size_vec)+
+ scale_y_log10()+
+ geom_ribbon(aes(
+ train_size,
+ ymin=regr.mse_mean-regr.mse_sd,
+ ymax=regr.mse_mean+regr.mse_sd,
+ fill=algorithm),
+ alpha=0.5,
+ data=reg.mean.dt)+
+ geom_line(aes(
+ train_size, regr.mse_mean, color=algorithm),
+ data=reg.mean.dt)+
+ facet_grid(
+ test.fold~task_id,
+ labeller=label_both,
+ scales="free")
+}
+```
+
+The plot above shows a line for the mean,
+and a ribbon for the standard deviation,
+over the three random seeds.
+It is clear from the plot above that
+
+* in the constant task, featureless always has prediction error rates
+  smaller than or equal to those of rpart, which indicates that rpart
+  sometimes overfits for large sample sizes.
+* in the sin task, more than 30 samples are required for rpart to be
+  more accurate than featureless, which indicates it has learned a
+  non-trivial relationship between input and output.
+
+### Interactive data viz
+
+The code below can be used to create an interactive data visualization
+which allows exploring how different functions are learned during
+different splits.
+
+```{r ResamplingVariableSizeTrainCVAnimintRegression}
+grid.dt <- data.table(x=seq(-abs.x, abs.x, l=101), y=0)
+grid.task <- mlr3::TaskRegr$new("grid", grid.dt, target="y")
+pred.dt.list <- list()
+point.dt.list <- list()
+for(score.i in 1:nrow(reg.bench.score)){
+ reg.bench.row <- reg.bench.score[score.i]
+ task.dt <- data.table(
+ reg.bench.row$task[[1]]$data(),
+ reg.bench.row$resampling[[1]]$instance$id.dt)
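+  ## Mark each task row as train/test for this split; rows in neither
+  ## set become "unused" below.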
+ set.ids <- data.table(
+ set.name=c("test","train")
+ )[
+ , data.table(row_id=reg.bench.row[[set.name]][[1]])
+ , by=set.name]
+ i.points <- set.ids[
+ task.dt, on="row_id"
+ ][
+ is.na(set.name), set.name := "unused"
+ ]
+ point.dt.list[[score.i]] <- data.table(
+ reg.bench.row[, .(task_id, iteration)],
+ i.points)
+ i.learner <- reg.bench.row$learner[[1]]
+ pred.dt.list[[score.i]] <- data.table(
+ reg.bench.row[, .(
+ task_id, iteration, algorithm
+ )],
+ as.data.table(
+ i.learner$predict(grid.task)
+ )[, .(x=grid.dt$x, y=response)]
+ )
+}
+(pred.dt <- rbindlist(pred.dt.list))
+(point.dt <- rbindlist(point.dt.list))
+set.colors <- c(
+ train="#1B9E77",
+ test="#D95F02",
+ unused="white")
+algo.colors <- c(
+ featureless="blue",
+ rpart="red")
+if(require(animint2)){
+ viz <- animint(
+ title="Variable size train set, regression",
+ pred=ggplot()+
+ ggtitle("Predictions for selected train/test split")+
+ theme_animint(height=400)+
+ scale_fill_manual(values=set.colors)+
+ geom_point(aes(
+ x, y, fill=set.name),
+ showSelected="iteration",
+ size=3,
+ shape=21,
+ data=point.dt)+
+ scale_size_manual(values=c(
+ featureless=3,
+ rpart=2))+
+ scale_color_manual(values=algo.colors)+
+ geom_line(aes(
+ x, y,
+ color=algorithm,
+ size=algorithm,
+ group=paste(algorithm, iteration)),
+ showSelected="iteration",
+ data=pred.dt)+
+ facet_grid(
+ task_id ~ .,
+ labeller=label_both),
+ err=ggplot()+
+ ggtitle("Test error for each split")+
+ theme_animint(width=500)+
+ theme(
+ panel.margin=grid::unit(1, "lines"),
+ legend.position="none")+
+ scale_y_log10(
+ "Mean squared error on test set")+
+ scale_color_manual(values=algo.colors)+
+ scale_x_log10(
+ "Train set size",
+ breaks=train_size_vec)+
+ geom_line(aes(
+ train_size, regr.mse,
+ group=paste(algorithm, seed),
+ color=algorithm),
+ clickSelects="seed",
+ alpha_off=0.2,
+ showSelected="algorithm",
+ size=4,
+ data=reg.bench.score)+
+ facet_grid(
+ test.fold~task_id,
+ labeller=label_both,
+ scales="free")+
+ geom_point(aes(
+ train_size, regr.mse,
+ color=algorithm),
+ size=5,
+ stroke=3,
+ fill="black",
+ fill_off=NA,
+ showSelected=c("algorithm","seed"),
+ clickSelects="iteration",
+ data=reg.bench.score),
+ source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/Simulations.Rmd")
+ viz
+}
+if(FALSE){
+ animint2pages(viz, "2023-12-26-train-sizes-regression")
+}
+```
+
+If you are viewing this in an installed package or on CRAN,
+then there will be no data viz on this page,
+but you can view it on:
+
+
+The interactive data viz consists of two plots:
+
+* The first plot shows the data, with each point colored according to
+  the set it was assigned to, in the currently selected
+ split/iteration. The red/blue lines additionally show the learned
+ prediction functions for the currently selected split/iteration.
+* The second plot shows the test error rates, as a function of train
+ set size. Clicking a line selects the corresponding random seed,
+ which makes the corresponding points on that line appear. Clicking a
+ point selects the corresponding iteration (seed, test fold, and train
+ set size).
+
+## Simulated classification problems
+
+Whereas in the section above, we focused on regression (output is a real number),
+in this section we simulate a binary classification problem (output is a factor with two levels).
+
+```{r}
+class.N <- 300
+class.abs.x <- 1
+rclass <- function(){
+ runif(class.N, -class.abs.x, class.abs.x)
+}
+library(data.table)
+set.seed(1)
+class.x.dt <- data.table(x1=rclass(), x2=rclass())
+class.fun.list <- list(
+ constant=function(...)0.5,
+ xor=function(x1, x2)xor(x1>0, x2>0))
+class.data.list <- list()
+class.task.list <- list()
+for(task_id in names(class.fun.list)){
+ class.fun <- class.fun.list[[task_id]]
+ y <- factor(ifelse(
+ class.x.dt[, class.fun(x1, x2)+rnorm(class.N, sd=0.5)]>0.5,
+ "spam", "not"))
+ task.dt <- data.table(class.x.dt, y)
+ class.task.list[[task_id]] <- mlr3::TaskClassif$new(
+ task_id, task.dt, target="y"
+ )$set_col_roles("y",c("target","stratum"))
+ class.data.list[[task_id]] <- data.table(task_id, task.dt)
+}
+(class.data <- rbindlist(class.data.list))
+```
+
+The simulated data table above consists of two input features (`x1`
+and `x2`) along with an output/label to predict (`y`). Below we count
+the number of times each label appears in each task:
+
+```{r}
+class.data[, .(count=.N), by=.(task_id, y)]
+```
+
+The table above shows that the `spam` label is the minority class
+(`not` is the majority class, so it will be the prediction of the featureless
+baseline). Below we visualize the data in the feature space:
+
+```{r}
+if(require(animint2)){
+ ggplot()+
+ geom_point(aes(
+ x1, x2, color=y),
+ shape=1,
+ data=class.data)+
+ facet_grid(. ~ task_id, labeller=label_both)+
+ coord_equal()
+}
+```
+
+The plot above shows how the output `y` is related to the two inputs `x1` and
+`x2`, for the two tasks.
+
+* For the constant task, the two inputs are not related to the output.
+* For the xor task, the spam label is associated with either `x1` or
+ `x2` being negative (but not both).
+
+In the mlr3 code below, we define a list of learners, our resampling
+method, and a benchmark grid:
+
+```{r}
+class.learner.list <- list(
+ if(requireNamespace("rpart"))mlr3::LearnerClassifRpart$new(),
+ mlr3::LearnerClassifFeatureless$new())
+size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
+(class.bench.grid <- mlr3::benchmark_grid(
+ class.task.list,
+ class.learner.list,
+ size_cv))
+```
+
+Below we run the learning algorithm for each of the train/test splits
+defined by our benchmark grid:
+
+```{r}
+if(FALSE){
+ if(require(future))plan("multisession")
+}
+(class.bench.result <- mlr3::benchmark(
+ class.bench.grid, store_models = TRUE))
+```
+
+Below we compute scores (test error) for each resampling iteration,
+and show the first row of the result.
+
+```{r}
+class.bench.score <- mlr3resampling::score(class.bench.result)
+class.bench.score[1]
+```
+
+The output above has columns which are very similar to the regression
+example in the previous section. The main difference is the
+`classif.ce` column, which is the classification error on the test
+set.
+
+Finally we plot the test error values below.
+
+```{r}
+if(require(animint2)){
+ ggplot()+
+ geom_line(aes(
+ train_size, classif.ce,
+ group=paste(algorithm, seed),
+ color=algorithm),
+ data=class.bench.score)+
+ geom_point(aes(
+ train_size, classif.ce, color=algorithm),
+ shape=1,
+ data=class.bench.score)+
+ facet_grid(
+ task_id ~ test.fold,
+ labeller=label_both,
+ scales="free")+
+ scale_x_log10()
+}
+```
+
+It is clear from the plot above that
+
+* in the constant task, rpart does not have significantly lower error
+  rates than featureless, which is expected, because the best
+  prediction function is constant (predict the most frequent class; no
+  relationship between inputs and output).
+* in the xor task, more than 30 samples are required for rpart to be
+  more accurate than featureless, which indicates it has learned a
+  non-trivial relationship between inputs and output.
+
+Exercise for the reader: compute and plot mean and SD for these
+classification tasks, similar to the plot for the regression tasks in
+the previous section.
+
+### Interactive visualization of data, test error, and splits
+
+The code below can be used to create an interactive data visualization
+which allows exploring how different functions are learned during
+different splits.
+
+```{r ResamplingVariableSizeTrainCVAnimintClassification}
+class.grid.vec <- seq(-class.abs.x, class.abs.x, l=21)
+class.grid.dt <- CJ(x1=class.grid.vec, x2=class.grid.vec)
+class.pred.dt.list <- list()
+class.point.dt.list <- list()
+for(score.i in 1:nrow(class.bench.score)){
+ class.bench.row <- class.bench.score[score.i]
+ task.dt <- data.table(
+ class.bench.row$task[[1]]$data(),
+ class.bench.row$resampling[[1]]$instance$id.dt)
+ set.ids <- data.table(
+ set.name=c("test","train")
+ )[
+ , data.table(row_id=class.bench.row[[set.name]][[1]])
+ , by=set.name]
+ i.points <- set.ids[
+ task.dt, on="row_id"
+ ][
+ is.na(set.name), set.name := "unused"
+ ][]
+ class.point.dt.list[[score.i]] <- data.table(
+ class.bench.row[, .(task_id, iteration)],
+ i.points)
+ if(class.bench.row$algorithm!="featureless"){
+ i.learner <- class.bench.row$learner[[1]]
+ i.learner$predict_type <- "prob"
+ i.task <- class.bench.row$task[[1]]
+ grid.class.task <- mlr3::TaskClassif$new(
+ "grid", class.grid.dt[, label:=factor(NA,levels(task.dt$y))], target="label")
+ pred.grid <- as.data.table(
+ i.learner$predict(grid.class.task)
+ )[, data.table(class.grid.dt, prob.spam)]
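+    ## Reshape the predicted probabilities into a grid matrix, then
+    ## trace the 0.5 decision boundary with contourLines().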
+ pred.wide <- dcast(pred.grid, x1 ~ x2, value.var="prob.spam")
+ prob.mat <- as.matrix(pred.wide[,-1])
+ if(length(table(prob.mat))>1){
+ contour.list <- contourLines(
+ class.grid.vec, class.grid.vec, prob.mat, levels=0.5)
+ class.pred.dt.list[[score.i]] <- data.table(
+ class.bench.row[, .(
+ task_id, iteration, algorithm
+ )],
+ data.table(contour.i=seq_along(contour.list))[, {
+ do.call(data.table, contour.list[[contour.i]])[, .(level, x1=x, x2=y)]
+ }, by=contour.i]
+ )
+ }
+ }
+}
+(class.pred.dt <- rbindlist(class.pred.dt.list))
+(class.point.dt <- rbindlist(class.point.dt.list))
+
+set.colors <- c(
+ train="#1B9E77",
+ test="#D95F02",
+ unused="white")
+algo.colors <- c(
+ featureless="blue",
+ rpart="red")
+if(require(animint2)){
+ viz <- animint(
+ title="Variable size train sets, classification",
+ pred=ggplot()+
+ ggtitle("Predictions for selected train/test split")+
+ theme(panel.margin=grid::unit(1, "lines"))+
+ theme_animint(width=600)+
+ coord_equal()+
+ scale_fill_manual(values=set.colors)+
+ scale_color_manual(values=c(spam="black","not spam"="white"))+
+ geom_point(aes(
+ x1, x2, color=y, fill=set.name),
+ showSelected="iteration",
+ size=3,
+ stroke=2,
+ shape=21,
+ data=class.point.dt)+
+ geom_path(aes(
+ x1, x2,
+ group=paste(algorithm, iteration, contour.i)),
+ showSelected=c("iteration","algorithm"),
+ color=algo.colors[["rpart"]],
+ data=class.pred.dt)+
+ facet_grid(
+ . ~ task_id,
+ labeller=label_both,
+ space="free",
+ scales="free"),
+ err=ggplot()+
+ ggtitle("Test error for each split")+
+ theme_animint(height=400)+
+ theme(panel.margin=grid::unit(1, "lines"))+
+ scale_y_continuous(
+ "Classification error on test set")+
+ scale_color_manual(values=algo.colors)+
+ scale_x_log10(
+ "Train set size")+
+ geom_line(aes(
+ train_size, classif.ce,
+ group=paste(algorithm, seed),
+ color=algorithm),
+ clickSelects="seed",
+ alpha_off=0.2,
+ showSelected="algorithm",
+ size=4,
+ data=class.bench.score)+
+ facet_grid(
+ test.fold~task_id,
+ labeller=label_both,
+ scales="free")+
+ geom_point(aes(
+ train_size, classif.ce,
+ color=algorithm),
+ size=5,
+ stroke=3,
+ fill="black",
+ fill_off=NA,
+ showSelected=c("algorithm","seed"),
+ clickSelects="iteration",
+ data=class.bench.score),
+ source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingVariableSizeTrainCV.Rmd")
+ viz
+}
+if(FALSE){
+ animint2pages(viz, "2023-12-27-train-sizes-classification")
+}
+```
+
+If you are viewing this in an installed package or on CRAN,
+then there will be no data viz on this page,
+but you can view it on:
+
+
+The interactive data viz consists of two plots:
+
+* The first plot shows the data, with each point colored according to
+ its label/y value (black outline for spam, white outline for not),
+  and the set it was assigned to (fill color) in the currently selected
+ split/iteration. The red lines additionally show the learned
+ decision boundary for rpart, given the currently selected
+ split/iteration. For constant, the ideal decision boundary is none
+ (always predict the most frequent class), and for xor, the ideal
+ decision boundary looks like a plus sign.
+* The second plot shows the test error rates, as a function of train
+ set size. Clicking a line selects the corresponding random seed,
+ which makes the corresponding points on that line appear. Clicking a
+ point selects the corresponding iteration (seed, test fold, and train
+ set size).
+
+## Conclusion
+
+In this vignette we have shown how to use mlr3resampling for comparing
+test error of models trained on different sized train sets.
+
+## Session info
+
+```{r}
+sessionInfo()
+```