Skip to content

Commit

Permalink
Informative error message when no outcomes in early stopping set xgbo…
Browse files Browse the repository at this point in the history
…ost (#450)

Closes #447
  • Loading branch information
egillax authored May 3, 2024
1 parent ca903a8 commit 1e315e4
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 8 deletions.
16 changes: 9 additions & 7 deletions R/GradientBoostingMachine.R
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,13 @@ fitXgboost <- function(
hyperParameters,
settings
){

if(!is.null(hyperParameters$earlyStopRound)){
set.seed(settings$seed)
if (!is.null(hyperParameters$earlyStopRound)) {
trainInd <- sample(nrow(dataMatrix), nrow(dataMatrix)*0.9)
if (sum(labels$outcomeCount[-trainInd]) == 0) {
stop("No outcomes in early stopping set, either increase size of training
set or turn off early stopping")
}
train <- xgboost::xgb.DMatrix(
data = dataMatrix[trainInd,, drop = F],
label = labels$outcomeCount[trainInd]
Expand All @@ -215,20 +219,19 @@ fitXgboost <- function(
data = dataMatrix[-trainInd,, drop = F],
label = labels$outcomeCount[-trainInd]
)
watchlist <- list(train=train, test=test)
watchlist <- list(train = train, test = test)

} else{
} else {
train <- xgboost::xgb.DMatrix(
data = dataMatrix,
label = labels$outcomeCount
)
watchlist <- list()
}

outcomes <- sum(labels$outcomeCount>0)
outcomes <- sum(labels$outcomeCount > 0)
N <- nrow(labels)
outcomeProportion <- outcomes/N
set.seed(settings$seed)
model <- xgboost::xgb.train(
data = train,
params = list(
Expand All @@ -240,7 +243,6 @@ fitXgboost <- function(
lambda = hyperParameters$lambda,
alpha = hyperParameters$alpha,
objective = "binary:logistic",
#eval.metric = "logloss"
base_score = outcomeProportion,
eval_metric = "auc"
),
Expand Down
31 changes: 30 additions & 1 deletion tests/testthat/helper-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,33 @@ createTinyPlpData <- function(plpData, plpResult, n= 20) {
attributes(tinyPlpData)$metaData <- attributes(plpData)$metaData
class(tinyPlpData) <- class(plpData)
return(tinyPlpData)
}
}

createData <- function(observations, features, totalFeatures,
numCovs = FALSE,
outcomeRate = 0.5,
seed = 42) {
rowId <- rep(1:observations, each = features)
withr::with_seed(42, {
columnId <- sample(1:totalFeatures, observations * features, replace = TRUE)
})
covariateValue <- rep(1, observations * features)
covariates <- data.frame(rowId = rowId, columnId = columnId, covariateValue = covariateValue)
if (numCovs) {
numRow <- 1:observations
numCol <- rep(totalFeatures + 1, observations)
withr::with_seed(seed, {
numVal <- runif(observations)
})
numCovariates <- data.frame(rowId = as.integer(numRow),
columnId = as.integer(numCol),
covariateValue = numVal)
covariates <- rbind(covariates, numCovariates)
}
withr::with_seed(seed, {
labels <- as.numeric(sample(0:1, observations, replace = TRUE, prob = c(1 - outcomeRate, outcomeRate)))
})

data <- list(covariates = covariates, labels = labels)
return(data)
}
34 changes: 34 additions & 0 deletions tests/testthat/test-rclassifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,37 @@ test_that("GBM working checks", {
expect_equal(sum(abs(fitModel$covariateImportance$covariateValue))>0, TRUE)

})


test_that("GBM without outcomes in early stopping set errors", {
hyperParameters <- list(
ntrees = 10,
earlyStopRound = 2,
maxDepth = 3,
learnRate = 0.1,
minChildWeight = 1,
scalePosWeight = 1,
lambda = 1,
alpha = 0
)
observations <- 100
features <- 10
data <- createData(observations = observations, features = features,
totalFeatures = 10,
numCovs = FALSE, outcomeRate = 0.05)
dataMatrix <- Matrix::sparseMatrix(
i = data$covariates %>% dplyr::pull("rowId"),
j = data$covariates %>% dplyr::pull("columnId"),
x = data$covariates %>% dplyr::pull("covariateValue"),
dims = c(observations,features)
)
labels <- data.frame(outcomeCount = data$labels)
settings <- list(seed = 42, threads = 2)
expect_error(fitXgboost(dataMatrix = dataMatrix,
labels = labels,
hyperParameters = hyperParameters,
settings = settings),
regexp = "* or turn off early stopping")

})

0 comments on commit 1e315e4

Please sign in to comment.