Skip to content

Commit

Permalink
implement rlang type checkers + cli conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Oct 24, 2024
1 parent 18bb633 commit 4d5c662
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 47 deletions.
55 changes: 26 additions & 29 deletions R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,18 @@ train_lightgbm <- function(x, y, weights = NULL, max_depth = -1, num_iterations
force(x)
force(y)

if (!is.logical(quiet)) {
rlang::abort("'quiet' should be a logical value.")
}
call <- call2("fit")

check_number_whole(max_depth, call = call)
check_number_whole(num_iterations, call = call)
check_number_decimal(learning_rate, call = call)
check_number_decimal(feature_fraction_bynode, call = call)
check_number_whole(min_data_in_leaf, call = call)
check_number_decimal(min_gain_to_split, call = call)
check_number_decimal(bagging_fraction, call = call)
check_number_decimal(early_stopping_round, allow_null = TRUE, call = call)
check_bool(counts, call = call)
check_bool(quiet, call = call)

feature_fraction_bynode <-
process_mtry(feature_fraction_bynode = feature_fraction_bynode,
Expand Down Expand Up @@ -101,38 +110,24 @@ train_lightgbm <- function(x, y, weights = NULL, max_depth = -1, num_iterations
res
}

process_mtry <- function(feature_fraction_bynode, counts, x, is_missing) {
if (!is.logical(counts)) {
rlang::abort("'counts' should be a logical value.")
}
process_mtry <- function(feature_fraction_bynode, counts, x, is_missing, call = call2("fit")) {
check_bool(counts, call = call)

ineq <- if (counts) {"greater"} else {"less"}
interp <- if (counts) {"count"} else {"proportion"}
opp <- if (!counts) {"count"} else {"proportion"}

if (rlang::is_call(feature_fraction_bynode)) {
if (rlang::call_name(feature_fraction_bynode) == "tune") {
rlang::abort(
glue::glue(
"The supplied `mtry` parameter is a call to `tune`. Did you forget ",
"to optimize hyperparameters with a tuning function like `tune::tune_grid`?"
),
call = NULL
)
}
}

if ((feature_fraction_bynode < 1 & counts) | (feature_fraction_bynode > 1 & !counts)) {
rlang::abort(
glue::glue(
"The supplied argument `mtry = {feature_fraction_bynode}` must be ",
"{ineq} than or equal to 1. \n\n`mtry` is currently being interpreted ",
"as a {interp} rather than a {opp}. Supply `counts = {!counts}` to ",
"`set_engine` to supply this argument as a {opp} rather than ",
# TODO: link to parsnip's lightgbm docs instead here
"a {interp}. \n\nSee `?train_lightgbm` for more details."
cli::cli_abort(
c(
"{.arg mtry} must be {ineq} than or equal to 1, not {feature_fraction_bynode}.",
"i" = "{.arg mtry} is currently being interpreted as a {interp}
rather than a {opp}.",
"i" = "Supply {.code counts = {!counts}} to {.fn set_engine} to supply
this argument as a {opp} rather than a {interp}.",
"i" = "See {.help train_lightgbm} for more details."
),
call = NULL
call = call
)
}

Expand Down Expand Up @@ -373,7 +368,9 @@ predict_lightgbm_regression_numeric <- function(object, new_data, ...) {
#' @rdname lightgbm_helpers
multi_predict._lgb.Booster <- function(object, new_data, type = NULL, trees = NULL, ...) {
if (any(names(rlang::enquos(...)) == "newdata")) {
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
cli::cli_abort(
"Did you mean to use {.code new_data} instead of {.code newdata}?"
)
}

trees <- sort(trees)
Expand Down
33 changes: 15 additions & 18 deletions tests/testthat/_snaps/lightgbm.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,17 @@

# bonsai handles mtry vs mtry_prop gracefully

The supplied argument `mtry = 0.5` must be greater than or equal to 1.

`mtry` is currently being interpreted as a count rather than a proportion. Supply `counts = FALSE` to `set_engine` to supply this argument as a proportion rather than a count.

See `?train_lightgbm` for more details.
`mtry` must be greater than or equal to 1, not 0.5.
i `mtry` is currently being interpreted as a count rather than a proportion.
i Supply `counts = FALSE` to `set_engine()` to supply this argument as a proportion rather than a count.
i See `?train_lightgbm()` for more details.

---

The supplied argument `mtry = 3` must be less than or equal to 1.

`mtry` is currently being interpreted as a proportion rather than a count. Supply `counts = TRUE` to `set_engine` to supply this argument as a count rather than a proportion.

See `?train_lightgbm` for more details.
`mtry` must be less than or equal to 1, not 3.
i `mtry` is currently being interpreted as a proportion rather than a count.
i Supply `counts = TRUE` to `set_engine()` to supply this argument as a count rather than a proportion.
i See `?train_lightgbm()` for more details.

---

Expand All @@ -47,21 +45,20 @@
Condition
Warning:
The argument `feature_fraction_bynode` cannot be manually modified and was removed.
Error:
! The supplied argument `mtry = 0.5` must be greater than or equal to 1.
`mtry` is currently being interpreted as a count rather than a proportion. Supply `counts = FALSE` to `set_engine` to supply this argument as a proportion rather than a count.
See `?train_lightgbm` for more details.
Error in `fit()`:
! `mtry` must be greater than or equal to 1, not 0.5.
i `mtry` is currently being interpreted as a count rather than a proportion.
i Supply `counts = FALSE` to `set_engine()` to supply this argument as a proportion rather than a count.
i See `?train_lightgbm()` for more details.

# tuning mtry vs mtry_prop

Code
boost_tree(mtry = tune::tune()) %>% set_engine("lightgbm") %>% set_mode(
"regression") %>% fit(bill_length_mm ~ ., data = penguins)
Condition
Error:
! The supplied `mtry` parameter is a call to `tune`. Did you forget to optimize hyperparameters with a tuning function like `tune::tune_grid`?
Error in `fit()`:
! `feature_fraction_bynode` must be a number, not a call.

# training wrapper warns on protected arguments

Expand Down

0 comments on commit 4d5c662

Please sign in to comment.