diff --git a/DESCRIPTION b/DESCRIPTION index 4b81ae2..f89ff21 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -49,7 +49,7 @@ Description: A supervised learning algorithm inputs a train set, test accuracy for each group; other is usually somewhat less accurate than same; other can be just as bad as featureless baseline when the groups have different patterns). - For more information, + For more information, describes the method in depth. How many train samples are required to get accurate predictions on a diff --git a/R/ResamplingSameOtherCV.R b/R/ResamplingSameOtherCV.R index 2d3eb07..3fd1b49 100644 --- a/R/ResamplingSameOtherCV.R +++ b/R/ResamplingSameOtherCV.R @@ -24,7 +24,7 @@ ResamplingSameOtherCV = R6::R6Class( } reserved.names <- c( "row_id", "fold", "subset", "display_row", - "train.subsets", "test.fold", "test.subset", "iteration", + "train.subsets", "test.fold", "test.subset", "iteration", "test", "train", "algorithm", "uhash", "nr", "task", "task_id", "learner", "learner_id", "resampling", "resampling_id", "prediction") @@ -106,7 +106,7 @@ ResamplingSameOtherCV = R6::R6Class( rows="fold", display_row=min(display_row), display_end=max(display_row) - ), by=.(subset, fold)]) + ), by=.(subset, fold)]) self$instance <- list( iteration.dt=iteration.dt, id.dt=id.fold.subsets[order(row_id)], diff --git a/R/ResamplingVariableSizeTrainCV.R b/R/ResamplingVariableSizeTrainCV.R index 1e5fc72..a82ab3b 100644 --- a/R/ResamplingVariableSizeTrainCV.R +++ b/R/ResamplingVariableSizeTrainCV.R @@ -22,8 +22,7 @@ ResamplingVariableSizeTrainCV = R6::R6Class( instantiate = function(task) { task = mlr3::assert_task(mlr3::as_task(task)) strata <- if(is.null(task$strata)){ - data.dt <- task$data() - data.table(N=nrow(data.dt), row_id=list(1:nrow(data.dt))) + data.table(N=task$nrow, row_id=list(seq_len(task$nrow))) }else task$strata strata.list <- lapply(strata$row_id, private$.sample, task = task) folds = private$.combine(strata.list)[order(row_id)] diff --git a/R/zzz.R b/R/zzz.R new file mode 100644 index 0000000..95a4f26 --- /dev/null +++ b/R/zzz.R @@ -0,0 +1,16 @@ +register_mlr3 = function() { + mlr_resamplings = utils::getFromNamespace("mlr_resamplings", ns = "mlr3") + mlr_resamplings$add("same_other_sizes_cv", ResamplingSameOtherSizesCV) +} + +.onLoad = function(libname, pkgname) { # nolint + # Configure Logger: + assign("lg", lgr::get_logger("mlr3"), envir = parent.env(environment())) + if (Sys.getenv("IN_PKGDOWN") == "true") { + lg$set_threshold("warn") # nolint + } + + mlr3misc::register_namespace_callback(pkgname, "mlr3", register_mlr3) +} + +mlr3misc::leanify_package()