
Commit

Merge pull request #3 from tdhock/variable_size_train
ResamplingVariableSizeCV
tdhock authored Dec 31, 2023
2 parents 8b23667 + e48be1c commit 0f33224
Showing 12 changed files with 1,131 additions and 129 deletions.
5 changes: 4 additions & 1 deletion DESCRIPTION
@@ -1,7 +1,7 @@
Package: mlr3resampling
Type: Package
Title: Resampling Algorithms for 'mlr3' Framework
Version: 2023.12.23
Version: 2023.12.28
Authors@R: c(
person("Toby", "Hocking",
email="[email protected]",
@@ -52,6 +52,9 @@ Description: A supervised learning algorithm inputs a train set,
For more information,
<https://tdhock.github.io/blog/2023/R-gen-new-subsets/>
describes the method in depth.
How many train samples are required to get accurate predictions on a
test set? Cross-validation can be used to answer this question, with
variable size train sets.
License: GPL-3
URL: https://github.com/tdhock/mlr3resampling
BugReports: https://github.com/tdhock/mlr3resampling/issues
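
An illustrative sketch (not part of this commit) of how the variable-size-train cross-validation described above could be used; the task and learner below are assumptions, and the parameter values are the defaults defined later in this diff:

    library(mlr3)
    library(mlr3resampling)
    task <- tsk("mtcars")                          # assumed example regression task
    learner <- lrn("regr.rpart")                   # assumed example learner
    var_cv <- ResamplingVariableSizeTrainCV$new()  # defaults: folds=3, min_train_data=10, random_seeds=3, train_sizes=5
    var_cv$instantiate(task)
    rr <- mlr3::resample(task, learner, var_cv)
    rr$score(msr("regr.mse"))                      # one row per (test fold, seed, train size)
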
2 changes: 1 addition & 1 deletion NAMESPACE
@@ -1,3 +1,3 @@
import(R6, checkmate, data.table, mlr3, mlr3misc, paradox)
export(ResamplingSameOtherCV, score)
export(ResamplingSameOtherCV, score, ResamplingVariableSizeTrainCV)

5 changes: 5 additions & 0 deletions NEWS
@@ -1,3 +1,8 @@
Changes in version 2023.12.28

- Rename Simulations vignette to ResamplingSameOtherCV.
- New ResamplingVariableSizeTrainCV class and vignette.

Changes in version 2023.12.23

- To get data set names in Simulations vignette, use task data names instead of learner$state$data_prototype.
83 changes: 83 additions & 0 deletions R/ResamplingBase.R
@@ -0,0 +1,83 @@
ResamplingBase = R6::R6Class(
"Resampling",
public = list(
id = NULL,
label = NULL,
param_set = NULL,
instance = NULL,
task_hash = NA_character_,
task_nrow = NA_integer_,
duplicated_ids = NULL,
man = NULL,
initialize = function(id, param_set = ps(), duplicated_ids = FALSE, label = NA_character_, man = NA_character_) {
self$id = checkmate::assert_string(id, min.chars = 1L)
self$label = checkmate::assert_string(label, na.ok = TRUE)
self$param_set = paradox::assert_param_set(param_set)
self$duplicated_ids = checkmate::assert_flag(duplicated_ids)
self$man = checkmate::assert_string(man, na.ok = TRUE)
},
format = function(...) {
sprintf("<%s>", class(self)[1L])
},
print = function(...) {
cat(
format(self),
if (is.null(self$label) || is.na(self$label))
"" else paste0(": ", self$label)
)
cat("\n* Iterations:", self$iters)
cat("\n* Instantiated:", self$is_instantiated)
cat("\n* Parameters:\n")
str(self$param_set$values)
},
help = function() {
self$man
},
train_set = function(i) {
self$instance$iteration.dt$train[[i]]
},
test_set = function(i) {
self$instance$iteration.dt$test[[i]]
}
),
active = list(
iters = function(rhs) {
nrow(self$instance$iteration.dt)
},
is_instantiated = function(rhs) {
!is.null(self$instance)
},
hash = function(rhs) {
if (!self$is_instantiated) {
return(NA_character_)
}
mlr3misc::calculate_hash(list(
class(self),
self$id,
self$param_set$values,
self$instance))
}
),
private = list(
.sample = function(ids, ...) {
data.table(
row_id = ids,
fold = sample(
seq(0, length(ids)-1) %%
as.integer(self$param_set$values$folds) + 1L
),
key = "fold"
)
},
.combine = function(instances) {
rbindlist(instances, use.names = TRUE)
},
deep_clone = function(name, value) {
switch(name,
"instance" = copy(value),
"param_set" = value$clone(deep = TRUE),
value
)
}
)
)
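
A standalone sketch of what the private .sample() helper above computes (the ids, seed, and fold count here are illustrative): each row id receives a label in 1..folds, with counts differing by at most one, in random order.

    library(data.table)
    set.seed(1)                                    # illustrative seed
    ids <- 1:10                                    # e.g. row ids of one stratum
    folds <- 3L
    fold <- sample(seq(0, length(ids) - 1) %% folds + 1L)
    data.table(row_id = ids, fold = fold, key = "fold")
    table(fold)                                    # fold sizes differ by at most one
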
119 changes: 13 additions & 106 deletions R/ResamplingSameOther.R → R/ResamplingSameOtherCV.R
@@ -1,33 +1,17 @@
ResamplingSameOther = R6::R6Class(
"Resampling",
ResamplingSameOtherCV = R6::R6Class(
"ResamplingSameOtherCV",
inherit=ResamplingBase,
public = list(
id = NULL,
label = NULL,
param_set = NULL,
instance = NULL,
task_hash = NA_character_,
task_nrow = NA_integer_,
duplicated_ids = NULL,
man = NULL,
initialize = function(id, param_set = ps(), duplicated_ids = FALSE, label = NA_character_, man = NA_character_) {
self$id = checkmate::assert_string(id, min.chars = 1L)
self$label = checkmate::assert_string(label, na.ok = TRUE)
self$param_set = paradox::assert_param_set(param_set)
self$duplicated_ids = checkmate::assert_flag(duplicated_ids)
self$man = checkmate::assert_string(man, na.ok = TRUE)
},
format = function(...) {
sprintf("<%s>", class(self)[1L])
},
print = function(...) {
cat(format(self), if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))
cat("\n* Iterations:", self$iters)
cat("\n* Instantiated:", self$is_instantiated)
cat("\n* Parameters:\n")
str(self$param_set$values)
},
help = function() {
self$man
initialize = function() {
ps = paradox::ps(
folds = paradox::p_int(2L, tags = "required")
)
ps$values = list(folds = 3L)
super$initialize(
id = "same_other_cv",
param_set = ps,
label = "Same versus Other Cross-Validation",
man = "ResamplingSameOtherCV")
},
instantiate = function(task) {
task = mlr3::assert_task(mlr3::as_task(task))
@@ -125,83 +109,6 @@ ResamplingSameOther = R6::R6Class(
self$task_hash = task$hash
self$task_nrow = task$nrow
invisible(self)
},
train_set = function(i) {
self$instance$iteration.dt$train[[i]]
},
test_set = function(i) {
self$instance$iteration.dt$test[[i]]
}
),
active = list(
is_instantiated = function(rhs) {
!is.null(self$instance)
},
hash = function(rhs) {
if (!self$is_instantiated) {
return(NA_character_)
}
mlr3misc::calculate_hash(list(class(self), self$id, self$param_set$values, self$instance))
}
)
)

ResamplingSameOtherCV = R6::R6Class(
"ResamplingSameOtherCV",
inherit = ResamplingSameOther,
public = list(
initialize = function() {
ps = paradox::ps(
folds = paradox::p_int(2L, tags = "required")
)
ps$values = list(folds = 3L)
super$initialize(
id = "same_other_cv",
param_set = ps,
label = "Cross-Validation",
man = "ResamplingSameOtherCV")
}
),
active = list(
iters = function(rhs) {
nrow(self$instance$iteration.dt)
}
),
private = list(
.sample = function(ids, ...) {
data.table(
row_id = ids,
fold = sample(
seq(0, length(ids)-1) %%
as.integer(self$param_set$values$folds) + 1L
),
key = "fold"
)
},
.combine = function(instances) {
rbindlist(instances, use.names = TRUE)
},
deep_clone = function(name, value) {
switch(name,
"instance" = copy(value),
"param_set" = value$clone(deep = TRUE),
value
)
}
)
)

score <- function(bench.result, ...){
algorithm <- learner_id <- NULL
## Above to avoid CRAN NOTE.
bench.score <- bench.result$score(...)
out.dt.list <- list()
for(score.i in 1:nrow(bench.score)){
bench.row <- bench.score[score.i]
it.dt <- bench.row$resampling[[1]]$instance$iteration.dt
out.dt.list[[score.i]] <- it.dt[
bench.row, on="iteration"
][, algorithm := sub(".*[.]", "", learner_id)]
}
rbindlist(out.dt.list)
}
80 changes: 80 additions & 0 deletions R/ResamplingVariableSizeTrainCV.R
@@ -0,0 +1,80 @@
ResamplingVariableSizeTrainCV = R6::R6Class(
"ResamplingVariableSizeTrainCV",
inherit=ResamplingBase,
public = list(
initialize = function() {
ps = paradox::ps(
folds = paradox::p_int(2L, tags = "required"),
min_train_data=paradox::p_int(1L, tags = "required"),
random_seeds=paradox::p_int(1L, tags = "required"),
train_sizes = paradox::p_int(2L, tags = "required"))
ps$values = list(
folds = 3L,
min_train_data=10L,
random_seeds=3L,
train_sizes=5L)
super$initialize(
id = "variable_size_train_cv",
param_set = ps,
label = "Cross-Validation with variable size train sets",
man = "ResamplingVariableSizeTrainCV")
},
instantiate = function(task) {
task = mlr3::assert_task(mlr3::as_task(task))
reserved.names <- c(
"row_id", "fold", "group", "display_row",
"train.groups", "test.fold", "test.group", "iteration",
"test", "train", "algorithm", "uhash", "nr", "task", "task_id",
"learner", "learner_id", "resampling", "resampling_id",
"prediction")
## bad.names <- group.name.vec[group.name.vec %in% reserved.names]
## if(length(bad.names)){
## first.bad <- bad.names[1]
## stop(sprintf("col with role group must not be named %s; please fix by renaming %s col", first.bad, first.bad))
## }
## orig.group.dt <- task$data(cols=group.name.vec)
strata <- if(is.null(task$strata)){
data.dt <- task$data()
data.table(N=nrow(data.dt), row_id=list(1:nrow(data.dt)))
}else task$strata
folds = private$.combine(
lapply(strata$row_id, private$.sample, task = task)
)[order(row_id)]
min_train_data <- self$param_set$values[["min_train_data"]]
if(task$nrow <= min_train_data){
stop(sprintf(
"task$nrow=%d but should be larger than min_train_data=%d",
task$nrow, min_train_data))
}
uniq.folds <- sort(unique(folds$fold))
iteration.dt.list <- list()
for(test.fold in uniq.folds){
is.set.fold <- list(
test=folds[["fold"]] == test.fold)
is.set.fold[["train"]] <- !is.set.fold[["test"]]
i.set.list <- lapply(is.set.fold, which)
max_train_data <- length(i.set.list$train)
log.range.data <- log(c(min_train_data, max_train_data))
seq.args <- c(as.list(log.range.data), list(l=self$param_set$values[["train_sizes"]]))
log.train.sizes <- do.call(seq, seq.args)
train.size.vec <- unique(as.integer(round(exp(log.train.sizes))))
for(seed in 1:self$param_set$values[["random_seeds"]]){
set.seed(seed)
ord.i.vec <- sample(i.set.list$train)
iteration.dt.list[[paste(test.fold, seed)]] <- data.table(
test.fold,
seed,
train_size=train.size.vec,
train=lapply(train.size.vec, function(last)ord.i.vec[1:last]),
test=list(i.set.list$test))
}
}
self$instance <- list(
iteration.dt=rbindlist(iteration.dt.list)[, iteration := .I][],
id.dt=folds)
self$task_hash = task$hash
self$task_nrow = task$nrow
invisible(self)
}
)
)
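
A standalone sketch of the log-spaced train set sizes computed inside instantiate() above, for a single test fold (the numbers below are illustrative, not from the commit):

    min_train_data <- 10L
    max_train_data <- 200L   # assumed number of rows outside the test fold
    train_sizes <- 5L
    log.range.data <- log(c(min_train_data, max_train_data))
    log.train.sizes <- seq(log.range.data[1], log.range.data[2], l = train_sizes)
    unique(as.integer(round(exp(log.train.sizes))))
    ## 10 21 45 95 200: train set sizes approximately evenly spaced on the log scale
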
14 changes: 14 additions & 0 deletions R/score.R
@@ -0,0 +1,14 @@
score <- function(bench.result, ...){
algorithm <- learner_id <- NULL
## Above to avoid CRAN NOTE.
bench.score <- bench.result$score(...)
out.dt.list <- list()
for(score.i in 1:nrow(bench.score)){
bench.row <- bench.score[score.i]
it.dt <- bench.row$resampling[[1]]$instance$iteration.dt
out.dt.list[[score.i]] <- it.dt[
bench.row, on="iteration"
][, algorithm := sub(".*[.]", "", learner_id)]
}
rbindlist(out.dt.list)
}
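
A sketch of assumed usage (not part of this commit): score() augments each row of BenchmarkResult$score() with the iteration metadata stored in the instantiated resampling, e.g. train_size, seed, and test.fold for ResamplingVariableSizeTrainCV. The task, learners, and measure below are assumptions.

    library(mlr3)
    library(mlr3resampling)
    bench.grid <- mlr3::benchmark_grid(
      tasks = tsk("mtcars"),
      learners = lrns(c("regr.rpart", "regr.featureless")),
      resamplings = ResamplingVariableSizeTrainCV$new())
    bench.result <- mlr3::benchmark(bench.grid)
    score.dt <- mlr3resampling::score(bench.result, msr("regr.mse"))
    score.dt[, .(algorithm, train_size, seed, test.fold, regr.mse)]
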
32 changes: 15 additions & 17 deletions man/ResamplingSameOther.Rd → man/ResamplingSameOtherCV.Rd
@@ -1,22 +1,20 @@
\name{ResamplingSameOther}
\alias{ResamplingSameOther}
\name{ResamplingSameOtherCV}
\alias{ResamplingSameOtherCV}
\title{Resampling for comparing training on same or other groups}
\description{
\code{\link{ResamplingSameOther}} is the abstract base class for
\code{\link{ResamplingSameOtherCV}},
which defines how a task is partitioned for
resampling, for example in
\code{\link[mlr3:resample]{resample()}} or
\code{\link[mlr3:benchmark]{benchmark()}}.

Resampling objects can be instantiated on a
\code{\link[mlr3:Task]{Task}},
which should define at least one group variable.

After instantiation, sets can be accessed via
\verb{$train_set(i)} and
\verb{$test_set(i)}, respectively.
\code{\link{ResamplingSameOtherCV}}
defines how a task is partitioned for
resampling, for example in
\code{\link[mlr3:resample]{resample()}} or
\code{\link[mlr3:benchmark]{benchmark()}}.

Resampling objects can be instantiated on a
\code{\link[mlr3:Task]{Task}},
which should define at least one group variable.

After instantiation, sets can be accessed via
\verb{$train_set(i)} and
\verb{$test_set(i)}, respectively.
}
\details{
A supervised learning algorithm inputs a train set, and outputs a
@@ -50,7 +48,7 @@ each combination of the values of the stratification variables forms a stratum.
The grouping variable is assumed to be discrete,
and must be stored in the \link{Task} with column role \code{"group"}.

Then number of cross-validation folds K should be defined as the
The number of cross-validation folds K should be defined as the
\code{fold} parameter.

In each group, there will be about an equal number of observations
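
A sketch of assumed usage based on the documentation above; the task and grouping column are illustrative assumptions, not from this commit:

    library(mlr3)
    library(mlr3resampling)
    task <- tsk("mtcars")
    task$col_roles$group <- "cyl"        # assumed column given role "group"
    same_other <- ResamplingSameOtherCV$new()
    same_other$instantiate(task)
    same_other$iters                     # number of iterations after instantiation
    same_other$train_set(1)              # integer row ids for the first iteration
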