Skip to content

Commit

Permalink
refactor train_lightgbm() to apply dataset arguments (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored May 13, 2024
1 parent f40d55a commit a91e80a
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 67 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

v0.2.1.9000 is a developmental version of the bonsai package.

* Enabled passing [Dataset Parameters](https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters) to the `"lightgbm"` engine. To pass an argument that would usually be supplied as an element of the `params` argument to `lightgbm::lgb.Dataset()`, pass the argument directly through the ellipses in `set_engine()`, e.g. `boost_tree() %>% set_engine("lightgbm", linear_tree = TRUE)` (#77).

* Enabled case weights with the `"lightgbm"` engine (#72 by `@p-schaefer`).

* Fixed issues in metadata for the `"partykit"` engine for `rand_forest()` where some engine arguments were mistakenly protected (#74).
Expand Down
152 changes: 85 additions & 67 deletions R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,49 +55,42 @@ train_lightgbm <- function(x, y, weights = NULL, max_depth = -1, num_iterations

check_lightgbm_aliases(...)

args <- list(
param = list(
# bonsai should be able to differentiate between
# 1) main arguments to `lgb.train()` (as in `names(formals(lgb.train))` other
# than `params`),
# 2) main arguments to `lgb.Dataset()` (as in `names(formals(lgb.Dataset))`
# other than `params`), and
# 3) arguments to pass to `lgb.train(params)` OR `lgb.Dataset(params)`.
# arguments to the `params` argument of either function can be concatenated
# together and passed to both (#77).
args <-
list(
num_iterations = num_iterations,
learning_rate = learning_rate,
max_depth = max_depth,
feature_fraction_bynode = feature_fraction_bynode,
min_data_in_leaf = min_data_in_leaf,
min_gain_to_split = min_gain_to_split,
bagging_fraction = bagging_fraction
),
main = list(
bagging_fraction = bagging_fraction,
early_stopping_round = early_stopping_round,
...
)
)

args <- process_bagging(args)
args <- process_parallelism(args)
args <- process_objective_function(args, x, y)

args <- sort_args(args)

if (!is.numeric(y)) {
y <- as.numeric(y) - 1
}

args <- process_parallelism(args)

args <- process_bagging(args, ...)

args <- process_data(args, x, y, weights, validation, missing(validation),
early_stopping_round)
args <- process_data(args, x, y, weights, validation, missing(validation))

args <- sort_args(args)
compacted <- c(list(params = args$params), args$main_args_train)

if (!"verbose" %in% names(args$main)) {
args$main$verbose <- 1L
}

compacted <-
c(
list(param = args$param),
args$main[names(args$main) != "data"],
list(data = quote(args$main$data))
)

call <- parsnip::make_call(fun = "lgb.train", ns = "lightgbm", compacted)
call <- rlang::call2("lgb.train", !!!compacted, .ns = "lightgbm")

if (quiet) {
junk <- utils::capture.output(res <- rlang::eval_tidy(call, env = rlang::current_env()))
Expand Down Expand Up @@ -152,51 +145,46 @@ process_mtry <- function(feature_fraction_bynode, counts, x, is_missing) {

process_objective_function <- function(args, x, y) {
  # Infer LightGBM's `objective` (and, for classification, `num_class`)
  # from the outcome when the user hasn't supplied one. `args` is the flat
  # argument list later partitioned by `sort_args()`; `x` is unused here but
  # kept for signature consistency with the other `process_*()` helpers.
  if (!"objective" %in% names(args)) {
    if (is.numeric(y)) {
      args$objective <- "regression"
    } else {
      # Factor outcome: choose binary vs multiclass by the number of levels.
      n_levels <- length(levels(y))
      if (n_levels == 2) {
        # LightGBM expects num_class = 1 for its binary objective.
        args$num_class <- 1
        args$objective <- "binary"
      } else {
        args$num_class <- n_levels
        args$objective <- "multiclass"
      }
    }
  }

  args
}

# ensure a user-supplied `num_threads` survives argument processing intact;
# `sort_args()` will later route it into `params`, where LightGBM reads it.
process_parallelism <- function(args) {
  # Use `[[` — `args["num_threads"]` returns a length-1 list (never NULL),
  # so an `is.null()` check on it would always be FALSE.
  if (!is.null(args[["num_threads"]])) {
    num_threads <- args[["num_threads"]]
    # Drop any (possibly duplicated) entries, then store the value once,
    # unwrapped. Note: deleting AFTER reassigning (as the previous revision
    # did) removes the value entirely.
    args[names(args) == "num_threads"] <- NULL
    args$num_threads <- num_threads
  }

  args
}

process_bagging <- function(args) {
  # LightGBM ignores `bagging_fraction` unless `bagging_freq` is nonzero,
  # so supply a sensible default frequency when the user asked for bagging
  # but didn't set one themselves.
  if (args$bagging_fraction != 1 &&
      !"bagging_freq" %in% names(args)) {
    args$bagging_freq <- 1
  }

  args
}

process_data <- function(args, x, y, weights, validation, missing_validation,
early_stopping_round) {
process_data <- function(args, x, y, weights, validation, missing_validation) {
# trn_index | val_index
# ----------------------------------
# needs_validation & missing_validation | 1:n 1:n
Expand All @@ -205,7 +193,12 @@ process_data <- function(args, x, y, weights, validation, missing_validation,
# !needs_validation & !missing_validation | sample(1:n, m) setdiff(trn_index, 1:n)

n <- nrow(x)
needs_validation <- !is.null(early_stopping_round)
needs_validation <- !is.null(args$params$early_stopping_round)
if (!needs_validation) {
# If early_stopping_round isn't set, clear it from arguments actually
# passed to LightGBM.
args$params$early_stopping_round <- NULL
}

if (missing_validation) {
trn_index <- 1:n
Expand All @@ -220,61 +213,86 @@ process_data <- function(args, x, y, weights, validation, missing_validation,
val_index <- setdiff(1:n, trn_index)
}

args$main$data <-
lightgbm::lgb.Dataset(
data = prepare_df_lgbm(x[trn_index, , drop = FALSE]),
label = y[trn_index],
categorical_feature = categorical_columns(x[trn_index, , drop = FALSE]),
params = list(feature_pre_filter = FALSE),
weight = weights[trn_index]
data_args <-
c(
list(
data = prepare_df_lgbm(x[trn_index, , drop = FALSE]),
label = y[trn_index],
categorical_feature = categorical_columns(x[trn_index, , drop = FALSE]),
params = c(list(feature_pre_filter = FALSE), args$params),
weight = weights[trn_index]
),
args$main_args_dataset
)

args$main_args_train$data <-
rlang::eval_bare(
rlang::call2("lgb.Dataset", !!!data_args, .ns = "lightgbm")
)

if (!is.null(val_index)) {
args$main$valids <-
list(validation =
lightgbm::lgb.Dataset(
valids_args <-
c(
list(
data = prepare_df_lgbm(x[val_index, , drop = FALSE]),
label = y[val_index],
categorical_feature = categorical_columns(x[val_index, , drop = FALSE]),
params = list(feature_pre_filter = FALSE),
params = list(feature_pre_filter = FALSE, args$params),
weight = weights[val_index]
),
args$main_args_dataset
)

args$main_args_train$valids <-
list(
validation =
rlang::eval_bare(
rlang::call2("lgb.Dataset", !!!valids_args, .ns = "lightgbm")
)
)
}

args
}

# identifies supplied arguments as destined for `lgb.Dataset()`, `lgb.train()`,
# or the `params` argument shared by both of those functions (#77).
sort_args <- function(args) {
  # Arguments that bonsai manages itself; passing them through would break
  # the wrappers' assumptions, so drop them with a warning.
  protected <- c("obj", "init_model", "colnames",
                 "categorical_feature", "callbacks", "reset_data")

  if (any(names(args) %in% protected)) {
    protected_args <- names(args[names(args) %in% protected])

    rlang::warn(
      glue::glue(
        "The following argument(s) are guarded by bonsai and will not ",
        "be passed to LightGBM: {glue::glue_collapse(protected_args, sep = ', ')}"
      )
    )

    args[protected_args] <- NULL
  }

  main_args_dataset <- main_args(lightgbm::lgb.Dataset)
  main_args_train <- main_args(lightgbm::lgb.train)

  # Partition the flat argument list into the three destinations. Anything
  # that is not a main argument of either function is treated as a `params`
  # entry passed to both (#77). An argument that is a main argument of both
  # functions (e.g. `data`) appears in both main-argument lists.
  list(
    main_args_dataset = args[names(args) %in% main_args_dataset],
    main_args_train = args[names(args) %in% main_args_train],
    params = args[!names(args) %in% c(main_args_dataset, main_args_train)]
  )
}

# Names of `fn`'s formal arguments, excluding `params` (which bonsai
# assembles separately and passes to both LightGBM entry points).
main_args <- function(fn) {
  formal_names <- names(formals(fn))
  formal_names[formal_names != "params"]
}

# in lightgbm <= 3.3.2, predict() for multiclass classification produced a single
# vector of length num_observations * num_classes, in row-major order
#
Expand Down
62 changes: 62 additions & 0 deletions tests/testthat/test-lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,68 @@ test_that("boost_tree with lightgbm",{
expect_equal(pars_preds_6_b, lgbm_preds_6)
})

# Verify that engine arguments destined for `lgb.Dataset()`'s `params`
# (here, `linear_tree`) are actually applied by the parsnip interface:
# a bonsai fit must match a hand-rolled lightgbm fit that sets the same
# dataset parameter directly (#77).
test_that("bonsai applies dataset parameters (#77)", {
  skip_if_not_installed("lightgbm")
  skip_if_not_installed("modeldata")

  suppressPackageStartupMessages({
    library(lightgbm)
    library(dplyr)
  })

  data("penguins", package = "modeldata")

  # lightgbm handles missingness itself; drop incomplete rows so the manual
  # comparison fit below sees exactly the same data.
  penguins <- penguins[complete.cases(penguins),]

  # regression -----------------------------------------------------------------
  # Dataset parameter passed through set_engine()'s ellipses must not error.
  expect_error_free({
    pars_fit_1 <-
      boost_tree() %>%
      set_engine("lightgbm", linear_tree = TRUE) %>%
      set_mode("regression") %>%
      fit(bill_length_mm ~ ., data = penguins)
  })

  expect_error_free({
    pars_preds_1 <-
      predict(pars_fit_1, penguins)
  })

  # Recreate bonsai's preprocessing by hand: factors encoded as 0-based
  # integers, then converted to a numeric matrix for lgb.Dataset().
  peng <-
    penguins %>%
    mutate(across(where(is.character), ~as.factor(.x))) %>%
    mutate(across(where(is.factor), ~as.integer(.x) - 1))

  peng_y <- peng$bill_length_mm

  peng_m <- peng %>%
    select(-bill_length_mm) %>%
    as.matrix()

  # Reference dataset with `linear_tree` supplied directly as a dataset
  # parameter — the behavior bonsai is expected to reproduce.
  peng_x <-
    lgb.Dataset(
      data = peng_m,
      label = peng_y,
      params = list(feature_pre_filter = FALSE, linear_tree = TRUE),
      categorical_feature = c(1L, 2L, 6L)
    )

  params_1 <- list(
    objective = "regression"
  )

  lgbm_fit_1 <-
    lightgbm::lgb.train(
      data = peng_x,
      params = params_1,
      verbose = -1
    )

  lgbm_preds_1 <- predict(lgbm_fit_1, peng_m)

  # Identical predictions imply the dataset parameter reached LightGBM;
  # also check it is recorded in the fitted model's params.
  expect_equal(pars_preds_1$.pred, lgbm_preds_1)
  expect_true(pars_fit_1$fit$params$linear_tree)
})

test_that("bonsai correctly determines objective when label is a factor", {
skip_if_not_installed("lightgbm")
Expand Down

0 comments on commit a91e80a

Please sign in to comment.