Skip to content

Commit

Permalink
Merge branch 'master' into xgboost_args
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jan 7, 2025
2 parents 46456cf + a475327 commit a9ed114
Show file tree
Hide file tree
Showing 35 changed files with 498 additions and 102 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/r_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ jobs:
uses: actions/cache@v4
with:
path: ${{ env.R_LIBS_USER }}
key: ${{ runner.os }}-r-${{ matrix.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }}
restore-keys: ${{ runner.os }}-r-${{ matrix.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }}
key: ${{ runner.os }}-r-${{ matrix.r }}-8-${{ hashFiles('R-package/DESCRIPTION') }}
restore-keys: ${{ runner.os }}-r-${{ matrix.r }}-8-${{ hashFiles('R-package/DESCRIPTION') }}
- uses: actions/setup-python@v5
with:
python-version: "3.10"
Expand Down
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ if(GOOGLE_TEST)
configure_file(
${xgboost_SOURCE_DIR}/tests/cli/machine.conf.in
${xgboost_BINARY_DIR}/tests/cli/machine.conf
@ONLY)
@ONLY
NEWLINE_STYLE UNIX)
if(BUILD_DEPRECATED_CLI)
add_test(
NAME TestXGBoostCLI
Expand Down
165 changes: 131 additions & 34 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -462,9 +462,8 @@ NULL
#' could in theory change again in the future, so XGBoost's serializers should be
#' preferred for long-term storage.
#'
#' Furthermore, note that using the package `qs` for serialization will require
#' version 0.26 or higher of said package, and will have the same compatibility
#' restrictions as R serializers.
#' Furthermore, note that model objects from XGBoost might not be serializable with third-party
#' R packages like `qs` or `qs2`.
#'
#' @details
#' Use [xgb.save()] to save the XGBoost model as a stand-alone file. You may opt into
Expand Down Expand Up @@ -527,32 +526,73 @@ NULL
#' @name a-compatibility-note-for-saveRDS-save
NULL

#' @name xgboost-options
#' @title XGBoost Options
#' @description XGBoost offers an \link[base:options]{option setting} for controlling the behavior
#' of deprecated and removed function arguments.
#'
#' Some of the arguments in functions like [xgb.train()] or [predict.xgb.Booster()] been renamed
#' from how they were in previous versions, or have been removed.
#'
#' In order to make the transition to newer XGBoost versions easier, some of these parameters are
#' still accepted but issue a warning when using them. \bold{Note that these warnings will become
#' errors in the future!!} - this is just a temporary workaround to make the transition easier.
#'
#' One can optionally use 'strict mode' to turn these warnings into errors, in order to ensure
#' that code calling xgboost will still work once those are removed in future releases.
#'
#' Currently, the only supported option is `xgboost.strict_mode`, which can be set to `TRUE` or
#' `FALSE` (default).
#' @examples
#' options("xgboost.strict_mode" = FALSE)
#' options("xgboost.strict_mode" = TRUE)
NULL

# Lookup table for the deprecated parameters bookkeeping
deprecated_train_params <- list(
'print.every.n' = 'print_every_n',
'early.stop.round' = 'early_stopping_rounds',
'training.data' = 'data',
'dtrain' = 'data',
'watchlist' = 'evals',
'feval' = 'custom_metric'
renamed = list(
'print.every.n' = 'print_every_n',
'early.stop.round' = 'early_stopping_rounds',
'training.data' = 'data',
'dtrain' = 'data',
'watchlist' = 'evals',
'feval' = 'custom_metric'
),
removed = character()
)
deprecated_dttree_params <- list(
'n_first_tree' = 'trees'
renamed = list('n_first_tree' = 'trees'),
removed = c("feature_names", "text")
)
deprecated_plot_params <- list(
'plot.height' = 'plot_height',
'plot.width' = 'plot_width'
deprecated_plotimp_params <- list(
renamed = list(
'plot.height' = 'plot_height',
'plot.width' = 'plot_width'
),
removed = character()
)
deprecated_multitrees_params <- c(
deprecated_plot_params,
list('features.keep' = 'features_keep')
deprecated_multitrees_params <- list(
renamed = c(
deprecated_plotimp_params$renamed,
list('features.keep' = 'features_keep')
),
removed = "feature_names"
)
deprecated_dump_params <- list(
'with.stats' = 'with_stats'
renamed = list('with.stats' = 'with_stats'),
removed = character()
)
deprecated_plottree_params <- c(
deprecated_plot_params,
deprecated_dump_params
renamed = list(
deprecated_plotimp_params$renamed,
deprecated_dump_params$renamed,
list('trees' = 'tree_idx')
),
removed = c("show_node_id", "feature_names")
)
deprecated_predict_params <- list(
renamed = list("ntreelimit" = "iterationrange"),
removed = "reshape"
)

# Checks the dot-parameters for deprecated names
Expand All @@ -570,42 +610,99 @@ check.deprecation <- function(
if (length(params) == 0) {
return(NULL)
}
error_on_deprecated <- getOption("xgboost.strict_mode", default = FALSE)
throw_err_or_depr_msg <- function(...) {
if (error_on_deprecated) {
stop(...)
} else {
warning(..., " This warning will become an error in a future version.")
}
}

if (is.null(names(params)) || min(nchar(names(params))) == 0L) {
stop("Passed invalid positional arguments")
throw_err_or_depr_msg("Passed invalid positional arguments")
}
all_match <- pmatch(names(params), names(deprecated_list))
list_renamed <- deprecated_list$renamed
list_removed <- deprecated_list$removed
has_params_arg <- list_renamed[[1L]] == deprecated_train_params$renamed[[1L]]
all_match <- pmatch(names(params), names(list_renamed))
# throw error on unrecognized parameters
if (!allow_unrecognized && anyNA(all_match)) {

names_unrecognized <- names(params)[is.na(all_match)]
# make it informative if they match something that goes under 'params'
if (deprecated_list[[1L]] == deprecated_train_params[[1L]]) {
if (has_params_arg) {
names_params <- formalArgs(xgb.params)
names_params <- c(names_params, gsub("_", ".", names_params, fixed = TRUE))
names_under_params <- intersect(names_unrecognized, names_params)
if (length(names_under_params)) {
stop(
"Passed invalid function arguments: ",
paste(head(names_under_params), collapse = ", "),
". These should be passed as a list to argument 'params'."
)
if (error_on_deprecated) {
stop(
"Passed invalid function arguments: ",
paste(head(names_under_params), collapse = ", "),
". These should be passed as a list to argument 'params'."
)
} else {
warning(
"Passed invalid function arguments: ",
paste(head(names_under_params), collapse = ", "),
". These should be passed as a list to argument 'params'.",
" Conversion from argument to 'params' entry will be done automatically, but this ",
"behavior will become an error in a future version."
)
if (any(names_under_params %in% names(env[["params"]]))) {
repeteated_params <- intersect(names_under_params, names(env[["params"]]))
stop(
"Passed entries as both function argument(s) and as elements under 'params': ",
paste(head(repeteated_params), collapse = ", ")
)
} else {
env[["params"]] <- c(env[["params"]], params[names_under_params])
}
}
names_unrecognized <- setdiff(names_unrecognized, names_under_params)
}
}

# check for parameters that were removed from a previous version
names_removed <- intersect(names_unrecognized, list_removed)
if (length(names_removed)) {
throw_err_or_depr_msg(
"Parameter(s) have been removed from this function: ",
paste(names_removed, collapse = ", "), "."
)
names_unrecognized <- setdiff(names_unrecognized, list_removed)
}

# otherwise throw a generic error
stop(
"Passed unrecognized parameters: ",
paste(head(names_unrecognized), collapse = ", ")
)
if (length(names_unrecognized)) {
throw_err_or_depr_msg(
"Passed unrecognized parameters: ",
paste(head(names_unrecognized), collapse = ", ")
)
}

} else {

names_removed <- intersect(names(params)[is.na(all_match)], list_removed)
if (length(names_removed)) {
throw_err_or_depr_msg(
"Parameter(s) have been removed from this function: ",
paste(names_removed, collapse = ", "), "."
)
}

}

matched_params <- deprecated_list[all_match[!is.na(all_match)]]
matched_params <- list_renamed[all_match[!is.na(all_match)]]
idx_orig <- seq_along(params)[!is.na(all_match)]
function_args_passed <- names(as.list(fn_call))[-1L]
for (idx in seq_along(matched_params)) {
match_old <- names(matched_params)[[idx]]
match_new <- matched_params[[idx]]
warning(
throw_err_or_depr_msg(
"Parameter '", match_old, "' has been renamed to '",
match_new, "' and will be removed in a future version."
match_new, "'."
)
if (match_new %in% function_args_passed) {
stop("Passed both '", match_new, "' and '", match_old, "'.")
Expand Down
4 changes: 1 addition & 3 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
training = FALSE, iterationrange = NULL, strict_shape = FALSE, avoid_transpose = FALSE,
validate_features = FALSE, base_margin = NULL, ...) {
if (NROW(list(...))) {
warning("Passed unused prediction arguments: ", paste(names(list(...)), collapse = ", "), ".")
}
check.deprecation(deprecated_predict_params, match.call(), ..., allow_unrecognized = TRUE)
if (validate_features) {
newdata <- validate.features(object, newdata)
}
Expand Down
33 changes: 28 additions & 5 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,9 @@ xgb.QuantileDMatrix <- function(
)
data_iterator <- .single.data.iterator(iterator_env)

env_keep_alive <- new.env()
env_keep_alive$keepalive <- NULL

# Note: the ProxyDMatrix has its finalizer assigned in the R externalptr
# object, but that finalizer will only be called once the object is
# garbage-collected, which doesn't happen immediately after it goes out
Expand All @@ -363,9 +366,10 @@ xgb.QuantileDMatrix <- function(
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive <- NULL
return(data_iterator$f_reset(iterator_env))
}
calling_env <- environment()
Expand Down Expand Up @@ -553,7 +557,8 @@ xgb.DataBatch <- function(
}

# This is only for internal usage, class is not exposed to the user.
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator, env_keep_alive) {
env_keep_alive$keepalive <- NULL
lst <- data_iterator$f_next(data_iterator$env)
if (is.null(lst)) {
return(0L)
Expand All @@ -566,13 +571,19 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
stop("Either one of 'group' or 'qid' should be NULL")
}
if (is.data.frame(lst$data)) {
tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types)
data <- lst$data
lst$data <- NULL
tmp <- .process.df.for.dmatrix(data, lst$feature_types)
lst$feature_types <- tmp$feature_types
data <- NULL
env_keep_alive$keepalive <- tmp
.Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst)
} else if (is.matrix(lst$data)) {
env_keep_alive$keepalive <- lst
.Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data)
} else if (inherits(lst$data, "dgRMatrix")) {
tmp <- list(p = lst$data@p, j = lst$data@j, x = lst$data@x, ncol = ncol(lst$data))
env_keep_alive$keepalive <- tmp
.Call(XGProxyDMatrixSetDataCSR_R, proxy_handle, tmp)
} else {
stop("'data' has unsupported type.")
Expand Down Expand Up @@ -712,14 +723,23 @@ xgb.ExtMemDMatrix <- function(
cache_prefix <- path.expand(cache_prefix)
nthread <- as.integer(NVL(nthread, -1L))

# The purpose of this environment is to keep data alive (protected from the
# garbage collector) after setting the data in the proxy dmatrix. The data
# held here (under name 'keepalive') should be unset (leaving it unprotected
# for garbage collection) before the start of each data iteration batch and
# during each iterator reset.
env_keep_alive <- new.env()
env_keep_alive$keepalive <- NULL

proxy_handle <- .make.proxy.handle()
on.exit({
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive <- NULL
return(data_iterator$f_reset(data_iterator$env))
}
calling_env <- environment()
Expand Down Expand Up @@ -779,14 +799,17 @@ xgb.QuantileDMatrix.from_iterator <- function( # nolint

nthread <- as.integer(NVL(nthread, -1L))

env_keep_alive <- new.env()
env_keep_alive$keepalive <- NULL
proxy_handle <- .make.proxy.handle()
on.exit({
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive <- NULL
return(data_iterator$f_reset(data_iterator$env))
}
calling_env <- environment()
Expand Down
2 changes: 1 addition & 1 deletion R-package/R/xgb.plot.importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
#' @export
xgb.plot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL,
rel_to_first = FALSE, left_margin = 10, cex = NULL, plot = TRUE, ...) {
check.deprecation(deprecated_plot_params, match.call(), ..., allow_unrecognized = TRUE)
check.deprecation(deprecated_plotimp_params, match.call(), ..., allow_unrecognized = TRUE)
if (!is.data.table(importance_matrix)) {
stop("importance_matrix: must be a data.table")
}
Expand Down
19 changes: 13 additions & 6 deletions R-package/R/xgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,18 @@
#' [xgb.save()] (but are kept when using R serializers like [saveRDS()]).
#' @param ... Not used.
#'
#' Some arguments are currently deprecated or have been renamed. If a deprecated argument
#' is passed, will throw a warning and use its current equivalent.
#' Some arguments that were part of this function in previous XGBoost versions are currently
#' deprecated or have been renamed. If a deprecated or renamed argument is passed, will throw
#' a warning (by default) and use its current equivalent instead. This warning will become an
#' error if using the \link[=xgboost-options]{'strict mode' option}.
#'
#' If some additional argument is passed that is neither a current function argument nor
#' a deprecated argument, an error will be thrown.
#' a deprecated or renamed argument, a warning or error will be thrown depending on the
#' 'strict mode' option.
#'
#' \bold{Important:} `...` will be removed in a future version, and all the current
#' deprecation warnings will become errors. Please use only arguments that form part of
#' the function signature.
#' @return An object of class `xgb.Booster`.
#' @details
#' Compared to [xgboost()], the `xgb.train()` interface supports advanced features such as
Expand Down Expand Up @@ -453,10 +460,10 @@ xgb.train <- function(params = xgb.params(), data, nrounds, evals = list(),
#' @param seed Random number seed. If not specified, will take a random seed through R's own RNG engine.
#' @param booster (default= `"gbtree"`)
#' Which booster to use. Can be `"gbtree"`, `"gblinear"` or `"dart"`; `"gbtree"` and `"dart"` use tree based models while `"gblinear"` uses linear functions.
#' @param eta,learning_rate (two aliases for the same parameter) (default=0.3)
#' @param eta,learning_rate (two aliases for the same parameter)
#' Step size shrinkage used in update to prevent overfitting. After each boosting step, we can directly get the weights of new features, and `eta` shrinks the feature weights to make the boosting process more conservative.
#'
#' range: \eqn{[0,1]}
#' - range: \eqn{[0,1]}
#' - default value: 0.3 for tree-based boosters, 0.5 for linear booster.
#'
#' Note: should only pass one of `eta` or `learning_rate`. Both refer to the same parameter and there's thus no difference between one or the other.
#' @param gamma,min_split_loss (two aliases for the same parameter) (for Tree Booster) (default=0, alias: `gamma`)
Expand Down
Loading

0 comments on commit a9ed114

Please sign in to comment.