From 1e315e4bf04dd2368c92f88c0b18fc6f45fa2857 Mon Sep 17 00:00:00 2001
From: Egill Axfjord Fridgeirsson
Date: Fri, 3 May 2024 11:47:26 -0400
Subject: [PATCH] Informative error message when no outcomes in early stopping
 set xgboost (#450)

Closes #447
---
Notes (ignored by `git am`, not part of the commit message):
 * stop() builds its message from two pieces, so the text contains no
   embedded newline or indentation.
 * createData() applies `seed` to every random draw.
 * The blob ids on the `index` lines and the commit id on the first line
   predate these revisions.

 R/GradientBoostingMachine.R       | 19 ++++++++++++-------
 tests/testthat/helper-functions.R | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 tests/testthat/test-rclassifier.R | 42 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 108 insertions(+), 8 deletions(-)

diff --git a/R/GradientBoostingMachine.R b/R/GradientBoostingMachine.R
index 7035ee44f..fc2a48151 100644
--- a/R/GradientBoostingMachine.R
+++ b/R/GradientBoostingMachine.R
@@ -204,9 +204,16 @@ fitXgboost <- function(
   hyperParameters,
   settings
   ){
-  
-  if(!is.null(hyperParameters$earlyStopRound)){
+  # Seed before sample() so the train/early-stopping split is reproducible.
+  set.seed(settings$seed)
+  if (!is.null(hyperParameters$earlyStopRound)) {
     trainInd <- sample(nrow(dataMatrix), nrow(dataMatrix)*0.9)
+    # The rows not in trainInd form the early-stopping set; fail with a clear
+    # message when none of them has an outcome.
+    if (sum(labels$outcomeCount[-trainInd]) == 0) {
+      stop("No outcomes in early stopping set, either increase size of ",
+           "training set or turn off early stopping")
+    }
     train <- xgboost::xgb.DMatrix(
       data = dataMatrix[trainInd,, drop = F],
       label = labels$outcomeCount[trainInd]
@@ -215,9 +222,9 @@ fitXgboost <- function(
       data = dataMatrix[-trainInd,, drop = F],
       label = labels$outcomeCount[-trainInd]
     )
-    watchlist <- list(train=train, test=test)
+    watchlist <- list(train = train, test = test)
     
-  } else{
+  } else {
     train <- xgboost::xgb.DMatrix(
       data = dataMatrix, 
       label = labels$outcomeCount
@@ -225,10 +232,9 @@ fitXgboost <- function(
     watchlist <- list()
   }
   
-  outcomes <- sum(labels$outcomeCount>0)
+  outcomes <- sum(labels$outcomeCount > 0)
   N <- nrow(labels)
   outcomeProportion <- outcomes/N
-  set.seed(settings$seed)
   model <- xgboost::xgb.train(
     data = train, 
     params = list(
@@ -240,7 +246,6 @@ fitXgboost <- function(
       lambda = hyperParameters$lambda,
       alpha = hyperParameters$alpha,
       objective = "binary:logistic",
-      #eval.metric = "logloss"
       base_score = outcomeProportion,
       eval_metric = "auc"
     ),
diff --git a/tests/testthat/helper-functions.R b/tests/testthat/helper-functions.R
index 3a44091e8..7170cf2aa 100644
--- a/tests/testthat/helper-functions.R
+++ b/tests/testthat/helper-functions.R
@@ -31,4 +31,57 @@ createTinyPlpData <- function(plpData, plpResult, n= 20) {
   attributes(tinyPlpData)$metaData <- attributes(plpData)$metaData
   class(tinyPlpData) <- class(plpData)
   return(tinyPlpData)
-}
\ No newline at end of file
+}
+
+#' Create a small synthetic data set for classifier tests
+#'
+#' Every observation gets `features` covariate entries with value 1 whose
+#' column ids are sampled with replacement from `seq_len(totalFeatures)`.
+#'
+#' @param observations Number of observations (row ids) to simulate.
+#' @param features Number of covariate entries generated per observation.
+#' @param totalFeatures Number of distinct column ids to sample from.
+#' @param numCovs If `TRUE`, append one uniform random covariate per
+#'   observation with column id `totalFeatures + 1`.
+#' @param outcomeRate Probability that an observation's label is 1.
+#' @param seed Seed applied to every random draw.
+#' @return A list with `covariates` (data frame with columns `rowId`,
+#'   `columnId`, `covariateValue`) and `labels` (numeric vector of 0/1).
+createData <- function(observations, features, totalFeatures,
+                       numCovs = FALSE,
+                       outcomeRate = 0.5,
+                       seed = 42) {
+  rowId <- rep(seq_len(observations), each = features)
+  # Use the caller's seed here as well, so `seed` controls all randomness.
+  withr::with_seed(seed, {
+    columnId <- sample(
+      seq_len(totalFeatures), observations * features,
+      replace = TRUE
+    )
+  })
+  covariates <- data.frame(
+    rowId = rowId,
+    columnId = columnId,
+    covariateValue = rep(1, observations * features)
+  )
+  if (numCovs) {
+    withr::with_seed(seed, {
+      numVal <- runif(observations)
+    })
+    numCovariates <- data.frame(
+      rowId = seq_len(observations),
+      columnId = as.integer(rep(totalFeatures + 1, observations)),
+      covariateValue = numVal
+    )
+    covariates <- rbind(covariates, numCovariates)
+  }
+  withr::with_seed(seed, {
+    labels <- as.numeric(sample(
+      0:1, observations,
+      replace = TRUE,
+      prob = c(1 - outcomeRate, outcomeRate)
+    ))
+  })
+
+  list(covariates = covariates, labels = labels)
+}
diff --git a/tests/testthat/test-rclassifier.R b/tests/testthat/test-rclassifier.R
index 5a0dadc31..72f6d024d 100644
--- a/tests/testthat/test-rclassifier.R
+++ b/tests/testthat/test-rclassifier.R
@@ -116,3 +116,45 @@ test_that("GBM working checks", {
   expect_equal(sum(abs(fitModel$covariateImportance$covariateValue))>0, TRUE)
   
 })
+
+
+test_that("GBM without outcomes in early stopping set errors", {
+  hyperParameters <- list(
+    ntrees = 10,
+    earlyStopRound = 2,
+    maxDepth = 3,
+    learnRate = 0.1,
+    minChildWeight = 1,
+    scalePosWeight = 1,
+    lambda = 1,
+    alpha = 0
+  )
+  observations <- 100
+  features <- 10
+  totalFeatures <- 10
+  # NOTE(review): relies on seed 42 leaving no outcomes among the 10% of rows
+  # held out for early stopping -- confirm if the seed or data size changes.
+  data <- createData(
+    observations = observations, features = features,
+    totalFeatures = totalFeatures,
+    numCovs = FALSE, outcomeRate = 0.05
+  )
+  # Column ids are drawn from seq_len(totalFeatures), so size columns by it.
+  dataMatrix <- Matrix::sparseMatrix(
+    i = data$covariates$rowId,
+    j = data$covariates$columnId,
+    x = data$covariates$covariateValue,
+    dims = c(observations, totalFeatures)
+  )
+  labels <- data.frame(outcomeCount = data$labels)
+  settings <- list(seed = 42, threads = 2)
+  expect_error(
+    fitXgboost(
+      dataMatrix = dataMatrix,
+      labels = labels,
+      hyperParameters = hyperParameters,
+      settings = settings
+    ),
+    regexp = "turn off early stopping"
+  )
+})