Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BART survival learner #290

Merged
merged 49 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
bd5f642
update doc links
bblodfon Aug 15, 2023
8fa59a2
add BART survival learner
bblodfon Aug 17, 2023
792b768
fix dnnsurv parameter test
bblodfon Aug 17, 2023
33b9269
add BART to 'Suggests'
bblodfon Aug 17, 2023
322f178
various small fixes
bblodfon Aug 17, 2023
33bc171
add more libraries to Suggests to run new BART example
bblodfon Aug 18, 2023
1871f0f
update doc
bblodfon Aug 18, 2023
989ff50
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
4d5ea9d
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
b82b1ed
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
3722e3c
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
387a41e
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
f18d2d9
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
c123fe1
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
38dbed2
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
6ab3570
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
ffd61c6
Update R/learner_BART_surv_bart.R
bblodfon Aug 21, 2023
8af98e3
change K parameter type
bblodfon Aug 21, 2023
f795951
simplify and speed-up the creation of the survival matrix
bblodfon Aug 21, 2023
4673be7
add importance parameter, remove factor feature type
bblodfon Aug 21, 2023
b1af8e9
change tag for importance to train + fix small bug
bblodfon Aug 21, 2023
ef42564
update doc
bblodfon Aug 21, 2023
031065c
fix tests
bblodfon Aug 21, 2023
16330d8
change section name
bblodfon Aug 21, 2023
94b1742
remove BART example and extra libraries
bblodfon Aug 21, 2023
e3a3700
return model list slot and name refactoring
bblodfon Aug 21, 2023
d52c6fa
update doc
bblodfon Aug 22, 2023
c8041b4
store full posterior survival array (testing version)
bblodfon Sep 8, 2023
f162a31
update mlr3proba to 0.5.3 + refactoring
bblodfon Sep 11, 2023
1c8b201
Merge branch 'main' into main
sebffischer Sep 11, 2023
460b0e8
add which.curve parameter, defaults to 0.5 (median posterior)
bblodfon Sep 18, 2023
bdc09af
update doc
bblodfon Sep 19, 2023
25b9d8d
update BART test
bblodfon Sep 19, 2023
6e9bcff
remove code after checks (distr6 converts survival array correctly)
bblodfon Oct 5, 2023
21e5be9
better constraction of 'which.curve' parameter
bblodfon Oct 6, 2023
b7064ff
fix bug (which.curve was always NULL)
bblodfon Oct 6, 2023
8acb393
Merge branch 'main' into main
bblodfon Oct 16, 2023
ffbcdc0
Update R/learner_BART_surv_bart.R
bblodfon Oct 16, 2023
362c845
Update R/learner_BART_surv_bart.R
bblodfon Oct 16, 2023
1231d25
Update R/learner_BART_surv_bart.R
bblodfon Oct 16, 2023
7e30f2a
remove new parameter (to be corrected in another PR)
bblodfon Oct 16, 2023
eeff74f
changes after code review
bblodfon Oct 16, 2023
2c0d0a3
remove delayedAssign + add more review suggestions
bblodfon Oct 16, 2023
594ed58
explain better 'varcount.mean'
bblodfon Oct 16, 2023
4204c2d
small update of BART doc
bblodfon Oct 16, 2023
1e66dc9
add more doc for 'which.curve'
bblodfon Oct 16, 2023
83d09b9
small style changes
bblodfon Oct 16, 2023
3b4db43
fix hanging indent
bblodfon Oct 17, 2023
b3699a9
add no lint
bblodfon Oct 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Suggests:
aorsf (>= 0.0.5),
actuar,
apcluster,
BART (>= 2.9.4),
C50,
coin,
CoxBoost,
Expand Down Expand Up @@ -66,7 +67,7 @@ Suggests:
mgcv,
mlr3cluster,
mlr3learners (>= 0.4.2),
mlr3proba,
mlr3proba (>= 0.5.3),
mlr3pipelines,
mvtnorm,
nnet,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ export(LearnerSurvGAMBoost)
export(LearnerSurvGBM)
export(LearnerSurvGLMBoost)
export(LearnerSurvGlmnet)
export(LearnerSurvLearnerSurvBART)
export(LearnerSurvLogisticHazard)
export(LearnerSurvMBoost)
export(LearnerSurvNelson)
Expand Down
218 changes: 218 additions & 0 deletions R/learner_BART_surv_bart.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
#' @title Survival Bayesian Additive Regression Trees Learner
#' @author bblodfon
#' @name mlr_learners_surv.bart
#'
#' @description
#' Fits a Bayesian Additive Regression Trees (BART) learner to right-censored
#' survival data. Calls [BART::mc.surv.bart()] from \CRANpkg{BART}.
#'
#' @details
#' Two types of prediction are returned for this learner:
#' 1. `distr`: a 3d survival array with observations as 1st dimension, time
#' points as 2nd and the posterior draws as 3rd dimension.
#' 2. `crank`: the expected mortality using [mlr3proba::.surv_return]. The parameter
#' `which.curve` decides which posterior draw (3rd dimension) will be used for the
#' calculation of the expected mortality. Note that the median posterior is
#' by default used for the calculation of survival measures that require a `distr`
#' prediction, see more info on [PredictionSurv][mlr3proba::PredictionSurv].
#'
#' @section Custom mlr3 defaults:
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
#' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}.
#'
#' @section Custom mlr3 parameters:
#' - `quiet` allows to suppress messages generated by the wrapped C++ code. Is
#' initialized to `TRUE`.
#' - `importance` allows to choose the type of importance. Default is `count`,
#' see documentation of method `$importance()` for more details.
#' - `which.curve` allows to choose which posterior draw will be used for the
#' calculation of the `crank` prediction. If between (0,1) it is taken as the
#' quantile of the curves otherwise if greater than 1 it is taken as the curve
#' index, can also be 'mean'. By default the **median posterior** is used,
#' i.e. `which.curve` is 0.5.
#'
#' @templateVar id surv.bart
#' @template learner
#'
#' @references
#' `r format_bib("sparapani2021nonparametric", "chipman2010bart")`
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART",
inherit = mlr3proba::LearnerSurv,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")),
events = p_uty(default = NULL, tags = c("train", "predict")),
ztimes = p_uty(default = NULL, tags = c("train", "predict")),
zdelta = p_uty(default = NULL, tags = c("train", "predict")),
sparse = p_lgl(default = FALSE, tags = "train"),
theta = p_dbl(default = 0, tags = "train"),
omega = p_dbl(default = 1, tags = "train"),
a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"),
b = p_dbl(default = 1L, tags = "train"),
augment = p_lgl(default = FALSE, tags = "train"),
rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
usequants = p_lgl(default = FALSE, tags = "train"),
rm.const = p_lgl(default = TRUE, tags = "train"),
type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"),
ntype = p_int(lower = 1, upper = 3, tags = "train"),
k = p_dbl(default = 2.0, lower = 0, tags = "train"),
power = p_dbl(default = 2.0, lower = 0, tags = "train"),
base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"),
offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
ntree = p_int(default = 50L, lower = 1L, tags = "train"),
numcut = p_int(default = 100L, lower = 1L, tags = "train"),
ndpost = p_int(default = 1000L, lower = 1L, tags = "train"),
nskip = p_int(default = 250L, lower = 0L, tags = "train"),
keepevery = p_int(default = 10L, lower = 1L, tags = "train"),
printevery = p_int(default = 100L, lower = 1L, tags = "train"),
seed = p_int(default = 99L, tags = "train"),
mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")),
nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")),
openmp = p_lgl(default = TRUE, tags = "predict"),
quiet = p_lgl(default = TRUE, tags = "predict"),
importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"),
which.curve = p_dbl(lower = 0L, special_vals = list("mean"), tags = "predict")
)

