From 3e40782b3780090f9db01b09975ea462fb07bcb8 Mon Sep 17 00:00:00 2001 From: bcjaeger Date: Fri, 22 Nov 2024 16:23:57 -0500 Subject: [PATCH] allow steps of >1 vs can run pretty slow when you have 500 predictors and drop one at a time. This should allow flexibility for data like that --- DESCRIPTION | 2 +- R/coerce_nans.R | 4 +- R/orsf_R6.R | 34 ++++++++++++---- R/orsf_data_prep.R | 5 ++- R/orsf_vs.R | 6 ++- man/orsf.Rd | 48 +++++++++------------- man/orsf_control_cph.Rd | 4 +- man/orsf_control_custom.Rd | 4 +- man/orsf_control_fast.Rd | 4 +- man/orsf_control_net.Rd | 4 +- man/orsf_pd_oob.Rd | 83 +++++++++++++++++++++++++++++++------- man/orsf_vs.Rd | 9 ++++- 12 files changed, 142 insertions(+), 65 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index f52ab031..fba54f91 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -48,4 +48,4 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.2 diff --git a/R/coerce_nans.R b/R/coerce_nans.R index 79e9cec3..ac15ed58 100644 --- a/R/coerce_nans.R +++ b/R/coerce_nans.R @@ -1,13 +1,15 @@ +#' @noRd coerce_nans <- function(x, to){ UseMethod('coerce_nans') } +#' @noRd coerce_nans.list <- function(x, to){ lapply(x, coerce_nans, to = to) } - +#' @noRd coerce_nans.factor <- coerce_nans.integer <- coerce_nans.double <- diff --git a/R/orsf_R6.R b/R/orsf_R6.R index 0f9bf4b1..a395eec5 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -703,7 +703,9 @@ ObliqueForest <- R6::R6Class( # Variable selection # returns a data.table with variable selection info - select_variables = function(n_predictor_min, verbose_progress){ + select_variables = function(n_predictor_min, + n_predictor_drop, + verbose_progress){ public_state <- list(verbose_progress = self$verbose_progress, forest = self$forest, @@ -712,7 +714,9 @@ ObliqueForest <- R6::R6Class( object_trained <- self$trained out <- try( - private$select_variables_internal(n_predictor_min, verbose_progress) + private$select_variables_internal(n_predictor_min, + n_predictor_drop, + verbose_progress) ) private$restore_state(public_state, private_state = NULL) @@ -2928,9 +2932,11 @@ ObliqueForest <- R6::R6Class( }, - select_variables_internal = function(n_predictor_min, verbose_progress){ + select_variables_internal = function(n_predictor_min, + n_predictor_drop, + verbose_progress){ - n_predictors <- length(private$data_names$x_original) + n_predictors <- length(private$data_names$x_ref_code) # verbose progress on the forest should always be FALSE # because for orsf_vs, verbosity is coordinated in R @@ -2941,7 +2947,7 @@ ObliqueForest <- R6::R6Class( stat_value = rep(NA_real_, n_predictors), variables_included = vector(mode = 'list', length = n_predictors), predictors_included = vector(mode = 'list', length = n_predictors), - predictor_dropped = rep(NA_character_, n_predictors) + predictor_dropped = vector(mode = 'list', length = n_predictors) ) # if the forest was not trained prior to variable selection @@ -3045,9 +3051,21 @@ ObliqueForest <- R6::R6Class( cpp_args$mtry <- mtry_safe cpp_output <- do.call(orsf_cpp, args = cpp_args) - worst_index <- which.min(cpp_output$importance) - worst_predictor <- colnames(cpp_args$x)[worst_index] + n_drop <- min(n_predictor_drop, + n_predictors - n_predictor_min) + + if(n_drop > 0){ + + worst_index <- order(cpp_output$importance)[seq(n_drop)] + + worst_predictor <- colnames(cpp_args$x)[worst_index] + } else { + + worst_predictor <- NA_character_ + n_drop <- 1 + + } .variables_included <- with( variable_key, @@ -3062,7 +3080,7 @@ ObliqueForest <- R6::R6Class( predictor_dropped = worst_predictor)] cpp_args$x <- cpp_args$x[, -worst_index, drop = FALSE] - n_predictors <- n_predictors - 1 + n_predictors <- n_predictors - n_drop current_progress <- current_progress + 1 } diff --git a/R/orsf_data_prep.R b/R/orsf_data_prep.R index 1ce57200..59309d62 100644 --- a/R/orsf_data_prep.R +++ b/R/orsf_data_prep.R @@ -1,8 +1,9 @@ - +#' @noRd orsf_data_prep <- function(data, ...){ UseMethod('orsf_data_prep') } +#' @noRd orsf_data_prep.list <- function(data, ...){ lengths <- vapply(data, length, integer(1)) @@ -43,12 +44,14 @@ orsf_data_prep.list <- function(data, ...){ } +#' @noRd orsf_data_prep.recipe <- function(data, ...){ getElement(data, 'template') } +#' @noRd orsf_data_prep.data.frame <- function(data, ...){ data } diff --git a/R/orsf_vs.R b/R/orsf_vs.R index ac5236f5..814dfa0f 100644 --- a/R/orsf_vs.R +++ b/R/orsf_vs.R @@ -3,6 +3,7 @@ #' #' @inheritParams predict.ObliqueForest #' @param n_predictor_min (*integer*) the minimum number of predictors allowed +#' @param n_predictor_drop (*integer*) the number of predictors dropped at each step #' @param verbose_progress (*logical*) not implemented yet. Should progress be printed to the console? #' #' @return a [data.table][data.table::data.table-package] with four columns: @@ -38,6 +39,7 @@ orsf_vs <- function(object, n_predictor_min = 3, + n_predictor_drop = 1, verbose_progress = NULL){ check_arg_is(arg_value = object, @@ -74,7 +76,9 @@ orsf_vs <- function(object, arg_name = 'verbose_progress', expected_length = 1) - object$select_variables(n_predictor_min, verbose_progress) + object$select_variables(n_predictor_min, + n_predictor_drop, + verbose_progress) } diff --git a/man/orsf.Rd b/man/orsf.Rd index dc217bce..25be4e30 100644 --- a/man/orsf.Rd +++ b/man/orsf.Rd @@ -366,18 +366,6 @@ data that were not used to train it, i.e., testing data. library(magrittr) # for \%>\% }\if{html}{\out{}} -\if{html}{\out{
}}\preformatted{## -## Attaching package: 'magrittr' - -## The following object is masked from 'package:tidyr': -## -## extract - -## The following objects are masked from 'package:testthat': -## -## equals, is_less_than, not -}\if{html}{\out{
}} - \code{orsf()} is the entry-point of the \code{aorsf} package. It can be used to fit classification, regression, and survival forests. @@ -400,7 +388,7 @@ penguin_fit ## N trees: 5 ## N predictors total: 7 ## N predictors per node: 3 -## Average leaves per tree: 4.6 +## Average leaves per tree: 5.2 ## Min observations in leaf: 5 ## OOB stat value: 0.99 ## OOB stat type: AUC-ROC @@ -427,9 +415,9 @@ bill_fit ## N trees: 5 ## N predictors total: 7 ## N predictors per node: 3 -## Average leaves per tree: 51 +## Average leaves per tree: 47.4 ## Min observations in leaf: 5 -## OOB stat value: 0.70 +## OOB stat value: 0.71 ## OOB stat type: RSQ ## Variable importance: anova ## @@ -459,10 +447,10 @@ pbc_fit ## N trees: 5 ## N predictors total: 17 ## N predictors per node: 5 -## Average leaves per tree: 22.2 +## Average leaves per tree: 21.4 ## Min observations in leaf: 5 ## Min events in leaf: 1 -## OOB stat value: 0.78 +## OOB stat value: 0.80 ## OOB stat type: Harrell's C-index ## Variable importance: anova ## @@ -509,7 +497,7 @@ take to fit the forest before you commit to it: orsf_time_to_train() }\if{html}{\out{}} -\if{html}{\out{
}}\preformatted{## Time difference of 2.429678 secs +\if{html}{\out{
}}\preformatted{## Time difference of 2.534871 secs }\if{html}{\out{
}} \enumerate{ \item If fitting multiple forests, use the blueprint along with @@ -580,12 +568,12 @@ brier_scores \if{html}{\out{
}}\preformatted{## # A tibble: 6 x 4 ## .metric .estimator .eval_time .estimate ## -## 1 brier_survival standard 500 0.0597 -## 2 brier_survival standard 1000 0.0943 -## 3 brier_survival standard 1500 0.0883 -## 4 brier_survival standard 2000 0.102 -## 5 brier_survival standard 2500 0.137 -## 6 brier_survival standard 3000 0.153 +## 1 brier_survival standard 500 0.0384 +## 2 brier_survival standard 1000 0.0827 +## 3 brier_survival standard 1500 0.0932 +## 4 brier_survival standard 2000 0.105 +## 5 brier_survival standard 2500 0.151 +## 6 brier_survival standard 3000 0.238 }\if{html}{\out{
}} \if{html}{\out{
}}\preformatted{roc_scores <- test_pred \%>\% @@ -597,12 +585,12 @@ roc_scores \if{html}{\out{
}}\preformatted{## # A tibble: 6 x 4 ## .metric .estimator .eval_time .estimate ## -## 1 roc_auc_survival standard 500 0.957 -## 2 roc_auc_survival standard 1000 0.912 -## 3 roc_auc_survival standard 1500 0.935 -## 4 roc_auc_survival standard 2000 0.931 -## 5 roc_auc_survival standard 2500 0.907 -## 6 roc_auc_survival standard 3000 0.889 +## 1 roc_auc_survival standard 500 0.828 +## 2 roc_auc_survival standard 1000 0.864 +## 3 roc_auc_survival standard 1500 0.925 +## 4 roc_auc_survival standard 2000 0.934 +## 5 roc_auc_survival standard 2500 0.863 +## 6 roc_auc_survival standard 3000 0.767 }\if{html}{\out{
}} } } diff --git a/man/orsf_control_cph.Rd b/man/orsf_control_cph.Rd index 6f09a551..d1987f7f 100644 --- a/man/orsf_control_cph.Rd +++ b/man/orsf_control_cph.Rd @@ -52,9 +52,9 @@ Springer, New York, NY. DOI: 10.1007/978-1-4757-3294-8_3 } \seealso{ linear combination control functions +\code{\link{orsf_control}()}, \code{\link{orsf_control_custom}()}, \code{\link{orsf_control_fast}()}, -\code{\link{orsf_control_net}()}, -\code{\link{orsf_control}()} +\code{\link{orsf_control_net}()} } \concept{orsf_control} diff --git a/man/orsf_control_custom.Rd b/man/orsf_control_custom.Rd index bc74a5b7..0742a348 100644 --- a/man/orsf_control_custom.Rd +++ b/man/orsf_control_custom.Rd @@ -32,9 +32,9 @@ an input for the \code{control} argument of \link{orsf}. } \seealso{ linear combination control functions +\code{\link{orsf_control}()}, \code{\link{orsf_control_cph}()}, \code{\link{orsf_control_fast}()}, -\code{\link{orsf_control_net}()}, -\code{\link{orsf_control}()} +\code{\link{orsf_control_net}()} } \concept{orsf_control} diff --git a/man/orsf_control_fast.Rd b/man/orsf_control_fast.Rd index 2365bbaf..1041be61 100644 --- a/man/orsf_control_fast.Rd +++ b/man/orsf_control_fast.Rd @@ -38,9 +38,9 @@ on the scale of your data, which is why the default value is \code{TRUE}. } \seealso{ linear combination control functions +\code{\link{orsf_control}()}, \code{\link{orsf_control_cph}()}, \code{\link{orsf_control_custom}()}, -\code{\link{orsf_control_net}()}, -\code{\link{orsf_control}()} +\code{\link{orsf_control_net}()} } \concept{orsf_control} diff --git a/man/orsf_control_net.Rd b/man/orsf_control_net.Rd index 9892f6ac..517dad9d 100644 --- a/man/orsf_control_net.Rd +++ b/man/orsf_control_net.Rd @@ -40,9 +40,9 @@ coordinate descent." \emph{Journal of statistical software}, \emph{39}(5), 1. } \seealso{ linear combination control functions +\code{\link{orsf_control}()}, \code{\link{orsf_control_cph}()}, \code{\link{orsf_control_custom}()}, -\code{\link{orsf_control_fast}()}, -\code{\link{orsf_control}()} +\code{\link{orsf_control_fast}()} } \concept{orsf_control} diff --git a/man/orsf_pd_oob.Rd b/man/orsf_pd_oob.Rd index 7e550962..c732e57e 100644 --- a/man/orsf_pd_oob.Rd +++ b/man/orsf_pd_oob.Rd @@ -277,12 +277,47 @@ pd_new ## 3: Gentoo Biscoe 3200 42.81649 40.19221 42.55664 46.84035 ## 4: Adelie Dream 3200 40.16219 36.95895 40.34633 43.90681 ## 5: Chinstrap Dream 3200 46.21778 43.53954 45.90929 49.19173 -## --- +## 6: Gentoo Dream 3200 42.60465 39.89647 42.63520 46.28769 +## 7: Adelie Torgersen 3200 39.91652 36.80227 39.79806 43.68842 +## 8: Chinstrap Torgersen 3200 44.27807 41.95470 44.40742 46.68848 +## 9: Gentoo Torgersen 3200 42.09510 39.49863 41.80049 45.81833 +## 10: Adelie Biscoe 3550 40.77971 38.04027 40.59561 44.57505 +## 11: Chinstrap Biscoe 3550 45.81304 43.52102 45.73116 48.36366 +## 12: Gentoo Biscoe 3550 43.31233 40.77355 43.03077 47.22936 +## 13: Adelie Dream 3550 40.77741 38.07399 40.78175 44.37273 +## 14: Chinstrap Dream 3550 47.30926 44.80493 46.77540 50.47092 +## 15: Gentoo Dream 3550 43.26955 40.86119 43.16204 46.89190 +## 16: Adelie Torgersen 3550 40.25780 37.35251 40.07871 44.04576 +## 17: Chinstrap Torgersen 3550 44.77911 42.60161 44.81944 47.14986 +## 18: Gentoo Torgersen 3550 42.49520 39.95866 42.14160 46.26237 +## 19: Adelie Biscoe 3975 41.61744 38.94515 41.36634 45.38752 +## 20: Chinstrap Biscoe 3975 46.59363 44.59970 46.44923 49.11457 +## 21: Gentoo Biscoe 3975 44.07857 41.60792 43.74562 47.85109 +## 22: Adelie Dream 3975 41.50511 39.06187 41.24741 45.13027 +## 23: Chinstrap Dream 3975 48.14978 45.87390 47.54867 51.50683 +## 24: Gentoo Dream 3975 44.01928 41.70577 43.84099 47.50470 +## 25: Adelie Torgersen 3975 40.94764 38.12519 40.66759 44.73689 +## 26: Chinstrap Torgersen 3975 45.44820 43.49986 45.44036 47.63243 +## 27: Gentoo Torgersen 3975 43.13791 40.70628 42.70627 46.87306 +## 28: Adelie Biscoe 4700 42.93914 40.48463 42.44768 46.81756 +## 29: Chinstrap Biscoe 4700 47.18534 45.40866 47.07739 49.55747 +## 30: Gentoo Biscoe 4700 45.32541 43.08173 44.93498 49.23391 +## 31: Adelie Dream 4700 42.73806 40.44229 42.22226 46.49936 +## 32: Chinstrap Dream 4700 48.37354 46.34335 48.00781 51.18955 +## 33: Gentoo Dream 4700 45.09132 42.88328 44.79530 48.82180 +## 34: Adelie Torgersen 4700 42.09349 39.72074 41.56168 45.68838 +## 35: Chinstrap Torgersen 4700 46.17045 44.39042 46.09525 48.35127 +## 36: Gentoo Torgersen 4700 44.31621 42.18968 43.81773 47.98024 +## 37: Adelie Biscoe 5300 43.89769 41.43335 43.28504 48.10892 +## 38: Chinstrap Biscoe 5300 47.53721 45.66038 47.52770 49.88701 +## 39: Gentoo Biscoe 5300 46.16115 43.81722 45.59309 50.57469 +## 40: Adelie Dream 5300 43.59846 41.25825 43.24518 47.46193 ## 41: Chinstrap Dream 5300 48.48139 46.36282 48.25679 51.02996 ## 42: Gentoo Dream 5300 45.91819 43.62832 45.54110 49.91622 ## 43: Adelie Torgersen 5300 42.92879 40.66576 42.31072 46.76406 ## 44: Chinstrap Torgersen 5300 46.59576 44.80400 46.49196 49.03906 ## 45: Gentoo Torgersen 5300 45.11384 42.95190 44.51289 49.27629 +## species island body_mass_g mean lwr medn upr }\if{html}{\out{
}} By default, all combinations of all variables are used. However, you can @@ -385,19 +420,39 @@ Specify \code{pred_horizon} to get partial dependence at each value: pd_train }\if{html}{\out{
}} -\if{html}{\out{
}}\preformatted{## pred_horizon bili mean lwr medn upr -## -## 1: 500 0.55 0.0617199 0.000443399 0.00865419 0.5907104 -## 2: 1000 0.55 0.1418501 0.005793742 0.05572853 0.7360749 -## 3: 1500 0.55 0.2082505 0.013609478 0.09174558 0.8556319 -## 4: 2000 0.55 0.2679017 0.023047689 0.14574169 0.8910549 -## 5: 2500 0.55 0.3179617 0.063797305 0.20254500 0.9017710 -## --- -## 26: 1000 7.25 0.3264627 0.135343689 0.25956791 0.8884333 -## 27: 1500 7.25 0.4641265 0.218208755 0.38787435 0.9702903 -## 28: 2000 7.25 0.5511761 0.293367409 0.48427730 0.9812413 -## 29: 2500 7.25 0.6200238 0.371965247 0.56954399 0.9845058 -## 30: 3000 7.25 0.6803482 0.425128031 0.64642318 0.9888637 +\if{html}{\out{
}}\preformatted{## pred_horizon bili mean lwr medn upr +## +## 1: 500 0.55 0.06171990 0.000443399 0.008654190 0.5907104 +## 2: 1000 0.55 0.14185009 0.005793742 0.055728527 0.7360749 +## 3: 1500 0.55 0.20825053 0.013609478 0.091745579 0.8556319 +## 4: 2000 0.55 0.26790167 0.023047689 0.145741690 0.8910549 +## 5: 2500 0.55 0.31796166 0.063797305 0.202544999 0.9017710 +## 6: 3000 0.55 0.39108086 0.090852131 0.301804690 0.9234812 +## 7: 500 0.70 0.06240527 0.000443399 0.008934806 0.5980510 +## 8: 1000 0.70 0.14313570 0.006159694 0.056348007 0.7432448 +## 9: 1500 0.70 0.21012128 0.013717586 0.092461532 0.8597396 +## 10: 2000 0.70 0.27013021 0.023169510 0.146344595 0.8935664 +## 11: 2500 0.70 0.31880954 0.062506113 0.201979102 0.9068170 +## 12: 3000 0.70 0.39286323 0.089707173 0.308392927 0.9252028 +## 13: 500 1.50 0.06679162 0.001271788 0.011028398 0.6241228 +## 14: 1000 1.50 0.15727919 0.011478962 0.068332010 0.7678732 +## 15: 1500 1.50 0.23316655 0.028732095 0.117289745 0.8789647 +## 16: 2000 1.50 0.30139227 0.046792721 0.180096425 0.9144202 +## 17: 2500 1.50 0.35260943 0.084586675 0.238015966 0.9266065 +## 18: 3000 1.50 0.43512074 0.131110330 0.346025144 0.9438562 +## 19: 500 3.50 0.08638646 0.005208753 0.028239001 0.6740930 +## 20: 1000 3.50 0.22353655 0.051917978 0.139604845 0.8283986 +## 21: 1500 3.50 0.32700976 0.090198324 0.217982772 0.9371150 +## 22: 2000 3.50 0.41618105 0.144532860 0.311508093 0.9566091 +## 23: 2500 3.50 0.49248461 0.219511094 0.402095677 0.9636221 +## 24: 3000 3.50 0.56008108 0.263569896 0.503253258 0.9734948 +## 25: 500 7.25 0.12585007 0.022092057 0.063550987 0.7543806 +## 26: 1000 7.25 0.32646274 0.135343689 0.259567907 0.8884333 +## 27: 1500 7.25 0.46412653 0.218208755 0.387874346 0.9702903 +## 28: 2000 7.25 0.55117610 0.293367409 0.484277295 0.9812413 +## 29: 2500 7.25 0.62002385 0.371965247 0.569543990 0.9845058 +## 30: 3000 7.25 0.68034820 0.425128031 0.646423180 0.9888637 +## pred_horizon bili mean lwr medn upr }\if{html}{\out{
}} vector-valued \code{pred_horizon} input comes with minimal extra diff --git a/man/orsf_vs.Rd b/man/orsf_vs.Rd index c3db8cca..277e21d8 100644 --- a/man/orsf_vs.Rd +++ b/man/orsf_vs.Rd @@ -4,13 +4,20 @@ \alias{orsf_vs} \title{Variable selection} \usage{ -orsf_vs(object, n_predictor_min = 3, verbose_progress = NULL) +orsf_vs( + object, + n_predictor_min = 3, + n_predictor_drop = 1, + verbose_progress = NULL +) } \arguments{ \item{object}{(\emph{ObliqueForest}) a trained oblique random forest object (see \link{orsf}).} \item{n_predictor_min}{(\emph{integer}) the minimum number of predictors allowed} +\item{n_predictor_drop}{(\emph{integer}) the number of predictors dropped at each step} + \item{verbose_progress}{(\emph{logical}) not implemented yet. Should progress be printed to the console?} } \value{