Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aorsf classification and regression engines #78

Merged
merged 22 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Suggests:
covr,
knitr,
lightgbm,
aorsf (>= 0.1.3),
modeldata,
partykit,
rmarkdown,
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

v0.2.1.9000 is a developmental version of the bonsai package.

* Introduced support for accelerated oblique random forests for the `"classification"` and `"regression"` modes using the new [`"aorsf"` engine](https://github.com/ropensci/aorsf) (#78 by `@bcjaeger`).

* Enabled passing [Dataset Parameters](https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters) to the `"lightgbm"` engine. To pass an argument that would be usually passed as an element to the `param` argument in `lightgbm::lgb.Dataset()`, pass the argument directly through the ellipses in `set_engine()`, e.g. `boost_tree() %>% set_engine("lightgbm", linear_tree = TRUE)` (#77).

* Enabled case weights with the `"lightgbm"` engine (#72 by `@p-schaefer`).
Expand Down
225 changes: 225 additions & 0 deletions R/aorsf_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# nocov start

make_rand_forest_aorsf <- function(){
parsnip::set_model_engine("rand_forest", "classification", "aorsf")
parsnip::set_model_engine("rand_forest", "regression", "aorsf")
parsnip::set_dependency("rand_forest", "aorsf", "aorsf", mode = "classification")
parsnip::set_dependency("rand_forest", "aorsf", "aorsf", mode = "regression")

parsnip::set_model_arg(
model = "rand_forest",
eng = "aorsf",
parsnip = "mtry",
original = "mtry",
func = list(pkg = "dials", fun = "mtry"),
has_submodel = FALSE
)

parsnip::set_model_arg(
model = "rand_forest",
eng = "aorsf",
parsnip = "trees",
original = "n_tree",
func = list(pkg = "dials", fun = "trees"),
has_submodel = FALSE
)

parsnip::set_model_arg(
model = "rand_forest",
eng = "aorsf",
parsnip = "min_n",
original = "leaf_min_obs",
func = list(pkg = "dials", fun = "min_n"),
has_submodel = FALSE
)

parsnip::set_model_arg(
model = "rand_forest",
eng = "aorsf",
parsnip = "mtry",
original = "mtry",
func = list(pkg = "dials", fun = "mtry"),
has_submodel = FALSE
)

parsnip::set_fit(
model = "rand_forest",
eng = "aorsf",
mode = "classification",
value = list(
interface = "formula",
protect = c("formula", "data", "weights"),
func = c(pkg = "aorsf", fun = "orsf"),
defaults =
list(
n_thread = 1,
verbose_progress = FALSE
)
)
)

parsnip::set_encoding(
model = "rand_forest",
eng = "aorsf",
mode = "classification",
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)

parsnip::set_fit(
model = "rand_forest",
eng = "aorsf",
mode = "regression",
value = list(
interface = "formula",
protect = c("formula", "data", "weights"),
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
func = c(pkg = "aorsf", fun = "orsf"),
defaults =
list(
n_thread = 1,
verbose_progress = FALSE
)
)
)

parsnip::set_encoding(
model = "rand_forest",
eng = "aorsf",
mode = "regression",
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)

parsnip::set_pred(
model = "rand_forest",
eng = "aorsf",
mode = "classification",
type = "class",
value = list(
pre = NULL,
# makes prob preds consistent with class ones.
# note: the class predict method in aorsf uses the standard 'each tree
# gets one vote' approach, which is usually consistent with probability
# but not all the time. I opted to make predicted probability totally
# consistent with predicted class in the parsnip bindings for aorsf b/c
# I think it's really confusing when predicted probs do not align with
# predicted classes. I'm fine with this in aorsf but in bonsai I want
# to minimize confusion (#78).
post = function(results, object){

missings <- apply(results, 1, function(x) any(is.na(x)))

if(!any(missings)) {
return(colnames(results)[apply(results, 1, which.max)])
}

obs <- which(!missings)

out <- rep(NA_character_, nrow(results))
out[obs] <- colnames(results)[apply(results[obs, ], 1, which.max)]
out

},
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
new_data = quote(new_data),
pred_type = "prob",
verbose_progress = FALSE,
na_action = 'pass'
)
)
)

parsnip::set_pred(
model = "rand_forest",
eng = "aorsf",
mode = "classification",
type = "prob",
value = list(
pre = NULL,
post = function(x, object) {
as_tibble(x)
},
bcjaeger marked this conversation as resolved.
Show resolved Hide resolved
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
new_data = quote(new_data),
pred_type = 'prob',
verbose_progress = FALSE,
na_action = 'pass'
)
)
)

parsnip::set_pred(
model = "rand_forest",
eng = "aorsf",
mode = "classification",
type = "raw",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
new_data = quote(new_data),
verbose_progress = FALSE,
na_action = 'pass'
)
)
)

parsnip::set_pred(
model = "rand_forest",
eng = "aorsf",
mode = "regression",
type = "numeric",
value = list(
pre = NULL,
post = as.numeric,
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
new_data = quote(new_data),
pred_type = "mean",
verbose_progress = FALSE,
na_action = 'pass'
)
)
)

parsnip::set_pred(
model = "rand_forest",
eng = "aorsf",
mode = "regression",
type = "raw",
value = list(
pre = NULL,
post = as.numeric,
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
new_data = quote(new_data),
pred_type = "mean",
verbose_progress = FALSE,
na_action = 'pass'
)
)
)
}

# nocov end
2 changes: 2 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

make_decision_tree_partykit()
make_rand_forest_partykit()

make_rand_forest_aorsf()
}


Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ following table:
| decision_tree | partykit | classification |
| rand_forest | partykit | regression |
| rand_forest | partykit | classification |
| rand_forest | aorsf | classification |
| rand_forest | aorsf | regression |

## Code of Conduct

Expand Down
Loading
Loading