Skip to content

Commit

Permalink
Merge pull request #102 from mlr-org/fix_catboost
Browse files Browse the repository at this point in the history
Fix catboost, closes #100
  • Loading branch information
RaphaelS1 authored Aug 31, 2021
2 parents 327b556 + b68cc16 commit ff5c013
Show file tree
Hide file tree
Showing 16 changed files with 164 additions and 74 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/rcmdcheck.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ jobs:
pak::pkg_install("rcmdcheck")
shell: Rscript {0}

- name: Install CatBoost
- name: Install catboost
run: |
install.packages("remotes")
remotes::install_url('https://github.com/catboost/catboost/releases/download/v0.24.1/catboost-R-Linux-0.24.1.tgz', INSTALL_opts = c("--no-multiarch"))
remotes::install_url('https://github.com/catboost/catboost/releases/download/v0.26.1/catboost-R-Linux-0.26.1.tgz', INSTALL_opts = c("--no-multiarch"))
shell: Rscript {0}

- name: Install Python
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,4 @@ docs/

.ccache/
.vscode/settings.json
catboost_info/
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3extralearners
Title: Extra Learners For mlr3
Version: 0.5.2
Version: 0.5.3
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ export(LearnerSurvPenalized)
export(LearnerSurvRandomForestSRC)
export(LearnerSurvSVM)
export(create_learner)
export(install_catboost)
export(install_learners)
export(list_mlr3learners)
export(lrn)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# mlr3extralearners 0.5.3

* Fixed bugs in catboost for classification
* Removed factor feature types from catboost
* Added `install_catboost` to make installation from catboost simpler

# mlr3extralearners 0.5.2