# custom defaults
param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count",
which.curve = 0.5) # 0.5 quantile => median posterior

super$initialize(
id = "surv.bart",
packages = "BART",
feature_types = c("logical", "integer", "numeric"),
predict_types = c("crank", "distr"),
param_set = param_set,
properties = c("importance", "missings"),
man = "mlr3extralearners::mlr_learners_surv.bart",
label = "Bayesian Additive Regression Trees"
)
},

#' @description
#' Two types of importance scores are supported based on the value
#' of the parameter `importance`:
#' 1. `prob`: The mean selection probability of each feature in the trees,
#' extracted from the slot `varprob.mean`.
#' If `sparse = FALSE` (default), this is a fixed constant.
#' Recommended to use this option when `sparse = TRUE`.
#' 2. `count`: The mean observed count of each feature in the trees (average
#' number of times the feature was used in a tree decision rule across all
#' posterior draws), extracted from the slot `varcount.mean`.
#' This is the default importance scores.
#'
#' In both cases, higher values signify more important variables.
#'
#' @return Named `numeric()`.
importance = function() {
if (is.null(self$model$model)) {
stopf("No model stored")
}

pars = self$param_set$get_values(tags = "train")

if (pars$importance == "prob") {
sort(self$model$model$varprob.mean[-1], decreasing = TRUE)
} else {
sort(self$model$model$varcount.mean[-1], decreasing = TRUE)
}
}
),

private = list(
.train = function(task) {
pars = self$param_set$get_values(tags = "train")
pars$importance = NULL # not used in the train function

x.train = as.data.frame(task$data(cols = task$feature_names)) # nolint
truth = task$truth()
times = truth[, 1]
delta = truth[, 2] # delta => status

list(
model = invoke(
BART::mc.surv.bart,
x.train = x.train,
times = times,
delta = delta,
.args = pars
),
# need these for predict
x.train = x.train,
times = times,
delta = delta
)
},

.predict = function(task) {
# get parameters with tag "predict"
pars = self$param_set$get_values(tags = "predict")

# get newdata and ensure same ordering in train and predict
x.test = as.data.frame(ordered_features(task, self)) # nolint

# subset parameters to use in `surv.pre.bart`
pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")]

# transform data to be suitable for BART survival analysis (needs train data)
trans_data = invoke(
BART::surv.pre.bart,
times = self$model$times,
delta = self$model$delta,
x.train = self$model$x.train,
x.test = x.test,
.args = pars_pre
)

# subset parameters to use in `predict`
pars_pred = pars[names(pars) %in% c("mc.cores", "nice")]

pred_fun = function() {
invoke(
predict,
self$model$model,
newdata = trans_data$tx.test,
.args = pars_pred
)
}

# don't print C++ generated info during prediction
if (pars$quiet) {
utils::capture.output({
pred = pred_fun()
})
} else {
pred = pred_fun()
}

# Number of test observations
N = task$nrow
# Number of unique times
K = pred$K
times = pred$times
# Number of posterior draws
M = nrow(pred$surv.test)

# Convert full posterior survival matrix to 3D survival array
# See page 34-35 in Sparapani (2021) for more details
surv_array = aperm(
array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)),
c(3, 2, 1)
)

# distr => 3d survival array
# crank => expected mortality
mlr3proba::.surv_return(times = times, surv = surv_array,
which.curve = pars$which.curve)
}
)
)

.extralrns_dict$add("surv.bart", LearnerSurvLearnerSurvBART)
Loading