Skip to content

Commit 282d081

Browse files
authored
generalized random forests (#1299)
* initial definitions for classification and regression * enable quantile regression * documentation * testing update * snapshot updates * fix typo * air formatting * fix case weight entries * redoc * fix test * remove quantreg from suggests * check for quantreg install * fix errors * add grf to Config/Needs/website * changes based on user comments
1 parent 8590444 commit 282d081

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+3217
-292
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ Suggests:
7171
VignetteBuilder:
7272
knitr
7373
ByteCompile: true
74-
Config/Needs/website: brulee, C50, dbarts, earth, glmnet, keras, kernlab,
74+
Config/Needs/website: brulee, C50, dbarts, earth, glmnet, grf, keras, kernlab,
7575
kknn, LiblineaR, mgcv, nnet, parsnip, quantreg, randomForest, ranger,
7676
rpart, rstanarm, tidymodels/tidymodels, tidyverse/tidytemplate,
7777
rstudio/reticulate, xgboost, rmarkdown

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# parsnip (development version)
22

3+
* Enable generalized random forest (`grf`) models for classification, regression, and quantile regression modes (#1288).
4+
35
* `surv_reg()` is now defunct and will error if called. Please use `survival_reg()` instead (#1206).
46

7+
58
# parsnip 1.3.3
69

710
* Bug fix in how tunable parameters were configured for brulee neural networks.

R/aaa_archive.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# no fmt
1+
# fmt: skip
22
model_info_table <-
33
tibble::tribble(
44
~model, ~mode, ~engine, ~pkg,
@@ -21,6 +21,7 @@ model_info_table <-
2121
"bag_tree", "classification", "rpart", "baguette",
2222
"bart", "classification", "dbarts", NA,
2323
"boost_tree", "classification", "C5.0", NA,
24+
"boost_tree", "classification", "catboost", "bonsai",
2425
"boost_tree", "classification", "h2o", "agua",
2526
"boost_tree", "classification", "h2o_gbm", "agua",
2627
"boost_tree", "classification", "lightgbm", "bonsai",
@@ -69,6 +70,7 @@ model_info_table <-
6970
"null_model", "classification", "parsnip", NA,
7071
"pls", "classification", "mixOmics", "plsmod",
7172
"rand_forest", "classification", "aorsf", "bonsai",
73+
"rand_forest", "classification", "grf", NA,
7274
"rand_forest", "classification", "h2o", "agua",
7375
"rand_forest", "classification", "partykit", "bonsai",
7476
"rand_forest", "classification", "randomForest", NA,
@@ -82,11 +84,13 @@ model_info_table <-
8284
"svm_rbf", "classification", "kernlab", NA,
8385
"svm_rbf", "classification", "liquidSVM", NA,
8486
"linear_reg", "quantile regression", "quantreg", NA,
87+
"rand_forest", "quantile regression", "grf", NA,
8588
"auto_ml", "regression", "h2o", "agua",
8689
"bag_mars", "regression", "earth", "baguette",
8790
"bag_mlp", "regression", "nnet", "baguette",
8891
"bag_tree", "regression", "rpart", "baguette",
8992
"bart", "regression", "dbarts", NA,
93+
"boost_tree", "regression", "catboost", "bonsai",
9094
"boost_tree", "regression", "h2o", "agua",
9195
"boost_tree", "regression", "h2o_gbm", "agua",
9296
"boost_tree", "regression", "lightgbm", "bonsai",
@@ -130,6 +134,7 @@ model_info_table <-
130134
"poisson_reg", "regression", "stan_glmer", "multilevelmod",
131135
"poisson_reg", "regression", "zeroinfl", "poissonreg",
132136
"rand_forest", "regression", "aorsf", "bonsai",
137+
"rand_forest", "regression", "grf", NA,
133138
"rand_forest", "regression", "h2o", "agua",
134139
"rand_forest", "regression", "partykit", "bonsai",
135140
"rand_forest", "regression", "randomForest", NA,
@@ -145,4 +150,3 @@ model_info_table <-
145150
"svm_rbf", "regression", "kernlab", NA,
146151
"svm_rbf", "regression", "liquidSVM", NA
147152
)
148-

R/augment.R

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,17 @@
8686
#'
8787
#' # ------------------------------------------------------------------------------
8888
#'
89-
#' # Quantile regression example
90-
#' qr_form <-
91-
#' linear_reg() |>
92-
#' set_engine("quantreg") |>
93-
#' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |>
94-
#' fit(mpg ~ ., data = car_trn)
95-
#'
96-
#' augment(qr_form, car_tst)
97-
#' augment(qr_form, car_tst[, -1])
89+
#' if (rlang::is_installed("quantreg")) {
90+
#' # Quantile regression example
91+
#' qr_form <-
92+
#' linear_reg() |>
93+
#' set_engine("quantreg") |>
94+
#' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |>
95+
#' fit(mpg ~ ., data = car_trn)
96+
#'
97+
#' augment(qr_form, car_tst)
98+
#' augment(qr_form, car_tst[, -1])
99+
#' }
98100
#'
99101
augment.model_fit <- function(x, new_data, eval_time = NULL, ...) {
100102
new_data <- tibble::new_tibble(new_data)

R/fit.R

Lines changed: 71 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,13 @@
109109
#' @export
110110
#' @export fit.model_spec
111111
fit.model_spec <-
112-
function(object,
113-
formula,
114-
data,
115-
case_weights = NULL,
116-
control = control_parsnip(),
117-
...
112+
function(
113+
object,
114+
formula,
115+
data,
116+
case_weights = NULL,
117+
control = control_parsnip(),
118+
...
118119
) {
119120
if (object$mode == "unknown") {
120121
cli::cli_abort(
@@ -135,7 +136,6 @@ fit.model_spec <-
135136
}
136137
check_formula(formula)
137138

138-
139139
if (is_sparse_matrix(data)) {
140140
data <- sparsevctrs::coerce_to_sparse_tibble(data, rlang::caller_env(0))
141141
}
@@ -153,12 +153,14 @@ fit.model_spec <-
153153
eng_vals <- possible_engines(object)
154154
object$engine <- eng_vals[1]
155155
if (control$verbosity > 0) {
156-
cli::cli_warn("Engine set to {.val {object$engine}}.")
156+
cli::cli_warn("Engine set to {.val {object$engine}}.")
157157
}
158158
}
159159

160160
if (all(c("x", "y") %in% names(dots))) {
161-
cli::cli_abort("{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead.")
161+
cli::cli_abort(
162+
"{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead."
163+
)
162164
}
163165
cl <- match.call(expand.dots = TRUE)
164166
# Create an environment with the evaluated argument objects. This will be
@@ -186,11 +188,12 @@ fit.model_spec <-
186188
fit_interface <-
187189
check_interface(eval_env$formula, eval_env$data, cl, object)
188190

189-
if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark"))
191+
if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark")) {
190192
cli::cli_abort(
191-
"spark objects can only be used with the formula interface to {.fn fit}
193+
"spark objects can only be used with the formula interface to {.fn fit}
192194
with a spark data object."
193-
)
195+
)
196+
}
194197

195198
# populate `method` with the details for this model type
196199
object <- add_methods(object, engine = object$engine)
@@ -208,51 +211,49 @@ fit.model_spec <-
208211
switch(
209212
interfaces,
210213
# homogeneous combinations:
211-
formula_formula =
212-
form_form(
213-
object = object,
214-
control = control,
215-
env = eval_env
216-
),
214+
formula_formula = form_form(
215+
object = object,
216+
control = control,
217+
env = eval_env
218+
),
217219

218220
# heterogeneous combinations
219-
formula_matrix =
220-
form_xy(
221-
object = object,
222-
control = control,
223-
env = eval_env,
224-
target = object$method$fit$interface,
225-
...
226-
),
227-
formula_data.frame =
228-
form_xy(
229-
object = object,
230-
control = control,
231-
env = eval_env,
232-
target = object$method$fit$interface,
233-
...
234-
),
221+
formula_matrix = form_xy(
222+
object = object,
223+
control = control,
224+
env = eval_env,
225+
target = object$method$fit$interface,
226+
...
227+
),
228+
formula_data.frame = form_xy(
229+
object = object,
230+
control = control,
231+
env = eval_env,
232+
target = object$method$fit$interface,
233+
...
234+
),
235235

236236
cli::cli_abort("{.val {interfaces}} is unknown.")
237237
)
238238
res$censor_probs <- reverse_km(object, eval_env)
239239
model_classes <- class(res$fit)
240240
class(res) <- c(paste0("_", model_classes[1]), "model_fit")
241241
res
242-
}
242+
}
243243

244244
# ------------------------------------------------------------------------------
245245

246246
#' @rdname fit
247247
#' @export
248248
#' @export fit_xy.model_spec
249249
fit_xy.model_spec <-
250-
function(object,
251-
x,
252-
y,
253-
case_weights = NULL,
254-
control = control_parsnip(),
255-
...
250+
function(
251+
object,
252+
x,
253+
y,
254+
case_weights = NULL,
255+
control = control_parsnip(),
256+
...
256257
) {
257258
if (object$mode == "unknown") {
258259
cli::cli_abort(
@@ -329,32 +330,32 @@ fit_xy.model_spec <-
329330
switch(
330331
interfaces,
331332
# homogeneous combinations:
332-
matrix_matrix = , data.frame_matrix =
333-
xy_xy(
334-
object = object,
335-
env = eval_env,
336-
control = control,
337-
target = "matrix",
338-
...
339-
),
340-
341-
data.frame_data.frame = , matrix_data.frame =
342-
xy_xy(
343-
object = object,
344-
env = eval_env,
345-
control = control,
346-
target = "data.frame",
347-
...
348-
),
333+
matrix_matrix = ,
334+
data.frame_matrix = xy_xy(
335+
object = object,
336+
env = eval_env,
337+
control = control,
338+
target = "matrix",
339+
...
340+
),
341+
342+
data.frame_data.frame = ,
343+
matrix_data.frame = xy_xy(
344+
object = object,
345+
env = eval_env,
346+
control = control,
347+
target = "data.frame",
348+
...
349+
),
349350

350351
# heterogeneous combinations
351-
matrix_formula = , data.frame_formula =
352-
xy_form(
353-
object = object,
354-
env = eval_env,
355-
control = control,
356-
...
357-
),
352+
matrix_formula = ,
353+
data.frame_formula = xy_form(
354+
object = object,
355+
env = eval_env,
356+
control = control,
357+
...
358+
),
358359
cli::cli_abort("{.val {interfaces}} is unknown.")
359360
)
360361
res$censor_probs <- reverse_km(object, eval_env)
@@ -368,7 +369,9 @@ fit_xy.model_spec <-
368369
eval_mod <- function(e, capture = FALSE, catch = FALSE, envir = NULL, ...) {
369370
if (capture) {
370371
if (catch) {
371-
junk <- capture.output(res <- try(eval_tidy(e, env = envir, ...), silent = TRUE))
372+
junk <- capture.output(
373+
res <- try(eval_tidy(e, env = envir, ...), silent = TRUE)
374+
)
372375
} else {
373376
junk <- capture.output(res <- eval_tidy(e, env = envir, ...))
374377
}
@@ -391,13 +394,13 @@ check_interface <- function(formula, data, cl, model, call = caller_env()) {
391394
# Determine the `fit()` interface
392395
form_interface <- !is.null(formula) & !is.null(data)
393396

394-
if (form_interface)
397+
if (form_interface) {
395398
return("formula")
399+
}
396400
cli::cli_abort("Error when checking the interface.", call = call)
397401
}
398402

399403
check_xy_interface <- function(x, y, cl, model, call = caller_env()) {
400-
401404
sparse_ok <- allow_sparse(model)
402405
sparse_x <- inherits(x, "dgCMatrix")
403406
if (!sparse_ok & sparse_x) {

R/rand_forest.R

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,17 @@
3434
#' @export
3535

3636
rand_forest <-
37-
function(mode = "unknown", engine = "ranger", mtry = NULL, trees = NULL, min_n = NULL) {
38-
37+
function(
38+
mode = "unknown",
39+
engine = "ranger",
40+
mtry = NULL,
41+
trees = NULL,
42+
min_n = NULL
43+
) {
3944
args <- list(
40-
mtry = enquo(mtry),
41-
trees = enquo(trees),
42-
min_n = enquo(min_n)
45+
mtry = enquo(mtry),
46+
trees = enquo(trees),
47+
min_n = enquo(min_n)
4348
)
4449

4550
new_model_spec(
@@ -60,15 +65,19 @@ rand_forest <-
6065
#' @rdname parsnip_update
6166
#' @export
6267
update.rand_forest <-
63-
function(object,
64-
parameters = NULL,
65-
mtry = NULL, trees = NULL, min_n = NULL,
66-
fresh = FALSE, ...) {
67-
68+
function(
69+
object,
70+
parameters = NULL,
71+
mtry = NULL,
72+
trees = NULL,
73+
min_n = NULL,
74+
fresh = FALSE,
75+
...
76+
) {
6877
args <- list(
69-
mtry = enquo(mtry),
70-
trees = enquo(trees),
71-
min_n = enquo(min_n)
78+
mtry = enquo(mtry),
79+
trees = enquo(trees),
80+
min_n = enquo(min_n)
7281
)
7382

7483
update_spec(
@@ -109,16 +118,17 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
109118

110119
# See "Details" in ?ml_random_forest_classifier. `feature_subset_strategy`
111120
# should be character even if it contains a number.
112-
if (any(names(arg_vals) == "feature_subset_strategy") &&
113-
isTRUE(is.numeric(quo_get_expr(arg_vals$feature_subset_strategy)))) {
121+
if (
122+
any(names(arg_vals) == "feature_subset_strategy") &&
123+
isTRUE(is.numeric(quo_get_expr(arg_vals$feature_subset_strategy)))
124+
) {
114125
arg_vals$feature_subset_strategy <-
115126
paste(quo_get_expr(arg_vals$feature_subset_strategy))
116127
}
117128
}
118129

119130
# add checks to error trap or change things for this method
120131
if (engine == "ranger") {
121-
122132
if (any(names(arg_vals) == "importance")) {
123133
if (isTRUE(is.logical(quo_get_expr(arg_vals$importance)))) {
124134
cli::cli_abort(
@@ -170,4 +180,3 @@ check_args.rand_forest <- function(object, call = rlang::caller_env()) {
170180
# move translate checks here?
171181
invisible(object)
172182
}
173-

0 commit comments

Comments
 (0)