* Fixed learner tests
Expand Down
50 changes: 50 additions & 0 deletions R/fn_install_catboost.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#' Helper function to install catboost
#' @title Install catboost
#' @param version `(character(1))`
#' Version to install, if `NULL` installs latest
#' @param os `(character(1))` \cr
#' Operating system to install on, if `NULL` automatically detected
#' @param install_required `(logical(1))` \cr
#' If `TRUE` (default) then installs required packages: {curl}, {jsonlite},
#' {devtools}
#' @param INSTALL_opts `(character())` \cr
#' Passed to [devtools::install_url]
#' @param ... `ANY` \cr
#' Other arguments passed to [devtools::install_url]
#' @export
install_catboost <- function(version = NULL, os = NULL,
install_required = TRUE,
INSTALL_opts = c("--no-multiarch",
"--no-test-load"), ...) {

if (is.null(version)) {

if (!requireNamespace("jsonlite", quietly = TRUE) && install_required) {
utils::install.packages("jsonlite", repos = "https://cloud.r-project.org")
}

if (!requireNamespace("curl", quietly = TRUE) && install_required) {
utils::install.packages("curl", repos = "https://cloud.r-project.org")
}

version <- jsonlite::fromJSON(
"https://api.github.com/repos/catboost/catboost/releases"
)$tag_name[1]
}

version <- gsub("v", "", version)

if (is.null(os)) {
os <- as.character(Sys.info()["sysname"])
}

url <- sprintf(
"https://github.com/catboost/catboost/releases/download/v%s/catboost-R-%s-%s.tgz",
version, os, version)

if (!requireNamespace("devtools", quietly = TRUE) && install_required) {
utils::install.packages("devtools", repos = "https://cloud.r-project.org")
}

devtools::install_url(url, INSTALL_opts = INSTALL_opts, ...)
}
57 changes: 22 additions & 35 deletions R/learner_catboost_classif_catboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
#' @templateVar id classif.catboost
#' @templateVar caller catboost.train
#'
#' @section Installation:
#' The easiest way to install catboost is with the helper function
#' [install_catboost].
#'
#' @section Custom mlr3 defaults:
#' - `logging_level`:
#' - Actual default: "Verbose"
Expand All @@ -25,11 +29,11 @@
#' - Reason for change: consistent with other mlr3 learners
#'
#' @references
#' CatBoost: unbiased boosting with categorical features.
#' catboost: unbiased boosting with categorical features.
#' Liudmila Prokhorenkova, Gleb Guse, Aleksandr Vorobev, Anna Veronika Dorogush and Andrey Gulin.
#' 2017. https://arxiv.org/abs/1706.09516.
#'
#' CatBoost: gradient boosting with categorical features support.
#' catboost: gradient boosting with categorical features support.
#' Anna Veronika Dorogush, Vasily Ershov and Andrey Gulin.
#' 2018. https://arxiv.org/abs/1810.11363.
#'
Expand Down Expand Up @@ -155,7 +159,7 @@ LearnerClassifCatboost = R6Class("LearnerClassifCatboost",
super$initialize(
id = "classif.catboost",
packages = "catboost",
feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
feature_types = c("numeric", "factor", "ordered"),
predict_types = c("response", "prob"),
param_set = ps,
properties = c(
Expand All @@ -182,18 +186,14 @@ LearnerClassifCatboost = R6Class("LearnerClassifCatboost",

private = list(
.train = function(task) {
# integer/logical features must be converted to numerics explicitly

data = task$data(cols = task$feature_names)
to_numerics = task$feature_types$id[task$feature_types$type %in%
c("integer", "logical")]
if (length(to_numerics)) {
data[, (to_numerics) := lapply(.SD, as.numeric), .SDcols = to_numerics]
if (packageVersion('catboost') < '0.21') {
stop('catboost v0.21 or greater is required, update with install_catboost')
}

# target is encoded as integer values from 0
# if binary, the positive class is 1
is_binary = (length(task$class_names) == 2L)
is_binary = length(task$class_names) == 2L
label = if (is_binary) {
ifelse(task$data(cols = task$target_names)[[1L]] == task$positive,
yes = 1L,
Expand All @@ -204,7 +204,7 @@ LearnerClassifCatboost = R6Class("LearnerClassifCatboost",

# data must be a dataframe
learn_pool = mlr3misc::invoke(catboost::catboost.load_pool,
data = data.table::setDF(data),
data = task$data(cols = task$feature_names),
label = label,
weight = task$weights$weight,
thread_count = self$param_set$values$thread_count)
Expand All @@ -219,29 +219,16 @@ LearnerClassifCatboost = R6Class("LearnerClassifCatboost",
pars$loss_function_twoclass = NULL
pars$loss_function_multiclass = NULL

mlr3misc::invoke(catboost::catboost.train,
learn_pool = learn_pool,
test_pool = NULL,
params = pars)
catboost::catboost.train(learn_pool, NULL, pars)
},

.predict = function(task) {
# integer/logical features must be converted to numerics explicitly

data = task$data(cols = task$feature_names)
to_numerics = task$feature_types$id[task$feature_types$type %in%
c("integer", "logical")]
if (length(to_numerics)) {
data[, (to_numerics) := lapply(.SD, as.numeric), .SDcols = to_numerics]
}

# target was encoded as integer values based on the train_task
# to later revert this, again use the train_task
is_binary = (length(self$state$train_task$class_names) == 2L)
is_binary = (length(task$class_names) == 2L)

# data must be a dataframe
pool = mlr3misc::invoke(catboost::catboost.load_pool,
data = data.table::setDF(data),
data = task$data(cols = task$feature_names),
thread_count = self$param_set$values$thread_count)

prediction_type = if (self$predict_type == "response") {
Expand All @@ -257,20 +244,20 @@ LearnerClassifCatboost = R6Class("LearnerClassifCatboost",

if (self$predict_type == "response") {
response = if (is_binary) {
ifelse(preds == 1L,
yes = self$state$train_task$positive,
no = setdiff(
self$state$train_task$class_names,
self$state$train_task$positive))
ifelse(preds == 1L, yes = task$positive, no = task$negative)
} else {
self$state$train_task$class_names[preds + 1L]
task$class_names[preds + 1L]
}
list(response = response)
list(response = as.character(unname(response)))
} else {

if (is_binary && is.null(dim(preds))) {
preds = matrix(c(preds, 1 - preds), ncol = 2L, nrow = length(preds))
colnames(preds) = c(task$positive, task$negative)
} else {
colnames(preds) = self$state$train_task$class_names
}
colnames(preds) = self$state$train_task$class_names

list(prob = preds)
}
}
Expand Down
34 changes: 14 additions & 20 deletions R/learner_catboost_regr_catboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
#' @templateVar id regr.catboost
#' @templateVar caller catboost.train
#'
#' @section Installation:
#' The easiest way to install catboost is with the helper function
#' [install_catboost].
#'
#' @section Custom mlr3 defaults:
#' - `logging_level`:
#' - Actual default: "Verbose"
Expand All @@ -25,11 +29,11 @@
#' - Reason for change: consistent with other mlr3 learners
#'
#' @references
#' CatBoost: unbiased boosting with categorical features.
#' catboost: unbiased boosting with categorical features.
#' Liudmila Prokhorenkova, Gleb Guse, Aleksandr Vorobev, Anna Veronika Dorogush and Andrey Gulin.
#' 2017. https://arxiv.org/abs/1706.09516.
#'
#' CatBoost: gradient boosting with categorical features support.
#' catboost: gradient boosting with categorical features support.
#' Anna Veronika Dorogush, Vasily Ershov and Andrey Gulin.
#' 2018. https://arxiv.org/abs/1810.11363.
#'
Expand Down Expand Up @@ -149,7 +153,7 @@ LearnerRegrCatboost = R6Class("LearnerRegrCatboost",
super$initialize(
id = "regr.catboost",
packages = "catboost",
feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
feature_types = c("numeric", "factor", "ordered"),
predict_types = "response",
param_set = ps,
properties = c(
Expand All @@ -176,17 +180,16 @@ LearnerRegrCatboost = R6Class("LearnerRegrCatboost",

private = list(
.train = function(task) {
# integer/logical features must be converted to numerics explicitly
data = task$data(cols = task$feature_names)
to_numerics = task$feature_types$id[task$feature_types$type %in%
c("integer", "logical")]
if (length(to_numerics)) {
data[, (to_numerics) := lapply(.SD, as.numeric), .SDcols = to_numerics]

if (packageVersion('catboost') < '0.21') {
stop('catboost v0.21 or greater is required, update with install_catboost')
}

self$state$feature_names = task$feature_names

# data must be a dataframe
learn_pool = mlr3misc::invoke(catboost::catboost.load_pool,
data = data.table::setDF(data),
data = task$data(cols = task$feature_names),
label = task$data(cols = task$target_names)[[1L]],
weight = task$weights$weight,
thread_count = self$param_set$values$thread_count)
Expand All @@ -198,18 +201,9 @@ LearnerRegrCatboost = R6Class("LearnerRegrCatboost",
},

.predict = function(task) {
# integer/logical features must be converted to numerics explicitly

data = task$data(cols = task$feature_names)
to_numerics = task$feature_types$id[task$feature_types$type %in%
c("integer", "logical")]
if (length(to_numerics)) {
data[, (to_numerics) := lapply(.SD, as.numeric), .SDcols = to_numerics]
}

# data must be a dataframe
pool = mlr3misc::invoke(catboost::catboost.load_pool,
data = data.table::setDF(data),
data = task$data(cols = self$state$feature_names),
thread_count = self$param_set$values$thread_count)

preds = mlr3misc::invoke(catboost::catboost.predict,
Expand Down
Binary file modified R/sysdata.rda
Binary file not shown.
5 changes: 3 additions & 2 deletions inst/paramtest/test_paramtest_catboost_classif_catboost.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
library(mlr3extralearners)
remotes::install_url('https://github.com/catboost/catboost/releases/download/v0.24.1/catboost-R-Linux-0.24.1.tgz', # nolint
INSTALL_opts = c("--no-multiarch"))
if (!requireNamespace("catboost", quietly = TRUE)) {
install_catboost("0.26.1")
}

test_that("classif.catboost_catboost.train", {
learner = lrn("classif.catboost")
Expand Down
5 changes: 3 additions & 2 deletions inst/paramtest/test_paramtest_catboost_regr_catboost.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
library(mlr3extralearners)
remotes::install_url('https://github.com/catboost/catboost/releases/download/v0.24.1/catboost-R-Linux-0.24.1.tgz', # nolint
INSTALL_opts = c("--no-multiarch"))
if (!requireNamespace("catboost", quietly = TRUE)) {
install_catboost("0.26.1")
}


test_that("regr.catboost_catboost.train", {
Expand Down
34 changes: 34 additions & 0 deletions man/install_catboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 9 additions & 3 deletions man/mlr_learners_classif.catboost.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ff5c013

Please sign in to comment.