diff --git a/.Rbuildignore b/.Rbuildignore
index f8d8128c..03781c23 100644
--- a/.Rbuildignore
+++ b/.Rbuildignore
@@ -16,3 +16,4 @@
^Rmd/
lastMiKTeXException
^\.zenodo\.json$
+^scratch\.R$
diff --git a/NEWS.md b/NEWS.md
index 496cb610..8d6a405a 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,4 +1,12 @@
-# aorsf 0.0.8
+# aorsf 0.1.0 (unreleased)
+
+* Re-worked `aorsf`'s C++, code following the design of `ranger`, to set it up for classification and regression trees.
+
+* Allowed multi-threading to be performed in `orsf()`, `predict.orsf_fit()`, and functions in the `orsf_vi()` and `orsf_pd()` family.
+
+* Allowed for sampling without replacement and sampling a specific fraction of observations in `orsf()`
+
+* Included Harrell's C-statistic as an option for assessing goodness of splits while growing trees.
* Fixed an issue where an uninformative error message would occur when `pred_horizon` was > max(time) for `orsf_summarize_uni`. Thanks to @JyHao1 and @DustinMLong for finding this!
diff --git a/R/RcppExports.R b/R/RcppExports.R
index 3791a30d..2501f394 100644
--- a/R/RcppExports.R
+++ b/R/RcppExports.R
@@ -1,67 +1,27 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
-std_setdiff <- function(x, y) {
- .Call(`_aorsf_std_setdiff`, x, y)
+coxph_fit_exported <- function(x_node, y_node, w_node, method, cph_eps, cph_iter_max) {
+ .Call(`_aorsf_coxph_fit_exported`, x_node, y_node, w_node, method, cph_eps, cph_iter_max)
}
-x_node_scale_exported <- function(x_, w_) {
- .Call(`_aorsf_x_node_scale_exported`, x_, w_)
+compute_cstat_exported_vec <- function(y, w, p, pred_is_risklike) {
+ .Call(`_aorsf_compute_cstat_exported_vec`, y, w, p, pred_is_risklike)
}
-leaf_kaplan_testthat <- function(y, w) {
- .Call(`_aorsf_leaf_kaplan_testthat`, y, w)
+compute_cstat_exported_uvec <- function(y, w, g, pred_is_risklike) {
+ .Call(`_aorsf_compute_cstat_exported_uvec`, y, w, g, pred_is_risklike)
}
-newtraph_cph_testthat <- function(x_in, y_in, w_in, method, cph_eps_, iter_max) {
- .Call(`_aorsf_newtraph_cph_testthat`, x_in, y_in, w_in, method, cph_eps_, iter_max)
+compute_logrank_exported <- function(y, w, g) {
+ .Call(`_aorsf_compute_logrank_exported`, y, w, g)
}
-lrt_multi_testthat <- function(y_node_, w_node_, XB_, n_split_, leaf_min_events_, leaf_min_obs_) {
- .Call(`_aorsf_lrt_multi_testthat`, y_node_, w_node_, XB_, n_split_, leaf_min_events_, leaf_min_obs_)
+cph_scale <- function(x, w) {
+ .Call(`_aorsf_cph_scale`, x, w)
}
-oobag_c_harrell_testthat <- function(y_mat, s_vec) {
- .Call(`_aorsf_oobag_c_harrell_testthat`, y_mat, s_vec)
-}
-
-ostree_pred_leaf_testthat <- function(tree, x_pred_) {
- .Call(`_aorsf_ostree_pred_leaf_testthat`, tree, x_pred_)
-}
-
-orsf_fit <- function(x, y, weights, n_tree, n_split_, mtry_, leaf_min_events_, leaf_min_obs_, split_min_events_, split_min_obs_, split_min_stat_, cph_method_, cph_eps_, cph_iter_max_, cph_do_scale_, net_alpha_, net_df_target_, oobag_pred_, oobag_pred_type_, oobag_pred_horizon_, oobag_eval_every_, oobag_importance_, oobag_importance_type_, tree_seeds, max_retry_, f_beta, type_beta_, f_oobag_eval, type_oobag_eval_, verbose_progress) {
- .Call(`_aorsf_orsf_fit`, x, y, weights, n_tree, n_split_, mtry_, leaf_min_events_, leaf_min_obs_, split_min_events_, split_min_obs_, split_min_stat_, cph_method_, cph_eps_, cph_iter_max_, cph_do_scale_, net_alpha_, net_df_target_, oobag_pred_, oobag_pred_type_, oobag_pred_horizon_, oobag_eval_every_, oobag_importance_, oobag_importance_type_, tree_seeds, max_retry_, f_beta, type_beta_, f_oobag_eval, type_oobag_eval_, verbose_progress)
-}
-
-orsf_oob_negate_vi <- function(x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_) {
- .Call(`_aorsf_orsf_oob_negate_vi`, x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_)
-}
-
-orsf_oob_permute_vi <- function(x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_) {
- .Call(`_aorsf_orsf_oob_permute_vi`, x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_)
-}
-
-orsf_pred_uni <- function(forest, x_new, time_dbl, pred_type) {
- .Call(`_aorsf_orsf_pred_uni`, forest, x_new, time_dbl, pred_type)
-}
-
-orsf_pred_multi <- function(forest, x_new, time_vec, pred_type) {
- .Call(`_aorsf_orsf_pred_multi`, forest, x_new, time_vec, pred_type)
-}
-
-pd_new_smry <- function(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type) {
- .Call(`_aorsf_pd_new_smry`, forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type)
-}
-
-pd_oob_smry <- function(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type) {
- .Call(`_aorsf_pd_oob_smry`, forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type)
-}
-
-pd_new_ice <- function(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type) {
- .Call(`_aorsf_pd_new_ice`, forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type)
-}
-
-pd_oob_ice <- function(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type) {
- .Call(`_aorsf_pd_oob_ice`, forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type)
+orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity) {
+ .Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity)
}
diff --git a/R/check.R b/R/check.R
index bb6bdbaf..67bd202f 100644
--- a/R/check.R
+++ b/R/check.R
@@ -610,9 +610,13 @@ check_orsf_inputs <- function(data = NULL,
n_tree = NULL,
n_split = NULL,
n_retry = NULL,
+ n_thread = NULL,
mtry = NULL,
+ sample_with_replacement = NULL,
+ sample_fraction = NULL,
leaf_min_events = NULL,
leaf_min_obs = NULL,
+ split_rule = NULL,
split_min_events = NULL,
split_min_obs = NULL,
split_min_stat = NULL,
@@ -796,6 +800,25 @@ check_orsf_inputs <- function(data = NULL,
}
+ if(!is.null(n_thread)){
+
+ check_arg_type(arg_value = n_thread,
+ arg_name = 'n_thread',
+ expected_type = 'numeric')
+
+ check_arg_is_integer(arg_name = 'n_thread',
+ arg_value = n_thread)
+
+ check_arg_gteq(arg_name = 'n_thread',
+ arg_value = n_thread,
+ bound = 0)
+
+ check_arg_length(arg_name = 'n_thread',
+ arg_value = n_thread,
+ expected_length = 1)
+
+ }
+
if(!is.null(mtry)){
check_arg_type(arg_value = mtry,
@@ -815,6 +838,38 @@ check_orsf_inputs <- function(data = NULL,
}
+ if(!is.null(sample_with_replacement)){
+
+ check_arg_type(arg_value = sample_with_replacement,
+ arg_name = 'sample_with_replacement',
+ expected_type = 'logical')
+
+ check_arg_length(arg_name = 'sample_with_replacement',
+ arg_value = sample_with_replacement,
+ expected_length = 1)
+
+ }
+
+ if(!is.null(sample_fraction)){
+
+ check_arg_type(arg_value = sample_fraction,
+ arg_name = 'sample_fraction',
+ expected_type = 'numeric')
+
+ check_arg_gt(arg_value = sample_fraction,
+ arg_name = 'sample_fraction',
+ bound = 0)
+
+ check_arg_lteq(arg_value = sample_fraction,
+ arg_name = 'sample_fraction',
+ bound = 1)
+
+ check_arg_length(arg_value = sample_fraction,
+ arg_name = 'sample_fraction',
+ expected_length = 1)
+
+ }
+
if(!is.null(leaf_min_events)){
check_arg_type(arg_value = leaf_min_events,
@@ -852,6 +907,22 @@ check_orsf_inputs <- function(data = NULL,
}
+ if(!is.null(split_rule)){
+
+ check_arg_type(arg_value = split_rule,
+ arg_name = 'split_rule',
+ expected_type = 'character')
+
+ check_arg_length(arg_value = split_rule,
+ arg_name = 'split_rule',
+ expected_length = 1)
+
+ check_arg_is_valid(arg_value = split_rule,
+ arg_name = 'split_rule',
+ valid_options = c("logrank", "cstat"))
+
+ }
+
if(!is.null(split_min_events)){
check_arg_type(arg_value = split_min_events,
@@ -916,19 +987,14 @@ check_orsf_inputs <- function(data = NULL,
arg_name = 'oobag_pred_type',
expected_length = 1)
- if(oobag_pred_type == 'mort') stop(
- "Out-of-bag mortality predictions aren't supported yet. ",
- " Sorry for the inconvenience - we plan on including this option",
- " in a future update.",
- call. = FALSE
- )
-
check_arg_is_valid(arg_value = oobag_pred_type,
arg_name = 'oobag_pred_type',
valid_options = c("none",
"surv",
"risk",
- "chf"))
+ "chf",
+ "mort",
+ "leaf"))
}
@@ -938,13 +1004,17 @@ check_orsf_inputs <- function(data = NULL,
arg_name = 'oobag_pred_horizon',
expected_type = 'numeric')
- check_arg_length(arg_value = oobag_pred_horizon,
- arg_name = 'oobag_pred_horizon',
- expected_length = 1)
+ # check_arg_length(arg_value = oobag_pred_horizon,
+ # arg_name = 'oobag_pred_horizon',
+ # expected_length = 1)
- check_arg_gteq(arg_value = oobag_pred_horizon,
- arg_name = 'oobag_pred_horizon',
- bound = 0)
+ for(i in seq_along(oobag_pred_horizon)){
+
+ check_arg_gteq(arg_value = oobag_pred_horizon[i],
+ arg_name = 'oobag_pred_horizon',
+ bound = 0)
+
+ }
}
@@ -1000,7 +1070,7 @@ check_orsf_inputs <- function(data = NULL,
check_arg_is_integer(tree_seeds, arg_name = 'tree_seeds')
- if(length(tree_seeds) != n_tree){
+ if(length(tree_seeds) > 1 && length(tree_seeds) != n_tree){
stop('tree_seeds should have length <', n_tree,
"> (the number of trees) but instead has length <",
@@ -1550,7 +1620,8 @@ check_predict <- function(object,
valid_options = c("risk",
"surv",
"chf",
- "mort"))
+ "mort",
+ "leaf"))
}
@@ -1615,8 +1686,8 @@ check_oobag_fun <- function(oobag_fun){
oobag_fun_args <- names(formals(oobag_fun))
- if(length(oobag_fun_args) != 2) stop(
- "oobag_fun should have 2 input arguments but instead has ",
+ if(length(oobag_fun_args) != 3) stop(
+ "oobag_fun should have 3 input arguments but instead has ",
length(oobag_fun_args),
call. = FALSE
)
@@ -1627,8 +1698,14 @@ check_oobag_fun <- function(oobag_fun){
call. = FALSE
)
- if(oobag_fun_args[2] != 's_vec') stop(
- "the second input argument of oobag_fun should be named 's_vec' ",
+ if(oobag_fun_args[2] != 'w_vec') stop(
+ "the second input argument of oobag_fun should be named 'w_vec' ",
+ "but is instead named '", oobag_fun_args[1], "'",
+ call. = FALSE
+ )
+
+ if(oobag_fun_args[3] != 's_vec') stop(
+ "the third input argument of oobag_fun should be named 's_vec' ",
"but is instead named '", oobag_fun_args[2], "'",
call. = FALSE
)
@@ -1637,9 +1714,12 @@ check_oobag_fun <- function(oobag_fun){
test_status <- rep(c(0,1), each = 50)
.y_mat <- cbind(time = test_time, status = test_status)
+ .w_vec <- rep(1, times = 100)
.s_vec <- seq(0.9, 0.1, length.out = 100)
- test_output <- try(oobag_fun(y_mat = .y_mat, s_vec = .s_vec),
+ test_output <- try(oobag_fun(y_mat = .y_mat,
+ w_vec = .w_vec,
+ s_vec = .s_vec),
silent = FALSE)
if(is_error(test_output)){
@@ -1649,8 +1729,9 @@ check_oobag_fun <- function(oobag_fun){
"test_time <- seq(from = 1, to = 5, length.out = 100)\n",
"test_status <- rep(c(0,1), each = 50)\n\n",
"y_mat <- cbind(time = test_time, status = test_status)\n",
+ "w_vec <- rep(1, times = 100)\n",
"s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
- "test_output <- oobag_fun(y_mat = y_mat, s_vec = s_vec)\n\n",
+ "test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
"test_output should be a numeric value of length 1",
call. = FALSE)
diff --git a/R/compute_mean_leaves.R b/R/compute_mean_leaves.R
new file mode 100644
index 00000000..08ce45f6
--- /dev/null
+++ b/R/compute_mean_leaves.R
@@ -0,0 +1,15 @@
+
+
+compute_mean_leaves <- function(forest){
+
+ if(is.null(forest$leaf_summary)){
+ return(0)
+ }
+
+ collapse::fmean(
+ vapply(forest$leaf_summary,
+ function(leaf_smry) sum(leaf_smry != 0),
+ FUN.VALUE = integer(1))
+ )
+
+}
diff --git a/R/misc.R b/R/misc.R
index 149a0e34..d735910b 100644
--- a/R/misc.R
+++ b/R/misc.R
@@ -104,6 +104,55 @@ paste_collapse <- function(x, sep=', ', last = ' or '){
}
+#' Find cut-point boundaries (R version)
+#'
+#' Used to test the cpp version for finding cutpoints
+#'
+#' @param y_node outcome matrix
+#' @param w_node weight vector
+#' @param XB linear combination of predictors
+#' @param xb_uni unique values in XB
+#' @param leaf_min_events min no. of events in a leaf
+#' @param leaf_min_obs min no. of observations in a leaf
+#'
+#' @noRd
+#'
+#' @return data.frame with description of valid cutpoints
+cp_find_bounds_R <- function(y_node,
+ w_node,
+ XB,
+ xb_uni,
+ leaf_min_events,
+ leaf_min_obs){
+
+ status = y_node[, 'status']
+
+ cp_stats <-
+ sapply(
+ X = xb_uni,
+ FUN = function(x){
+ c(
+ cp = x,
+ e_right = sum(status[XB > x]),
+ e_left = sum(status[XB <= x]),
+ n_right = sum(XB > x),
+ n_left = sum(XB <= x)
+ )
+ }
+ )
+
+ cp_stats <- as.data.frame(t(cp_stats))
+
+ cp_stats$valid_cp = with(
+ cp_stats,
+ e_right >= leaf_min_events & e_left >= leaf_min_events &
+ n_right >= leaf_min_obs & n_left >= leaf_min_obs
+ )
+
+ cp_stats
+
+}
+
has_units <- function(x){
inherits(x, 'units')
}
diff --git a/R/oobag_c_harrell.R b/R/oobag_c_harrell.R
deleted file mode 100644
index 3e22ca23..00000000
--- a/R/oobag_c_harrell.R
+++ /dev/null
@@ -1,61 +0,0 @@
-
-
-#' Harrell's C-statistic
-#'
-#' This function is for testing and internal use.
-#'
-#' @param y_mat outcome matrix
-#' @param s_vec vector of predicted survival
-#'
-#' @return the C-statistic
-#'
-#' @noRd
-#'
-
-oobag_c_harrell <- function(y_mat, s_vec){
-
- sorted <- order(y_mat[, 1], -y_mat[, 2])
-
- y_mat <- y_mat[sorted, ]
- s_vec <- s_vec[sorted]
-
- time = y_mat[, 1]
- status = y_mat[, 2]
- events = which(status == 1)
-
- k = nrow(y_mat)
-
- total <- 0
- concordant <- 0
-
- for(i in events){
-
- if(i+1 <= k){
-
- for(j in seq(i+1, k)){
-
- if(time[j] > time[i]){
-
- total <- total + 1
-
- if(s_vec[j] > s_vec[i]){
-
- concordant <- concordant + 1
-
- } else if (s_vec[j] == s_vec[i]){
-
- concordant <- concordant + 0.5
-
- }
-
- }
-
- }
-
- }
-
- }
-
- concordant / total
-
-}
diff --git a/R/oobag_c_survival.R b/R/oobag_c_survival.R
new file mode 100644
index 00000000..d5f819c6
--- /dev/null
+++ b/R/oobag_c_survival.R
@@ -0,0 +1,41 @@
+
+
+#' Harrell's C-statistic
+#'
+#' This function is for testing and internal use.
+#'
+#' @param y_mat outcome matrix
+#' @param s_vec vector of predicted survival
+#'
+#' @return the C-statistic
+#'
+#' @noRd
+#'
+
+oobag_c_survival <- function(y_mat, w_vec, s_vec){
+
+ data <- as.data.frame(cbind(y_mat, s_vec))
+ names(data) = c("time", "status", "x")
+
+ survival::concordance(
+ survival::Surv(time, status) ~ x,
+ data = data,
+ weights = w_vec
+ )$concordance
+
+}
+
+oobag_c_risk <- function(y_mat, w_vec, s_vec){
+
+ data <- as.data.frame(cbind(y_mat, s_vec))
+ names(data) = c("time", "status", "x")
+
+ 1 - survival::concordance(
+ survival::Surv(time, status) ~ x,
+ data = data,
+ weights = w_vec
+ )$concordance
+
+}
+
+
diff --git a/R/orsf.R b/R/orsf.R
index 473482d2..d501a96c 100644
--- a/R/orsf.R
+++ b/R/orsf.R
@@ -85,16 +85,35 @@
#' of randomly selected predictors, up to `n_retry` times. Default is
#' `n_retry = 3`. Set `n_retry = 0` to prevent any retries.
#'
+#' @param n_thread `r roxy_n_thread_header("growing trees, computing predictions, and computing importance")`
+#'
#' @param mtry (_integer_) Number of predictors randomly included as candidates
#' for splitting a node. The default is the smallest integer greater than
#' the square root of the number of total predictors, i.e.,
#' `mtry = ceiling(sqrt(number of predictors))`
#'
+#' @param sample_with_replacement (_logical_) If `TRUE` (the default),
+#' observations are sampled with replacement when an in-bag sample
+#' is created for a decision tree. If `FALSE`, observations are
+#' sampled without replacement and each tree will have an in-bag sample
+#' containing `sample_fraction`% of the original sample.
+#'
+#' @param sample_fraction (_double_) the proportion of observations that
+#' each trees' in-bag sample will contain, relative to the number of
+#' rows in `data`. Only used if `sample_with_replacement` is `FALSE`.
+#' Default value is 0.632.
+#'
#' @param leaf_min_events (_integer_) minimum number of events in a
#' leaf node. Default is `leaf_min_events = 1`
#'
#' @param leaf_min_obs (_integer_) minimum number of observations in a
-#' leaf node. Default is `leaf_min_obs = 5`
+#' leaf node. Default is `leaf_min_obs = 5`.
+#'
+#' @param split_rule (_character_) how to assess the quality of a potential
+#' splitting rule for a node. Valid options are
+#'
+#' - 'logrank' : a log-rank test statistic.
+#' - 'cstat' : Harrell's concordance statistic.
#'
#' @param split_min_events (_integer_) minimum number of events required
#' in a node to consider splitting it. Default is `split_min_events = 5`
@@ -103,19 +122,21 @@
#' in a node to consider splitting it. Default is `split_min_obs = 10`.
#'
#' @param split_min_stat (double) minimum test statistic required to split
-#' a node. Default is 3.841459 for the log-rank test, which is roughly
-#' a p-value of 0.05
+#' a node. Default is 3.841459 if `split_rule = 'logrank'` and 0.50 if
+#' `split_rule = 'cstat'`. If no splits are found with a statistic
+#' exceeding `split_min_stat`, the given node either becomes a leaf or
+#' a retry occurs (up to `n_retry` retries).
#'
#' @param oobag_pred_type (_character_) The type of out-of-bag predictions
#' to compute while fitting the ensemble. Valid options are
#'
#' - 'none' : don't compute out-of-bag predictions
-#' - 'risk' : predict the probability of having an event at or before `oobag_pred_horizon`.
+#' - 'risk' : probability of event occurring at or before `oobag_pred_horizon`.
#' - 'surv' : 1 - risk.
-#' - 'chf' : predict cumulative hazard function
-#'
-#' Mortality ('mort')is not implemented for out of bag predictions yet, but it
-#' will be in a future update.
+#' - 'chf' : cumulative hazard function at `oobag_pred_horizon`.
+#' - 'mort' : mortality, i.e., the number of events expected if all
+#' observations in the training data were identical to a
+#' given observation.
#'
#' @param oobag_pred_horizon (_numeric_) A numeric value indicating what time
#' should be used for out-of-bag predictions. Default is the median
@@ -172,7 +193,7 @@
#' to the output will be the imputed version of `data`.
#'
#' @param verbose_progress (_logical_) if `TRUE`, progress messages are
-#' printed in the console.
+#' printed in the console. If `FALSE` (the default), nothing is printed.
#'
#' @param ... `r roxy_dots()`
#'
@@ -239,6 +260,13 @@
#' importance or permutation importance, but it will not have any role
#' for ANOVA importance.
#'
+#' **n_thread**:
+#'
+#' If an R function must be called from C++ (i.e., user-supplied function to
+#' compute out-of-bag error or identify linear combinations of variables),
+#' `n_thread` will automatically be set to 1 because attempting to run R
+#' functions in multiple threads will cause the R session to crash.
+#'
#' @section What is an oblique decision tree?:
#'
#' Decision trees are developed by splitting a set of training data into two
@@ -312,12 +340,18 @@ orsf <- function(data,
n_tree = 500,
n_split = 5,
n_retry = 3,
+ n_thread = 1,
mtry = NULL,
+ sample_with_replacement = TRUE,
+ sample_fraction = 0.632,
leaf_min_events = 1,
leaf_min_obs = 5,
+ split_rule = 'logrank',
split_min_events = 5,
split_min_obs = 10,
- split_min_stat = 3.841459,
+ split_min_stat = switch(split_rule,
+ "logrank" = 3.841459,
+ "cstat" = 0.50),
oobag_pred_type = 'surv',
oobag_pred_horizon = NULL,
oobag_eval_every = n_tree,
@@ -346,9 +380,13 @@ orsf <- function(data,
n_tree = n_tree,
n_split = n_split,
n_retry = n_retry,
+ n_thread = n_thread,
mtry = mtry,
+ sample_with_replacement = sample_with_replacement,
+ sample_fraction = sample_fraction,
leaf_min_events = leaf_min_events,
leaf_min_obs = leaf_min_obs,
+ split_rule = split_rule,
split_min_events = split_min_events,
split_min_obs = split_min_obs,
split_min_stat = split_min_stat,
@@ -360,8 +398,22 @@ orsf <- function(data,
attach_data = attach_data
)
+ #TODO: more polish
+ if(split_rule == "cstat" && split_min_stat >= 1){
+ stop("If split_rule is 'cstat', split_min_stat must be < 1",
+ call. = FALSE)
+ }
+
oobag_pred <- oobag_pred_type != 'none'
+ if(sample_fraction == 1 && oobag_pred){
+ stop(
+ "cannot compute out-of-bag predictions if no samples are out-of-bag.",
+ "To resolve this, set sample_fraction < 1 or oobag_pred_type = 'none'.",
+ call. = FALSE
+ )
+ }
+
orsf_type <- attr(control, 'type')
switch(
@@ -408,19 +460,34 @@ orsf <- function(data,
)
+ if(importance %in% c("permute", "negate") && !oobag_pred){
+ # oobag_pred <- TRUE # Should I add a warning?
+ oobag_pred_type <- 'surv'
+ }
+
if(is.null(oobag_fun)){
f_oobag_eval <- function(x) x
- type_oobag_eval <- 'H'
+ type_oobag_eval <- if(oobag_pred) 'cstat' else 'none'
} else {
check_oobag_fun(oobag_fun)
f_oobag_eval <- oobag_fun
- type_oobag_eval <- 'U'
+ type_oobag_eval <- 'user'
+
+ if(oobag_pred_type == 'leaf'){
+ stop("a user-supplied oobag function cannot be",
+ "applied when oobag_pred_type = 'leaf'",
+ call. = FALSE)
+ }
}
+ # can't evaluate the oobag predictions if they aren't aggregated
+ if(oobag_pred_type == 'leaf') type_oobag_eval <- 'none'
+
+
cph_method <- control_cph$cph_method
cph_eps <- control_cph$cph_eps
cph_iter_max <- control_cph$cph_iter_max
@@ -428,11 +495,6 @@ orsf <- function(data,
net_alpha <- control_net$net_alpha
net_df_target <- control_net$net_df_target
- if(importance %in% c("permute", "negate") && !oobag_pred){
- oobag_pred <- TRUE # Should I add a warning?
- oobag_pred_type <- 'surv'
- }
-
formula_terms <- suppressWarnings(stats::terms(formula, data=data))
@@ -652,15 +714,15 @@ orsf <- function(data,
if(!is.null(oobag_pred_horizon)){
- if(oobag_pred_horizon <= 0)
+ if(any(oobag_pred_horizon <= 0))
stop("Out of bag prediction horizon (oobag_pred_horizon) must be > 0",
call. = FALSE)
} else {
- # tell orsf.cpp to make its own oobag_pred_horizon by setting this to 0
- oobag_pred_horizon <- 0
+ # use training data to provide sensible default
+ oobag_pred_horizon <- stats::median(y[, 1])
}
@@ -668,108 +730,135 @@ orsf <- function(data,
collapse::radixorder(y[, 1], # order this way for risk sets
-y[, 2]) # order this way for oob C-statistic.
+ if(is.null(weights)) weights <- rep(1, nrow(x))
+
x_sort <- x[sorted, , drop = FALSE]
y_sort <- y[sorted, , drop = FALSE]
+ w_sort <- weights[sorted]
+
+ if(length(tree_seeds) == 1 && n_tree > 1){
+ set.seed(tree_seeds)
+ tree_seeds <- sample(x = n_tree*10, size = n_tree, replace = FALSE)
+ } else if(is.null(tree_seeds)){
+ tree_seeds <- sample(x = n_tree*10, size = n_tree, replace = FALSE)
+ }
- if(is.null(weights)) weights <- double()
- if(is.null(tree_seeds)) tree_seeds <- vector(mode = 'integer', length = 0L)
-
- orsf_out <- orsf_fit(
- x = x_sort,
- y = y_sort,
- weights = if(length(weights) > 0) weights[sorted] else weights,
- n_tree = if(no_fit) 0 else n_tree,
- n_split_ = n_split,
- mtry_ = mtry,
- leaf_min_events_ = leaf_min_events,
- leaf_min_obs_ = leaf_min_obs,
- split_min_events_ = split_min_events,
- split_min_obs_ = split_min_obs,
- split_min_stat_ = split_min_stat,
- cph_method_ = switch(tolower(cph_method),
- 'breslow' = 0,
- 'efron' = 1),
- cph_eps_ = cph_eps,
- cph_iter_max_ = cph_iter_max,
- cph_do_scale_ = cph_do_scale,
- net_alpha_ = net_alpha,
- net_df_target_ = net_df_target,
- oobag_pred_ = oobag_pred,
- oobag_pred_type_ = switch(oobag_pred_type,
- "none" = "N",
- "surv" = "S",
- "risk" = "R",
- "chf" = "H"),
- oobag_pred_horizon_ = oobag_pred_horizon,
- oobag_eval_every_ = oobag_eval_every,
- oobag_importance_ = importance %in% c("negate", "permute"),
- oobag_importance_type_ = switch(importance,
- "none" = "O",
- "anova" = "A",
- "negate" = "N",
- "permute" = "P"),
- #' @srrstats {G2.4a} *converting to integer in case R does that thing where it assumes the integer values you gave it are supposed to be doubles*
- tree_seeds = as.integer(tree_seeds),
- max_retry_ = n_retry,
- f_beta = f_beta,
- type_beta_ = switch(orsf_type,
- 'fast' = 'C',
- 'cph' = 'C',
- 'net' = 'N',
- 'custom' = 'U'),
- f_oobag_eval = f_oobag_eval,
- type_oobag_eval_ = type_oobag_eval,
- verbose_progress = verbose_progress
- )
+
+ vi_max_pvalue = 0.01
+ tree_type_R = 3
+
+ orsf_out <- orsf_cpp(x = x_sort,
+ y = y_sort,
+ w = w_sort,
+ tree_type_R = tree_type_R,
+ tree_seeds = as.integer(tree_seeds),
+ loaded_forest = list(),
+ n_tree = n_tree,
+ mtry = mtry,
+ sample_with_replacement = sample_with_replacement,
+ sample_fraction = sample_fraction,
+ vi_type_R = switch(importance,
+ "none" = 0,
+ "negate" = 1,
+ "permute" = 2,
+ "anova" = 3),
+ vi_max_pvalue = vi_max_pvalue,
+ lincomb_R_function = f_beta,
+ oobag_R_function = f_oobag_eval,
+ leaf_min_events = leaf_min_events,
+ leaf_min_obs = leaf_min_obs,
+ split_rule_R = switch(split_rule,
+ "logrank" = 1,
+ "cstat" = 2),
+ split_min_events = split_min_events,
+ split_min_obs = split_min_obs,
+ split_min_stat = split_min_stat,
+ split_max_cuts = n_split,
+ split_max_retry = n_retry,
+ lincomb_type_R = switch(orsf_type,
+ 'fast' = 1,
+ 'cph' = 1,
+ 'random' = 2,
+ 'net' = 3,
+ 'custom' = 4),
+ lincomb_eps = cph_eps,
+ lincomb_iter_max = cph_iter_max,
+ lincomb_scale = cph_do_scale,
+ lincomb_alpha = net_alpha,
+ lincomb_df_target = net_df_target,
+ lincomb_ties_method = switch(tolower(cph_method),
+ 'breslow' = 0,
+ 'efron' = 1),
+ pred_type_R = switch(oobag_pred_type,
+ "none" = 0,
+ "risk" = 1,
+ "surv" = 2,
+ "chf" = 3,
+ "mort" = 4,
+ "leaf" = 8),
+ pred_mode = FALSE,
+ pred_aggregate = oobag_pred_type != 'leaf',
+ pred_horizon = oobag_pred_horizon,
+ oobag = oobag_pred,
+ oobag_eval_type_R = switch(type_oobag_eval,
+ 'none' = 0,
+ 'cstat' = 1,
+ 'user' = 2),
+ oobag_eval_every = oobag_eval_every,
+ pd_type_R = 0,
+ pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
+ pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
+ pd_probs = c(0),
+ n_thread = n_thread,
+ write_forest = TRUE,
+ run_forest = !no_fit,
+ verbosity = as.integer(verbose_progress))
# if someone says no_fit and also says don't attach the data,
# give them a warning but also do the right thing for them.
orsf_out$data <- if(attach_data) data else NULL
-
- if(importance != 'none'){
+ if(importance != 'none' && !no_fit){
rownames(orsf_out$importance) <- colnames(x)
orsf_out$importance <-
rev(orsf_out$importance[order(orsf_out$importance), , drop=TRUE])
}
- if(oobag_pred){
-
+ if(oobag_pred && !no_fit){
# put the oob predictions into the same order as the training data.
unsorted <- collapse::radixorder(sorted)
- # clear labels for oobag evaluation type
+ # makes labels for oobag evaluation type
orsf_out$eval_oobag$stat_type <-
- switch(EXPR = orsf_out$eval_oobag$stat_type,
- 'H' = "Harrell's C-statistic",
- 'U' = "User-specified function")
+ switch(EXPR = as.character(orsf_out$eval_oobag$stat_type),
+ "0" = "None",
+ "1" = "Harrell's C-statistic",
+ "2" = "User-specified function")
+
+ if(oobag_pred_type == 'leaf'){
+ all_rows <- seq(nrow(data))
+ for(i in seq(n_tree)){
+ rows_inbag <- setdiff(all_rows, orsf_out$forest$rows_oobag[[i]]+1)
+ orsf_out$pred_oobag[rows_inbag, i] <- NA
+ }
+ }
#' @srrstats {G2.10} *drop = FALSE for type consistency*
orsf_out$pred_oobag <- orsf_out$pred_oobag[unsorted, , drop = FALSE]
- } else {
+ orsf_out$pred_oobag[is.nan(orsf_out$pred_oobag)] <- NA_real_
- if(oobag_pred_horizon == 0)
- # this would get added by orsf_fit if oobag_pred was TRUE
- orsf_out$pred_horizon <- stats::median(y[, 1])
- else
- orsf_out$pred_horizon <- oobag_pred_horizon
- }
- n_leaves_mean <- 0
- if(!no_fit) {
- n_leaves_mean <-
- collapse::fmean(
- vapply(orsf_out$forest,
- function(t) nrow(t$leaf_node_index),
- FUN.VALUE = integer(1))
- )
}
+ orsf_out$pred_horizon <- oobag_pred_horizon
+
+ n_leaves_mean <- compute_mean_leaves(orsf_out$forest)
+
attr(orsf_out, 'control') <- control
attr(orsf_out, 'mtry') <- mtry
attr(orsf_out, 'n_obs') <- nrow(y_sort)
@@ -811,13 +900,19 @@ orsf <- function(data,
attr(orsf_out, 'oobag_fun') <- oobag_fun
attr(orsf_out, 'oobag_pred_type') <- oobag_pred_type
attr(orsf_out, 'oobag_eval_every') <- oobag_eval_every
+ attr(orsf_out, 'oobag_pred_horizon') <- oobag_pred_horizon
attr(orsf_out, 'importance') <- importance
attr(orsf_out, 'importance_values') <- orsf_out$importance
attr(orsf_out, 'group_factors') <- group_factors
attr(orsf_out, 'weights_user') <- weights
attr(orsf_out, 'verbose_progress') <- verbose_progress
-
- attr(orsf_out, 'tree_seeds') <- if(is.null(tree_seeds)) c() else tree_seeds
+ attr(orsf_out, 'vi_max_pvalue') <- vi_max_pvalue
+ attr(orsf_out, 'split_rule') <- split_rule
+ attr(orsf_out, 'n_thread') <- n_thread
+ attr(orsf_out, 'tree_type') <- tree_type_R
+ attr(orsf_out, 'tree_seeds') <- tree_seeds
+ attr(orsf_out, 'sample_with_replacement') <- sample_with_replacement
+ attr(orsf_out, 'sample_fraction') <- sample_fraction
#' @srrstats {ML5.0a} *orsf output has its own class*
class(orsf_out) <- "orsf_fit"
@@ -1011,75 +1106,98 @@ orsf_train_ <- function(object,
x <- prep_x_from_orsf(object)
}
+ if(is.null(n_tree)){
+ n_tree <- get_n_tree(object)
+ }
+
if(is.null(sorted)){
sorted <-
collapse::radixorder(y[, 1], # order this way for risk sets
-y[, 2]) # order this way for oob C-statistic.
}
+ weights <- get_weights_user(object)
x_sort <- x[sorted, ]
y_sort <- y[sorted, ]
-
- if(is.null(n_tree)) n_tree <- get_n_tree(object)
+ w_sort <- weights[sorted]
oobag_eval_every <- min(n_tree, get_oobag_eval_every(object))
- weights <- get_weights_user(object)
-
- orsf_out <- orsf_fit(
- x = x_sort,
- y = y_sort,
- weights = if(length(weights) > 0) weights[sorted] else weights,
- n_tree = n_tree,
- n_split_ = get_n_split(object),
- mtry_ = get_mtry(object),
- leaf_min_events_ = get_leaf_min_events(object),
- leaf_min_obs_ = get_leaf_min_obs(object),
- split_min_events_ = get_split_min_events(object),
- split_min_obs_ = get_split_min_obs(object),
- split_min_stat_ = get_split_min_stat(object),
- cph_method_ = switch(tolower(get_cph_method(object)),
- 'breslow' = 0,
- 'efron' = 1),
- cph_eps_ = get_cph_eps(object), #
- cph_iter_max_ = get_cph_iter_max(object),
- cph_do_scale_ = get_cph_do_scale(object),
- net_alpha_ = get_net_alpha(object),
- net_df_target_ = get_net_df_target(object),
- oobag_pred_ = get_oobag_pred(object),
- oobag_pred_type_ = switch(get_oobag_pred_type(object),
- "none" = "N",
- "surv" = "S",
- "risk" = "R",
- "chf" = "H"),
- oobag_pred_horizon_ = object$pred_horizon,
- oobag_eval_every_ = oobag_eval_every,
- oobag_importance_ = get_importance(object) %in% c("negate", "permute"),
- oobag_importance_type_ = switch(get_importance(object),
- "none" = "O",
- "anova" = "A",
- "negate" = "N",
- "permute" = "P"),
- tree_seeds = as.integer(get_tree_seeds(object)),
- max_retry_ = get_n_retry(object),
- f_beta = get_f_beta(object),
- type_beta_ = switch(get_orsf_type(object),
- 'fast' = 'C',
- 'cph' = 'C',
- 'net' = 'N',
- 'custom' = 'U'),
- f_oobag_eval = get_f_oobag_eval(object),
- type_oobag_eval_ = get_type_oobag_eval(object),
- verbose_progress = get_verbose_progress(object)
- )
+ orsf_out <- orsf_cpp(x = x_sort,
+ y = y_sort,
+ w = w_sort,
+ tree_type_R = 3,
+ tree_seeds = get_tree_seeds(object),
+ loaded_forest = list(),
+ n_tree = n_tree,
+ mtry = get_mtry(object),
+ sample_with_replacement = get_sample_with_replacement(object),
+ sample_fraction = get_sample_fraction(object),
+ vi_type_R = switch(get_importance(object),
+ "none" = 0,
+ "negate" = 1,
+ "permute" = 2,
+ "anova" = 3),
+ vi_max_pvalue = get_vi_max_pvalue(object),
+ lincomb_R_function = get_f_beta(object),
+ oobag_R_function = get_f_oobag_eval(object),
+ leaf_min_events = get_leaf_min_events(object),
+ leaf_min_obs = get_leaf_min_obs(object),
+ split_rule_R = switch(get_split_rule(object),
+ "logrank" = 1,
+ "cstat" = 2),
+ split_min_events = get_split_min_events(object),
+ split_min_obs = get_split_min_obs(object),
+ split_min_stat = get_split_min_stat(object),
+ split_max_cuts = get_n_split(object),
+ split_max_retry = get_n_retry(object),
+ lincomb_type_R = switch(get_orsf_type(object),
+ 'fast' = 1,
+ 'cph' = 1,
+ 'random' = 2,
+ 'net' = 3,
+ 'custom' = 4),
+ lincomb_eps = get_cph_eps(object),
+ lincomb_iter_max = get_cph_iter_max(object),
+ lincomb_scale = get_cph_do_scale(object),
+ lincomb_alpha = get_net_alpha(object),
+ lincomb_df_target = get_net_df_target(object),
+ lincomb_ties_method = switch(
+ tolower(get_cph_method(object)),
+ 'breslow' = 0,
+ 'efron' = 1
+ ),
+ pred_type_R = switch(get_oobag_pred_type(object),
+ "none" = 0,
+ "risk" = 1,
+ "surv" = 2,
+ "chf" = 3,
+ "mort" = 4),
+ pred_mode = FALSE,
+ pred_aggregate = TRUE,
+ pred_horizon = get_oobag_pred_horizon(object),
+ oobag = get_oobag_pred(object),
+ oobag_eval_type_R = switch(get_type_oobag_eval(object),
+ 'none' = 0,
+ 'cstat' = 1,
+ 'user' = 2),
+ oobag_eval_every = oobag_eval_every,
+ pd_type_R = 0,
+ pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
+ pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
+ pd_probs = c(0),
+ n_thread = get_n_thread(object),
+ write_forest = TRUE,
+ run_forest = TRUE,
+ verbosity = get_verbose_progress(object))
- object$forest <- orsf_out$forest
object$pred_oobag <- orsf_out$pred_oobag
- object$pred_horizon <- orsf_out$pred_horizon
object$eval_oobag <- orsf_out$eval_oobag
+ object$forest <- orsf_out$forest
object$importance <- orsf_out$importance
+ object$pred_horizon <- get_oobag_pred_horizon(object)
if(get_importance(object) != 'none'){
@@ -1107,22 +1225,19 @@ orsf_train_ <- function(object,
# clear labels for oobag evaluation type
object$eval_oobag$stat_type <-
- switch(EXPR = object$eval_oobag$stat_type,
- 'H' = "Harrell's C-statistic",
- 'U' = "User-specified function")
+ switch(EXPR = as.character(object$eval_oobag$stat_type),
+ "0" = "None",
+ "1" = "Harrell's C-statistic",
+ "2" = "User-specified function")
object$pred_oobag <- object$pred_oobag[unsorted, , drop = FALSE]
}
- attr(object, "n_leaves_mean") <-
- mean(vapply(orsf_out$forest,
- function(t) nrow(t$leaf_node_index),
- FUN.VALUE = integer(1)))
+ attr(object, "n_leaves_mean") <- compute_mean_leaves(orsf_out$forest)
attr(object, 'trained') <- TRUE
-
object
}
diff --git a/R/orsf_attr.R b/R/orsf_attr.R
index 4cab8b5f..8c9e8a9d 100644
--- a/R/orsf_attr.R
+++ b/R/orsf_attr.R
@@ -57,6 +57,13 @@ get_tree_seeds <- function(object) attr(object, 'tree_seeds')
get_weights_user <- function(object) attr(object, 'weights_user')
get_event_times <- function(object) attr(object, 'event_times')
get_verbose_progress <- function(object) attr(object, 'verbose_progress')
+get_vi_max_pvalue <- function(object) attr(object, 'vi_max_pvalue')
+get_split_rule <- function(object) attr(object, 'split_rule')
+get_n_thread <- function(object) attr(object, 'n_thread')
+get_tree_type <- function(object) attr(object, 'tree_type')
+get_sample_with_replacement <- function(object) attr(object, 'sample_with_replacement')
+get_sample_fraction <- function(object) attr(object, 'sample_fraction')
+
#' ORSF status
#'
@@ -78,7 +85,7 @@ is_trained <- function(object) attr(object, 'trained')
#'
#' @noRd
#'
-contains_oobag <- function(object) {!is_empty(object$pred_oobag)}
+contains_oobag <- function(object) {!is_empty(object$eval_oobag$stat_values)}
#' Determine whether object has variable importance estimates
#'
diff --git a/R/orsf_pd.R b/R/orsf_pd.R
index 9a261951..011c77d3 100644
--- a/R/orsf_pd.R
+++ b/R/orsf_pd.R
@@ -60,6 +60,8 @@
#' percentile in the object's training data. If `FALSE`, these checks are
#' skipped.
#'
+#' @param n_thread `r roxy_n_thread_header("computing predictions")`
+#'
#' @param ... `r roxy_dots()`
#'
#' @return a [data.table][data.table::data.table-package] containing
@@ -86,6 +88,7 @@ orsf_pd_oob <- function(object,
prob_values = c(0.025, 0.50, 0.975),
prob_labels = c('lwr', 'medn', 'upr'),
boundary_checks = TRUE,
+ n_thread = 1,
...){
check_dots(list(...), orsf_pd_oob)
@@ -99,6 +102,7 @@ orsf_pd_oob <- function(object,
prob_values = prob_values,
prob_labels = prob_labels,
boundary_checks = boundary_checks,
+ n_thread = n_thread,
oobag = TRUE,
type_output = 'smry')
@@ -114,6 +118,7 @@ orsf_pd_inb <- function(object,
prob_values = c(0.025, 0.50, 0.975),
prob_labels = c('lwr', 'medn', 'upr'),
boundary_checks = TRUE,
+ n_thread = 1,
...){
check_dots(list(...), orsf_pd_inb)
@@ -132,6 +137,7 @@ orsf_pd_inb <- function(object,
prob_values = prob_values,
prob_labels = prob_labels,
boundary_checks = boundary_checks,
+ n_thread = n_thread,
oobag = FALSE,
type_output = 'smry')
@@ -149,6 +155,7 @@ orsf_pd_new <- function(object,
prob_values = c(0.025, 0.50, 0.975),
prob_labels = c('lwr', 'medn', 'upr'),
boundary_checks = TRUE,
+ n_thread = 1,
...){
check_dots(list(...), orsf_pd_new)
@@ -163,6 +170,7 @@ orsf_pd_new <- function(object,
prob_values = prob_values,
prob_labels = prob_labels,
boundary_checks = boundary_checks,
+ n_thread = n_thread,
oobag = FALSE,
type_output = 'smry')
@@ -192,6 +200,7 @@ orsf_ice_oob <- function(object,
pred_type = 'risk',
expand_grid = TRUE,
boundary_checks = TRUE,
+ n_thread = 1,
...){
check_dots(list(...), orsf_ice_oob)
@@ -203,6 +212,7 @@ orsf_ice_oob <- function(object,
pred_type = pred_type,
expand_grid = expand_grid,
boundary_checks = boundary_checks,
+ n_thread = n_thread,
oobag = TRUE,
type_output = 'ice')
@@ -216,6 +226,7 @@ orsf_ice_inb <- function(object,
pred_type = 'risk',
expand_grid = TRUE,
boundary_checks = TRUE,
+ n_thread = 1,
...){
check_dots(list(...), orsf_ice_oob)
@@ -232,6 +243,7 @@ orsf_ice_inb <- function(object,
pred_type = pred_type,
expand_grid = expand_grid,
boundary_checks = boundary_checks,
+ n_thread = n_thread,
oobag = FALSE,
type_output = 'ice')
@@ -247,6 +259,7 @@ orsf_ice_new <- function(object,
na_action = 'fail',
expand_grid = TRUE,
boundary_checks = TRUE,
+ n_thread = 1,
...){
check_dots(list(...), orsf_ice_new)
@@ -259,6 +272,7 @@ orsf_ice_new <- function(object,
na_action = na_action,
expand_grid = expand_grid,
boundary_checks = boundary_checks,
+ n_thread = n_thread,
oobag = FALSE,
type_output = 'ice')
@@ -290,12 +304,16 @@ orsf_pred_dependence <- function(object,
expand_grid,
prob_values = NULL,
prob_labels = NULL,
+ boundary_checks,
+ n_thread,
oobag,
- type_output,
- boundary_checks){
+ type_output){
pred_horizon <- infer_pred_horizon(object, pred_horizon)
+ # make a visible binding for CRAN
+ id_variable = NULL
+
if(is.null(prob_values)) prob_values <- c(0.025, 0.50, 0.975)
if(is.null(prob_labels)) prob_labels <- c('lwr', 'medn', 'upr')
@@ -310,13 +328,6 @@ orsf_pred_dependence <- function(object,
pred_horizon = pred_horizon,
na_action = na_action)
- if(pred_type == 'mort') stop(
- "mortality predictions aren't supported in partial dependence functions",
- " yet. Sorry for the inconvenience - we plan on including this option",
- " in a future update.",
- call. = FALSE
- )
-
if(oobag && is.null(object$data))
stop("no data were found in object. ",
"did you use attach_data = FALSE when ",
@@ -329,8 +340,6 @@ orsf_pred_dependence <- function(object,
pred_type, " predictions.", call. = FALSE)
}
- type_input <- if(expand_grid) 'grid' else 'loop'
-
names_x_data <- intersect(get_names_x(object), names(pd_data))
cc <- which(stats::complete.cases(select_cols(pd_data, names_x_data)))
@@ -345,7 +354,6 @@ orsf_pred_dependence <- function(object,
x_new <- prep_x_from_orsf(object, data = pd_data[cc, ])
-
# the values in pred_spec need to be centered & scaled to match x_new,
# which is also centered and scaled
means <- get_means(object)
@@ -355,272 +363,221 @@ orsf_pred_dependence <- function(object,
pred_spec[[i]] <- (pred_spec[[i]] - means[i]) / standard_deviations[i]
}
- if(is.data.frame(pred_spec)) type_input <- 'grid'
+ pred_type_R <- switch(pred_type,
+ "risk" = 1,
+ "surv" = 2,
+ "chf" = 3,
+ "mort" = 4)
+ fi <- get_fctr_info(object)
- pd_fun_structure <- switch(type_input,
- 'grid' = pd_grid,
- 'loop' = pd_loop)
+ if(expand_grid){
- pd_fun_predict <- switch(paste(oobag, type_output, sep = "_"),
- "TRUE_ice" = pd_oob_ice,
- "TRUE_smry" = pd_oob_smry,
- "FALSE_ice" = pd_new_ice,
- "FALSE_smry" = pd_new_smry)
+ if(!is.data.frame(pred_spec))
+ pred_spec <- expand.grid(pred_spec, stringsAsFactors = TRUE)
- pred_type_cpp <- switch(
- pred_type,
- "risk" = "R",
- "surv" = "S",
- "chf" = "H",
- "mort" = "M"
- )
+ for(i in seq_along(fi$cols)){
- out_list <- lapply(
+ ii <- fi$cols[i]
- X = pred_horizon,
+ if(is.character(pred_spec[[ii]]) && !fi$ordr[i]){
- FUN = function(.pred_horizon){
+ pred_spec[[ii]] <- factor(pred_spec[[ii]], levels = fi$lvls[[ii]])
- pd_fun_structure(object,
- x_new,
- pred_spec,
- .pred_horizon,
- pd_fun_predict,
- type_output,
- prob_values,
- prob_labels,
- oobag,
- pred_type_cpp)
+ }
}
- )
+ check_new_data_fctrs(new_data = pred_spec,
+ names_x = get_names_x(object),
+ fi_ref = fi,
+ label_new = "pred_spec")
- names(out_list) <- as.character(pred_horizon)
+ pred_spec_new <- ref_code(x_data = pred_spec,
+ fi = get_fctr_info(object),
+ names_x_data = names(pred_spec))
- out <- rbindlist(l = out_list,
- fill = TRUE,
- idcol = 'pred_horizon')
+ x_cols <- list(match(names(pred_spec_new), colnames(x_new)) - 1)
- out[, pred_horizon := as.numeric(pred_horizon)]
+ pred_spec_new <- list(as.matrix(pred_spec_new))
- # put data back into original scale
- for(j in intersect(names(means), names(pred_spec))){
+ pd_bind <- list(pred_spec)
- if(j %in% names(out)){
+ } else {
- var_index <- collapse::seq_row(out)
- var_value <- (out[[j]] * standard_deviations[j]) + means[j]
- var_name <- j
+ pred_spec_new <- pd_bind <- x_cols <- list()
- } else {
+ for(i in seq_along(pred_spec)){
- var_index <- out$variable %==% j
- var_value <- (out$value[var_index] * standard_deviations[j]) + means[j]
- var_name <- 'value'
+ pred_spec_new[[i]] <- as.data.frame(pred_spec[i])
+ pd_name <- names(pred_spec)[i]
- }
+ pd_bind[[i]] <- data.frame(
+ variable = pd_name,
+ value = rep(NA_real_, length(pred_spec[[i]])),
+ level = rep(NA_character_, length(pred_spec[[i]]))
+ )
- set(out, i = var_index, j = var_name, value = var_value)
+ if(pd_name %in% fi$cols) {
- }
+ pd_bind[[i]]$level <- as.character(pred_spec[[i]])
- # silent print after modify in place
- out[]
-
- out
-
-}
-
-
-#' grid working function in orsf_pd family
-#'
-#' This function expands pred_spec into a grid with all combos of inputs,
-#' and computes partial dependence for each one.
-#'
-#' @inheritParams orsf_pred_dependence
-#' @param x_new the x-matrix used to compute partial dependence
-#' @param pd_fun_predict which cpp function to use.
-#'
-#' @return a `data.table` containing summarized partial dependence
-#' values if using `orsf_pd_summery` or individual conditional
-#' expectation (ICE) partial dependence if using `orsf_ice`.
-#'
-#' @noRd
-
-pd_grid <- function(object,
- x_new,
- pred_spec,
- pred_horizon,
- pd_fun_predict,
- type_output,
- prob_values,
- prob_labels,
- oobag,
- pred_type_cpp){
-
- if(!is.data.frame(pred_spec))
- pred_spec <- expand.grid(pred_spec, stringsAsFactors = TRUE)
-
- fi_ref <- get_fctr_info(object)
+ pred_spec_new[[i]] <- ref_code(pred_spec_new[[i]],
+ fi = fi,
+ names_x_data = pd_name)
- for(i in seq_along(fi_ref$cols)){
+ } else {
- ii <- fi_ref$cols[i]
+ pd_bind[[i]]$value <- pred_spec[[i]]
- if(is.character(pred_spec[[ii]]) && !fi_ref$ordr[i]){
+ }
- pred_spec[[ii]] <- factor(pred_spec[[ii]],
- levels = fi_ref$lvls[[ii]])
+ x_cols[[i]] <- match(names(pred_spec_new[[i]]), colnames(x_new))-1
+ pred_spec_new[[i]] <- as.matrix(pred_spec_new[[i]])
}
}
- check_new_data_fctrs(new_data = pred_spec,
- names_x = get_names_x(object),
- fi_ref = fi_ref,
- label_new = "pred_spec")
-
- pred_spec_new <- ref_code(x_data = pred_spec,
- fi = get_fctr_info(object),
- names_x_data = names(pred_spec))
-
- x_cols <- match(names(pred_spec_new), colnames(x_new))
-
- pd_vals <- pd_fun_predict(forest = object$forest,
- x_new_ = x_new,
- x_cols_ = x_cols-1,
- x_vals_ = as_matrix(pred_spec_new),
- probs_ = prob_values,
- time_dbl = pred_horizon,
- pred_type = pred_type_cpp)
-
-
- if(type_output == 'smry'){
-
- rownames(pd_vals) <- c('mean', prob_labels)
- output <- cbind(pred_spec, t(pd_vals))
- .names <- names(output)
-
- }
-
- if(type_output == 'ice'){
-
- colnames(pd_vals) <- c('id_variable', 'pred')
- pred_spec$id_variable <- seq(nrow(pred_spec))
- output <- merge(pred_spec, pd_vals, by = 'id_variable')
- output$id_row <- rep(seq(nrow(x_new)), pred_horizon = nrow(pred_spec))
+ orsf_out <- orsf_cpp(x = x_new,
+ y = matrix(1, ncol=2),
+ w = rep(1, nrow(x_new)),
+ tree_type_R = get_tree_type(object),
+ tree_seeds = get_tree_seeds(object),
+ loaded_forest = object$forest,
+ n_tree = get_n_tree(object),
+ mtry = get_mtry(object),
+ sample_with_replacement = get_sample_with_replacement(object),
+ sample_fraction = get_sample_fraction(object),
+ vi_type_R = 0,
+ vi_max_pvalue = get_vi_max_pvalue(object),
+ lincomb_R_function = get_f_beta(object),
+ oobag_R_function = get_f_oobag_eval(object),
+ leaf_min_events = get_leaf_min_events(object),
+ leaf_min_obs = get_leaf_min_obs(object),
+ split_rule_R = switch(get_split_rule(object),
+ "logrank" = 1,
+ "cstat" = 2),
+ split_min_events = get_split_min_events(object),
+ split_min_obs = get_split_min_obs(object),
+ split_min_stat = get_split_min_stat(object),
+ split_max_cuts = get_n_split(object),
+ split_max_retry = get_n_retry(object),
+ lincomb_type_R = switch(get_orsf_type(object),
+ 'fast' = 1,
+ 'cph' = 1,
+ 'random' = 2,
+ 'net' = 3,
+ 'custom' = 4),
+ lincomb_eps = get_cph_eps(object),
+ lincomb_iter_max = get_cph_iter_max(object),
+ lincomb_scale = get_cph_do_scale(object),
+ lincomb_alpha = get_net_alpha(object),
+ lincomb_df_target = get_net_df_target(object),
+ lincomb_ties_method = switch(
+ tolower(get_cph_method(object)),
+ 'breslow' = 0,
+ 'efron' = 1
+ ),
+ pred_type_R = pred_type_R,
+ pred_mode = FALSE,
+ pred_aggregate = TRUE,
+ pred_horizon = pred_horizon,
+ oobag = oobag,
+ oobag_eval_type_R = 0,
+ oobag_eval_every = get_n_tree(object),
+ pd_type_R = switch(type_output,
+ "smry" = 1L,
+ "ice" = 2L),
+ pd_x_vals = pred_spec_new,
+ pd_x_cols = x_cols,
+ pd_probs = prob_values,
+ n_thread = n_thread,
+ write_forest = FALSE,
+ run_forest = TRUE,
+ verbosity = 0)
- ids <- c('id_variable', 'id_row')
- .names <- c(ids, setdiff(names(output), ids))
+ pd_vals <- orsf_out$pd_values
- }
+ for(i in seq_along(pd_vals)){
- as.data.table(output[, .names])
+ pd_bind[[i]]$id_variable <- seq(nrow(pd_bind[[i]]))
-}
-
-#' loop working function in orsf_pd family
-#'
-#' This function loops through the items in pred_spec one by one,
-#' computing partial dependence for each one separately.
-#'
-#' @inheritParams orsf_pd_
-#' @param x_new the x-matrix used to compute partial dependence
-#' @param pd_fun_predict which cpp function to use.
-#'
-#' @return a `data.table` containing summarized partial dependence
-#' values if using `orsf_pd_summery` or individual conditional
-#' expectation (ICE) partial dependence if using `orsf_ice`.
-#'
-#' @noRd
+ for(j in seq_along(pd_vals[[i]])){
-pd_loop <- function(object,
- x_new,
- pred_spec,
- pred_horizon,
- pd_fun_predict,
- type_output,
- prob_values,
- prob_labels,
- oobag,
- pred_type_cpp){
+ pd_vals[[i]][[j]] <- matrix(pd_vals[[i]][[j]],
+ nrow=length(pred_horizon),
+ byrow = T)
- fi <- get_fctr_info(object)
+ rownames(pd_vals[[i]][[j]]) <- pred_horizon
- output <- vector(mode = 'list', length = length(pred_spec))
+ if(type_output=='smry')
+ colnames(pd_vals[[i]][[j]]) <- c('mean', prob_labels)
+ else
+ colnames(pd_vals[[i]][[j]]) <- c(paste(1:nrow(x_new)))
- for(i in seq_along(pred_spec)){
+ pd_vals[[i]][[j]] <- as.data.table(pd_vals[[i]][[j]],
+ keep.rownames = 'pred_horizon')
- pd_new <- as.data.frame(pred_spec[i])
- pd_name <- names(pred_spec)[i]
+ if(type_output == 'ice')
+ pd_vals[[i]][[j]] <- melt(data = pd_vals[[i]][[j]],
+ id.vars = 'pred_horizon',
+ variable.name = 'id_row',
+ value.name = 'pred')
- pd_bind <- data.frame(variable = pd_name,
- value = rep(NA_real_, length(pred_spec[[i]])),
- level = rep(NA_character_, length(pred_spec[[i]])))
+ }
- if(pd_name %in% fi$cols) {
+ pd_vals[[i]] <- rbindlist(pd_vals[[i]], idcol = 'id_variable')
- pd_bind$level <- as.character(pred_spec[[i]])
+ pd_vals[[i]] <- merge(pd_vals[[i]],
+ as.data.table(pd_bind[[i]]),
+ by = 'id_variable')
- pd_new <- ref_code(pd_new,
- fi = fi,
- names_x_data = pd_name)
- } else {
+ }
- pd_bind$value <- pred_spec[[i]]
- }
+ out <- rbindlist(pd_vals)
- x_cols <- match(names(pd_new), colnames(x_new))
+ ids <- c('id_variable', if(type_output == 'ice') 'id_row')
- x_vals <- x_new[, x_cols]
+ mid <- setdiff(names(out), c(ids, 'mean', prob_labels, 'pred'))
+ end <- setdiff(names(out), c(ids, mid))
- pd_vals <- pd_fun_predict(forest = object$forest,
- x_new_ = x_new,
- x_cols_ = x_cols-1,
- x_vals_ = as.matrix(pd_new),
- probs_ = prob_values,
- time_dbl = pred_horizon,
- pred_type = pred_type_cpp)
+ setcolorder(out, neworder = c(ids, mid, end))
+ out[, pred_horizon := as.numeric(pred_horizon)]
- # pd_fun_predict modifies x_new by reference, so reset it.
- x_new[, x_cols] <- x_vals
+ # not needed for summary
+ if(type_output == 'smry')
+ out[, id_variable := NULL]
- if(type_output == 'smry'){
+ # put data back into original scale
+ for(j in intersect(names(means), names(pred_spec))){
- rownames(pd_vals) <- c('mean', prob_labels)
- output[[i]] <- cbind(pd_bind, t(pd_vals))
+ if(j %in% names(out)){
- }
+ var_index <- collapse::seq_row(out)
+ var_value <- (out[[j]] * standard_deviations[j]) + means[j]
+ var_name <- j
- if(type_output == 'ice'){
+ } else {
- colnames(pd_vals) <- c('id_variable', 'pred')
- pd_bind$id_variable <- seq(nrow(pd_bind))
- output[[i]] <- merge(pd_bind, pd_vals, by = 'id_variable')
- output[[i]]$id_row <- seq(nrow(output[[i]]))
+ var_index <- out$variable %==% j
+ var_value <- (out$value[var_index] * standard_deviations[j]) + means[j]
+ var_name <- 'value'
}
- }
-
- output <- rbindlist(output)
-
- if(type_output == 'ice'){
-
- ids <- c('id_variable', 'id_row')
- .names <- c(ids, setdiff(names(output), ids))
- setcolorder(output, neworder = .names)
+ set(out, i = var_index, j = var_name, value = var_value)
}
+ # silent print after modify in place
+ out[]
- output
+ out
}
+
diff --git a/R/orsf_predict.R b/R/orsf_predict.R
index ab784490..6cbae26f 100644
--- a/R/orsf_predict.R
+++ b/R/orsf_predict.R
@@ -45,6 +45,18 @@
#' observed time in `object`'s training data. If `FALSE`, these checks
#' are skipped.
#'
+#' @param n_thread `r roxy_n_thread_header("computing predictions")`
+#'
+#' @param pred_aggregate (_logical_) If `TRUE` (the default), predictions
+#' will be aggregated over all trees by taking the mean. If `FALSE`, the
+#' returned output will contain one row per observation and one column
+#' for each tree. If the length of `pred_horizon` is two or more and
+#' `pred_aggregate` is `FALSE`, then the result will be a list of such
+#' matrices, with the i'th item in the list corresponding to the i'th
+#' value of `pred_horizon`.
+#'
+#' @inheritParams orsf
+#'
#' @param ... `r roxy_dots()`
#'
#' @return a `matrix` of predictions. Column `j` of the matrix corresponds
@@ -79,6 +91,9 @@ predict.orsf_fit <- function(object,
pred_type = 'risk',
na_action = 'fail',
boundary_checks = TRUE,
+ n_thread = 1,
+ verbose_progress = FALSE,
+ pred_aggregate = TRUE,
...){
# catch any arguments that didn't match and got relegated to ...
@@ -87,6 +102,21 @@ predict.orsf_fit <- function(object,
names_x_data <- intersect(get_names_x(object), names(new_data))
+ if(pred_type %in% c('leaf', 'mort') && !is.null(pred_horizon)){
+
+ extra_text <- if(length(pred_horizon)>1){
+ " Predictions at each value of pred_horizon will be identical."
+ } else {
+ ""
+ }
+
+ warning("pred_horizon does not impact predictions",
+ " when pred_type is '", pred_type, "'.",
+ extra_text, call. = FALSE)
+ # avoid copies of predictions and copies of this warning.
+ pred_horizon <- pred_horizon[1]
+ }
+
pred_horizon <- infer_pred_horizon(object, pred_horizon)
check_predict(object = object,
@@ -96,6 +126,27 @@ predict.orsf_fit <- function(object,
na_action = na_action,
boundary_checks = boundary_checks)
+ if(length(pred_horizon) > 1 && !pred_aggregate){
+
+ results <- lapply(
+ X = pred_horizon,
+ FUN = function(t){
+ predict.orsf_fit(object = object,
+ new_data = new_data,
+ pred_horizon = t,
+ pred_type = pred_type,
+ na_action = na_action,
+ boundary_checks = boundary_checks,
+ n_thread = n_thread,
+ verbose_progress = verbose_progress,
+ pred_aggregate = pred_aggregate)
+ }
+ )
+
+ return(simplify2array(results))
+
+ }
+
pred_horizon_order <- order(pred_horizon)
pred_horizon_ordered <- pred_horizon[pred_horizon_order]
@@ -123,33 +174,73 @@ predict.orsf_fit <- function(object,
x_new <- prep_x_from_orsf(object, data = new_data[cc, ])
- # x_new <- as.matrix(
- # ref_code(x_data = new_data[cc, ],
- # fi = get_fctr_info(object),
- # names_x_data = names_x_data)
- # )
-
- pred_type_cpp <- switch(
- pred_type,
- "risk" = "R",
- "surv" = "S",
- "chf" = "H",
- "mort" = "M"
- )
-
- out_values <-
- if(pred_type_cpp == "M"){
- orsf_pred_mort(object, x_new)
- } else if (length(pred_horizon) == 1L) {
- orsf_pred_uni(object$forest, x_new, pred_horizon_ordered, pred_type_cpp)
- } else {
- orsf_pred_multi(object$forest, x_new, pred_horizon_ordered, pred_type_cpp)
- }
+ orsf_out <- orsf_cpp(x = x_new,
+ y = matrix(1, ncol=2),
+ w = rep(1, nrow(x_new)),
+ tree_type_R = get_tree_type(object),
+ tree_seeds = get_tree_seeds(object),
+ loaded_forest = object$forest,
+ n_tree = get_n_tree(object),
+ mtry = get_mtry(object),
+ sample_with_replacement = get_sample_with_replacement(object),
+ sample_fraction = get_sample_fraction(object),
+ vi_type_R = 0,
+ vi_max_pvalue = get_vi_max_pvalue(object),
+ lincomb_R_function = get_f_beta(object),
+ oobag_R_function = get_f_oobag_eval(object),
+ leaf_min_events = get_leaf_min_events(object),
+ leaf_min_obs = get_leaf_min_obs(object),
+ split_rule_R = switch(get_split_rule(object),
+ "logrank" = 1,
+ "cstat" = 2),
+ split_min_events = get_split_min_events(object),
+ split_min_obs = get_split_min_obs(object),
+ split_min_stat = get_split_min_stat(object),
+ split_max_cuts = get_n_split(object),
+ split_max_retry = get_n_retry(object),
+ lincomb_type_R = switch(get_orsf_type(object),
+ 'fast' = 1,
+ 'cph' = 1,
+ 'random' = 2,
+ 'net' = 3,
+ 'custom' = 4),
+ lincomb_eps = get_cph_eps(object),
+ lincomb_iter_max = get_cph_iter_max(object),
+ lincomb_scale = get_cph_do_scale(object),
+ lincomb_alpha = get_net_alpha(object),
+ lincomb_df_target = get_net_df_target(object),
+ lincomb_ties_method = switch(
+ tolower(get_cph_method(object)),
+ 'breslow' = 0,
+ 'efron' = 1
+ ),
+ pred_type_R = switch(pred_type,
+ "risk" = 1,
+ "surv" = 2,
+ "chf" = 3,
+ "mort" = 4,
+ "leaf" = 8),
+ pred_mode = TRUE,
+ pred_aggregate = pred_aggregate,
+ pred_horizon = pred_horizon_ordered,
+ oobag = FALSE,
+ oobag_eval_type_R = 0,
+ oobag_eval_every = get_n_tree(object),
+ pd_type_R = 0,
+ pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
+ pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
+ pd_probs = c(0),
+ n_thread = n_thread,
+ write_forest = FALSE,
+ run_forest = TRUE,
+ verbosity = as.integer(verbose_progress))
+
+ out_values <- orsf_out$pred_new
if(na_action == "pass"){
out <- matrix(nrow = nrow(new_data),
- ncol = length(pred_horizon))
+ ncol = ncol(out_values))
out[cc, ] <- out_values
@@ -159,20 +250,10 @@ predict.orsf_fit <- function(object,
}
+ if(pred_type == "leaf" || !pred_aggregate) return(out)
+
# output in the same order as pred_horizon
out[, order(pred_horizon_order), drop = FALSE]
}
-orsf_pred_mort <- function(object, x_new){
-
- pred_mat <- orsf_pred_multi(object$forest,
- x_new = x_new,
- time_vec = get_event_times(object),
- pred_type = 'H')
-
- matrix(apply(pred_mat, MARGIN = 1, FUN = sum), ncol = 1)
-
-}
-
-
diff --git a/R/orsf_scale_cph.R b/R/orsf_scale_cph.R
index c858c32f..958b023c 100644
--- a/R/orsf_scale_cph.R
+++ b/R/orsf_scale_cph.R
@@ -78,7 +78,7 @@ orsf_scale_cph <- function(x_mat, w_vec = NULL){
call. = FALSE)
# pass x[, ] instead of x to prevent x from being modified in place.
- output <- x_node_scale_exported(x_mat[, ], w_vec)
+ output <- cph_scale(x_mat[, ], w_vec)
colnames(output$x_scaled) <- colnames(x_mat)
colnames(output$x_transforms) <- c("mean", "scale")
diff --git a/R/orsf_vi.R b/R/orsf_vi.R
index 60b8398a..d9afae1d 100644
--- a/R/orsf_vi.R
+++ b/R/orsf_vi.R
@@ -83,6 +83,8 @@ orsf_vi <- function(object,
group_factors = TRUE,
importance = NULL,
oobag_fun = NULL,
+ n_thread = 1,
+ verbose_progress = FALSE,
...){
check_dots(list(...), .f = orsf_vi)
@@ -110,30 +112,60 @@ orsf_vi <- function(object,
orsf_vi_(object,
group_factors = group_factors,
type_vi = type_vi,
- oobag_fun = oobag_fun)
+ oobag_fun = oobag_fun,
+ n_thread = n_thread,
+ verbose_progress = verbose_progress)
}
#' @rdname orsf_vi
#' @export
-orsf_vi_negate <- function(object, group_factors = TRUE, oobag_fun = NULL, ...){
- check_dots(list(...), .f = orsf_vi_negate)
- orsf_vi_(object, group_factors, type_vi = 'negate', oobag_fun = oobag_fun)
-}
+orsf_vi_negate <-
+ function(object,
+ group_factors = TRUE,
+ oobag_fun = NULL,
+ n_thread = 1,
+ verbose_progress = FALSE,
+ ...) {
+ check_dots(list(...), .f = orsf_vi_negate)
+ orsf_vi_(object,
+ group_factors,
+ type_vi = 'negate',
+ oobag_fun = oobag_fun,
+ n_thread = n_thread,
+ verbose_progress = verbose_progress)
+ }
#' @rdname orsf_vi
#' @export
-orsf_vi_permute <- function(object, group_factors = TRUE, oobag_fun = NULL, ...){
- check_dots(list(...), .f = orsf_vi_permute)
- orsf_vi_(object, group_factors, type_vi = 'permute', oobag_fun = oobag_fun)
-}
+orsf_vi_permute <-
+ function(object,
+ group_factors = TRUE,
+ oobag_fun = NULL,
+ n_thread = 1,
+ verbose_progress = FALSE,
+ ...) {
+ check_dots(list(...), .f = orsf_vi_permute)
+ orsf_vi_(object,
+ group_factors,
+ type_vi = 'permute',
+ oobag_fun = oobag_fun,
+ n_thread = n_thread,
+ verbose_progress = verbose_progress)
+ }
#' @rdname orsf_vi
#' @export
-orsf_vi_anova <- function(object, group_factors = TRUE, ...){
+orsf_vi_anova <- function(object,
+ group_factors = TRUE,
+ ...) {
check_dots(list(...), .f = orsf_vi_anova)
- orsf_vi_(object, group_factors, type_vi = 'anova', oobag_fun = NULL)
+ orsf_vi_(object,
+ group_factors,
+ type_vi = 'anova',
+ oobag_fun = NULL,
+ verbose_progress = FALSE)
}
#' Variable importance working function
@@ -143,7 +175,12 @@ orsf_vi_anova <- function(object, group_factors = TRUE, ...){
#'
#' @noRd
#'
-orsf_vi_ <- function(object, group_factors, type_vi, oobag_fun = NULL){
+orsf_vi_ <- function(object,
+ group_factors,
+ type_vi,
+ oobag_fun,
+ n_thread,
+ verbose_progress){
#' @srrstats {G2.8} *As part of initial pre-processing, run checks on inputs to ensure that all other sub-functions receive inputs of a single defined class or type.*
@@ -156,10 +193,14 @@ orsf_vi_ <- function(object, group_factors, type_vi, oobag_fun = NULL){
" orsf object with importance = 'anova'",
call. = FALSE)
- out <- switch(type_vi,
- 'anova' = as.matrix(get_importance_values(object)),
- 'negate' = orsf_vi_oobag_(object, type_vi, oobag_fun),
- 'permute' = orsf_vi_oobag_(object, type_vi, oobag_fun))
+ out <- switch(
+ type_vi,
+ 'anova' = as.matrix(get_importance_values(object)),
+ 'negate' = orsf_vi_oobag_(object, type_vi, oobag_fun,
+ n_thread, verbose_progress),
+ 'permute' = orsf_vi_oobag_(object, type_vi, oobag_fun,
+ n_thread, verbose_progress)
+ )
if(group_factors) {
@@ -209,15 +250,20 @@ orsf_vi_ <- function(object, group_factors, type_vi, oobag_fun = NULL){
#'
#' @noRd
#'
-orsf_vi_oobag_ <- function(object, type_vi, oobag_fun){
-
- if(!contains_oobag(object)){
- stop("cannot compute ",
- switch(type_vi, 'negate' = 'negation', 'permute' = 'permutation'),
- " importance if the orsf_fit object does not have out-of-bag error",
- " (see oobag_pred in ?orsf).",
- call. = FALSE)
- }
+orsf_vi_oobag_ <- function(object,
+ type_vi,
+ oobag_fun,
+ n_thread,
+ verbose_progress){
+
+ # can remove this b/c prediction accuracy is now computed at tree level
+ # if(!contains_oobag(object)){
+ # stop("cannot compute ",
+ # switch(type_vi, 'negate' = 'negation', 'permute' = 'permutation'),
+ # " importance if the orsf_fit object does not have out-of-bag error",
+ # " (see oobag_pred in ?orsf).",
+ # call. = FALSE)
+ # }
if(contains_vi(object) &&
is.null(oobag_fun) &&
@@ -234,13 +280,13 @@ orsf_vi_oobag_ <- function(object, type_vi, oobag_fun){
if(is.null(oobag_fun)){
f_oobag_eval <- function(x) x
- type_oobag_eval <- 'H'
+ type_oobag_eval <- 'cstat'
} else {
check_oobag_fun(oobag_fun)
f_oobag_eval <- oobag_fun
- type_oobag_eval <- 'U'
+ type_oobag_eval <- 'user'
}
@@ -250,43 +296,70 @@ orsf_vi_oobag_ <- function(object, type_vi, oobag_fun){
# Put data in the same order that it was in when object was fit
sorted <- order(y[, 1], -y[, 2])
-
- if(is.null(oobag_fun)) {
-
- last_eval_stat <-
- last_value(object$eval_oobag$stat_values[, 1, drop=TRUE])
-
- } else {
-
- last_eval_stat <-
- f_oobag_eval(y_mat = y, s_vec = object$pred_oobag)
-
- }
-
- f_oobag_vi <- switch(
- type_vi,
- 'negate' = orsf_oob_negate_vi,
- 'permute' = orsf_oob_permute_vi
- )
-
- pred_type <- switch(
- get_oobag_pred_type(object),
- "surv" = "S",
- "risk" = "R",
- "chf" = "H"
- )
-
- out <- f_oobag_vi(x = x[sorted, ],
- y = y[sorted, ],
- forest = object$forest,
- last_eval_stat = last_eval_stat,
- time_pred_ = object$pred_horizon,
- f_oobag_eval = f_oobag_eval,
- pred_type_ = pred_type,
- type_oobag_eval_ = type_oobag_eval)
-
+ pred_type <- 'mort'
+
+ orsf_out <- orsf_cpp(x = x[sorted, , drop = FALSE],
+ y = y[sorted, , drop = FALSE],
+ w = get_weights_user(object),
+ tree_type_R = get_tree_type(object),
+ tree_seeds = get_tree_seeds(object),
+ loaded_forest = object$forest,
+ n_tree = get_n_tree(object),
+ mtry = get_mtry(object),
+ sample_with_replacement = get_sample_with_replacement(object),
+ sample_fraction = get_sample_fraction(object),
+ vi_type_R = switch(type_vi,
+ 'negate' = 1,
+ 'permute' = 2),
+ vi_max_pvalue = get_vi_max_pvalue(object),
+ lincomb_R_function = get_f_beta(object),
+ oobag_R_function = f_oobag_eval,
+ leaf_min_events = get_leaf_min_events(object),
+ leaf_min_obs = get_leaf_min_obs(object),
+ split_rule_R = switch(get_split_rule(object),
+ "logrank" = 1,
+ "cstat" = 2),
+ split_min_events = get_split_min_events(object),
+ split_min_obs = get_split_min_obs(object),
+ split_min_stat = get_split_min_stat(object),
+ split_max_cuts = get_n_split(object),
+ split_max_retry = get_n_retry(object),
+ lincomb_type_R = switch(get_orsf_type(object),
+ 'fast' = 1,
+ 'cph' = 1,
+ 'random' = 2,
+ 'net' = 3,
+ 'custom' = 4),
+ lincomb_eps = get_cph_eps(object),
+ lincomb_iter_max = get_cph_iter_max(object),
+ lincomb_scale = get_cph_do_scale(object),
+ lincomb_alpha = get_net_alpha(object),
+ lincomb_df_target = get_net_df_target(object),
+ lincomb_ties_method = switch(
+ tolower(get_cph_method(object)),
+ 'breslow' = 0,
+ 'efron' = 1
+ ),
+ pred_type_R = 4,
+ pred_mode = FALSE,
+ pred_aggregate = TRUE,
+ pred_horizon = get_oobag_pred_horizon(object),
+ oobag = FALSE,
+ oobag_eval_type_R = switch(type_oobag_eval,
+ 'cstat' = 1,
+ 'user' = 2),
+ oobag_eval_every = get_n_tree(object),
+ pd_type_R = 0,
+ pd_x_vals = list(matrix(0, ncol=1, nrow=1)),
+ pd_x_cols = list(matrix(1L, ncol=1, nrow=1)),
+ pd_probs = c(0),
+ n_thread = n_thread,
+ write_forest = FALSE,
+ run_forest = TRUE,
+ verbosity = as.integer(verbose_progress))
+
+ out <- orsf_out$importance
rownames(out) <- colnames(x)
-
out
}
diff --git a/R/penalized_cph.R b/R/penalized_cph.R
index 11d98956..733d09c5 100644
--- a/R/penalized_cph.R
+++ b/R/penalized_cph.R
@@ -29,6 +29,8 @@ penalized_cph <- function(x_node,
alpha,
df_target){
+ colnames(y_node) <- c('time', 'status')
+
suppressWarnings(
fit <- try(
glmnet::glmnet(x = x_node,
diff --git a/R/roxy.R b/R/roxy.R
index 8fb1d5a4..509fcac4 100644
--- a/R/roxy.R
+++ b/R/roxy.R
@@ -14,6 +14,20 @@ roxy_data_allowed <- function(){
)
}
+# multi-threading ---------------------------------------------------------
+
+roxy_n_thread_header <- function(action){
+ paste("(_integer_) number of threads to use while ",
+ action, ". Default is one thread. ",
+ "To use the maximum number of threads that ",
+ "your system provides for concurrent execution, ",
+ "set `n_thread = 0`.", sep = "")
+}
+
+roxy_n_thread_details <- function(){
+ "(_integer_) number of threads to use. Default is one thread."
+}
+
# importance --------------------------------------------------------------
roxy_importance_header <- function(){
diff --git a/Rmd/orsf_examples.Rmd b/Rmd/orsf_examples.Rmd
index 4c79596a..0ff1fdc2 100644
--- a/Rmd/orsf_examples.Rmd
+++ b/Rmd/orsf_examples.Rmd
@@ -107,7 +107,7 @@ Let's make two customized functions to identify linear combinations of predictor
# estimate two principal components.
pca <- stats::prcomp(x_node, rank. = 2)
# use the second principal component to split the node
- pca$rotation[, 2L, drop = FALSE]
+ pca$rotation[, 1L, drop = FALSE]
}
@@ -163,11 +163,11 @@ sc$Brier$score[order(-IPA), .(model, times, IPA)]
From inspection,
-- the PCA approach has the highest discrimination, showing that you can do very well with just a two line custom function.
+- the `glmnet` approach has the highest discrimination and index of prediction accuracy.
-- the accelerated ORSF has the highest index of prediction accuracy
+- the accelerated ORSF is a close second.
-- the random coefficients generally don't do that well.
+- the random coefficients don't do that well, but they aren't bad.
## tidymodels
@@ -279,8 +279,8 @@ Score(
From inspection,
- `aorsf` obtained slightly higher discrimination (AUC)
+
- `aorsf` obtained higher index of prediction accuracy (IPA)
-- Way to go, `aorsf`
## mlr3 pipelines
@@ -443,9 +443,9 @@ tbl_data <-
structure(
list(
learner_id = c("surv.aorsf", "surv.ranger", "surv.rfsrc"),
- surv.graf = c(0.151447953930207, 0.166799975594481, 0.15586242346754),
- surv.cindex = c(0.729057822587355, 0.706476104709337, 0.714969112063354),
- time_train = c(0.344528301886968, 2.53641509434031, 0.782641509433885)
+ surv.graf = c(0.151771237677512, 0.166032273495838, 0.155174775571719),
+ surv.cindex = c(0.733123595064337, 0.71210747198625, 0.723016206784682),
+ time_train = c(1.41181818181788, 1.95254545454584, 0.744727272727191)
),
row.names = c(NA, -3L),
class = c("tbl_df", "tbl", "data.frame")
@@ -453,20 +453,11 @@ tbl_data <-
tbl_data
-# knitr::kable(tbl_data,
-# col.names = c('Learner',
-# 'Brier score',
-# 'C-index',
-# 'Time to train')) %>%
-# kableExtra::kable_styling()
-
```
From inspection,
-- `aorsf` appears to have a higher expected value for 'surv.cindex' (higher is better)
-- `aorsf` appears to have a lower expected value for 'surv.graf' (lower is better)
-- `aorsf` has the lowest training time.
+- `aorsf` has a higher expected value for 'surv.cindex' (higher is better)
-the lower training time for `aorsf` is likely due to the fact that there are many unique event times in the benchmark tasks. `ranger` and `rfsrc` create grids of time points based on each unique event time in each leaf of each decision tree, whereas `aorsf` also uses a grid but restricts it to the unique event times among observations in the current leaf.
+- `aorsf` has a lower expected value for 'surv.graf' (lower is better)
diff --git a/man/orsf.Rd b/man/orsf.Rd
index efa3e7b8..cf7cbcc9 100644
--- a/man/orsf.Rd
+++ b/man/orsf.Rd
@@ -13,12 +13,16 @@ orsf(
n_tree = 500,
n_split = 5,
n_retry = 3,
+ n_thread = 1,
mtry = NULL,
+ sample_with_replacement = TRUE,
+ sample_fraction = 0.632,
leaf_min_events = 1,
leaf_min_obs = 5,
+ split_rule = "logrank",
split_min_events = 5,
split_min_obs = 10,
- split_min_stat = 3.841459,
+ split_min_stat = switch(split_rule, logrank = 3.841459, cstat = 0.5),
oobag_pred_type = "surv",
oobag_pred_horizon = NULL,
oobag_eval_every = n_tree,
@@ -79,16 +83,36 @@ will try again with a new linear combination based on a different set
of randomly selected predictors, up to \code{n_retry} times. Default is
\code{n_retry = 3}. Set \code{n_retry = 0} to prevent any retries.}
+\item{n_thread}{(\emph{integer}) number of threads to use while growing trees, computing predictions, and computing importance. Default is one thread. To use the maximum number of threads that your system provides for concurrent execution, set \code{n_thread = 0}.}
+
\item{mtry}{(\emph{integer}) Number of predictors randomly included as candidates
for splitting a node. The default is the smallest integer greater than
the square root of the number of total predictors, i.e.,
\verb{mtry = ceiling(sqrt(number of predictors))}}
+\item{sample_with_replacement}{(\emph{logical}) If \code{TRUE} (the default),
+observations are sampled with replacement when an in-bag sample
+is created for a decision tree. If \code{FALSE}, observations are
+sampled without replacement and each tree will have an in-bag sample
+containing \code{sample_fraction}\% of the original sample.}
+
+\item{sample_fraction}{(\emph{double}) the proportion of observations that
+each trees' in-bag sample will contain, relative to the number of
+rows in \code{data}. Only used if \code{sample_with_replacement} is \code{FALSE}.
+Default value is 0.632.}
+
\item{leaf_min_events}{(\emph{integer}) minimum number of events in a
leaf node. Default is \code{leaf_min_events = 1}}
\item{leaf_min_obs}{(\emph{integer}) minimum number of observations in a
-leaf node. Default is \code{leaf_min_obs = 5}}
+leaf node. Default is \code{leaf_min_obs = 5}.}
+
+\item{split_rule}{(\emph{character}) how to assess the quality of a potential
+splitting rule for a node. Valid options are
+\itemize{
+\item 'logrank' : a log-rank test statistic.
+\item 'cstat' : Harrell's concordance statistic.
+}}
\item{split_min_events}{(\emph{integer}) minimum number of events required
in a node to consider splitting it. Default is \code{split_min_events = 5}}
@@ -97,20 +121,22 @@ in a node to consider splitting it. Default is \code{split_min_events = 5}}
in a node to consider splitting it. Default is \code{split_min_obs = 10}.}
\item{split_min_stat}{(double) minimum test statistic required to split
-a node. Default is 3.841459 for the log-rank test, which is roughly
-a p-value of 0.05}
+a node. Default is 3.841459 if \code{split_rule = 'logrank'} and 0.50 if
+\code{split_rule = 'cstat'}. If no splits are found with a statistic
+exceeding \code{split_min_stat}, the given node either becomes a leaf or
+a retry occurs (up to \code{n_retry} retries).}
\item{oobag_pred_type}{(\emph{character}) The type of out-of-bag predictions
to compute while fitting the ensemble. Valid options are
\itemize{
\item 'none' : don't compute out-of-bag predictions
-\item 'risk' : predict the probability of having an event at or before \code{oobag_pred_horizon}.
+\item 'risk' : probability of event occurring at or before \code{oobag_pred_horizon}.
\item 'surv' : 1 - risk.
-\item 'chf' : predict cumulative hazard function
-}
-
-Mortality ('mort')is not implemented for out of bag predictions yet, but it
-will be in a future update.}
+\item 'chf' : cumulative hazard function at \code{oobag_pred_horizon}.
+\item 'mort' : mortality, i.e., the number of events expected if all
+observations in the training data were identical to a
+given observation.
+}}
\item{oobag_pred_horizon}{(\emph{numeric}) A numeric value indicating what time
should be used for out-of-bag predictions. Default is the median
@@ -172,7 +198,7 @@ to the output will be the imputed version of \code{data}.
}}
\item{verbose_progress}{(\emph{logical}) if \code{TRUE}, progress messages are
-printed in the console.}
+printed in the console. If \code{FALSE} (the default), nothing is printed.}
\item{...}{Further arguments passed to or from other methods (not currently used).}
@@ -235,6 +261,13 @@ occur when using \link{orsf_control_net}.
If \code{oobag_fun} is specified, it will be used in to compute negation
importance or permutation importance, but it will not have any role
for ANOVA importance.
+
+\strong{n_thread}:
+
+If an R function must be called from C++ (i.e., user-supplied function to
+compute out-of-bag error or identify linear combinations of variables),
+\code{n_thread} will automatically be set to 1 because attempting to run R
+functions in multiple threads will cause the R session to crash.
}
\section{What is an oblique decision tree?}{
@@ -326,7 +359,7 @@ printing \code{fit} provides quick descriptive summaries:
## N trees: 500
## N predictors total: 17
## N predictors per node: 5
-## Average leaves per tree: 24
+## Average leaves per tree: 25
## Min observations in leaf: 5
## Min events in leaf: 1
## OOB stat value: 0.84
@@ -403,7 +436,7 @@ predictors.
# estimate two principal components.
pca <- stats::prcomp(x_node, rank. = 2)
# use the second principal component to split the node
- pca$rotation[, 2L, drop = FALSE]
+ pca$rotation[, 1L, drop = FALSE]
\}
}\if{html}{\out{}}
@@ -447,11 +480,11 @@ The AUC values, from highest to lowest:
}\if{html}{\out{}}
\if{html}{\out{
}}\preformatted{## model times AUC se lower upper
-## 1: net 1788 0.9107925 0.02116880 0.8693024 0.9522826
-## 2: accel 1788 0.9106308 0.02178112 0.8679406 0.9533210
-## 3: cph 1788 0.9072690 0.02120139 0.8657150 0.9488229
-## 4: pca 1788 0.8915619 0.02335399 0.8457889 0.9373349
-## 5: rando 1788 0.8900944 0.02228487 0.8464168 0.9337719
+## 1: accel 1788 0.9095660 0.02113628 0.8681397 0.9509924
+## 2: net 1788 0.9093490 0.02158187 0.8670493 0.9516487
+## 3: cph 1788 0.9066412 0.02196233 0.8635958 0.9496866
+## 4: rando 1788 0.9013929 0.02194349 0.8583845 0.9444014
+## 5: pca 1788 0.9001017 0.02233370 0.8563284 0.9438749
}\if{html}{\out{
}}
And the indices of prediction accuracy:
@@ -460,20 +493,20 @@ And the indices of prediction accuracy:
}\if{html}{\out{}}
\if{html}{\out{}}\preformatted{## model times IPA
-## 1: accel 1788 0.4891448
-## 2: cph 1788 0.4687734
-## 3: net 1788 0.4652211
-## 4: rando 1788 0.4011573
-## 5: pca 1788 0.3845911
+## 1: accel 1788 0.4812191
+## 2: net 1788 0.4810210
+## 3: cph 1788 0.4735707
+## 4: pca 1788 0.4408537
+## 5: rando 1788 0.4240110
## 6: Null model 1788 0.0000000
}\if{html}{\out{
}}
From inspection,
\itemize{
-\item the PCA approach has the highest discrimination, showing that you can
-do very well with just a two line custom function.
-\item the accelerated ORSF has the highest index of prediction accuracy
-\item the random coefficients generally don’t do that well.
+\item the \code{glmnet} approach has the highest discrimination and index of
+prediction accuracy.
+\item the accelerated ORSF is a close second.
+\item the random coefficients don’t do that well, but they aren’t bad.
}
}
@@ -577,29 +610,29 @@ glimpse(results)
\if{html}{\out{}}\preformatted{## Rows: 276
## Columns: 23
-## $ id 2, 16, 27, 66, 79, 97, 107, 116, 136, 137, 158, 189, 193, ~
-## $ trt d_penicill_main, placebo, placebo, d_penicill_main, d_peni~
-## $ age 56.44627, 40.44353, 54.43943, 46.45311, 46.51608, 71.89322~
-## $ sex f, f, f, m, f, m, f, f, f, f, f, f, f, f, f, f, f, f, f, f~
-## $ ascites 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0~
-## $ hepato 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1~
-## $ spiders 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1~
-## $ edema 0, 0, 0.5, 0, 0, 0.5, 0, 0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0~
-## $ bili 1.1, 0.7, 21.6, 1.4, 0.8, 2.0, 0.6, 3.0, 0.8, 1.1, 3.4, 1.~
-## $ chol 302, 204, 175, 427, 315, 420, 212, 458, 263, 399, 450, 360~
-## $ albumin 4.14, 3.66, 3.31, 3.70, 4.24, 3.26, 4.03, 3.63, 3.35, 3.60~
-## $ copper 54, 28, 221, 105, 13, 62, 10, 74, 27, 79, 32, 52, 267, 76,~
-## $ alk.phos 7394.8, 685.0, 3697.4, 1909.0, 1637.0, 3196.0, 648.0, 1588~
-## $ ast 113.52, 72.85, 101.91, 182.90, 170.50, 77.50, 71.30, 106.9~
-## $ trig 88, 58, 168, 171, 70, 91, 77, 382, 69, 152, 118, 164, 157,~
-## $ platelet 221, 198, 80, 123, 426, 344, 316, 438, 206, 344, 313, 256,~
-## $ protime 10.6, 10.8, 12.0, 11.0, 10.9, 11.4, 17.1, 9.9, 9.8, 10.1, ~
-## $ stage 3, 3, 4, 3, 3, 3, 1, 3, 2, 2, 2, 3, 4, 4, 2, 2, 3, 3, 4, 4~
-## $ time 4500, 3672, 77, 4191, 3707, 611, 3388, 3336, 3098, 2990, 2~
-## $ status 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0~
-## $ pred_aorsf 0.21650571, 0.01569191, 0.93095617, 0.36737089, 0.12868206~
-## $ pred_rfsrc 0.15202784, 0.01104486, 0.81913559, 0.20173550, 0.13806608~
-## $ pred_ranger 0.11418963, 0.02130315, 0.77073269, 0.22130305, 0.18419972~
+## $ id 17, 23, 34, 43, 50, 51, 61, 71, 78, 80, 92, 97, 100, 121, ~
+## $ trt placebo, placebo, d_penicill_main, d_penicill_main, d_peni~
+## $ age 52.18344, 55.96715, 52.06023, 48.87064, 53.50856, 52.08761~
+## $ sex f, f, f, f, f, f, m, f, f, m, f, m, m, m, f, m, f, f, f, f~
+## $ ascites 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0~
+## $ hepato 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0~
+## $ spiders 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0~
+## $ edema 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0.5, 0, 1, 0, 0, 0, 0, 0,~
+## $ bili 2.7, 17.4, 0.8, 1.1, 1.1, 0.8, 0.6, 1.2, 6.3, 7.2, 1.4, 2.~
+## $ chol 274, 395, 364, 361, 257, 276, 216, 258, 436, 247, 206, 420~
+## $ albumin 3.15, 2.94, 3.70, 3.64, 3.36, 3.60, 3.94, 3.57, 3.02, 3.72~
+## $ copper 159, 558, 37, 36, 43, 54, 28, 79, 75, 269, 36, 62, 145, 73~
+## $ alk.phos 1533.0, 6064.8, 1840.0, 5430.2, 1080.0, 4332.0, 601.0, 220~
+## $ ast 117.80, 227.04, 170.50, 67.08, 106.95, 99.33, 60.45, 120.9~
+## $ trig 128, 191, 64, 89, 73, 143, 188, 76, 104, 91, 70, 91, 122, ~
+## $ platelet 224, 214, 273, 203, 128, 273, 211, 410, 236, 360, 145, 344~
+## $ protime 10.5, 11.7, 10.5, 10.6, 10.6, 10.6, 13.0, 11.5, 10.6, 11.2~
+## $ stage 4, 4, 2, 2, 4, 2, 1, 4, 4, 4, 4, 3, 4, 4, 4, 4, 2, 2, 2, 3~
+## $ time 769, 264, 3933, 4556, 2598, 3853, 4256, 4196, 1690, 890, 3~
+## $ status 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0~
+## $ pred_aorsf 0.45334129, 0.95144543, 0.05125555, 0.04889358, 0.10217167~
+## $ pred_rfsrc 0.399822742, 0.849260628, 0.078051369, 0.062322537, 0.0593~
+## $ pred_ranger 0.37288105, 0.78251828, 0.04888381, 0.03342850, 0.05038629~
}\if{html}{\out{
}}
And finish by aggregating the predictions and computing performance in
@@ -625,16 +658,16 @@ counts.
## Results by model:
##
## model times AUC lower upper
-## 1: aorsf 1826 90.1 85.7 94.6
-## 2: rfsrc 1826 89.4 85.0 93.7
-## 3: ranger 1826 90.1 85.9 94.3
+## 1: aorsf 1826 90.0 85.4 94.7
+## 2: rfsrc 1826 89.9 85.4 94.4
+## 3: ranger 1826 89.5 85.0 94.1
##
## Results of model comparisons:
##
## times model reference delta.AUC lower upper p
-## 1: 1826 rfsrc aorsf -0.7 -2.3 0.8 0.4
-## 2: 1826 ranger aorsf -0.0 -1.7 1.6 1.0
-## 3: 1826 ranger rfsrc 0.7 -0.4 1.8 0.2
+## 1: 1826 rfsrc aorsf -0.1 -1.5 1.2 0.8
+## 2: 1826 ranger aorsf -0.5 -2.0 1.0 0.5
+## 3: 1826 ranger rfsrc -0.4 -1.5 0.8 0.5
##
## NOTE: Values are multiplied by 100 and given in \%.
@@ -648,19 +681,19 @@ counts.
##
## model times Brier lower upper IPA
## 1: Null model 1826.25 20.5 18.1 22.9 0.0
-## 2: aorsf 1826.25 11.1 8.8 13.4 45.8
-## 3: rfsrc 1826.25 12.0 9.8 14.1 41.6
-## 4: ranger 1826.25 11.8 9.7 13.9 42.5
+## 2: aorsf 1826.25 11.1 8.9 13.3 45.8
+## 3: rfsrc 1826.25 11.8 9.6 13.9 42.5
+## 4: ranger 1826.25 11.8 9.7 13.9 42.2
##
## Results of model comparisons:
##
## times model reference delta.Brier lower upper p
-## 1: 1826.25 aorsf Null model -9.4 -12.1 -6.6 2.423961e-11
-## 2: 1826.25 rfsrc Null model -8.5 -10.8 -6.2 2.104905e-13
-## 3: 1826.25 ranger Null model -8.7 -11.0 -6.4 1.802417e-13
-## 4: 1826.25 rfsrc aorsf 0.9 -0.0 1.7 5.277607e-02
-## 5: 1826.25 ranger aorsf 0.7 -0.1 1.5 1.008730e-01
-## 6: 1826.25 ranger rfsrc -0.2 -0.7 0.3 4.550782e-01
+## 1: 1826.25 aorsf Null model -9.4 -12.1 -6.6 1.836147e-11
+## 2: 1826.25 rfsrc Null model -8.7 -10.9 -6.5 2.460068e-14
+## 3: 1826.25 ranger Null model -8.6 -11.0 -6.3 3.215459e-13
+## 4: 1826.25 rfsrc aorsf 0.7 -0.2 1.5 1.176276e-01
+## 5: 1826.25 ranger aorsf 0.7 -0.0 1.5 5.782500e-02
+## 6: 1826.25 ranger rfsrc 0.1 -0.5 0.6 8.143879e-01
##
## NOTE: Values are multiplied by 100 and given in \%.
@@ -672,7 +705,6 @@ From inspection,
\itemize{
\item \code{aorsf} obtained slightly higher discrimination (AUC)
\item \code{aorsf} obtained higher index of prediction accuracy (IPA)
-\item Way to go, \code{aorsf}
}
}
@@ -823,25 +855,17 @@ Let’s look at the overall results:
\if{html}{\out{}}\preformatted{## # A tibble: 3 x 4
## learner_id surv.graf surv.cindex time_train
##
-## 1 surv.aorsf 0.151 0.729 0.345
-## 2 surv.ranger 0.167 0.706 2.54
-## 3 surv.rfsrc 0.156 0.715 0.783
+## 1 surv.aorsf 0.152 0.733 1.41
+## 2 surv.ranger 0.166 0.712 1.95
+## 3 surv.rfsrc 0.155 0.723 0.745
}\if{html}{\out{
}}
From inspection,
\itemize{
-\item \code{aorsf} appears to have a higher expected value for ‘surv.cindex’
-(higher is better)
-\item \code{aorsf} appears to have a lower expected value for ‘surv.graf’ (lower
-is better)
-\item \code{aorsf} has the lowest training time.
+\item \code{aorsf} has a higher expected value for ‘surv.cindex’ (higher is
+better)
+\item \code{aorsf} has a lower expected value for ‘surv.graf’ (lower is better)
}
-
-the lower training time for \code{aorsf} is likely due to the fact that there
-are many unique event times in the benchmark tasks. \code{ranger} and \code{rfsrc}
-create grids of time points based on each unique event time in each leaf
-of each decision tree, whereas \code{aorsf} also uses a grid but restricts it
-to the unique event times among observations in the current leaf.
}
}
diff --git a/man/orsf_control_custom.Rd b/man/orsf_control_custom.Rd
index 79d415dc..f4888c7a 100644
--- a/man/orsf_control_custom.Rd
+++ b/man/orsf_control_custom.Rd
@@ -67,10 +67,10 @@ fit_rando
## N trees: 500
## N predictors total: 17
## N predictors per node: 5
-## Average leaves per tree: 23
+## Average leaves per tree: 20
## Min observations in leaf: 5
## Min events in leaf: 1
-## OOB stat value: 0.82
+## OOB stat value: 0.84
## OOB stat type: Harrell's C-statistic
## Variable importance: anova
##
@@ -110,7 +110,7 @@ prediction accuracy based on out-of-bag predictions:
\if{html}{\out{}}\preformatted{library(riskRegression)
}\if{html}{\out{
}}
-\if{html}{\out{}}\preformatted{## riskRegression version 2022.11.28
+\if{html}{\out{
}}\preformatted{## riskRegression version 2023.09.08
}\if{html}{\out{
}}
\if{html}{\out{
}}\preformatted{library(survival)
@@ -135,15 +135,15 @@ The PCA ORSF does quite well! (higher IPA is better)
##
## model times Brier lower upper IPA
## 1: Null model 1788 20.479 18.090 22.868 0.000
-## 2: rando 1788 12.381 10.175 14.588 39.541
-## 3: pca 1788 12.496 10.476 14.515 38.983
+## 2: rando 1788 11.784 9.689 13.878 42.460
+## 3: pca 1788 12.685 10.694 14.675 38.061
##
## Results of model comparisons:
##
## times model reference delta.Brier lower upper p
-## 1: 1788 rando Null model -8.098 -10.392 -5.804 4.558033e-12
-## 2: 1788 pca Null model -7.983 -9.888 -6.078 2.142713e-16
-## 3: 1788 pca rando 0.114 -0.703 0.932 7.838255e-01
+## 1: 1788 rando Null model -8.695 -10.811 -6.580 7.854170e-16
+## 2: 1788 pca Null model -7.794 -9.475 -6.114 9.797721e-20
+## 3: 1788 pca rando 0.901 0.174 1.629 1.519521e-02
##
## NOTE: Values are multiplied by 100 and given in \%.
diff --git a/man/orsf_ice_oob.Rd b/man/orsf_ice_oob.Rd
index a9047c18..4948204d 100644
--- a/man/orsf_ice_oob.Rd
+++ b/man/orsf_ice_oob.Rd
@@ -13,6 +13,7 @@ orsf_ice_oob(
pred_type = "risk",
expand_grid = TRUE,
boundary_checks = TRUE,
+ n_thread = 1,
...
)
@@ -23,6 +24,7 @@ orsf_ice_inb(
pred_type = "risk",
expand_grid = TRUE,
boundary_checks = TRUE,
+ n_thread = 1,
...
)
@@ -35,6 +37,7 @@ orsf_ice_new(
na_action = "fail",
expand_grid = TRUE,
boundary_checks = TRUE,
+ n_thread = 1,
...
)
}
@@ -80,6 +83,8 @@ to make sure the requested values are between the 10th and 90th
percentile in the object's training data. If \code{FALSE}, these checks are
skipped.}
+\item{n_thread}{(\emph{integer}) number of threads to use while computing predictions. Default is one thread. To use the maximum number of threads that your system provides for concurrent execution, set \code{n_thread = 0}.}
+
\item{...}{Further arguments passed to or from other methods (not currently used).}
\item{new_data}{a \link{data.frame}, \link[tibble:tibble-package]{tibble}, or \link[data.table:data.table]{data.table} to compute predictions in.}
@@ -146,18 +151,18 @@ ice_oob <- orsf_ice_oob(fit, pred_spec, boundary_checks = FALSE)
ice_oob
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## pred_horizon id_variable id_row bili pred
-## 1: 1788 1 1 1 0.8935318
-## 2: 1788 1 2 1 0.1025087
-## 3: 1788 1 3 1 0.6959198
-## 4: 1788 1 4 1 0.3465760
-## 5: 1788 1 5 1 0.1105536
+\if{html}{\out{
}}\preformatted{## id_variable id_row pred_horizon bili pred
+## 1: 1 1 1788 1 0.9011797
+## 2: 1 2 1788 1 0.1096207
+## 3: 1 3 1788 1 0.7646444
+## 4: 1 4 1788 1 0.3531060
+## 5: 1 5 1788 1 0.1228441
## ---
-## 6896: 1788 25 272 10 0.4409361
-## 6897: 1788 25 273 10 0.4493052
-## 6898: 1788 25 274 10 0.4696659
-## 6899: 1788 25 275 10 0.3892409
-## 6900: 1788 25 276 10 0.4565133
+## 6896: 25 272 1788 10 0.3089586
+## 6897: 25 273 1788 10 0.4005430
+## 6898: 25 274 1788 10 0.4933945
+## 6899: 25 275 1788 10 0.3134373
+## 6900: 25 276 1788 10 0.5002014
}\if{html}{\out{
}}
Much more detailed examples are given in the
diff --git a/man/orsf_pd_oob.Rd b/man/orsf_pd_oob.Rd
index 9592449d..bf900bd0 100644
--- a/man/orsf_pd_oob.Rd
+++ b/man/orsf_pd_oob.Rd
@@ -15,6 +15,7 @@ orsf_pd_oob(
prob_values = c(0.025, 0.5, 0.975),
prob_labels = c("lwr", "medn", "upr"),
boundary_checks = TRUE,
+ n_thread = 1,
...
)
@@ -27,6 +28,7 @@ orsf_pd_inb(
prob_values = c(0.025, 0.5, 0.975),
prob_labels = c("lwr", "medn", "upr"),
boundary_checks = TRUE,
+ n_thread = 1,
...
)
@@ -41,6 +43,7 @@ orsf_pd_new(
prob_values = c(0.025, 0.5, 0.975),
prob_labels = c("lwr", "medn", "upr"),
boundary_checks = TRUE,
+ n_thread = 1,
...
)
}
@@ -98,6 +101,8 @@ to make sure the requested values are between the 10th and 90th
percentile in the object's training data. If \code{FALSE}, these checks are
skipped.}
+\item{n_thread}{(\emph{integer}) number of threads to use while computing predictions. Default is one thread. To use the maximum number of threads that your system provides for concurrent execution, set \code{n_thread = 0}.}
+
\item{...}{Further arguments passed to or from other methods (not currently used).}
\item{new_data}{a \link{data.frame}, \link[tibble:tibble-package]{tibble}, or \link[data.table:data.table]{data.table} to compute predictions in.}
@@ -155,12 +160,12 @@ You can compute partial dependence and ICE three ways with \code{aorsf}:
pd_train
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## pred_horizon bili mean lwr medn upr
-## 1: 1826.25 1 0.2054232 0.01599366 0.0929227 0.8077278
-## 2: 1826.25 2 0.2369077 0.02549869 0.1268457 0.8227315
-## 3: 1826.25 3 0.2808514 0.05027265 0.1720280 0.8457834
-## 4: 1826.25 4 0.3428065 0.09758988 0.2545869 0.8575243
-## 5: 1826.25 5 0.3992909 0.16392752 0.3232681 0.8634269
+\if{html}{\out{
}}\preformatted{## pred_horizon bili mean lwr medn upr
+## 1: 1826.25 1 0.2151663 0.02028479 0.09634648 0.7997269
+## 2: 1826.25 2 0.2576618 0.03766695 0.15497447 0.8211875
+## 3: 1826.25 3 0.2998484 0.06436773 0.20771324 0.8425637
+## 4: 1826.25 4 0.3390664 0.08427149 0.25401067 0.8589590
+## 5: 1826.25 5 0.3699045 0.10650098 0.28284427 0.8689855
}\if{html}{\out{
}}
\item using out-of-bag predictions for the training data
@@ -170,11 +175,11 @@ pd_train
}\if{html}{\out{
}}
\if{html}{\out{
}}\preformatted{## pred_horizon bili mean lwr medn upr
-## 1: 1826.25 1 0.2068300 0.01479443 0.08824123 0.8053317
-## 2: 1826.25 2 0.2377046 0.02469718 0.12623031 0.8258154
-## 3: 1826.25 3 0.2810546 0.04080813 0.18721220 0.8484846
-## 4: 1826.25 4 0.3417839 0.09076851 0.24968438 0.8611884
-## 5: 1826.25 5 0.3979925 0.16098228 0.32147532 0.8554402
+## 1: 1826.25 1 0.2145044 0.01835000 0.09619052 0.7980629
+## 2: 1826.25 2 0.2566241 0.03535358 0.14185734 0.8173143
+## 3: 1826.25 3 0.2984693 0.05900059 0.20515477 0.8334243
+## 4: 1826.25 4 0.3383547 0.07887323 0.24347513 0.8469769
+## 5: 1826.25 5 0.3696260 0.10450534 0.28065473 0.8523756
}\if{html}{\out{
}}
\item using predictions for a new set of data
@@ -186,11 +191,11 @@ pd_test
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{## pred_horizon bili mean lwr medn upr
-## 1: 1826.25 1 0.2510900 0.01631318 0.1872414 0.8162621
-## 2: 1826.25 2 0.2807327 0.02903956 0.2269297 0.8332956
-## 3: 1826.25 3 0.3247386 0.05860235 0.2841853 0.8481825
-## 4: 1826.25 4 0.3850799 0.10741224 0.3405760 0.8588955
-## 5: 1826.25 5 0.4394952 0.17572657 0.4050864 0.8657886
+## 1: 1826.25 1 0.2542230 0.02901386 0.1943767 0.8143912
+## 2: 1826.25 2 0.2955726 0.05037316 0.2474559 0.8317684
+## 3: 1826.25 3 0.3388434 0.07453896 0.3010898 0.8488622
+## 4: 1826.25 4 0.3800254 0.10565022 0.3516805 0.8592057
+## 5: 1826.25 5 0.4124587 0.12292465 0.3915066 0.8690074
}\if{html}{\out{
}}
\item in-bag partial dependence indicates relationships that the model has
learned during training. This is helpful if your goal is to interpret
diff --git a/man/orsf_vi.Rd b/man/orsf_vi.Rd
index 33019467..5312dc31 100644
--- a/man/orsf_vi.Rd
+++ b/man/orsf_vi.Rd
@@ -7,11 +7,33 @@
\alias{orsf_vi_anova}
\title{ORSF variable importance}
\usage{
-orsf_vi(object, group_factors = TRUE, importance = NULL, oobag_fun = NULL, ...)
-
-orsf_vi_negate(object, group_factors = TRUE, oobag_fun = NULL, ...)
-
-orsf_vi_permute(object, group_factors = TRUE, oobag_fun = NULL, ...)
+orsf_vi(
+ object,
+ group_factors = TRUE,
+ importance = NULL,
+ oobag_fun = NULL,
+ n_thread = 1,
+ verbose_progress = FALSE,
+ ...
+)
+
+orsf_vi_negate(
+ object,
+ group_factors = TRUE,
+ oobag_fun = NULL,
+ n_thread = 1,
+ verbose_progress = FALSE,
+ ...
+)
+
+orsf_vi_permute(
+ object,
+ group_factors = TRUE,
+ oobag_fun = NULL,
+ n_thread = 1,
+ verbose_progress = FALSE,
+ ...
+)
orsf_vi_anova(object, group_factors = TRUE, ...)
}
@@ -49,6 +71,11 @@ importance is estimated.
For more details, see the out-of-bag
\href{https://docs.ropensci.org/aorsf/articles/oobag.html}{vignette}.}
+\item{n_thread}{(\emph{integer}) number of threads to use while computing predictions. Default is one thread. To use the maximum number of threads that your system provides for concurrent execution, set \code{n_thread = 0}.}
+
+\item{verbose_progress}{(\emph{logical}) if \code{TRUE}, progress messages are
+printed in the console. If \code{FALSE} (the default), nothing is printed.}
+
\item{...}{Further arguments passed to or from other methods (not currently used).}
}
\value{
@@ -129,12 +156,12 @@ the ‘raw’ variable importance values can be accessed from the fit object
\if{html}{\out{}}\preformatted{attr(fit, 'importance_values')
}\if{html}{\out{
}}
-\if{html}{\out{}}\preformatted{## edema_1 ascites_1 bili copper age albumin
-## 0.41468531 0.34547820 0.27357335 0.19702602 0.17831563 0.17231851
-## edema_0.5 protime chol stage sex_f spiders_1
-## 0.16100917 0.15265823 0.14529486 0.13818084 0.13186813 0.12881052
-## ast hepato_1 alk.phos trig platelet trt_placebo
-## 0.12509496 0.11370348 0.10024752 0.09878683 0.08006941 0.06398488
+\if{html}{\out{
}}\preformatted{## ascites_1 edema_1 bili albumin copper edema_0.5
+## 0.44146501 0.43190921 0.29391304 0.22145499 0.22120519 0.20110957
+## age protime chol spiders_1 stage sex_f
+## 0.19980193 0.19329637 0.17777778 0.17772293 0.16048729 0.15926709
+## hepato_1 ast trig alk.phos platelet trt_placebo
+## 0.15816481 0.15734785 0.13200993 0.12433796 0.11844461 0.09404636
}\if{html}{\out{
}}
these are ‘raw’ because values for factors have not been aggregated into
@@ -153,24 +180,24 @@ To get aggregated values across all levels of each factor,
\if{html}{\out{
}}\preformatted{fit$importance
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## ascites bili edema copper age albumin protime
-## 0.34547820 0.27357335 0.26368761 0.19702602 0.17831563 0.17231851 0.15265823
-## chol stage sex spiders ast hepato alk.phos
-## 0.14529486 0.13818084 0.13186813 0.12881052 0.12509496 0.11370348 0.10024752
-## trig platelet trt
-## 0.09878683 0.08006941 0.06398488
+\if{html}{\out{
}}\preformatted{## ascites edema bili albumin copper age protime
+## 0.44146501 0.29452847 0.29391304 0.22145499 0.22120519 0.19980193 0.19329637
+## chol spiders stage sex hepato ast trig
+## 0.17777778 0.17772293 0.16048729 0.15926709 0.15816481 0.15734785 0.13200993
+## alk.phos platelet trt
+## 0.12433796 0.11844461 0.09404636
}\if{html}{\out{
}}
\item use \code{orsf_vi()} with group_factors set to \code{TRUE} (the default)
\if{html}{\out{
}}\preformatted{orsf_vi(fit)
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## ascites bili edema copper age albumin protime
-## 0.34547820 0.27357335 0.26368761 0.19702602 0.17831563 0.17231851 0.15265823
-## chol stage sex spiders ast hepato alk.phos
-## 0.14529486 0.13818084 0.13186813 0.12881052 0.12509496 0.11370348 0.10024752
-## trig platelet trt
-## 0.09878683 0.08006941 0.06398488
+\if{html}{\out{
}}\preformatted{## ascites edema bili albumin copper age protime
+## 0.44146501 0.29452847 0.29391304 0.22145499 0.22120519 0.19980193 0.19329637
+## chol spiders stage sex hepato ast trig
+## 0.17777778 0.17772293 0.16048729 0.15926709 0.15816481 0.15734785 0.13200993
+## alk.phos platelet trt
+## 0.12433796 0.11844461 0.09404636
}\if{html}{\out{
}}
}
@@ -193,27 +220,27 @@ You can fit an ORSF without VI, then add VI later
orsf_vi_negate(fit_no_vi)
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## bili copper age protime albumin
-## 0.0717336945 0.0288601792 0.0253698687 0.0110960617 0.0100020838
-## chol ascites spiders ast stage
-## 0.0075015628 0.0060950198 0.0045321942 0.0044280058 0.0025526151
-## edema sex hepato platelet alk.phos
-## 0.0024856369 0.0015628256 0.0004688477 0.0003646593 -0.0007293186
-## trig trt
-## -0.0020316733 -0.0061471140
+\if{html}{\out{
}}\preformatted{## bili copper sex protime stage
+## 0.1139657923 0.0498712200 0.0355366377 0.0283554322 0.0263792287
+## albumin age ascites chol ast
+## 0.0231636378 0.0195791833 0.0175120075 0.0148252414 0.0104918262
+## edema spiders hepato trt trig
+## 0.0084871358 0.0070608860 0.0067054788 0.0052040792 0.0030363455
+## alk.phos platelet
+## 0.0029918139 -0.0003309069
}\if{html}{\out{
}}
\if{html}{\out{
}}\preformatted{orsf_vi_permute(fit_no_vi)
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## age bili copper albumin chol
-## 1.109606e-02 1.083559e-02 7.032715e-03 5.157324e-03 4.636383e-03
-## protime ascites spiders ast platelet
-## 4.011252e-03 3.854970e-03 2.396333e-03 1.146072e-03 5.209419e-04
-## alk.phos edema sex hepato trig
-## 2.083767e-04 1.959734e-04 5.209419e-05 -4.688477e-04 -1.719108e-03
-## trt
-## -3.698687e-03
+\if{html}{\out{
}}\preformatted{## bili copper albumin protime ascites
+## 0.0538801986 0.0235904126 0.0144632299 0.0142037786 0.0123519716
+## stage age edema hepato chol
+## 0.0120377993 0.0110782938 0.0055307145 0.0052409958 0.0047839166
+## ast spiders sex trig alk.phos
+## 0.0042115620 0.0039660651 0.0028902730 0.0021803920 0.0018880548
+## platelet trt
+## 0.0005279898 -0.0024330707
}\if{html}{\out{
}}
}
@@ -229,14 +256,14 @@ fit an ORSF and compute vi at the same time
orsf_vi_permute(fit_permute_vi)
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## bili age copper stage ascites
-## 0.0114086268 0.0094811419 0.0055219837 0.0043238175 0.0032298395
-## albumin hepato protime ast edema
-## 0.0031256512 0.0030214628 0.0029172744 0.0021358616 0.0019051588
-## spiders chol alk.phos platelet trt
-## 0.0017712023 0.0013023547 0.0008335070 -0.0009376954 -0.0016149198
-## sex trig
-## -0.0020837675 -0.0022921442
+\if{html}{\out{
}}\preformatted{## bili copper age albumin ascites
+## 0.0513074950 0.0217622790 0.0131467379 0.0121683721 0.0120025410
+## stage protime chol edema ast
+## 0.0112281635 0.0108887695 0.0064301068 0.0061316531 0.0055392320
+## spiders hepato sex alk.phos trig
+## 0.0046819086 0.0026387295 0.0026066599 0.0017043328 0.0012899918
+## platelet trt
+## 0.0007224274 -0.0005790547
}\if{html}{\out{
}}
You can still get negation VI from this fit, but it needs to be computed
@@ -244,14 +271,12 @@ You can still get negation VI from this fit, but it needs to be computed
\if{html}{\out{
}}\preformatted{orsf_vi_negate(fit_permute_vi)
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## bili copper age protime albumin
-## 0.0773598666 0.0272452594 0.0258387164 0.0115649094 0.0084392582
-## sex chol ast ascites stage
-## 0.0081787872 0.0074494686 0.0060429256 0.0058866431 0.0043238175
-## hepato edema spiders platelet trig
-## 0.0040112523 0.0027684339 0.0026047093 0.0005730360 0.0002083767
-## trt alk.phos
-## -0.0003125651 -0.0016149198
+\if{html}{\out{
}}\preformatted{## bili copper sex stage age protime
+## 0.1106715167 0.0456031656 0.0306666098 0.0304383573 0.0252136203 0.0224838590
+## albumin ascites chol ast edema trt
+## 0.0212630703 0.0168893963 0.0134174671 0.0132075752 0.0099681058 0.0088378768
+## spiders hepato trig alk.phos platelet
+## 0.0078776082 0.0062877323 0.0043076141 0.0030432581 0.0005571111
}\if{html}{\out{
}}
}
}
diff --git a/man/predict.orsf_fit.Rd b/man/predict.orsf_fit.Rd
index 460bb69e..bafe78ca 100644
--- a/man/predict.orsf_fit.Rd
+++ b/man/predict.orsf_fit.Rd
@@ -11,6 +11,9 @@
pred_type = "risk",
na_action = "fail",
boundary_checks = TRUE,
+ n_thread = 1,
+ verbose_progress = FALSE,
+ pred_aggregate = TRUE,
...
)
}
@@ -51,6 +54,19 @@ checked to make sure the requested values are less than the maximum
observed time in \code{object}'s training data. If \code{FALSE}, these checks
are skipped.}
+\item{n_thread}{(\emph{integer}) number of threads to use while computing predictions. Default is one thread. To use the maximum number of threads that your system provides for concurrent execution, set \code{n_thread = 0}.}
+
+\item{verbose_progress}{(\emph{logical}) if \code{TRUE}, progress messages are
+printed in the console. If \code{FALSE} (the default), nothing is printed.}
+
+\item{pred_aggregate}{(\emph{logical}) If \code{TRUE} (the default), predictions
+will be aggregated over all trees by taking the mean. If \code{FALSE}, the
+returned output will contain one row per observation and one column
+for each tree. If the length of \code{pred_horizon} is two or more and
+\code{pred_aggregate} is \code{FALSE}, then the result will be a list of such
+matrices, with the i'th item in the list corresponding to the i'th
+value of \code{pred_horizon}.}
+
\item{...}{Further arguments passed to or from other methods (not currently used).}
}
\value{
@@ -104,11 +120,11 @@ predict(fit,
}\if{html}{\out{
}}
\if{html}{\out{
}}\preformatted{## [,1] [,2] [,3]
-## [1,] 0.48792661 0.75620281 0.90618133
-## [2,] 0.04293829 0.09112952 0.18602887
-## [3,] 0.12147573 0.27784498 0.41600114
-## [4,] 0.01136075 0.03401092 0.08236831
-## [5,] 0.01294947 0.02070625 0.05645823
+## [1,] 0.49679905 0.77309053 0.90830168
+## [2,] 0.03363621 0.08527972 0.17061414
+## [3,] 0.15129784 0.30402666 0.43747212
+## [4,] 0.01152480 0.02950914 0.07068198
+## [5,] 0.01035341 0.01942262 0.05117679
}\if{html}{\out{
}}
\if{html}{\out{
}}\preformatted{# predicted survival, i.e., 1 - risk
@@ -119,11 +135,11 @@ predict(fit,
}\if{html}{\out{
}}
\if{html}{\out{
}}\preformatted{## [,1] [,2] [,3]
-## [1,] 0.5120734 0.2437972 0.09381867
-## [2,] 0.9570617 0.9088705 0.81397113
-## [3,] 0.8785243 0.7221550 0.58399886
-## [4,] 0.9886393 0.9659891 0.91763169
-## [5,] 0.9870505 0.9792937 0.94354177
+## [1,] 0.5032009 0.2269095 0.09169832
+## [2,] 0.9663638 0.9147203 0.82938586
+## [3,] 0.8487022 0.6959733 0.56252788
+## [4,] 0.9884752 0.9704909 0.92931802
+## [5,] 0.9896466 0.9805774 0.94882321
}\if{html}{\out{
}}
\if{html}{\out{
}}\preformatted{# predicted cumulative hazard function
@@ -135,11 +151,11 @@ predict(fit,
}\if{html}{\out{
}}
\if{html}{\out{
}}\preformatted{## [,1] [,2] [,3]
-## [1,] 0.68107429 1.28607479 1.70338193
-## [2,] 0.04519460 0.10911618 0.24871482
-## [3,] 0.14686474 0.41252079 0.69005048
-## [4,] 0.01149952 0.03951923 0.10628942
-## [5,] 0.01338978 0.02214232 0.06644605
+## [1,] 0.74442414 1.39538511 1.78344589
+## [2,] 0.03473938 0.10418984 0.24047328
+## [3,] 0.19732086 0.47015754 0.73629459
+## [4,] 0.01169147 0.03223257 0.09564168
+## [5,] 0.01072007 0.02240040 0.06464319
}\if{html}{\out{
}}
Predict mortality, defined as the number of events in the forest’s
@@ -152,12 +168,12 @@ prediction horizon
pred_type = 'mort')
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## [,1]
-## [1,] 68.394152
-## [2,] 12.299344
-## [3,] 28.208251
-## [4,] 6.475339
-## [5,] 4.247305
+\if{html}{\out{
}}\preformatted{## [,1]
+## [1,] 83.08611
+## [2,] 27.48146
+## [3,] 43.52432
+## [4,] 15.20281
+## [5,] 10.56334
}\if{html}{\out{
}}
}
diff --git a/src/Coxph.cpp b/src/Coxph.cpp
new file mode 100644
index 00000000..ed7c5ac6
--- /dev/null
+++ b/src/Coxph.cpp
@@ -0,0 +1,682 @@
+/*-----------------------------------------------------------------------------
+ This file is part of aorsf.
+ Author: Byron C Jaeger
+ aorsf may be modified and distributed under the terms of the MIT license.
+#----------------------------------------------------------------------------*/
+
+#include
+#include "globals.h"
+#include "Coxph.h"
+#include "utility.h"
+
+ using namespace arma;
+ using namespace Rcpp;
+
+ namespace aorsf {
+
+ void cholesky_decomp(mat& vmat){
+
+ double eps_chol = 0;
+ double toler = 1e-8;
+ double pivot1, pivot2;
+ uword n_vars = vmat.n_cols;
+ uword i, j, k;
+
+ for(i = 0; i < n_vars; i++){
+
+ if(vmat.at(i,i) > eps_chol) eps_chol = vmat.at(i,i);
+
+ // copy upper right values to bottom left
+ for(j = (i+1); j eps_chol) {
+
+ for(j = (i+1); j < n_vars; j++){
+
+ pivot2 = vmat.at(j,i) / pivot1;
+ vmat.at(j,i) = pivot2;
+ vmat.at(j,j) -= pivot2*pivot2*pivot1;
+
+ for(k = (j+1); k < n_vars; k++){
+
+ vmat.at(k, j) -= pivot2 * vmat.at(k, i);
+
+ }
+
+ }
+
+ } else {
+
+ vmat.at(i, i) = 0;
+
+ }
+
+ }
+
+ }
+
+
+ void cholesky_solve(mat& vmat,
+ vec& u){
+
+ uword n_vars = vmat.n_cols;
+ uword i, j;
+ double temp;
+
+ for (i = 0; i < n_vars; i++) {
+
+ temp = u[i];
+
+ for (j = 0; j < i; j++){
+
+ temp -= u[j] * vmat.at(i, j);
+ u[i] = temp;
+
+ }
+
+ }
+
+
+ for (i = n_vars; i >= 1; i--){
+
+ if (vmat.at(i-1, i-1) == 0){
+
+ u[i-1] = 0;
+
+ } else {
+
+ temp = u[i-1] / vmat.at(i-1, i-1);
+
+ for (j = i; j < n_vars; j++){
+ temp -= u[j] * vmat.at(j, i-1);
+ }
+
+ u[i-1] = temp;
+
+ }
+
+ }
+
+ }
+
+ void cholesky_invert(mat& vmat){
+
+ uword n_vars = vmat.n_cols;
+ uword i, j, k;
+ double temp;
+
+ for (i=0; i0) {
+
+ // take full advantage of the cholesky's diagonal of 1's
+ vmat.at(i,i) = 1.0 / vmat.at(i,i);
+
+ for (j=(i+1); j 0)
+ scales.at(i) = w_node_sum / scales.at(i);
+ else
+ scales.at(i) = 1.0; // rare case of constant covariate;
+
+ x_node.col(i) *= scales.at(i);
+
+ }
+
+ }
+
+ beta_current.zeros(n_vars);
+ beta_new.zeros(n_vars);
+
+ // these are filled with initial values later
+ Risk.set_size(x_node.n_rows);
+ u.set_size(n_vars);
+ a.set_size(n_vars);
+ a2.set_size(n_vars);
+ vmat.set_size(n_vars, n_vars);
+ cmat.set_size(n_vars, n_vars);
+ cmat2.set_size(n_vars, n_vars);
+
+ halving = 0;
+
+ // do the initial iteration
+ denom = 0;
+ loglik = 0;
+ n_risk = 0;
+
+ person = x_node.n_rows - 1;
+
+ u.fill(0);
+ a.fill(0);
+ a2.fill(0);
+ vmat.fill(0);
+ cmat.fill(0);
+ cmat2.fill(0);
+
+
+ // the outer loop needs to be broken when a condition occurs in
+ // the inner loop - set up a bool to break the outer loop
+ break_loop = false;
+
+ // xb = 0.0;
+
+ for( ; ; ){
+
+ temp2 = y_node.at(person, 0); // time of event for current person
+ n_events = 0 ; // number of deaths at this time point
+ weight_events = 0 ; // sum of w_node for the deaths
+ denom_events = 0 ; // sum of weighted risks for the deaths
+
+ // walk through this set of tied times
+ while(y_node.at(person, 0) == temp2){
+
+ n_risk++;
+
+ risk = w_node.at(person);
+
+ if (y_node.at(person, 1) == 0) {
+
+ denom += risk;
+
+ /* a contains weighted sums of x, cmat sums of squares */
+
+ for (i=0; i 0) {
+
+ if (ties_method == 0 || n_events == 1) { // Breslow
+
+ denom += denom_events;
+ loglik -= denom_events * log(denom);
+
+ for (i=0; i 1 && stat_best < R_PosInf){
+
+ for(iter = 1; iter < iter_max; iter++){
+
+ // if(VERBOSITY > 1){
+ //
+ // Rcout << "--------- Newt-Raph algo; iter " << iter;
+ // Rcout << " ---------" << std::endl;
+ // Rcout << "beta: " << beta_new.t();
+ // Rcout << "loglik: " << stat_best;
+ // Rcout << std::endl;
+ // Rcout << "------------------------------------------";
+ // Rcout << std::endl << std::endl << std::endl;
+ //
+ // }
+
+ // do the next iteration
+
+ denom = 0;
+ loglik = 0;
+ n_risk = 0;
+
+ person = x_node.n_rows - 1;
+
+ u.fill(0);
+ a.fill(0);
+ a2.fill(0);
+ vmat.fill(0);
+ cmat.fill(0);
+ cmat2.fill(0);
+
+ // this loop has a strange break condition to accomodate
+ // the restriction that a uvec (uword) cannot be < 0
+
+ break_loop = false;
+
+ XB = x_node * beta_new;
+ Risk = exp(XB) % w_node;
+
+
+ for( ; ; ){
+
+ temp2 = y_node.at(person, 0); // time of event for current person
+ n_events = 0 ; // number of deaths at this time point
+ weight_events = 0 ; // sum of w_node for the deaths
+ denom_events = 0 ; // sum of weighted risks for the deaths
+
+ // walk through this set of tied times
+ while(y_node.at(person, 0) == temp2){
+
+ n_risk++;
+
+ xb = XB.at(person);
+ risk = Risk.at(person);
+
+ // xb = 0;
+ //
+ // for(i = 0; i < n_vars; i++){
+ // xb += beta.at(i) * x_node.at(person, i);
+ // }
+
+ w_node_person = w_node.at(person);
+
+ // risk = exp(xb) * w_node_person;
+
+ if (y_node.at(person, 1) == 0) {
+
+ denom += risk;
+
+ /* a contains weighted sums of x, cmat sums of squares */
+
+ for (i=0; i 0) {
+
+ if (ties_method == 0 || n_events == 1) { // Breslow
+
+ denom += denom_events;
+ loglik -= weight_events * log(denom);
+
+ for (i=0; i
+#include "globals.h"
+
+
+ namespace aorsf {
+
+ // cholesky decomposition
+ //
+ // @description this function is copied from the survival package and
+ // translated into arma.
+ //
+ // @param vmat matrix with covariance estimates
+ // @param n_vars the number of predictors used in the current node
+ //
+ // prepares vmat for cholesky_solve()
+
+
+ void cholesky_decomp(arma::mat& vmat);
+
+ // solve cholesky decomposition
+ //
+ // @description this function is copied from the survival package and
+ // translated into arma. Prepares u, the vector used to update beta.
+ //
+ // @param vmat matrix with covariance estimates
+ // @param n_vars the number of predictors used in the current node
+ //
+ //
+ void cholesky_solve(arma::mat& vmat,
+ arma::vec& u);
+
+ // invert the cholesky in the lower triangle
+ //
+ // @description this function is copied from the survival package and
+ // translated into arma. Inverts vmat
+ //
+ // @param vmat matrix with covariance estimates
+ // @param n_vars the number of predictors used in the current node
+ //
+
+ void cholesky_invert(arma::mat& vmat);
+
+ // run the newton raphson procedure
+ //
+ // @description identify a linear combination of predictors.
+ // This function is copied from the survival package and
+ // translated into arma with light modifications for efficiency.
+ // The procedure works with the partial likelihood function
+ // of the Cox model. All inputs are described above
+ // in newtraph_cph_iter()
+ //
+ arma::mat coxph_fit(arma::mat& x_node,
+ arma::mat& y_node,
+ arma::vec& w_node,
+ bool do_scale,
+ int ties_method,
+ double epsilon,
+ arma::uword iter_max);
+
+ }
+
+#endif /* COXPH_H */
+
diff --git a/src/Data.h b/src/Data.h
new file mode 100644
index 00000000..c561be3c
--- /dev/null
+++ b/src/Data.h
@@ -0,0 +1,136 @@
+/*-----------------------------------------------------------------------------
+ This file is part of aorsf.
+ Author: Byron C Jaeger
+ aorsf may be modified and distributed under the terms of the MIT license.
+#----------------------------------------------------------------------------*/
+
+#ifndef DATA_H_
+#define DATA_H_
+
+#include
+#include "globals.h"
+
+ using namespace arma;
+ using namespace Rcpp;
+
+ namespace aorsf {
+
+ class Data {
+
+ public:
+
+ Data() = default;
+
+ Data(arma::mat& x,
+ arma::mat& y,
+ arma::vec& w) {
+
+ this->x = x;
+ this->y = y;
+ this->w = w;
+
+ this->n_rows = x.n_rows;
+ this->n_cols = x.n_cols;
+ this->has_weights = !w.empty();
+ this->saved_values.resize(n_cols);
+
+ }
+
+ Data(const Data&) = delete;
+ Data& operator=(const Data&) = delete;
+
+ arma::uword get_n_rows() {
+ return(n_rows);
+ }
+
+ arma::uword get_n_cols() {
+ return(n_cols);
+ }
+
+ arma::mat& get_x(){
+ return(x);
+ }
+
+ arma::mat& get_y(){
+ return(y);
+ }
+
+ arma::vec& get_w(){
+ return(w);
+ }
+
+ arma::mat x_rows(arma::uvec& row_indices) {
+ return(x.rows(row_indices));
+ }
+
+ arma::mat x_cols(arma::uvec& column_indices) {
+ return(x.cols(column_indices));
+ }
+
+ arma::mat y_rows(arma::uvec& row_indices) {
+ return(y.rows(row_indices));
+ }
+
+ arma::mat y_cols(arma::uvec& column_indices) {
+ return(y.cols(column_indices));
+ }
+
+ arma::mat x_submat(arma::uvec& row_indices,
+ arma::uvec& column_indices){
+ return(x.submat(row_indices, column_indices));
+ }
+
+ arma::mat y_submat(arma::uvec& row_indices,
+ arma::uvec& column_indices){
+ return(y.submat(row_indices, column_indices));
+ }
+
+ arma::vec w_subvec(arma::uvec& indices){
+ return(w(indices));
+ }
+
+ void permute_col(arma::uword j, std::mt19937_64& rng){
+
+ arma::vec x_j = x.unsafe_col(j);
+ // make and store a copy
+ this->saved_values[j] = arma::vec(x_j.begin(), x_j.size(), true);
+ // shuffle the vector in-place
+ std::shuffle(x_j.begin(), x_j.end(), rng);
+
+ }
+
+ void save_col(arma::uword j){
+ saved_values[j] = x.col(j);
+ }
+
+ void restore_col(arma::uword j){
+ x.col(j) = saved_values[j];
+ }
+
+ void fill_col(double value, uword j){
+ x.col(j).fill(value);
+ }
+
+
+ // member variables
+
+ arma::uword n_cols;
+ arma::uword n_rows;
+ arma::vec w;
+
+ // for multi-column ops (e.g., partial dependence)
+ std::vector saved_values;
+
+ bool has_weights;
+
+ private:
+
+ arma::mat x;
+ arma::mat y;
+
+ };
+
+
+ } // namespace aorsf
+
+#endif /* DATA_H_ */
diff --git a/src/Forest.cpp b/src/Forest.cpp
new file mode 100644
index 00000000..0c5a76e6
--- /dev/null
+++ b/src/Forest.cpp
@@ -0,0 +1,847 @@
+// Forest.cpp
+
+#include
+#include "Forest.h"
+#include "Tree.h"
+
+using namespace arma;
+using namespace Rcpp;
+
+namespace aorsf {
+
+Forest::Forest(){ }
+
+void Forest::init(std::unique_ptr input_data,
+ Rcpp::IntegerVector& tree_seeds,
+ arma::uword n_tree,
+ arma::uword mtry,
+ bool sample_with_replacement,
+ double sample_fraction,
+ bool grow_mode,
+ VariableImportance vi_type,
+ double vi_max_pvalue,
+ // leaves
+ double leaf_min_obs,
+ // node splitting
+ SplitRule split_rule,
+ double split_min_obs,
+ double split_min_stat,
+ arma::uword split_max_cuts,
+ arma::uword split_max_retry,
+ // linear combinations
+ LinearCombo lincomb_type,
+ double lincomb_eps,
+ arma::uword lincomb_iter_max,
+ bool lincomb_scale,
+ double lincomb_alpha,
+ arma::uword lincomb_df_target,
+ arma::uword lincomb_ties_method,
+ RObject lincomb_R_function,
+ // predictions
+ PredType pred_type,
+ bool pred_mode,
+ bool pred_aggregate,
+ bool oobag_pred,
+ EvalType oobag_eval_type,
+ arma::uword oobag_eval_every,
+ Rcpp::RObject oobag_R_function,
+ uint n_thread,
+ int verbosity){
+
+ this->data = std::move(input_data);
+ this->tree_seeds = tree_seeds;
+ this->n_tree = n_tree;
+ this->mtry = mtry;
+ this->sample_with_replacement = sample_with_replacement;
+ this->sample_fraction = sample_fraction;
+ this->grow_mode = grow_mode;
+ this->vi_type = vi_type;
+ this->vi_max_pvalue = vi_max_pvalue;
+ this->leaf_min_obs = leaf_min_obs;
+ this->split_rule = split_rule;
+ this->split_min_obs = split_min_obs;
+ this->split_min_stat = split_min_stat;
+ this->split_max_cuts = split_max_cuts;
+ this->split_max_retry = split_max_retry;
+ this->lincomb_type = lincomb_type; this->lincomb_eps = lincomb_eps;
+ this->lincomb_iter_max = lincomb_iter_max;
+ this->lincomb_scale = lincomb_scale;
+ this->lincomb_alpha = lincomb_alpha;
+ this->lincomb_df_target = lincomb_df_target;
+ this->lincomb_ties_method = lincomb_ties_method;
+ this->lincomb_R_function = lincomb_R_function;
+ this->pred_type = pred_type;
+ this->pred_mode = pred_mode;
+ this->pred_aggregate = pred_aggregate;
+ this->oobag_pred = oobag_pred;
+ this->oobag_eval_type = oobag_eval_type;
+ this->oobag_eval_every = oobag_eval_every;
+ this->oobag_R_function = oobag_R_function;
+ this->n_thread = n_thread;
+ this->verbosity = verbosity;
+
+ if(vi_type != VI_NONE){
+ vi_numer.zeros(data->get_n_cols());
+ if(vi_type == VI_ANOVA){
+ vi_denom.zeros(data->get_n_cols());
+ }
+ }
+
+ // oobag denominator tracks the number of times an obs is oobag
+ oobag_denom.zeros(data->get_n_rows());
+
+ if(verbosity > 1){
+
+ Rcout << "------------ input data dimensions ------------" << std::endl;
+ Rcout << "N observations total: " << data->get_n_rows() << std::endl;
+ Rcout << "N columns total: " << data->get_n_cols() << std::endl;
+ Rcout << "-----------------------------------------------";
+ Rcout << std::endl;
+ Rcout << std::endl;
+
+ }
+
+}
+
+void Forest::run(bool oobag){
+
+ if (grow_mode) { // if the forest hasn't been grown
+
+ // plant first
+ plant();
+ // initialize
+ init_trees();
+ // grow
+ grow();
+
+ // compute + evaluate out-of-bag predictions if oobag is true
+ if(oobag){
+ this->pred_values = predict(oobag);
+ }
+
+ } else { // if the forest was already grown
+ // initialize trees
+ init_trees();
+ }
+
+ // if using a grown forest for prediction
+ if(pred_mode){
+ this->pred_values = predict(oobag);
+ }
+
+ // if using a grown forest for variable importance
+ if(vi_type == VI_PERMUTE || vi_type == VI_NEGATE){
+ compute_oobag_vi();
+ }
+
+ // if using a grown forest for partial dependence
+ if(pd_type == PD_SUMMARY || pd_type == PD_ICE){
+ this->pd_values = compute_dependence(oobag);
+ }
+
+}
+
+void Forest::init_trees(){
+
+ for(uword i = 0; i < n_tree; ++i){
+
+ trees[i]->init(data.get(),
+ tree_seeds[i],
+ mtry,
+ sample_with_replacement,
+ sample_fraction,
+ pred_type,
+ leaf_min_obs,
+ vi_type,
+ vi_max_pvalue,
+ split_rule,
+ split_min_obs,
+ split_min_stat,
+ split_max_cuts,
+ split_max_retry,
+ lincomb_type,
+ lincomb_eps,
+ lincomb_iter_max,
+ lincomb_scale,
+ lincomb_alpha,
+ lincomb_df_target,
+ lincomb_ties_method,
+ lincomb_R_function,
+ oobag_R_function,
+ oobag_eval_type,
+ verbosity);
+
+ }
+
+}
+
+void Forest::grow() {
+
+ // Create thread ranges
+ equalSplit(thread_ranges, 0, n_tree - 1, n_thread);
+
+ // reset progress to 0
+ progress = 0;
+
+ if(n_thread == 1){
+ // ensure safe usage of R functions and glmnet
+ // by growing trees in a single thread.
+ grow_single_thread(&vi_numer, &vi_denom);
+ return;
+ }
+
+ // catch interrupts from threads
+ aborted = false;
+ aborted_threads = 0;
+
+ // containers
+ std::vector threads;
+ std::vector vi_numer_threads(n_thread);
+ std::vector vi_denom_threads(n_thread);
+
+ // reserve memory
+ threads.reserve(n_thread);
+
+ // begin multi-thread grow
+ for (uint i = 0; i < n_thread; ++i) {
+
+ vi_numer_threads[i].zeros(data->n_cols);
+ if(vi_type == VI_ANOVA) vi_denom_threads[i].zeros(data->n_cols);
+
+ threads.emplace_back(&Forest::grow_multi_thread, this, i,
+ &(vi_numer_threads[i]),
+ &(vi_denom_threads[i]));
+ }
+
+ if(verbosity == 1){
+ show_progress("Growing trees", n_tree);
+ }
+
+ // end multi-thread grow
+ for (auto &thread : threads) {
+ thread.join();
+ }
+
+ if (aborted_threads > 0) {
+ throw std::runtime_error("User interrupt.");
+ }
+
+ if(vi_type == VI_ANOVA){
+
+ for(uint i = 0; i < n_thread; ++i){
+ vi_numer += vi_numer_threads[i];
+ vi_denom += vi_denom_threads[i];
+ }
+
+ }
+
+}
+
+void Forest::grow_single_thread(vec* vi_numer_ptr,
+ uvec* vi_denom_ptr){
+
+
+ using std::chrono::steady_clock;
+ using std::chrono::duration_cast;
+ using std::chrono::seconds;
+
+ steady_clock::time_point start_time = steady_clock::now();
+ steady_clock::time_point last_time = steady_clock::now();
+ size_t max_progress = n_tree;
+
+ for (uint i = 0; i < n_tree; ++i) {
+
+ if(verbosity > 1){
+ Rcout << "------------ Growing tree " << i << " --------------";
+ Rcout << std::endl;
+ Rcout << std::endl;
+ }
+
+ trees[i]->grow(vi_numer_ptr, vi_denom_ptr);
+
+ ++progress;
+
+ if(verbosity == 1){
+
+ seconds elapsed_time = duration_cast(steady_clock::now() - last_time);
+
+ if ((progress > 0 && elapsed_time.count() > STATUS_INTERVAL) ||
+ (progress == max_progress)) {
+
+ double relative_progress = (double) progress / (double) max_progress;
+ seconds time_from_start = duration_cast(steady_clock::now() - start_time);
+ uint remaining_time = (1 / relative_progress - 1) * time_from_start.count();
+
+ Rcout << "Growing trees: ";
+ Rcout << round(100 * relative_progress) << "%. ";
+
+ if(progress < max_progress){
+ Rcout << "~ time remaining: ";
+ Rcout << beautifyTime(remaining_time) << ".";
+ }
+
+ Rcout << std::endl;
+
+ last_time = steady_clock::now();
+
+ }
+
+ }
+
+ Rcpp::checkUserInterrupt();
+
+ }
+
+}
+
+
+void Forest::grow_multi_thread(uint thread_idx,
+ vec* vi_numer_ptr,
+ uvec* vi_denom_ptr) {
+
+
+ if (thread_ranges.size() > thread_idx + 1) {
+
+ for (uint i = thread_ranges[thread_idx]; i < thread_ranges[thread_idx + 1]; ++i) {
+
+ trees[i]->grow(vi_numer_ptr, vi_denom_ptr);
+
+ // Check for user interrupt
+ if (aborted) {
+ std::unique_lock lock(mutex);
+ ++aborted_threads;
+ condition_variable.notify_one();
+ return;
+ }
+
+ // Increase progress by 1 tree
+ std::unique_lock lock(mutex);
+ ++progress;
+ condition_variable.notify_one();
+
+ }
+
+ }
+
+}
+
+void Forest::compute_oobag_vi() {
+
+ // catch interrupts from threads
+ aborted = false;
+ aborted_threads = 0;
+
+ // show progress from threads
+ progress = 0;
+
+ if(n_thread == 1){
+ compute_oobag_vi_single_thread(&vi_numer);
+ return;
+ }
+
+ std::vector threads;
+ std::vector vi_numer_threads(n_thread);
+ // no denominator b/c it is equal to n_tree for all oob vi methods
+
+ threads.reserve(n_thread);
+
+ for (uint i = 0; i < n_thread; ++i) {
+
+ vi_numer_threads[i].zeros(data->n_cols);
+
+ threads.emplace_back(&Forest::compute_oobag_vi_multi_thread,
+ this, i, &(vi_numer_threads[i]));
+ }
+
+ if(verbosity == 1){
+ show_progress("Computing importance", n_tree);
+ }
+
+ for (auto &thread : threads) {
+ thread.join();
+ }
+
+ if (aborted_threads > 0) {
+ throw std::runtime_error("User interrupt.");
+ }
+
+ for(uint i = 0; i < n_thread; ++i){
+ vi_numer += vi_numer_threads[i];
+ }
+
+}
+
+void Forest::compute_oobag_vi_single_thread(vec* vi_numer_ptr) {
+
+ using std::chrono::steady_clock;
+ using std::chrono::duration_cast;
+ using std::chrono::seconds;
+
+ steady_clock::time_point start_time = steady_clock::now();
+ steady_clock::time_point last_time = steady_clock::now();
+ size_t max_progress = n_tree;
+
+ for(uint i = 0; i < n_tree; ++i){
+
+ trees[i]->compute_oobag_vi(vi_numer_ptr, vi_type);
+
+ ++progress;
+
+ if(verbosity == 1){
+
+ seconds elapsed_time = duration_cast(steady_clock::now() - last_time);
+
+ if ((progress > 0 && elapsed_time.count() > STATUS_INTERVAL) ||
+ (progress == max_progress)) {
+
+ double relative_progress = (double) progress / (double) max_progress;
+ seconds time_from_start = duration_cast(steady_clock::now() - start_time);
+ uint remaining_time = (1 / relative_progress - 1) * time_from_start.count();
+
+ Rcout << "Computing importance: ";
+ Rcout << round(100 * relative_progress) << "%. ";
+
+ if(progress < max_progress){
+ Rcout << "~ time remaining: ";
+ Rcout << beautifyTime(remaining_time) << ".";
+ }
+
+ Rcout << std::endl;
+
+ last_time = steady_clock::now();
+
+ }
+
+ }
+
+ Rcpp::checkUserInterrupt();
+
+ }
+
+}
+
+void Forest::compute_oobag_vi_multi_thread(uint thread_idx, vec* vi_numer_ptr) {
+
+ if (thread_ranges.size() > thread_idx + 1) {
+
+ for(uint i=thread_ranges[thread_idx]; icompute_oobag_vi(vi_numer_ptr, vi_type);
+
+ // Check for user interrupt
+ if (aborted) {
+ std::unique_lock lock(mutex);
+ ++aborted_threads;
+ condition_variable.notify_one();
+ return;
+ }
+
+ // Increase progress by 1 tree
+ std::unique_lock lock(mutex);
+ ++progress;
+ condition_variable.notify_one();
+
+ }
+
+ }
+
+}
+
+void Forest::compute_prediction_accuracy(Data* prediction_data,
+ arma::mat& prediction_values,
+ arma::uword row_fill){
+
+ // avoid dividing by zero
+ uvec valid_observations = find(oobag_denom > 0);
+
+ // subset each data input
+ mat y_valid = prediction_data->y_rows(valid_observations);
+ vec w_valid = prediction_data->w_subvec(valid_observations);
+ mat p_valid = prediction_values.rows(valid_observations);
+
+ // scale predictions based on how many trees contributed
+ // (important to note it's different for each oobag obs)
+ vec valid_denom = oobag_denom(valid_observations);
+ p_valid.each_col() /= valid_denom;
+
+ // pass along to forest-specific version
+ compute_prediction_accuracy(y_valid, w_valid, p_valid, row_fill);
+
+}
+
+mat Forest::predict(bool oobag) {
+
+ mat result;
+
+ // No. of cols in pred mat depend on the type of forest
+ resize_pred_mat(result);
+
+ // Slots to hold oobag prediction accuracy
+ // (needs to be resized even if !oobag)
+ resize_oobag_eval();
+
+ progress = 0;
+ aborted = false;
+ aborted_threads = 0;
+
+ if(n_thread == 1){
+ // ensure safe usage of R functions
+ predict_single_thread(data.get(), oobag, result);
+
+ } else {
+
+ std::vector threads;
+ std::vector result_threads(n_thread);
+ std::vector oobag_denom_threads(n_thread);
+
+ threads.reserve(n_thread);
+
+ for (uint i = 0; i < n_thread; ++i) {
+
+ resize_pred_mat(result_threads[i]);
+ if(oobag) oobag_denom_threads[i].zeros(data->n_rows);
+
+ threads.emplace_back(&Forest::predict_multi_thread,
+ this, i, data.get(), oobag,
+ &(result_threads[i]),
+ &(oobag_denom_threads[i]));
+ }
+
+ if(verbosity == 1){
+ show_progress("Computing predictions", n_tree);
+ }
+
+ // wait for all threads to finish before proceeding
+ for (auto &thread : threads) {
+ thread.join();
+ }
+
+ for(uint i = 0; i < n_thread; ++i){
+
+ result += result_threads[i];
+
+ if(oobag){
+
+ oobag_denom += oobag_denom_threads[i];
+
+ // evaluate oobag error after joining each thread
+ // (only safe to do this when the condition below holds)
+ if(grow_mode &&
+ n_tree/oobag_eval_every == n_thread &&
+ i < n_thread - 1){
+
+ // i should be uint to access threads,
+ // eval_row should be uword to access oobag_eval
+ uword eval_row = i;
+
+ compute_prediction_accuracy(data.get(), result, eval_row);
+
+ }
+ }
+
+ }
+
+ }
+
+ if(pred_type == PRED_TERMINAL_NODES || !pred_aggregate){
+ return(result);
+ }
+
+ if(oobag){
+
+ if(grow_mode){
+ compute_prediction_accuracy(data.get(), result, oobag_eval.n_rows-1);
+ }
+
+ // it's okay if we divide by 0 here. It makes the result NaN but
+ // that will be fixed when the results are post-processed in R/orsf.R
+ result.each_col() /= oobag_denom;
+
+ } else {
+
+ result /= n_tree;
+
+ }
+
+ return(result);
+
+}
+
+std::vector> Forest::compute_dependence(bool oobag){
+
+ std::vector> result;
+
+ result.reserve(pd_x_vals.size());
+
+ // looping through each item in the pd list
+ for(uword k = 0; k < pd_x_vals.size(); ++k){
+
+ uword n = pd_x_vals[k].n_rows;
+
+ std::vector result_k;
+
+ result_k.reserve(n);
+
+ // saving x values
+ for(const auto& x_col : pd_x_cols[k]){
+ data->save_col(x_col);
+ }
+
+ // loop through each row in the current pd matrix
+ for(uword i = 0; i < n; ++i){
+
+ uword j = 0;
+ // fill x with current pd values
+ for(const auto& x_col : pd_x_cols[k]){
+ data->fill_col(pd_x_vals[k].at(i, j), x_col);
+ ++j;
+ }
+
+ if(oobag) oobag_denom.fill(0);
+
+ mat preds = predict(oobag);
+
+ if(pd_type == PD_SUMMARY){
+
+ mat preds_summary = mean(preds, 0);
+
+ mat preds_quant = quantile(preds, pd_probs, 0);
+ result_k.push_back(join_vert(preds_summary, preds_quant));
+
+ } else if(pd_type == PD_ICE) {
+
+ result_k.push_back(preds);
+
+ }
+
+ }
+
+ // bring back original values before moving to next pd item
+ for(const auto& x_col : pd_x_cols[k]){
+ data->restore_col(x_col);
+ }
+
+ result.push_back(result_k);
+
+ }
+
+ return(result);
+
+}
+
+void Forest::predict_single_thread(Data* prediction_data,
+ bool oobag,
+ mat& result) {
+
+ using std::chrono::steady_clock;
+ using std::chrono::duration_cast;
+ using std::chrono::seconds;
+ steady_clock::time_point start_time = steady_clock::now();
+ steady_clock::time_point last_time = steady_clock::now();
+ size_t max_progress = n_tree;
+
+ for (uint i = 0; i < n_tree; ++i) {
+
+ if(verbosity > 1){
+ if(oobag){
+ Rcout << "--- Computing oobag predictions: tree " << i << " ---";
+ } else {
+ Rcout << "------ Computing predictions: tree " << i << " -----";
+ }
+ Rcout << std::endl;
+ Rcout << std::endl;
+ }
+
+ trees[i]->predict_leaf(prediction_data, oobag);
+
+ if(pred_type == PRED_TERMINAL_NODES){
+
+ result.col(i) = conv_to::from(trees[i]->get_pred_leaf());
+
+ } else if (!pred_aggregate){
+
+ vec col_i = result.unsafe_col(i);
+ trees[i]->predict_value(&col_i, &oobag_denom, pred_type, oobag);
+
+ } else {
+
+ trees[i]->predict_value(&result, &oobag_denom, pred_type, oobag);
+
+ }
+
+ progress++;
+
+ if(verbosity == 1){
+
+ seconds elapsed_time = duration_cast(steady_clock::now() - last_time);
+
+ if ((progress > 0 && elapsed_time.count() > STATUS_INTERVAL) ||
+ (progress == max_progress)) {
+
+ double relative_progress = (double) progress / (double) max_progress;
+ seconds time_from_start = duration_cast(steady_clock::now() - start_time);
+ uint remaining_time = (1 / relative_progress - 1) * time_from_start.count();
+
+ Rcout << "Computing predictions: ";
+ Rcout << round(100 * relative_progress) << "%. ";
+
+ if(progress < max_progress){
+ Rcout << "~ time remaining: ";
+ Rcout << beautifyTime(remaining_time) << ".";
+ }
+
+ Rcout << std::endl;
+
+ last_time = steady_clock::now();
+
+ }
+
+ }
+
+ // if tracking oobag error over time:
+ if(oobag && grow_mode && (progress%oobag_eval_every==0) && pred_aggregate){
+
+ uword eval_row = (progress / oobag_eval_every) - 1;
+ // mat preds = result.each_col() / oobag_denom;
+ compute_prediction_accuracy(prediction_data, result, eval_row);
+
+ }
+
+ }
+
+}
+
+void Forest::predict_multi_thread(uint thread_idx,
+ Data* prediction_data,
+ bool oobag,
+ mat* result_ptr,
+ vec* denom_ptr) {
+
+ if (thread_ranges.size() > thread_idx + 1) {
+
+ for (uint i = thread_ranges[thread_idx]; i < thread_ranges[thread_idx + 1]; ++i) {
+
+ trees[i]->predict_leaf(prediction_data, oobag);
+
+ if(pred_type == PRED_TERMINAL_NODES){
+
+ (*result_ptr).col(i) = conv_to::from(trees[i]->get_pred_leaf());
+
+ } else if (!pred_aggregate){
+
+ vec col_i = (*result_ptr).unsafe_col(i);
+ trees[i]->predict_value(&col_i, denom_ptr, pred_type, oobag);
+
+ } else {
+
+ trees[i]->predict_value(result_ptr, denom_ptr, pred_type, oobag);
+
+ }
+
+ // Check for user interrupt
+ if (aborted) {
+ std::unique_lock lock(mutex);
+ ++aborted_threads;
+ condition_variable.notify_one();
+ return;
+ }
+
+ // Increase progress by 1 tree
+ std::unique_lock lock(mutex);
+ ++progress;
+
+ condition_variable.notify_one();
+
+ }
+
+ }
+
+}
+
+arma::uword Forest::find_max_eval_steps(){
+
+ if(!oobag_pred) return(0);
+
+ uword n_evals = std::ceil(n_tree / oobag_eval_every);
+
+ if(n_evals > n_tree) n_evals = n_tree;
+
+ if(n_evals < 1) n_evals = 1;
+
+ return(n_evals);
+
+}
+
+void Forest::resize_oobag_eval(){
+
+ uword n_evals = find_max_eval_steps();
+
+ oobag_eval.resize(n_evals, 1);
+
+}
+
+void Forest::show_progress(std::string operation, size_t max_progress) {
+
+ using std::chrono::steady_clock;
+ using std::chrono::duration_cast;
+ using std::chrono::seconds;
+
+ steady_clock::time_point start_time = steady_clock::now();
+ steady_clock::time_point last_time = steady_clock::now();
+ std::unique_lock lock(mutex);
+
+ // Wait for message from threads and show output if enough time elapsed
+ while (progress < max_progress) {
+
+ condition_variable.wait(lock);
+
+ seconds elapsed_time = duration_cast(steady_clock::now() - last_time);
+
+ // Check for user interrupt
+ if (!aborted && checkInterrupt()) {
+ aborted = true;
+ }
+ if (aborted && aborted_threads >= n_thread) {
+ return;
+ }
+
+ if ((progress > 0 && elapsed_time.count() > STATUS_INTERVAL) ||
+ (progress == max_progress)) {
+
+ double relative_progress = (double) progress / (double) max_progress;
+ seconds time_from_start = duration_cast(steady_clock::now() - start_time);
+ uint remaining_time = (1 / relative_progress - 1) * time_from_start.count();
+
+ Rcout << operation << ": ";
+ Rcout << round(100 * relative_progress) << "%. ";
+
+ if(progress < max_progress){
+ Rcout << "~ time remaining: ";
+ Rcout << beautifyTime(remaining_time) << ".";
+ }
+
+ Rcout << std::endl;
+
+ last_time = steady_clock::now();
+
+ }
+ }
+}
+
+void Forest::resize_pred_mat(arma::mat& p){
+
+ if(pred_type == PRED_TERMINAL_NODES || !pred_aggregate){
+
+ p.zeros(data->n_rows, n_tree);
+
+ } else {
+
+ resize_pred_mat_internal(p);
+
+ }
+
+}
+
+}
+
+
diff --git a/src/Forest.h b/src/Forest.h
new file mode 100644
index 00000000..39a5ee58
--- /dev/null
+++ b/src/Forest.h
@@ -0,0 +1,333 @@
+
+// Forest.h
+
+#ifndef FOREST_H
+#define FOREST_H
+
+#include "Data.h"
+#include "globals.h"
+#include "utility.h"
+#include "Tree.h"
+#include "TreeSurvival.h"
+
+#include
+#include
+#include
+
+namespace aorsf {
+
+class Forest {
+
+public:
+
+ // Constructor
+
+ Forest();
+
+ // deleting the copy constructor
+ Forest(const Forest&) = delete;
+ // deleting the copy assignment operator
+ Forest& operator=(const Forest&) = delete;
+
+ virtual ~Forest() = default;
+
+ // Methods
+
+ void init(std::unique_ptr input_data,
+ Rcpp::IntegerVector& tree_seeds,
+ arma::uword n_tree,
+ arma::uword mtry,
+ bool sample_with_replacement,
+ double sample_fraction,
+ bool grow_mode,
+ VariableImportance vi_type,
+ double vi_max_pvalue,
+ // leaves
+ double leaf_min_obs,
+ // node splitting
+ SplitRule split_rule,
+ double split_min_obs,
+ double split_min_stat,
+ arma::uword split_max_cuts,
+ arma::uword split_max_retry,
+ // linear combinations
+ LinearCombo lincomb_type,
+ double lincomb_eps,
+ arma::uword lincomb_iter_max,
+ bool lincomb_scale,
+ double lincomb_alpha,
+ arma::uword lincomb_df_target,
+ arma::uword lincomb_ties_method,
+ RObject lincomb_R_function,
+ // predictions
+ PredType pred_type,
+ bool pred_mode,
+ bool pred_aggregate,
+ bool oobag_pred,
+ EvalType oobag_eval_type,
+ arma::uword oobag_eval_every,
+ Rcpp::RObject oobag_R_function,
+ uint n_thread,
+ int verbosity);
+
+ // Grow or predict
+ // void run(bool verbose, bool oobag);
+
+ virtual void compute_prediction_accuracy(
+ Data* prediction_data,
+ arma::mat& prediction_values,
+ arma::uword row_fill
+ );
+
+ virtual void compute_prediction_accuracy(
+ arma::mat& y,
+ arma::vec& w,
+ arma::mat& predictions,
+ arma::uword row_fill
+ ) = 0;
+
+ std::vector> get_cutpoint() {
+
+ std::vector> result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ result.push_back(tree->get_cutpoint());
+ }
+
+ return result;
+
+ }
+ std::vector get_rows_oobag() {
+
+ std::vector result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ result.push_back(tree->get_rows_oobag());
+ }
+
+ return result;
+
+ }
+
+ std::vector> get_child_left() {
+
+ std::vector> result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ result.push_back(tree->get_child_left());
+ }
+
+ return result;
+
+ }
+
+ std::vector> get_coef_values() {
+
+ std::vector> result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ result.push_back(tree->get_coef_values());
+ }
+
+ return result;
+
+ }
+ std::vector> get_coef_indices() {
+
+ std::vector> result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ result.push_back(tree->get_coef_indices());
+ }
+
+ return result;
+
+ }
+
+ std::vector> get_leaf_summary() {
+
+ std::vector> result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ result.push_back(tree->get_leaf_summary());
+ }
+
+ return result;
+
+ }
+
+ void set_unique_event_times(arma::vec& x){
+ this->unique_event_times = x;
+ }
+
+ arma::vec& get_unique_event_times(){
+ return(unique_event_times);
+ }
+
+ arma::vec& get_vi_numer(){
+ return(vi_numer);
+ }
+
+ arma::uvec& get_vi_denom(){
+ return(vi_denom);
+ }
+
+ arma::mat& get_oobag_eval(){
+ return(oobag_eval);
+ }
+
+ arma::mat& get_predictions(){
+ return(pred_values);
+ }
+
+ std::vector>& get_pd_values(){
+ return(pd_values);
+ }
+
+ void run(bool oobag);
+
+ virtual void plant() = 0;
+
+ void grow();
+
+ arma::mat predict(bool oobag);
+
+ std::vector> compute_dependence(bool oobag);
+
+protected:
+
+ void init_trees();
+
+ void grow_single_thread(vec* vi_numer_ptr,
+ uvec* vi_denom_ptr);
+
+ void grow_multi_thread(uint thread_idx,
+ vec* vi_numer_ptr,
+ uvec* vi_denom_ptr);
+
+ void predict_single_thread(Data* prediction_data,
+ bool oobag,
+ mat& result);
+
+ void predict_multi_thread(uint thread_idx,
+ Data* prediction_data,
+ bool oobag,
+ mat* result_ptr,
+ vec* denom_ptr);
+
+ void compute_oobag_vi();
+
+ void compute_oobag_vi_single_thread(vec* vi_numer_ptr);
+
+ void compute_oobag_vi_multi_thread(uint thread_idx, vec* vi_numer_ptr);
+
+ void show_progress(std::string operation, size_t max_progress);
+
+ virtual void resize_pred_mat(arma::mat& p);
+
+ virtual void resize_pred_mat_internal(arma::mat& p) = 0;
+
+ arma::uword find_max_eval_steps();
+
+ virtual void resize_oobag_eval();
+
+ // Member variables
+
+ arma::uword n_tree;
+ arma::uword mtry;
+ bool sample_with_replacement;
+ double sample_fraction;
+ Rcpp::IntegerVector tree_seeds;
+
+ std::vector> trees;
+
+ std::unique_ptr data;
+
+ arma::vec unique_event_times;
+
+ // variable importance
+ VariableImportance vi_type;
+ double vi_max_pvalue;
+
+ arma::vec vi_numer;
+ arma::uvec vi_denom;
+
+ // leaves
+ double leaf_min_events;
+ double leaf_min_obs;
+
+ // node splitting
+ SplitRule split_rule;
+ double split_min_events;
+ double split_min_obs;
+ double split_min_stat;
+ arma::uword split_max_cuts;
+ arma::uword split_max_retry;
+
+ // linear combinations
+ LinearCombo lincomb_type;
+ double lincomb_eps;
+ bool lincomb_scale;
+ double lincomb_alpha;
+ arma::uword lincomb_iter_max;
+ arma::uword lincomb_df_target;
+ arma::uword lincomb_ties_method;
+ RObject lincomb_R_function;
+
+ bool grow_mode;
+
+ // predictions
+ PredType pred_type;
+ bool pred_mode;
+ bool pred_aggregate;
+ arma::mat pred_values;
+
+ // partial dependence
+ PartialDepType pd_type;
+ std::vector> pd_values;
+ std::vector pd_x_vals;
+ std::vector pd_x_cols;
+ arma::vec pd_probs;
+
+ // out-of-bag
+ bool oobag_pred;
+ arma::vec oobag_denom;
+ arma::mat oobag_eval;
+ EvalType oobag_eval_type;
+ arma::uword oobag_eval_every;
+ RObject oobag_R_function;
+
+
+ // multi-threading
+ uint n_thread;
+ std::vector thread_ranges;
+ std::mutex mutex;
+ std::condition_variable condition_variable;
+
+ size_t progress;
+ size_t aborted_threads;
+ bool aborted;
+
+ // printing to console
+ int verbosity;
+
+
+};
+
+}
+
+
+
+#endif /* Forest_H */
diff --git a/src/ForestSurvival.cpp b/src/ForestSurvival.cpp
new file mode 100644
index 00000000..b6ef3576
--- /dev/null
+++ b/src/ForestSurvival.cpp
@@ -0,0 +1,191 @@
+// Forest.cpp
+
+#include
+#include "ForestSurvival.h"
+#include "TreeSurvival.h"
+
+using namespace arma;
+using namespace Rcpp;
+
+namespace aorsf {
+
+ForestSurvival::ForestSurvival() { }
+
+ForestSurvival::ForestSurvival(double leaf_min_events,
+ double split_min_events,
+ arma::vec& pred_horizon){
+
+ this->leaf_min_events = leaf_min_events;
+ this->split_min_events = split_min_events;
+ this->pred_horizon = pred_horizon;
+
+}
+
+
+
+void ForestSurvival::load(
+ arma::uword n_tree,
+ std::vector& forest_rows_oobag,
+ std::vector>& forest_cutpoint,
+ std::vector>& forest_child_left,
+ std::vector>& forest_coef_values,
+ std::vector>& forest_coef_indices,
+ std::vector>& forest_leaf_pred_indx,
+ std::vector>& forest_leaf_pred_prob,
+ std::vector>& forest_leaf_pred_chaz,
+ std::vector>& forest_leaf_summary,
+ PartialDepType pd_type,
+ std::vector& pd_x_vals,
+ std::vector& pd_x_cols,
+ arma::vec& pd_probs
+) {
+
+ this->n_tree = n_tree;
+ this->pd_type = pd_type;
+ this->pd_x_vals = pd_x_vals;
+ this->pd_x_cols = pd_x_cols;
+ this->pd_probs = pd_probs;
+
+ if(VERBOSITY > 0){
+ Rcout << "---- loading forest from input list ----";
+ Rcout << std::endl << std::endl;
+ }
+
+ // Create trees
+ trees.reserve(n_tree);
+
+ for (uword i = 0; i < n_tree; ++i) {
+ trees.push_back(
+ std::make_unique(forest_rows_oobag[i],
+ forest_cutpoint[i],
+ forest_child_left[i],
+ forest_coef_values[i],
+ forest_coef_indices[i],
+ forest_leaf_pred_indx[i],
+ forest_leaf_pred_prob[i],
+ forest_leaf_pred_chaz[i],
+ forest_leaf_summary[i],
+ &pred_horizon)
+ );
+ }
+
+ // Create thread ranges
+ equalSplit(thread_ranges, 0, n_tree - 1, n_thread);
+
+}
+
+// growInternal() in ranger
+void ForestSurvival::plant() {
+
+ this->unique_event_times = find_unique_event_times(data->get_y());
+
+ trees.reserve(n_tree);
+
+ for (arma::uword i = 0; i < n_tree; ++i) {
+ trees.push_back(std::make_unique(leaf_min_events,
+ split_min_events,
+ &unique_event_times,
+ &pred_horizon));
+ }
+
+}
+
+void ForestSurvival::resize_pred_mat_internal(arma::mat& p){
+
+ p.zeros(data->n_rows, pred_horizon.size());
+
+}
+
+std::vector> ForestSurvival::get_leaf_pred_indx() {
+
+ std::vector> result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ auto& temp = dynamic_cast(*tree);
+ result.push_back(temp.get_leaf_pred_indx());
+ }
+
+ return result;
+
+}
+
+std::vector> ForestSurvival::get_leaf_pred_prob() {
+
+ std::vector> result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ auto& temp = dynamic_cast(*tree);
+ result.push_back(temp.get_leaf_pred_prob());
+ }
+
+ return result;
+
+}
+
+std::vector> ForestSurvival::get_leaf_pred_chaz() {
+
+ std::vector> result;
+
+ result.reserve(n_tree);
+
+ for (auto& tree : trees) {
+ auto& temp = dynamic_cast(*tree);
+ result.push_back(temp.get_leaf_pred_chaz());
+ }
+
+ return result;
+
+}
+
+void ForestSurvival::resize_oobag_eval(){
+
+ uword n_evals = find_max_eval_steps();
+
+ oobag_eval.resize(n_evals, pred_horizon.size());
+
+}
+
+void ForestSurvival::compute_prediction_accuracy(arma::mat& y,
+ arma::vec& w,
+ arma::mat& predictions,
+ arma::uword row_fill){
+
+ bool pred_is_risklike = true;
+
+ if(pred_type == PRED_SURVIVAL) pred_is_risklike = false;
+
+
+ if(oobag_eval_type == EVAL_R_FUNCTION){
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but Robjects can)
+ Function f_oobag_eval = as(oobag_R_function);
+ NumericMatrix y_ = wrap(y);
+ NumericVector w_ = wrap(w);
+
+ for(arma::uword i = 0; i < oobag_eval.n_cols; ++i){
+ vec p = predictions.col(i);
+ NumericVector p_ = wrap(p);
+ NumericVector R_result = f_oobag_eval(y_, w_, p_);
+ oobag_eval(row_fill, i) = R_result[0];
+ }
+ return;
+ }
+
+
+ for(arma::uword i = 0; i < oobag_eval.n_cols; ++i){
+ vec p = predictions.unsafe_col(i);
+ oobag_eval(row_fill, i) = compute_cstat(y, w, p, pred_is_risklike);
+ }
+
+
+}
+
+
+}
+
+
diff --git a/src/ForestSurvival.h b/src/ForestSurvival.h
new file mode 100644
index 00000000..65e10592
--- /dev/null
+++ b/src/ForestSurvival.h
@@ -0,0 +1,71 @@
+
+// Forest.h
+
+#ifndef FORESTSURVIVAL_H
+#define FORESTSURVIVAL_H
+
+#include "Data.h"
+#include "globals.h"
+#include "Forest.h"
+
+namespace aorsf {
+
+class ForestSurvival: public Forest {
+
+public:
+
+ ForestSurvival();
+
+ ForestSurvival(double leaf_min_events,
+ double split_min_events,
+ arma::vec& pred_horizon);
+
+
+ ForestSurvival(const ForestSurvival&) = delete;
+ ForestSurvival& operator=(const ForestSurvival&) = delete;
+
+ void load(arma::uword n_tree,
+ std::vector& rows_oobag,
+ std::vector>& forest_cutpoint,
+ std::vector>& forest_child_left,
+ std::vector>& forest_coef_values,
+ std::vector>& forest_coef_indices,
+ std::vector>& forest_leaf_pred_indx,
+ std::vector>& forest_leaf_pred_prob,
+ std::vector>& forest_leaf_pred_chaz,
+ std::vector>& forest_leaf_summary,
+ PartialDepType pd_type,
+ std::vector& pd_x_vals,
+ std::vector& pd_x_cols,
+ arma::vec& pd_probs);
+
+ std::vector> get_leaf_pred_indx();
+ std::vector> get_leaf_pred_prob();
+ std::vector> get_leaf_pred_chaz();
+
+ // growInternal() in ranger
+ void plant() override;
+
+ void compute_prediction_accuracy(
+ arma::mat& y,
+ arma::vec& w,
+ arma::mat& predictions,
+ arma::uword row_fill
+ ) override;
+
+protected:
+
+ void resize_pred_mat_internal(arma::mat& p) override;
+
+ void resize_oobag_eval() override;
+
+ arma::vec pred_horizon;
+
+
+};
+
+}
+
+
+
+#endif /* Forest_H */
diff --git a/src/Makevars b/src/Makevars
index d715e498..2321181e 100644
--- a/src/Makevars
+++ b/src/Makevars
@@ -1,3 +1,3 @@
-CXX_STD = CXX11
+CXX_STD = CXX17
PKG_CXXFLAGS = $(SHLIB_OPENMP_CXXFLAGS)
PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
diff --git a/src/Makevars.win b/src/Makevars.win
index d715e498..2321181e 100644
--- a/src/Makevars.win
+++ b/src/Makevars.win
@@ -1,3 +1,3 @@
-CXX_STD = CXX11
+CXX_STD = CXX17
PKG_CXXFLAGS = $(SHLIB_OPENMP_CXXFLAGS)
PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp
index 2996eb25..17a78909 100644
--- a/src/RcppExports.cpp
+++ b/src/RcppExports.cpp
@@ -11,288 +11,137 @@ Rcpp::Rostream& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif
-// std_setdiff
-arma::uvec std_setdiff(arma::uvec& x, arma::uvec& y);
-RcppExport SEXP _aorsf_std_setdiff(SEXP xSEXP, SEXP ySEXP) {
+// coxph_fit_exported
+List coxph_fit_exported(arma::mat& x_node, arma::mat& y_node, arma::vec& w_node, int method, double cph_eps, arma::uword cph_iter_max);
+RcppExport SEXP _aorsf_coxph_fit_exported(SEXP x_nodeSEXP, SEXP y_nodeSEXP, SEXP w_nodeSEXP, SEXP methodSEXP, SEXP cph_epsSEXP, SEXP cph_iter_maxSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< arma::uvec& >::type x(xSEXP);
- Rcpp::traits::input_parameter< arma::uvec& >::type y(ySEXP);
- rcpp_result_gen = Rcpp::wrap(std_setdiff(x, y));
- return rcpp_result_gen;
-END_RCPP
-}
-// x_node_scale_exported
-List x_node_scale_exported(NumericMatrix& x_, NumericVector& w_);
-RcppExport SEXP _aorsf_x_node_scale_exported(SEXP x_SEXP, SEXP w_SEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_(x_SEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type w_(w_SEXP);
- rcpp_result_gen = Rcpp::wrap(x_node_scale_exported(x_, w_));
- return rcpp_result_gen;
-END_RCPP
-}
-// leaf_kaplan_testthat
-arma::mat leaf_kaplan_testthat(const arma::mat& y, const arma::vec& w);
-RcppExport SEXP _aorsf_leaf_kaplan_testthat(SEXP ySEXP, SEXP wSEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< const arma::mat& >::type y(ySEXP);
- Rcpp::traits::input_parameter< const arma::vec& >::type w(wSEXP);
- rcpp_result_gen = Rcpp::wrap(leaf_kaplan_testthat(y, w));
- return rcpp_result_gen;
-END_RCPP
-}
-// newtraph_cph_testthat
-arma::vec newtraph_cph_testthat(NumericMatrix& x_in, NumericMatrix& y_in, NumericVector& w_in, int method, double cph_eps_, int iter_max);
-RcppExport SEXP _aorsf_newtraph_cph_testthat(SEXP x_inSEXP, SEXP y_inSEXP, SEXP w_inSEXP, SEXP methodSEXP, SEXP cph_eps_SEXP, SEXP iter_maxSEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_in(x_inSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type y_in(y_inSEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type w_in(w_inSEXP);
+ Rcpp::traits::input_parameter< arma::mat& >::type x_node(x_nodeSEXP);
+ Rcpp::traits::input_parameter< arma::mat& >::type y_node(y_nodeSEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w_node(w_nodeSEXP);
Rcpp::traits::input_parameter< int >::type method(methodSEXP);
- Rcpp::traits::input_parameter< double >::type cph_eps_(cph_eps_SEXP);
- Rcpp::traits::input_parameter< int >::type iter_max(iter_maxSEXP);
- rcpp_result_gen = Rcpp::wrap(newtraph_cph_testthat(x_in, y_in, w_in, method, cph_eps_, iter_max));
- return rcpp_result_gen;
-END_RCPP
-}
-// lrt_multi_testthat
-List lrt_multi_testthat(NumericMatrix& y_node_, NumericVector& w_node_, NumericVector& XB_, int n_split_, int leaf_min_events_, int leaf_min_obs_);
-RcppExport SEXP _aorsf_lrt_multi_testthat(SEXP y_node_SEXP, SEXP w_node_SEXP, SEXP XB_SEXP, SEXP n_split_SEXP, SEXP leaf_min_events_SEXP, SEXP leaf_min_obs_SEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< NumericMatrix& >::type y_node_(y_node_SEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type w_node_(w_node_SEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type XB_(XB_SEXP);
- Rcpp::traits::input_parameter< int >::type n_split_(n_split_SEXP);
- Rcpp::traits::input_parameter< int >::type leaf_min_events_(leaf_min_events_SEXP);
- Rcpp::traits::input_parameter< int >::type leaf_min_obs_(leaf_min_obs_SEXP);
- rcpp_result_gen = Rcpp::wrap(lrt_multi_testthat(y_node_, w_node_, XB_, n_split_, leaf_min_events_, leaf_min_obs_));
- return rcpp_result_gen;
-END_RCPP
-}
-// oobag_c_harrell_testthat
-double oobag_c_harrell_testthat(NumericMatrix y_mat, NumericVector s_vec);
-RcppExport SEXP _aorsf_oobag_c_harrell_testthat(SEXP y_matSEXP, SEXP s_vecSEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< NumericMatrix >::type y_mat(y_matSEXP);
- Rcpp::traits::input_parameter< NumericVector >::type s_vec(s_vecSEXP);
- rcpp_result_gen = Rcpp::wrap(oobag_c_harrell_testthat(y_mat, s_vec));
- return rcpp_result_gen;
-END_RCPP
-}
-// ostree_pred_leaf_testthat
-arma::uvec ostree_pred_leaf_testthat(List& tree, NumericMatrix& x_pred_);
-RcppExport SEXP _aorsf_ostree_pred_leaf_testthat(SEXP treeSEXP, SEXP x_pred_SEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< List& >::type tree(treeSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_pred_(x_pred_SEXP);
- rcpp_result_gen = Rcpp::wrap(ostree_pred_leaf_testthat(tree, x_pred_));
- return rcpp_result_gen;
-END_RCPP
-}
-// orsf_fit
-List orsf_fit(NumericMatrix& x, NumericMatrix& y, NumericVector& weights, const int& n_tree, const int& n_split_, const int& mtry_, const double& leaf_min_events_, const double& leaf_min_obs_, const double& split_min_events_, const double& split_min_obs_, const double& split_min_stat_, const int& cph_method_, const double& cph_eps_, const int& cph_iter_max_, const bool& cph_do_scale_, const double& net_alpha_, const int& net_df_target_, const bool& oobag_pred_, const char& oobag_pred_type_, const double& oobag_pred_horizon_, const int& oobag_eval_every_, const bool& oobag_importance_, const char& oobag_importance_type_, IntegerVector& tree_seeds, const int& max_retry_, Function f_beta, const char& type_beta_, Function f_oobag_eval, const char& type_oobag_eval_, const bool verbose_progress);
-RcppExport SEXP _aorsf_orsf_fit(SEXP xSEXP, SEXP ySEXP, SEXP weightsSEXP, SEXP n_treeSEXP, SEXP n_split_SEXP, SEXP mtry_SEXP, SEXP leaf_min_events_SEXP, SEXP leaf_min_obs_SEXP, SEXP split_min_events_SEXP, SEXP split_min_obs_SEXP, SEXP split_min_stat_SEXP, SEXP cph_method_SEXP, SEXP cph_eps_SEXP, SEXP cph_iter_max_SEXP, SEXP cph_do_scale_SEXP, SEXP net_alpha_SEXP, SEXP net_df_target_SEXP, SEXP oobag_pred_SEXP, SEXP oobag_pred_type_SEXP, SEXP oobag_pred_horizon_SEXP, SEXP oobag_eval_every_SEXP, SEXP oobag_importance_SEXP, SEXP oobag_importance_type_SEXP, SEXP tree_seedsSEXP, SEXP max_retry_SEXP, SEXP f_betaSEXP, SEXP type_beta_SEXP, SEXP f_oobag_evalSEXP, SEXP type_oobag_eval_SEXP, SEXP verbose_progressSEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< NumericMatrix& >::type x(xSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type y(ySEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type weights(weightsSEXP);
- Rcpp::traits::input_parameter< const int& >::type n_tree(n_treeSEXP);
- Rcpp::traits::input_parameter< const int& >::type n_split_(n_split_SEXP);
- Rcpp::traits::input_parameter< const int& >::type mtry_(mtry_SEXP);
- Rcpp::traits::input_parameter< const double& >::type leaf_min_events_(leaf_min_events_SEXP);
- Rcpp::traits::input_parameter< const double& >::type leaf_min_obs_(leaf_min_obs_SEXP);
- Rcpp::traits::input_parameter< const double& >::type split_min_events_(split_min_events_SEXP);
- Rcpp::traits::input_parameter< const double& >::type split_min_obs_(split_min_obs_SEXP);
- Rcpp::traits::input_parameter< const double& >::type split_min_stat_(split_min_stat_SEXP);
- Rcpp::traits::input_parameter< const int& >::type cph_method_(cph_method_SEXP);
- Rcpp::traits::input_parameter< const double& >::type cph_eps_(cph_eps_SEXP);
- Rcpp::traits::input_parameter< const int& >::type cph_iter_max_(cph_iter_max_SEXP);
- Rcpp::traits::input_parameter< const bool& >::type cph_do_scale_(cph_do_scale_SEXP);
- Rcpp::traits::input_parameter< const double& >::type net_alpha_(net_alpha_SEXP);
- Rcpp::traits::input_parameter< const int& >::type net_df_target_(net_df_target_SEXP);
- Rcpp::traits::input_parameter< const bool& >::type oobag_pred_(oobag_pred_SEXP);
- Rcpp::traits::input_parameter< const char& >::type oobag_pred_type_(oobag_pred_type_SEXP);
- Rcpp::traits::input_parameter< const double& >::type oobag_pred_horizon_(oobag_pred_horizon_SEXP);
- Rcpp::traits::input_parameter< const int& >::type oobag_eval_every_(oobag_eval_every_SEXP);
- Rcpp::traits::input_parameter< const bool& >::type oobag_importance_(oobag_importance_SEXP);
- Rcpp::traits::input_parameter< const char& >::type oobag_importance_type_(oobag_importance_type_SEXP);
- Rcpp::traits::input_parameter< IntegerVector& >::type tree_seeds(tree_seedsSEXP);
- Rcpp::traits::input_parameter< const int& >::type max_retry_(max_retry_SEXP);
- Rcpp::traits::input_parameter< Function >::type f_beta(f_betaSEXP);
- Rcpp::traits::input_parameter< const char& >::type type_beta_(type_beta_SEXP);
- Rcpp::traits::input_parameter< Function >::type f_oobag_eval(f_oobag_evalSEXP);
- Rcpp::traits::input_parameter< const char& >::type type_oobag_eval_(type_oobag_eval_SEXP);
- Rcpp::traits::input_parameter< const bool >::type verbose_progress(verbose_progressSEXP);
- rcpp_result_gen = Rcpp::wrap(orsf_fit(x, y, weights, n_tree, n_split_, mtry_, leaf_min_events_, leaf_min_obs_, split_min_events_, split_min_obs_, split_min_stat_, cph_method_, cph_eps_, cph_iter_max_, cph_do_scale_, net_alpha_, net_df_target_, oobag_pred_, oobag_pred_type_, oobag_pred_horizon_, oobag_eval_every_, oobag_importance_, oobag_importance_type_, tree_seeds, max_retry_, f_beta, type_beta_, f_oobag_eval, type_oobag_eval_, verbose_progress));
- return rcpp_result_gen;
-END_RCPP
-}
-// orsf_oob_negate_vi
-arma::vec orsf_oob_negate_vi(NumericMatrix& x, NumericMatrix& y, List& forest, const double& last_eval_stat, const double& time_pred_, Function f_oobag_eval, const char& pred_type_, const char& type_oobag_eval_);
-RcppExport SEXP _aorsf_orsf_oob_negate_vi(SEXP xSEXP, SEXP ySEXP, SEXP forestSEXP, SEXP last_eval_statSEXP, SEXP time_pred_SEXP, SEXP f_oobag_evalSEXP, SEXP pred_type_SEXP, SEXP type_oobag_eval_SEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< NumericMatrix& >::type x(xSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type y(ySEXP);
- Rcpp::traits::input_parameter< List& >::type forest(forestSEXP);
- Rcpp::traits::input_parameter< const double& >::type last_eval_stat(last_eval_statSEXP);
- Rcpp::traits::input_parameter< const double& >::type time_pred_(time_pred_SEXP);
- Rcpp::traits::input_parameter< Function >::type f_oobag_eval(f_oobag_evalSEXP);
- Rcpp::traits::input_parameter< const char& >::type pred_type_(pred_type_SEXP);
- Rcpp::traits::input_parameter< const char& >::type type_oobag_eval_(type_oobag_eval_SEXP);
- rcpp_result_gen = Rcpp::wrap(orsf_oob_negate_vi(x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_));
- return rcpp_result_gen;
-END_RCPP
-}
-// orsf_oob_permute_vi
-arma::vec orsf_oob_permute_vi(NumericMatrix& x, NumericMatrix& y, List& forest, const double& last_eval_stat, const double& time_pred_, Function f_oobag_eval, const char& pred_type_, const char& type_oobag_eval_);
-RcppExport SEXP _aorsf_orsf_oob_permute_vi(SEXP xSEXP, SEXP ySEXP, SEXP forestSEXP, SEXP last_eval_statSEXP, SEXP time_pred_SEXP, SEXP f_oobag_evalSEXP, SEXP pred_type_SEXP, SEXP type_oobag_eval_SEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< NumericMatrix& >::type x(xSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type y(ySEXP);
- Rcpp::traits::input_parameter< List& >::type forest(forestSEXP);
- Rcpp::traits::input_parameter< const double& >::type last_eval_stat(last_eval_statSEXP);
- Rcpp::traits::input_parameter< const double& >::type time_pred_(time_pred_SEXP);
- Rcpp::traits::input_parameter< Function >::type f_oobag_eval(f_oobag_evalSEXP);
- Rcpp::traits::input_parameter< const char& >::type pred_type_(pred_type_SEXP);
- Rcpp::traits::input_parameter< const char& >::type type_oobag_eval_(type_oobag_eval_SEXP);
- rcpp_result_gen = Rcpp::wrap(orsf_oob_permute_vi(x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_));
- return rcpp_result_gen;
-END_RCPP
-}
-// orsf_pred_uni
-arma::mat orsf_pred_uni(List& forest, NumericMatrix& x_new, double time_dbl, char pred_type);
-RcppExport SEXP _aorsf_orsf_pred_uni(SEXP forestSEXP, SEXP x_newSEXP, SEXP time_dblSEXP, SEXP pred_typeSEXP) {
-BEGIN_RCPP
- Rcpp::RObject rcpp_result_gen;
- Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< List& >::type forest(forestSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_new(x_newSEXP);
- Rcpp::traits::input_parameter< double >::type time_dbl(time_dblSEXP);
- Rcpp::traits::input_parameter< char >::type pred_type(pred_typeSEXP);
- rcpp_result_gen = Rcpp::wrap(orsf_pred_uni(forest, x_new, time_dbl, pred_type));
+ Rcpp::traits::input_parameter< double >::type cph_eps(cph_epsSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type cph_iter_max(cph_iter_maxSEXP);
+ rcpp_result_gen = Rcpp::wrap(coxph_fit_exported(x_node, y_node, w_node, method, cph_eps, cph_iter_max));
return rcpp_result_gen;
END_RCPP
}
-// orsf_pred_multi
-arma::mat orsf_pred_multi(List& forest, NumericMatrix& x_new, NumericVector& time_vec, char pred_type);
-RcppExport SEXP _aorsf_orsf_pred_multi(SEXP forestSEXP, SEXP x_newSEXP, SEXP time_vecSEXP, SEXP pred_typeSEXP) {
+// compute_cstat_exported_vec
+double compute_cstat_exported_vec(arma::mat& y, arma::vec& w, arma::vec& p, bool pred_is_risklike);
+RcppExport SEXP _aorsf_compute_cstat_exported_vec(SEXP ySEXP, SEXP wSEXP, SEXP pSEXP, SEXP pred_is_risklikeSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< List& >::type forest(forestSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_new(x_newSEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type time_vec(time_vecSEXP);
- Rcpp::traits::input_parameter< char >::type pred_type(pred_typeSEXP);
- rcpp_result_gen = Rcpp::wrap(orsf_pred_multi(forest, x_new, time_vec, pred_type));
+ Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type p(pSEXP);
+ Rcpp::traits::input_parameter< bool >::type pred_is_risklike(pred_is_risklikeSEXP);
+ rcpp_result_gen = Rcpp::wrap(compute_cstat_exported_vec(y, w, p, pred_is_risklike));
return rcpp_result_gen;
END_RCPP
}
-// pd_new_smry
-arma::mat pd_new_smry(List& forest, NumericMatrix& x_new_, IntegerVector& x_cols_, NumericMatrix& x_vals_, NumericVector& probs_, const double time_dbl, char pred_type);
-RcppExport SEXP _aorsf_pd_new_smry(SEXP forestSEXP, SEXP x_new_SEXP, SEXP x_cols_SEXP, SEXP x_vals_SEXP, SEXP probs_SEXP, SEXP time_dblSEXP, SEXP pred_typeSEXP) {
+// compute_cstat_exported_uvec
+double compute_cstat_exported_uvec(arma::mat& y, arma::vec& w, arma::uvec& g, bool pred_is_risklike);
+RcppExport SEXP _aorsf_compute_cstat_exported_uvec(SEXP ySEXP, SEXP wSEXP, SEXP gSEXP, SEXP pred_is_risklikeSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< List& >::type forest(forestSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_new_(x_new_SEXP);
- Rcpp::traits::input_parameter< IntegerVector& >::type x_cols_(x_cols_SEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_vals_(x_vals_SEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type probs_(probs_SEXP);
- Rcpp::traits::input_parameter< const double >::type time_dbl(time_dblSEXP);
- Rcpp::traits::input_parameter< char >::type pred_type(pred_typeSEXP);
- rcpp_result_gen = Rcpp::wrap(pd_new_smry(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type));
+ Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
+ Rcpp::traits::input_parameter< arma::uvec& >::type g(gSEXP);
+ Rcpp::traits::input_parameter< bool >::type pred_is_risklike(pred_is_risklikeSEXP);
+ rcpp_result_gen = Rcpp::wrap(compute_cstat_exported_uvec(y, w, g, pred_is_risklike));
return rcpp_result_gen;
END_RCPP
}
-// pd_oob_smry
-arma::mat pd_oob_smry(List& forest, NumericMatrix& x_new_, IntegerVector& x_cols_, NumericMatrix& x_vals_, NumericVector& probs_, const double time_dbl, char pred_type);
-RcppExport SEXP _aorsf_pd_oob_smry(SEXP forestSEXP, SEXP x_new_SEXP, SEXP x_cols_SEXP, SEXP x_vals_SEXP, SEXP probs_SEXP, SEXP time_dblSEXP, SEXP pred_typeSEXP) {
+// compute_logrank_exported
+double compute_logrank_exported(arma::mat& y, arma::vec& w, arma::uvec& g);
+RcppExport SEXP _aorsf_compute_logrank_exported(SEXP ySEXP, SEXP wSEXP, SEXP gSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< List& >::type forest(forestSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_new_(x_new_SEXP);
- Rcpp::traits::input_parameter< IntegerVector& >::type x_cols_(x_cols_SEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_vals_(x_vals_SEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type probs_(probs_SEXP);
- Rcpp::traits::input_parameter< const double >::type time_dbl(time_dblSEXP);
- Rcpp::traits::input_parameter< char >::type pred_type(pred_typeSEXP);
- rcpp_result_gen = Rcpp::wrap(pd_oob_smry(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type));
+ Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
+ Rcpp::traits::input_parameter< arma::uvec& >::type g(gSEXP);
+ rcpp_result_gen = Rcpp::wrap(compute_logrank_exported(y, w, g));
return rcpp_result_gen;
END_RCPP
}
-// pd_new_ice
-arma::mat pd_new_ice(List& forest, NumericMatrix& x_new_, IntegerVector& x_cols_, NumericMatrix& x_vals_, NumericVector& probs_, const double time_dbl, char pred_type);
-RcppExport SEXP _aorsf_pd_new_ice(SEXP forestSEXP, SEXP x_new_SEXP, SEXP x_cols_SEXP, SEXP x_vals_SEXP, SEXP probs_SEXP, SEXP time_dblSEXP, SEXP pred_typeSEXP) {
+// cph_scale
+List cph_scale(arma::mat& x, arma::vec& w);
+RcppExport SEXP _aorsf_cph_scale(SEXP xSEXP, SEXP wSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< List& >::type forest(forestSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_new_(x_new_SEXP);
- Rcpp::traits::input_parameter< IntegerVector& >::type x_cols_(x_cols_SEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_vals_(x_vals_SEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type probs_(probs_SEXP);
- Rcpp::traits::input_parameter< const double >::type time_dbl(time_dblSEXP);
- Rcpp::traits::input_parameter< char >::type pred_type(pred_typeSEXP);
- rcpp_result_gen = Rcpp::wrap(pd_new_ice(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type));
+ Rcpp::traits::input_parameter< arma::mat& >::type x(xSEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
+ rcpp_result_gen = Rcpp::wrap(cph_scale(x, w));
return rcpp_result_gen;
END_RCPP
}
-// pd_oob_ice
-arma::mat pd_oob_ice(List& forest, NumericMatrix& x_new_, IntegerVector& x_cols_, NumericMatrix& x_vals_, NumericVector& probs_, const double time_dbl, char pred_type);
-RcppExport SEXP _aorsf_pd_oob_ice(SEXP forestSEXP, SEXP x_new_SEXP, SEXP x_cols_SEXP, SEXP x_vals_SEXP, SEXP probs_SEXP, SEXP time_dblSEXP, SEXP pred_typeSEXP) {
+// orsf_cpp
+List orsf_cpp(arma::mat& x, arma::mat& y, arma::vec& w, arma::uword tree_type_R, Rcpp::IntegerVector& tree_seeds, Rcpp::List& loaded_forest, Rcpp::RObject lincomb_R_function, Rcpp::RObject oobag_R_function, arma::uword n_tree, arma::uword mtry, bool sample_with_replacement, double sample_fraction, arma::uword vi_type_R, double vi_max_pvalue, double leaf_min_events, double leaf_min_obs, arma::uword split_rule_R, double split_min_events, double split_min_obs, double split_min_stat, arma::uword split_max_cuts, arma::uword split_max_retry, arma::uword lincomb_type_R, double lincomb_eps, arma::uword lincomb_iter_max, bool lincomb_scale, double lincomb_alpha, arma::uword lincomb_df_target, arma::uword lincomb_ties_method, bool pred_mode, arma::uword pred_type_R, arma::vec pred_horizon, bool pred_aggregate, bool oobag, arma::uword oobag_eval_type_R, arma::uword oobag_eval_every, int pd_type_R, std::vector& pd_x_vals, std::vector& pd_x_cols, arma::vec& pd_probs, unsigned int n_thread, bool write_forest, bool run_forest, int verbosity);
+RcppExport SEXP _aorsf_orsf_cpp(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP tree_type_RSEXP, SEXP tree_seedsSEXP, SEXP loaded_forestSEXP, SEXP lincomb_R_functionSEXP, SEXP oobag_R_functionSEXP, SEXP n_treeSEXP, SEXP mtrySEXP, SEXP sample_with_replacementSEXP, SEXP sample_fractionSEXP, SEXP vi_type_RSEXP, SEXP vi_max_pvalueSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP, SEXP split_rule_RSEXP, SEXP split_min_eventsSEXP, SEXP split_min_obsSEXP, SEXP split_min_statSEXP, SEXP split_max_cutsSEXP, SEXP split_max_retrySEXP, SEXP lincomb_type_RSEXP, SEXP lincomb_epsSEXP, SEXP lincomb_iter_maxSEXP, SEXP lincomb_scaleSEXP, SEXP lincomb_alphaSEXP, SEXP lincomb_df_targetSEXP, SEXP lincomb_ties_methodSEXP, SEXP pred_modeSEXP, SEXP pred_type_RSEXP, SEXP pred_horizonSEXP, SEXP pred_aggregateSEXP, SEXP oobagSEXP, SEXP oobag_eval_type_RSEXP, SEXP oobag_eval_everySEXP, SEXP pd_type_RSEXP, SEXP pd_x_valsSEXP, SEXP pd_x_colsSEXP, SEXP pd_probsSEXP, SEXP n_threadSEXP, SEXP write_forestSEXP, SEXP run_forestSEXP, SEXP verbositySEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
- Rcpp::traits::input_parameter< List& >::type forest(forestSEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_new_(x_new_SEXP);
- Rcpp::traits::input_parameter< IntegerVector& >::type x_cols_(x_cols_SEXP);
- Rcpp::traits::input_parameter< NumericMatrix& >::type x_vals_(x_vals_SEXP);
- Rcpp::traits::input_parameter< NumericVector& >::type probs_(probs_SEXP);
- Rcpp::traits::input_parameter< const double >::type time_dbl(time_dblSEXP);
- Rcpp::traits::input_parameter< char >::type pred_type(pred_typeSEXP);
- rcpp_result_gen = Rcpp::wrap(pd_oob_ice(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type));
+ Rcpp::traits::input_parameter< arma::mat& >::type x(xSEXP);
+ Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type tree_type_R(tree_type_RSEXP);
+ Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type tree_seeds(tree_seedsSEXP);
+ Rcpp::traits::input_parameter< Rcpp::List& >::type loaded_forest(loaded_forestSEXP);
+ Rcpp::traits::input_parameter< Rcpp::RObject >::type lincomb_R_function(lincomb_R_functionSEXP);
+ Rcpp::traits::input_parameter< Rcpp::RObject >::type oobag_R_function(oobag_R_functionSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type n_tree(n_treeSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type mtry(mtrySEXP);
+ Rcpp::traits::input_parameter< bool >::type sample_with_replacement(sample_with_replacementSEXP);
+ Rcpp::traits::input_parameter< double >::type sample_fraction(sample_fractionSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type vi_type_R(vi_type_RSEXP);
+ Rcpp::traits::input_parameter< double >::type vi_max_pvalue(vi_max_pvalueSEXP);
+ Rcpp::traits::input_parameter< double >::type leaf_min_events(leaf_min_eventsSEXP);
+ Rcpp::traits::input_parameter< double >::type leaf_min_obs(leaf_min_obsSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type split_rule_R(split_rule_RSEXP);
+ Rcpp::traits::input_parameter< double >::type split_min_events(split_min_eventsSEXP);
+ Rcpp::traits::input_parameter< double >::type split_min_obs(split_min_obsSEXP);
+ Rcpp::traits::input_parameter< double >::type split_min_stat(split_min_statSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type split_max_cuts(split_max_cutsSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type split_max_retry(split_max_retrySEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type lincomb_type_R(lincomb_type_RSEXP);
+ Rcpp::traits::input_parameter< double >::type lincomb_eps(lincomb_epsSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type lincomb_iter_max(lincomb_iter_maxSEXP);
+ Rcpp::traits::input_parameter< bool >::type lincomb_scale(lincomb_scaleSEXP);
+ Rcpp::traits::input_parameter< double >::type lincomb_alpha(lincomb_alphaSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type lincomb_df_target(lincomb_df_targetSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type lincomb_ties_method(lincomb_ties_methodSEXP);
+ Rcpp::traits::input_parameter< bool >::type pred_mode(pred_modeSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type pred_type_R(pred_type_RSEXP);
+ Rcpp::traits::input_parameter< arma::vec >::type pred_horizon(pred_horizonSEXP);
+ Rcpp::traits::input_parameter< bool >::type pred_aggregate(pred_aggregateSEXP);
+ Rcpp::traits::input_parameter< bool >::type oobag(oobagSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type oobag_eval_type_R(oobag_eval_type_RSEXP);
+ Rcpp::traits::input_parameter< arma::uword >::type oobag_eval_every(oobag_eval_everySEXP);
+ Rcpp::traits::input_parameter< int >::type pd_type_R(pd_type_RSEXP);
+ Rcpp::traits::input_parameter< std::vector& >::type pd_x_vals(pd_x_valsSEXP);
+ Rcpp::traits::input_parameter< std::vector& >::type pd_x_cols(pd_x_colsSEXP);
+ Rcpp::traits::input_parameter< arma::vec& >::type pd_probs(pd_probsSEXP);
+ Rcpp::traits::input_parameter< unsigned int >::type n_thread(n_threadSEXP);
+ Rcpp::traits::input_parameter< bool >::type write_forest(write_forestSEXP);
+ Rcpp::traits::input_parameter< bool >::type run_forest(run_forestSEXP);
+ Rcpp::traits::input_parameter< int >::type verbosity(verbositySEXP);
+ rcpp_result_gen = Rcpp::wrap(orsf_cpp(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity));
return rcpp_result_gen;
END_RCPP
}
static const R_CallMethodDef CallEntries[] = {
- {"_aorsf_std_setdiff", (DL_FUNC) &_aorsf_std_setdiff, 2},
- {"_aorsf_x_node_scale_exported", (DL_FUNC) &_aorsf_x_node_scale_exported, 2},
- {"_aorsf_leaf_kaplan_testthat", (DL_FUNC) &_aorsf_leaf_kaplan_testthat, 2},
- {"_aorsf_newtraph_cph_testthat", (DL_FUNC) &_aorsf_newtraph_cph_testthat, 6},
- {"_aorsf_lrt_multi_testthat", (DL_FUNC) &_aorsf_lrt_multi_testthat, 6},
- {"_aorsf_oobag_c_harrell_testthat", (DL_FUNC) &_aorsf_oobag_c_harrell_testthat, 2},
- {"_aorsf_ostree_pred_leaf_testthat", (DL_FUNC) &_aorsf_ostree_pred_leaf_testthat, 2},
- {"_aorsf_orsf_fit", (DL_FUNC) &_aorsf_orsf_fit, 30},
- {"_aorsf_orsf_oob_negate_vi", (DL_FUNC) &_aorsf_orsf_oob_negate_vi, 8},
- {"_aorsf_orsf_oob_permute_vi", (DL_FUNC) &_aorsf_orsf_oob_permute_vi, 8},
- {"_aorsf_orsf_pred_uni", (DL_FUNC) &_aorsf_orsf_pred_uni, 4},
- {"_aorsf_orsf_pred_multi", (DL_FUNC) &_aorsf_orsf_pred_multi, 4},
- {"_aorsf_pd_new_smry", (DL_FUNC) &_aorsf_pd_new_smry, 7},
- {"_aorsf_pd_oob_smry", (DL_FUNC) &_aorsf_pd_oob_smry, 7},
- {"_aorsf_pd_new_ice", (DL_FUNC) &_aorsf_pd_new_ice, 7},
- {"_aorsf_pd_oob_ice", (DL_FUNC) &_aorsf_pd_oob_ice, 7},
+ {"_aorsf_coxph_fit_exported", (DL_FUNC) &_aorsf_coxph_fit_exported, 6},
+ {"_aorsf_compute_cstat_exported_vec", (DL_FUNC) &_aorsf_compute_cstat_exported_vec, 4},
+ {"_aorsf_compute_cstat_exported_uvec", (DL_FUNC) &_aorsf_compute_cstat_exported_uvec, 4},
+ {"_aorsf_compute_logrank_exported", (DL_FUNC) &_aorsf_compute_logrank_exported, 3},
+ {"_aorsf_cph_scale", (DL_FUNC) &_aorsf_cph_scale, 2},
+ {"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 44},
{NULL, NULL, 0}
};
diff --git a/src/Tree.cpp b/src/Tree.cpp
new file mode 100644
index 00000000..4113fb28
--- /dev/null
+++ b/src/Tree.cpp
@@ -0,0 +1,1157 @@
+/*-----------------------------------------------------------------------------
+ This file is part of aorsf.
+ Author: Byron C Jaeger
+ aorsf may be modified and distributed under the terms of the MIT license.
+#----------------------------------------------------------------------------*/
+
+#include
+#include "Tree.h"
+#include "Coxph.h"
+
+#include
+#include
+
+ using namespace arma;
+ using namespace Rcpp;
+
+ namespace aorsf {
+
+ Tree::Tree() :
+ data(0),
+ n_cols_total(0),
+ n_rows_total(0),
+ seed(0),
+ mtry(0),
+ pred_type(DEFAULT_PRED_TYPE),
+ vi_type(VI_NONE),
+ vi_max_pvalue(DEFAULT_ANOVA_VI_PVALUE),
+ // leaf_min_events(DEFAULT_LEAF_MIN_EVENTS),
+ leaf_min_obs(DEFAULT_LEAF_MIN_OBS),
+ split_rule(DEFAULT_SPLITRULE),
+ // split_min_events(DEFAULT_SPLIT_MIN_EVENTS),
+ split_min_obs(DEFAULT_SPLIT_MIN_OBS),
+ split_min_stat(DEFAULT_SPLIT_MIN_STAT),
+ split_max_cuts(DEFAULT_SPLIT_MAX_CUTS),
+ split_max_retry(DEFAULT_SPLIT_MAX_RETRY),
+ lincomb_type(DEFAULT_LINCOMB),
+ lincomb_eps(DEFAULT_LINCOMB_EPS),
+ lincomb_iter_max(DEFAULT_LINCOMB_ITER_MAX),
+ lincomb_scale(DEFAULT_LINCOMB_SCALE),
+ lincomb_alpha(DEFAULT_LINCOMB_ALPHA),
+ lincomb_df_target(0),
+ lincomb_ties_method(DEFAULT_LINCOMB_TIES_METHOD),
+ lincomb_R_function(0),
+ verbosity(0){
+
+ }
+
+ Tree::Tree(arma::uvec& rows_oobag,
+ std::vector& cutpoint,
+ std::vector& child_left,
+ std::vector& coef_values,
+ std::vector& coef_indices,
+ std::vector& leaf_summary) :
+ data(0),
+ n_cols_total(0),
+ n_rows_total(0),
+ seed(0),
+ mtry(0),
+ pred_type(DEFAULT_PRED_TYPE),
+ vi_type(VI_NONE),
+ vi_max_pvalue(DEFAULT_ANOVA_VI_PVALUE),
+ // leaf_min_events(DEFAULT_LEAF_MIN_EVENTS),
+ leaf_min_obs(DEFAULT_LEAF_MIN_OBS),
+ split_rule(DEFAULT_SPLITRULE),
+ // split_min_events(DEFAULT_SPLIT_MIN_EVENTS),
+ split_min_obs(DEFAULT_SPLIT_MIN_OBS),
+ split_min_stat(DEFAULT_SPLIT_MIN_STAT),
+ split_max_cuts(DEFAULT_SPLIT_MAX_CUTS),
+ split_max_retry(DEFAULT_SPLIT_MAX_RETRY),
+ lincomb_type(DEFAULT_LINCOMB),
+ lincomb_eps(DEFAULT_LINCOMB_EPS),
+ lincomb_iter_max(DEFAULT_LINCOMB_ITER_MAX),
+ lincomb_scale(DEFAULT_LINCOMB_SCALE),
+ lincomb_alpha(DEFAULT_LINCOMB_ALPHA),
+ lincomb_df_target(0),
+ lincomb_ties_method(DEFAULT_LINCOMB_TIES_METHOD),
+ lincomb_R_function(0),
+ verbosity(0),
+ rows_oobag(rows_oobag),
+ cutpoint(cutpoint),
+ child_left(child_left),
+ coef_values(coef_values),
+ coef_indices(coef_indices),
+ leaf_summary(leaf_summary){
+
+ this->max_nodes = cutpoint.size()+1;
+ this->max_leaves = cutpoint.size()+1;
+
+ }
+
+
+ void Tree::init(Data* data,
+ int seed,
+ arma::uword mtry,
+ bool sample_with_replacement,
+ double sample_fraction,
+ PredType pred_type,
+ // double leaf_min_events,
+ double leaf_min_obs,
+ VariableImportance vi_type,
+ double vi_max_pvalue,
+ SplitRule split_rule,
+ // double split_min_events,
+ double split_min_obs,
+ double split_min_stat,
+ arma::uword split_max_cuts,
+ arma::uword split_max_retry,
+ LinearCombo lincomb_type,
+ double lincomb_eps,
+ arma::uword lincomb_iter_max,
+ bool lincomb_scale,
+ double lincomb_alpha,
+ arma::uword lincomb_df_target,
+ arma::uword lincomb_ties_method,
+ RObject lincomb_R_function,
+ RObject oobag_R_function,
+ EvalType oobag_eval_type,
+ int verbosity){
+
+ // Initialize random number generator and set seed
+ random_number_generator.seed(seed);
+
+ this->data = data;
+ this->n_cols_total = data->n_cols;
+ this->n_rows_total = data->n_rows;
+ this->seed = seed;
+ this->mtry = mtry;
+ this->sample_with_replacement = sample_with_replacement;
+ this->sample_fraction = sample_fraction;
+ this->pred_type = pred_type;
+ // this->leaf_min_events = leaf_min_events;
+ this->leaf_min_obs = leaf_min_obs;
+ this->vi_type = vi_type;
+ this->vi_max_pvalue = vi_max_pvalue;
+ this->split_rule = split_rule;
+ // this->split_min_events = split_min_events;
+ this->split_min_obs = split_min_obs;
+ this->split_min_stat = split_min_stat;
+ this->split_max_cuts = split_max_cuts;
+ this->split_max_retry = split_max_retry;
+ this->lincomb_type = lincomb_type;
+ this->lincomb_eps = lincomb_eps;
+ this->lincomb_iter_max = lincomb_iter_max;
+ this->lincomb_scale = lincomb_scale;
+ this->lincomb_alpha = lincomb_alpha;
+ this->lincomb_df_target = lincomb_df_target;
+ this->lincomb_ties_method = lincomb_ties_method;
+ this->lincomb_R_function = lincomb_R_function;
+ this->oobag_R_function = oobag_R_function;
+ this->oobag_eval_type = oobag_eval_type;
+ this->verbosity = verbosity;
+
+ }
+
+ void Tree::allocate_oobag_memory(){
+
+ if(rows_oobag.size() == 0){
+ stop("attempting to allocate oob memory with empty rows_oobag");
+ }
+
+ x_oobag = data->x_rows(rows_oobag);
+ y_oobag = data->y_rows(rows_oobag);
+ w_oobag = data->w_subvec(rows_oobag);
+
+ }
+
+ void Tree::resize_leaves(arma::uword new_size){
+
+ leaf_summary.resize(new_size);
+
+ }
+
+ void Tree::sample_rows(){
+
+ uword i, draw, n = data->n_rows;
+
+ // Start with all samples OOB
+ vec w_inbag(n, fill::zeros);
+
+ std::uniform_int_distribution udist_rows(0, n - 1);
+
+ if(sample_with_replacement){
+
+ for (i = 0; i < n; ++i) {
+ draw = udist_rows(random_number_generator);
+ ++w_inbag[draw];
+ }
+
+ } else {
+
+ if(sample_fraction == 1){
+ w_inbag.fill(1);
+ } else {
+
+ uword n_sample = (uword) std::round(n * sample_fraction);
+ for (i = 0; i < n_sample; ++i) {
+ draw = udist_rows(random_number_generator);
+ while(w_inbag[draw] == 1){
+ draw = udist_rows(random_number_generator);
+ }
+ ++w_inbag[draw];
+
+ }
+ }
+
+ }
+
+ // multiply w_inbag by user specified weights.
+ if(data->has_weights){
+ w_inbag = w_inbag % data->w;
+ }
+
+ this->rows_inbag = find(w_inbag > 0);
+ this->rows_oobag = find(w_inbag == 0);
+ // shrink the size of w_inbag from n to n wts > 0
+ this->w_inbag = w_inbag(rows_inbag);
+
+ }
+
+ void Tree::sample_cols(){
+
+ // Start empty
+ this->cols_node.set_size(mtry);
+ uint cols_accepted = 0;
+
+ // Set all to not selected
+ std::vector temp;
+ temp.resize(n_cols_total, false);
+
+ std::uniform_int_distribution udist_cols(0, n_cols_total - 1);
+
+ uword i, draw;
+
+ for (i = 0; i < n_cols_total; ++i) {
+
+ do {draw = udist_cols(random_number_generator); } while (temp[draw]);
+
+ temp[draw] = true;
+
+ if(is_col_splittable(draw)){
+ cols_node[cols_accepted] = draw;
+ cols_accepted++;
+ }
+
+ if(cols_accepted == mtry) break;
+
+ }
+
+ if(cols_accepted < mtry) cols_node.resize(cols_accepted);
+
+ }
+
+ bool Tree::is_col_splittable(uword j){
+
+ uvec::iterator i;
+
+ // initialize as 0 but do not make comparisons until x_first_value
+ // is formally defined at the first instance of status == 1
+ double x_first_value=0;
+
+ bool x_first_undef = true;
+
+ for (i = rows_node.begin(); i != rows_node.end(); ++i) {
+
+ if(x_first_undef){
+
+ x_first_value = x_inbag.at(*i, j);
+ x_first_undef = false;
+
+ } else {
+
+ if(x_inbag.at(*i, j) != x_first_value){
+ return(true);
+ }
+
+ }
+
+ }
+
+ if(VERBOSITY > 1){
+
+ mat x_print = x_inbag.rows(rows_node);
+ Rcout << "Column " << j << " was sampled but ";
+ Rcout << "unique values of column " << j << " are ";
+ Rcout << unique(x_print.col(j)) << std::endl;
+
+ }
+
+ return(false);
+
+ }
+
+ bool Tree::is_node_splittable(uword node_id){
+
+ if(node_id == 0){
+
+ // all inbag observations are in the first node
+ rows_node = regspace(0, n_rows_inbag-1);
+ y_node = y_inbag;
+ w_node = w_inbag;
+ return(true);
+
+ }
+
+ rows_node = find(node_assignments == node_id);
+
+ y_node = y_inbag.rows(rows_node);
+ w_node = w_inbag(rows_node);
+
+ bool result = is_node_splittable_internal();
+
+ return(result);
+
+ }
+
+ bool Tree::is_node_splittable_internal(){
+
+ double n_obs = sum(w_node);
+
+ return(n_obs >= 2*leaf_min_obs &&
+ n_obs >= split_min_obs);
+
+ }
+
+
+ uvec Tree::find_cutpoints(){
+
+ // placeholder with values indicating invalid cps
+ uvec output;
+
+ uword i, j, k;
+
+ uvec::iterator it, it_min, it_max;
+
+ double n_obs = 0;
+
+ if(VERBOSITY > 1){
+ Rcout << "----- finding lower bound for cut-points -----" << std::endl;
+ }
+
+ // stop at end-1 b/c we access it+1 in lincomb_sort
+ for(it = lincomb_sort.begin(); it < lincomb_sort.end()-1; ++it){
+
+ n_obs += w_node[*it];
+
+ // If we want to make the current value of lincomb a cut-point, we need
+ // to make sure the next value of lincomb isn't equal to this current value.
+ // Otherwise, we will have the same value of lincomb in both groups!
+
+ if(lincomb[*it] != lincomb[*(it+1)]){
+
+ if(n_obs >= leaf_min_obs) {
+
+ if(VERBOSITY > 0){
+ Rcout << std::endl;
+ Rcout << "lower cutpoint: " << lincomb(*it) << std::endl;
+ Rcout << " - n_obs, left node: " << n_obs << std::endl;
+ Rcout << std::endl;
+ }
+
+ break;
+
+ }
+
+ }
+
+ }
+
+ it_min = it;
+
+ if(it == lincomb_sort.end()-1) {
+
+ if(VERBOSITY > 1){
+ Rcout << "Could not find a valid cut-point" << std::endl;
+ }
+
+ return(output);
+
+ }
+
+ // j = number of steps we have taken forward in lincomb
+ j = it - lincomb_sort.begin();
+
+ // reset before finding the upper limit
+ n_obs=0;
+
+ if(VERBOSITY > 1){
+ Rcout << "----- finding upper bound for cut-points -----" << std::endl;
+ }
+
+ // stop at beginning+1 b/c we access it-1 in lincomb_sort
+ for(it = lincomb_sort.end()-1; it >= lincomb_sort.begin()+1; --it){
+
+ n_obs += w_node[*it];
+
+ if(lincomb[*it] != lincomb[*(it-1)]){
+
+ if(n_obs >= leaf_min_obs) {
+
+ // the upper cutpoint needs to be one step below the current
+ // it value, because we use x <= cp to determine whether a
+ // value x goes to the left node versus the right node. So,
+ // if it currently points to 3, and the next value down is 2,
+ // then we want to say the cut-point is 2 because then all
+ // values <= 2 will go left, and 3 will go right. This matters
+ // when 3 is the highest value in the vector.
+
+ --it;
+
+ if(VERBOSITY > 0){
+ Rcout << std::endl;
+ Rcout << "upper cutpoint: " << lincomb(*it) << std::endl;
+ Rcout << " - n_obs, right node: " << n_obs << std::endl;
+ Rcout << std::endl;
+ }
+
+ break;
+
+ }
+
+ }
+
+ }
+
+ it_max = it;
+
+ // k = n steps from beginning of sorted lincomb to current it
+ k = it - lincomb_sort.begin();
+
+ if(j > k){
+
+ if(VERBOSITY > 0) {
+ Rcout << "Could not find valid cut-points" << std::endl;
+ }
+
+ return(output);
+
+ }
+
+ // only one valid cutpoint
+ if(j == k){
+
+ output = {j};
+ return(output);
+
+ }
+
+ i = 0;
+ uvec output_middle(k-j);
+
+ for(it = it_min+1;
+ it < it_max; ++it){
+ if(lincomb[*it] != lincomb[*(it+1)]){
+ output_middle[i] = it - lincomb_sort.begin();
+ i++;
+ }
+ }
+
+ output_middle.resize(i);
+
+ uvec output_left = {j};
+ uvec output_right = {k};
+
+ output = join_vert(output_left, output_middle, output_right);
+
+ return(output);
+
+ }
+
+ double Tree::compute_split_score(){
+
+ // default method is to pick one completely at random
+ // (this won't stay the default - it's a placeholder)
+
+ std::normal_distribution ndist_score(0, 1);
+
+ double result = ndist_score(random_number_generator);
+
+ return(result);
+
+ }
+
+ double Tree::split_node(arma::uvec& cuts_all){
+
+ // sample a subset of cutpoints.
+ uvec cuts_sampled;
+
+ if(split_max_cuts >= cuts_all.size()){
+
+ // no need for random sample if there are fewer valid cut-points
+ // than the number of cut-points we planned to sample.
+ cuts_sampled = cuts_all;
+
+ } else { // split_max_cuts < cuts_all.size()
+
+ cuts_sampled.resize(split_max_cuts);
+
+ std::uniform_int_distribution udist_cuts(0, cuts_all.size() - 1);
+
+ // Set all to not selected
+ std::vector temp;
+ temp.resize(cuts_all.size(), false);
+
+ uword draw;
+
+ for (uword i = 0; i < split_max_cuts; ++i) {
+
+ do {draw = udist_cuts(random_number_generator); } while (temp[draw]);
+
+ temp[draw] = true;
+
+ cuts_sampled[i] = draw;
+
+ }
+
+ // important that cut-points are ordered from low to high
+ cuts_sampled = sort(cuts_sampled);
+
+ }
+
+ // initialize grouping for the current node
+ // value of 1 indicates go to right node
+ g_node.ones(lincomb.size());
+
+ uvec::iterator it;
+
+ uword it_start = 0, it_best = 0;
+
+ double stat, stat_best = 0;
+
+ if(verbosity > 3){
+ Rcout << " -- cutpoint (score)" << std::endl;
+ }
+
+ for(it = cuts_sampled.begin(); it != cuts_sampled.end(); ++it){
+
+ // flip node assignments from left to right, up to the next cutpoint
+ g_node.elem(lincomb_sort.subvec(it_start, *it)).fill(0);
+ // compute split statistics with this cut-point
+ stat = compute_split_score();
+ // stat = score_logrank();
+ // update leaderboard
+ if(stat > stat_best) { stat_best = stat; it_best = *it; }
+ // set up next loop run
+ it_start = *it;
+
+ if(verbosity > 3){
+ Rcout << " --- ";
+ Rcout << lincomb.at(lincomb_sort(*it));
+ Rcout << " (" << stat << "), ";
+ Rcout << "N = " << sum(g_node % w_node) << " moving right";
+ Rcout << std::endl;
+ }
+
+ }
+
+ if(verbosity > 3){
+ Rcout << std::endl;
+ Rcout << " -- best stat: " << stat_best;
+ Rcout << ", min to split: " << split_min_stat;
+ Rcout << std::endl;
+ Rcout << std::endl;
+ }
+
+
+ // do not split if best stat < minimum stat
+ if(stat_best < split_min_stat){
+
+ return(R_PosInf);
+
+ }
+
+ // backtrack g_node to be what it was when best it was found
+ if(it_best < it_start){
+ g_node.elem(lincomb_sort.subvec(it_best+1, it_start)).fill(1);
+ }
+
+
+ // return the cut-point from best split
+ return(lincomb[lincomb_sort[it_best]]);
+
+ }
+
+ void Tree::sprout_leaf(uword node_id){
+
+ if(verbosity > 2){
+ Rcout << "-- sprouting node " << node_id << " into a leaf";
+ Rcout << " (N = " << sum(w_node) << ")";
+ Rcout << std::endl;
+ Rcout << std::endl;
+ }
+
+ leaf_summary[node_id] = mean(y_node.col(0));
+
+ }
+
+ double Tree::compute_max_leaves(){
+
+ // find maximum number of leaves for this tree
+ // there are two ways to have maximal tree size:
+ vec max_leaves_2ways = {
+ // 1. every leaf node has exactly leaf_min_obs,
+ n_obs_inbag / leaf_min_obs,
+ // 2. every leaf node has exactly split_min_obs - 1,
+ n_obs_inbag / (split_min_obs - 1)
+ };
+
+ double max_leaves = std::ceil(max(max_leaves_2ways));
+
+ return(max_leaves);
+
+ }
+
+ void Tree::grow(arma::vec* vi_numer,
+ arma::uvec* vi_denom){
+
+ this->vi_numer = vi_numer;
+ this->vi_denom = vi_denom;
+
+ sample_rows();
+
+ // create inbag views of x, y, and w,
+ this->x_inbag = data->x_rows(rows_inbag);
+ this->y_inbag = data->y_rows(rows_inbag);
+
+ this->n_obs_inbag = sum(w_inbag);
+ this->n_rows_inbag = x_inbag.n_rows;
+
+ node_assignments.zeros(n_rows_inbag);
+
+ this->max_leaves = compute_max_leaves();
+ this->max_nodes = (2 * max_leaves) - 1;
+
+ if(verbosity > 2){
+
+ Rcout << "- N obs inbag: " << n_obs_inbag;
+ Rcout << std::endl;
+ Rcout << "- N row inbag: " << n_rows_inbag;
+ Rcout << std::endl;
+ Rcout << "- max nodes: " << max_nodes;
+ Rcout << std::endl;
+ Rcout << "- max leaves: " << max_leaves;
+ Rcout << std::endl;
+ Rcout << std::endl;
+
+
+ }
+
+ // reserve memory for outputs (likely more than we need)
+ cutpoint.resize(max_nodes);
+ child_left.resize(max_nodes);
+ coef_values.resize(max_nodes);
+ coef_indices.resize(max_nodes);
+ // memory for leaves based on corresponding tree type
+ resize_leaves(max_nodes);
+
+ // coordinate the order that nodes are grown.
+ std::vector nodes_open;
+
+ // start node 0
+ nodes_open.push_back(0);
+
+ // nodes to grow in the next run through the do-loop
+ std::vector nodes_queued;
+
+ // reserve space (most we could ever need is max_leaves)
+ nodes_open.reserve(max_leaves);
+ nodes_queued.reserve(max_leaves);
+
+ // number of nodes in the tree
+ uword n_nodes = 0;
+
+ // iterate through nodes to be grown
+ std::vector::iterator node;
+
+ // ID of the left node (node_right = node_left + 1)
+ uword node_left;
+
+ // all possible cut-points for a linear combination
+ uvec cuts_all;
+
+ do{
+
+ for(node = nodes_open.begin(); node != nodes_open.end(); ++node){
+
+ // determine rows in the current node and if it can be split
+ if(!is_node_splittable(*node)){
+ sprout_leaf(*node);
+ continue;
+ }
+
+ uword n_retry = 0;
+
+ // determines if a node is split or sprouted
+ // (split means two new nodes are created)
+ // (sprouted means the node becomes a leaf)
+ for(; ;){
+
+ // repeat until all the retries are spent.
+ n_retry++;
+
+ if(verbosity > 3){
+
+ Rcout << "-- attempting to split node " << *node;
+ Rcout << " (N = " << sum(w_node) << ",";
+ Rcout << " try number " << n_retry << ")";
+ Rcout << std::endl;
+ Rcout << std::endl;
+ }
+
+ sample_cols();
+
+ if(!cols_node.is_empty()){
+
+ x_node = x_inbag(rows_node, cols_node);
+
+ if(verbosity > 3) {
+ print_uvec(cols_node, "columns sampled (showing up to 5)", 5);
+ }
+
+ // beta holds estimates (first item) and variance (second)
+ // for the regression coefficients that created lincomb.
+ // the variances are optional (only used for VI_ANOVA)
+ mat beta;
+
+ lincomb.zeros(x_node.n_rows);
+
+ switch (lincomb_type) {
+
+ case LC_NEWTON_RAPHSON: {
+
+ beta = coxph_fit(x_node, y_node, w_node,
+ lincomb_scale, lincomb_ties_method,
+ lincomb_eps, lincomb_iter_max);
+
+ break;
+
+ }
+
+ case LC_RANDOM_COEFS: {
+
+ beta.set_size(x_node.n_cols, 1);
+
+ std::uniform_real_distribution unif_coef(0.0, 1.0);
+
+ for(uword i = 0; i < x_node.n_cols; ++i){
+ beta.at(i, 0) = unif_coef(random_number_generator);
+ }
+
+ break;
+
+ }
+
+ case LC_GLMNET: {
+
+ NumericMatrix xx = wrap(x_node);
+ NumericMatrix yy = wrap(y_node);
+ NumericVector ww = wrap(w_node);
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but RObjects can)
+ Function f_beta = as(lincomb_R_function);
+
+ NumericMatrix beta_R = f_beta(xx, yy, ww,
+ lincomb_alpha,
+ lincomb_df_target);
+
+ beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
+
+ break;
+
+ }
+
+ case LC_R_FUNCTION: {
+
+ NumericMatrix xx = wrap(x_node);
+ NumericMatrix yy = wrap(y_node);
+ NumericVector ww = wrap(w_node);
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but RObjects can)
+ Function f_beta = as(lincomb_R_function);
+
+ NumericMatrix beta_R = f_beta(xx, yy, ww);
+
+ beta = mat(beta_R.begin(), beta_R.nrow(), beta_R.ncol(), false);
+
+ break;
+
+ }
+
+ } // end switch lincomb_type
+
+ vec beta_est = beta.unsafe_col(0);
+
+ if(verbosity > 3) {
+ print_vec(beta_est, "linear combo weights (showing up to 5)", 5);
+ }
+
+
+ lincomb = x_node * beta_est;
+
+ // sorted in ascending order
+ lincomb_sort = sort_index(lincomb);
+
+ // find all valid cutpoints for lincomb
+ cuts_all = find_cutpoints();
+
+ if(verbosity > 3 && cuts_all.is_empty()){
+
+ Rcout << " -- no cutpoints identified";
+ Rcout << std::endl;
+
+ }
+
+ // empty cuts_all => no valid cutpoints => make leaf or retry
+ if(!cuts_all.is_empty()){
+
+ double cut_point = split_node(cuts_all);
+
+ if(cut_point < R_PosInf){
+
+ if(vi_type == VI_ANOVA && lincomb_type == LC_NEWTON_RAPHSON){
+
+ // only do ANOVA variable importance when
+ // 1. a split of the node is guaranteed
+ // 2. the method used for lincombs allows it
+
+ if(verbosity > 3){
+ Rcout << " -- p-values:" << std::endl;
+ }
+
+ vec beta_var = beta.unsafe_col(1);
+
+ double pvalue;
+
+ for(uword i = 0; i < beta_est.size(); ++i){
+
+ (*vi_denom)[cols_node[i]]++;
+
+ if(beta_est[i] != 0){
+
+ pvalue = R::pchisq(pow(beta_est[i],2)/beta_var[i], 1, false, false);
+
+ if(verbosity > 3){
+
+ Rcout << " --- column " << cols_node[i] << ": ";
+ Rcout << pvalue;
+ if(pvalue < 0.05) Rcout << "*";
+ if(pvalue < 0.01) Rcout << "*";
+ if(pvalue < 0.001) Rcout << "*";
+ if(pvalue < vi_max_pvalue) Rcout << " [+1 to VI numerator]";
+ Rcout << std::endl;
+
+ }
+
+ if(pvalue < vi_max_pvalue){ (*vi_numer)[cols_node[i]]++; }
+
+ }
+
+ }
+
+ if(verbosity > 3){ Rcout << std::endl; }
+
+ }
+
+ // make new nodes if a valid cutpoint was found
+ node_left = n_nodes + 1;
+ n_nodes += 2;
+ // update tree parameters
+ cutpoint[*node] = cut_point;
+ coef_values[*node] = beta_est;
+ coef_indices[*node] = cols_node;
+
+ child_left[*node] = node_left;
+ // re-assign observations in the current node
+ // (note that g_node is 0 if left, 1 if right)
+ node_assignments.elem(rows_node) = node_left + g_node;
+
+ if(verbosity > 2){
+ Rcout << "-- node " << *node << " was split into ";
+ Rcout << "node " << node_left << " (left) and ";
+ Rcout << node_left+1 << " (right)";
+ Rcout << std::endl;
+ Rcout << std::endl;
+ }
+
+ nodes_queued.push_back(node_left);
+ nodes_queued.push_back(node_left + 1);
+ break;
+
+ }
+
+ }
+
+ }
+
+ if(n_retry >= split_max_retry){
+ sprout_leaf(*node);
+ break;
+ }
+
+ }
+
+
+ }
+
+ nodes_open = nodes_queued;
+ nodes_queued.clear();
+
+ } while (nodes_open.size() > 0);
+
+ // don't forget to count the root node
+ n_nodes++;
+
+ cutpoint.resize(n_nodes);
+ child_left.resize(n_nodes);
+ coef_values.resize(n_nodes);
+ coef_indices.resize(n_nodes);
+
+ resize_leaves(n_nodes);
+
+ } // Tree::grow
+
+ void Tree::predict_leaf(Data* prediction_data, bool oobag) {
+
+ pred_leaf.zeros(prediction_data->n_rows);
+
+ // if tree is root node, 0 is the correct leaf prediction
+ if(coef_values.size() == 0) return;
+
+ if(VERBOSITY > 0){
+ Rcout << "---- computing leaf predictions ----" << std::endl;
+ }
+
+ uvec obs_in_node;
+
+ // it iterates over the observations in a node
+ uvec::iterator it;
+
+ // i iterates over nodes, j over observations
+ uword i, j;
+
+ for(i = 0; i < coef_values.size(); i++){
+
+ // if child_left == 0, it's a leaf (no need to find next child)
+ if(child_left[i] != 0){
+
+ if(i == 0 && oobag){
+ obs_in_node = rows_oobag;
+ } else if (i == 0 && !oobag) {
+ obs_in_node = regspace(0, 1, pred_leaf.size()-1);
+ } else {
+ obs_in_node = find(pred_leaf == i);
+ }
+
+ if(obs_in_node.size() > 0){
+
+ lincomb = prediction_data->x_submat(obs_in_node, coef_indices[i]) * coef_values[i];
+
+ it = obs_in_node.begin();
+
+ for(j = 0; j < lincomb.size(); ++j, ++it){
+
+ if(lincomb[j] <= cutpoint[i]) {
+
+ pred_leaf[*it] = child_left[i];
+
+ } else {
+
+ pred_leaf[*it] = child_left[i]+1;
+
+ }
+
+ }
+
+ if(verbosity > 4){
+ uvec in_left = find(pred_leaf == child_left[i]);
+ uvec in_right = find(pred_leaf == child_left[i]+1);
+ Rcout << "No. to node " << child_left[i] << ": ";
+ Rcout << in_left.size() << "; " << std::endl;
+ Rcout << "No. to node " << child_left[i]+1 << ": ";
+ Rcout << in_right.size() << std::endl << std::endl;
+ }
+
+ }
+
+ }
+
+ }
+
+ if(oobag){
+ // If the forest is loaded, only rows_oobag is saved.
+ if(rows_inbag.size() == 0){
+ pred_leaf.elem(find(pred_leaf == 0)).fill(max_nodes);
+ } else {
+ pred_leaf.elem(rows_inbag).fill(max_nodes);
+ }
+
+ }
+
+ }
+
+ double Tree::compute_prediction_accuracy(arma::vec& preds){
+
+ if (oobag_eval_type == EVAL_R_FUNCTION){
+
+ NumericMatrix y_wrap = wrap(y_oobag);
+ NumericVector w_wrap = wrap(w_oobag);
+ NumericVector p_wrap = wrap(preds);
+
+ // initialize function from tree object
+ // (Functions can't be stored in C++ classes, but RObjects can)
+ Function f_oobag = as(oobag_R_function);
+
+ NumericVector result_R = f_oobag(y_wrap, w_wrap, p_wrap);
+
+ return(result_R[0]);
+
+ }
+
+ return(compute_prediction_accuracy_internal(preds));
+
+ }
+
+ void Tree::negate_coef(arma::uword pred_col){
+
+ for(uint j = 0; j < coef_indices.size(); ++j){
+
+ for(uword k = 0; k < coef_indices[j].size(); ++k){
+ if(coef_indices[j][k] == pred_col){
+ coef_values[j][k] *= (-1);
+ }
+ }
+
+ }
+
+ }
+
+ void Tree::compute_oobag_vi(arma::vec* vi_numer, VariableImportance vi_type) {
+
+ allocate_oobag_memory();
+ std::unique_ptr data_oobag { };
+ data_oobag = std::make_unique(x_oobag, y_oobag, w_oobag);
+
+ // using oobag = false for predict b/c data_oobag is already subsetted
+ predict_leaf(data_oobag.get(), false);
+
+ vec pred_values(data_oobag->n_rows);
+
+ for(uword i = 0; i < pred_values.size(); ++i){
+ pred_values[i] = leaf_summary[pred_leaf[i]];
+ }
+
+ // Compute normal prediction accuracy for each tree. Predictions already computed..
+ double accuracy_normal = compute_prediction_accuracy(pred_values);
+
+ if(VERBOSITY > 1){
+ Rcout << "prediction accuracy before noising: ";
+ Rcout << accuracy_normal << std::endl;
+ Rcout << " - mean leaf pred: ";
+ Rcout << mean(conv_to::from(pred_leaf));
+ Rcout << std::endl << std::endl;
+ }
+
+
+ // Randomly permute for all independent variables
+ for (uword pred_col = 0; pred_col < data->get_n_cols(); ++pred_col) {
+
+ // Check whether the i-th variable is used in the tree:
+ bool pred_is_used = false;
+
+ for(uint j = 0; j < coef_indices.size(); ++j){
+ for(uword k = 0; k < coef_indices[j].size(); ++k){
+ if(coef_indices[j][k] == pred_col){
+ pred_is_used = true;
+ break;
+ }
+ }
+ }
+
+ // proceed if the variable is used in the tree, otherwise vi = 0
+ if (pred_is_used) {
+
+ if(vi_type == VI_PERMUTE){
+ // everyone gets the same permutation
+ random_number_generator.seed(seed);
+ data_oobag->permute_col(pred_col, random_number_generator);
+ } else if (vi_type == VI_NEGATE){
+ negate_coef(pred_col);
+ }
+
+ predict_leaf(data_oobag.get(), false);
+
+ for(uword i = 0; i < pred_values.size(); ++i){
+ pred_values[i] = leaf_summary[pred_leaf[i]];
+ }
+
+ double accuracy_permuted = compute_prediction_accuracy(pred_values);
+
+ if(VERBOSITY>1){
+ Rcout << "prediction accuracy after noising " << pred_col << ": ";
+ Rcout << accuracy_permuted << std::endl;
+ Rcout << " - mean leaf pred: ";
+ Rcout << mean(conv_to::from(pred_leaf));
+ Rcout << std::endl << std::endl;
+ }
+
+ double accuracy_difference = accuracy_normal - accuracy_permuted;
+
+ (*vi_numer)[pred_col] += accuracy_difference;
+
+ if(vi_type == VI_PERMUTE){
+ data_oobag->restore_col(pred_col);
+ } else if (vi_type == VI_NEGATE){
+ negate_coef(pred_col);
+ }
+
+ }
+ }
+ }
+
+ void Tree::restore_rows_inbag(arma::uword n_obs) {
+
+ rows_inbag.set_size(n_obs);
+ uword rows_inbag_counter = 0;
+
+ if(rows_oobag[0] != 0){
+ rows_inbag[0] = 0;
+ rows_inbag_counter = 1;
+ }
+
+ for(arma::uword i = 1; i < rows_oobag.size(); i++){
+ if(rows_oobag[i-1]+1 != rows_oobag[i]){
+ for(arma::uword j = rows_oobag[i-1]+1; j < rows_oobag[i]; ++j){
+ rows_inbag[rows_inbag_counter] = j;
+ rows_inbag_counter++;
+ }
+ }
+ }
+
+ if(rows_oobag.back() < n_obs){
+ for(arma::uword j = rows_oobag.back()+1; j < n_obs; ++j){
+ rows_inbag[rows_inbag_counter] = j;
+ rows_inbag_counter++;
+ }
+ }
+
+ rows_inbag.resize(rows_inbag_counter);
+
+ }
+
+
+
+
+ } // namespace aorsf
+
diff --git a/src/Tree.h b/src/Tree.h
new file mode 100644
index 00000000..a8465bf8
--- /dev/null
+++ b/src/Tree.h
@@ -0,0 +1,268 @@
+/*-----------------------------------------------------------------------------
+ This file is part of aorsf.
+ Author: Byron C Jaeger
+ aorsf may be modified and distributed under the terms of the MIT license.
+#----------------------------------------------------------------------------*/
+
+#ifndef TREE_H_
+#define TREE_H_
+
+#include "Data.h"
+#include "globals.h"
+#include "utility.h"
+
+ namespace aorsf {
+
+ class Tree {
+
+ public:
+
+ Tree();
+
+ // Create from loaded forest
+ Tree(arma::uvec& rows_oobag,
+ std::vector& cutpoint,
+ std::vector& child_left,
+ std::vector& coef_values,
+ std::vector& coef_indices,
+ std::vector& leaf_summary);
+
+ virtual ~Tree() = default;
+
+ // deleting the copy constructor
+ Tree(const Tree&) = delete;
+ // deleting the copy assignment operator
+ Tree& operator=(const Tree&) = delete;
+
+ void init(Data* data,
+ int seed,
+ arma::uword mtry,
+ bool sample_with_replacement,
+ double sample_fraction,
+ PredType pred_type,
+ // double leaf_min_events,
+ double leaf_min_obs,
+ VariableImportance vi_type,
+ double vi_max_pvalue,
+ SplitRule split_rule,
+ // double split_min_events,
+ double split_min_obs,
+ double split_min_stat,
+ arma::uword split_max_cuts,
+ arma::uword split_max_retry,
+ LinearCombo lincomb_type,
+ double lincomb_eps,
+ arma::uword lincomb_iter_max,
+ bool lincomb_scale,
+ double lincomb_alpha,
+ arma::uword lincomb_df_target,
+ arma::uword lincomb_ties_method,
+ Rcpp::RObject lincomb_R_function,
+ Rcpp::RObject oobag_R_function,
+ EvalType oobag_eval_type,
+ int verbosity);
+
+
+ virtual void resize_leaves(arma::uword new_size);
+
+ void sample_rows();
+
+ void sample_cols();
+
+ virtual bool is_col_splittable(arma::uword j);
+
+ bool is_node_splittable(arma::uword node_id);
+
+ virtual bool is_node_splittable_internal();
+
+ virtual arma::uvec find_cutpoints();
+
+ virtual double compute_split_score();
+
+ double split_node(arma::uvec& cuts_all);
+
+ virtual void sprout_leaf(arma::uword node_id);
+
+ virtual double compute_max_leaves();
+
+ void grow(arma::vec* vi_numer,
+ arma::uvec* vi_denom);
+
+ void predict_leaf(Data* prediction_data,
+ bool oobag);
+
+ virtual void predict_value(arma::mat* pred_output,
+ arma::vec* pred_denom,
+ PredType pred_type,
+ bool oobag) = 0;
+
+ void negate_coef(arma::uword pred_col);
+
+ void compute_oobag_vi(arma::vec* vi_numer, VariableImportance vi_type);
+
+ // void grow(arma::vec& vi_numer, arma::uvec& vi_denom);
+
+ std::vector& get_coef_indices() {
+ return(coef_indices);
+ }
+
+ arma::uvec& get_rows_oobag() {
+ return(rows_oobag);
+ }
+
+ std::vector& get_coef_values() {
+ return(coef_values);
+ }
+
+ std::vector& get_leaf_summary(){
+ return(leaf_summary);
+ }
+
+ std::vector& get_cutpoint(){
+ return(cutpoint);
+ }
+
+ std::vector& get_child_left(){
+ return(child_left);
+ }
+
+ arma::uvec& get_pred_leaf(){
+ return(pred_leaf);
+ }
+
+
+ protected:
+
+ void allocate_oobag_memory();
+
+ void restore_rows_inbag(arma::uword n_obs);
+
+ // pointers to variable importance in forest
+ arma::vec* vi_numer;
+ arma::uvec* vi_denom;
+
+ // Pointer to original data
+ Data* data;
+
+ arma::uword n_cols_total;
+ arma::uword n_rows_total;
+
+ arma::uword n_rows_inbag;
+
+ double n_obs_inbag;
+ double n_events_inbag;
+
+ double max_nodes;
+ double max_leaves;
+
+
+ // views of data
+ arma::mat x_inbag;
+ arma::mat x_oobag;
+ arma::mat x_node;
+
+ arma::vec x_oobag_restore;
+
+ arma::mat y_inbag;
+ arma::mat y_oobag;
+ arma::mat y_node;
+
+ // the 'w' is short for 'weights'
+ arma::vec w_inbag;
+ arma::vec w_oobag;
+ arma::vec w_node;
+
+ // g_node indicates where observations will go when this node splits
+ // 0 means go down to left node, 1 means go down to right node
+ // the 'g' is short for 'groups'
+ arma::uvec g_node;
+
+ int seed;
+
+ arma::uword mtry;
+
+ bool sample_with_replacement;
+ double sample_fraction;
+
+ // what type of predictions you compute
+ PredType pred_type;
+
+ // variable importance
+ VariableImportance vi_type;
+ double vi_max_pvalue;
+
+ // Random number generator
+ std::mt19937_64 random_number_generator;
+
+ // tree growing members
+ // double leaf_min_events;
+ double leaf_min_obs;
+
+ // node split members
+ SplitRule split_rule;
+ // double split_min_events;
+ double split_min_obs;
+ double split_min_stat;
+ arma::uword split_max_cuts;
+ arma::uword split_max_retry;
+
+ // linear combination members
+ LinearCombo lincomb_type;
+ arma::vec lincomb;
+ arma::uvec lincomb_sort;
+ double lincomb_eps;
+ arma::uword lincomb_iter_max;
+ bool lincomb_scale;
+ double lincomb_alpha;
+ arma::uword lincomb_df_target;
+ arma::uword lincomb_ties_method;
+ Rcpp::RObject lincomb_R_function;
+
+ // allow customization of oobag prediction accuracy
+ Rcpp::RObject oobag_R_function;
+ EvalType oobag_eval_type;
+
+ int verbosity;
+
+ // which rows of data are held out while growing the tree
+ arma::uvec rows_inbag;
+ arma::uvec rows_oobag;
+ arma::uvec rows_node;
+ arma::uvec cols_node;
+
+
+ // predicted leaf node
+ arma::uvec pred_leaf;
+
+ // which node each inbag observation is currently in.
+ arma::uvec node_assignments;
+
+ // cutpoints used to split the nodes
+ std::vector cutpoint;
+
+ // left child nodes (right child is left + 1)
+ std::vector child_left;
+
+ // coefficients for linear combinations;
+ // one row per variable (mtry rows), one column per node
+ // leaf nodes have all coefficients=0
+ std::vector coef_values;
+ // std::vector coef_values;
+
+ // indices of the predictors used by a node
+ std::vector coef_indices;
+ // std::vector coef_indices;
+
+ // leaf values (only in leaf nodes)
+ std::vector leaf_summary;
+
+
+ virtual double compute_prediction_accuracy(arma::vec& preds);
+
+ virtual double compute_prediction_accuracy_internal(arma::vec& preds) = 0;
+
+ };
+
+ } // namespace aorsf
+
+#endif /* TREE_H_ */
diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp
new file mode 100644
index 00000000..b38ebd02
--- /dev/null
+++ b/src/TreeSurvival.cpp
@@ -0,0 +1,729 @@
+/*-----------------------------------------------------------------------------
+ This file is part of aorsf.
+ Author: Byron C Jaeger
+ aorsf may be modified and distributed under the terms of the MIT license.
+#----------------------------------------------------------------------------*/
+
+#include
+#include "TreeSurvival.h"
+#include "Coxph.h"
+#include "utility.h"
+// #include "NodeSplitStats.h"
+
+ using namespace arma;
+ using namespace Rcpp;
+
+ namespace aorsf {
+
+ TreeSurvival::TreeSurvival() { }
+
+ TreeSurvival::TreeSurvival(double leaf_min_events,
+ double split_min_events,
+ arma::vec* unique_event_times,
+ arma::vec* pred_horizon){
+
+ this->leaf_min_events = leaf_min_events;
+ this->split_min_events = split_min_events;
+ this->unique_event_times = unique_event_times;
+ this->pred_horizon = pred_horizon;
+
+ }
+
+ TreeSurvival::TreeSurvival(arma::uvec& rows_oobag,
+ std::vector& cutpoint,
+ std::vector& child_left,
+ std::vector& coef_values,
+ std::vector& coef_indices,
+ std::vector& leaf_pred_indx,
+ std::vector& leaf_pred_prob,
+ std::vector& leaf_pred_chaz,
+ std::vector& leaf_summary,
+ arma::vec* pred_horizon) :
+ Tree(rows_oobag, cutpoint, child_left, coef_values, coef_indices, leaf_summary),
+ leaf_pred_indx(leaf_pred_indx),
+ leaf_pred_prob(leaf_pred_prob),
+ leaf_pred_chaz(leaf_pred_chaz),
+ pred_horizon(pred_horizon){ }
+
+ void TreeSurvival::resize_leaves(arma::uword new_size) {
+
+ leaf_pred_indx.resize(new_size);
+ leaf_pred_prob.resize(new_size);
+ leaf_pred_chaz.resize(new_size);
+ leaf_summary.resize(new_size);
+
+ }
+
+ double TreeSurvival::compute_max_leaves(){
+
+ n_events_inbag = sum(w_inbag % y_inbag.col(1));
+
+ // find maximum number of leaves for this tree
+ // there are four ways to have maximal tree size:
+ vec max_leaves_4ways = {
+ // 1. every leaf node has exactly leaf_min_obs,
+ n_obs_inbag / leaf_min_obs,
+ // 2. every leaf node has exactly leaf_min_events,
+ n_events_inbag / leaf_min_events,
+ // 3. every leaf node has exactly split_min_obs - 1,
+ n_obs_inbag / (split_min_obs - 1),
+ // 4. every leaf node has exactly split_min_events-1
+ n_events_inbag / (split_min_events - 1)
+ };
+
+ // number of nodes total in binary tree is 2*L - 1,
+ // where L is the number of leaf nodes in the tree.
+ // (can prove by induction)
+ double max_leaves = std::ceil(max(max_leaves_4ways));
+
+ return(max_leaves);
+
+ }
+
+ bool TreeSurvival::is_col_splittable(uword j){
+
+ uvec::iterator i;
+
+ // initialize as 0 but do not make comparisons until x_first_value
+ // is formally defined at the first instance of status == 1
+ double x_first_value=0;
+
+ bool x_first_undef = true;
+
+ for (i = rows_node.begin(); i != rows_node.end(); ++i) {
+
+ // if event occurred for this observation
+ // column is only splittable if X is non-constant among
+ // observations where an event occurred.
+ if(y_inbag.at(*i, 1) == 1){
+
+ if(x_first_undef){
+
+ x_first_value = x_inbag.at(*i, j);
+ x_first_undef = false;
+
+ } else {
+
+ if(x_inbag.at(*i, j) != x_first_value){
+ return(true);
+ }
+
+ }
+
+ }
+
+ }
+
+ if(verbosity > 3){
+
+ mat x_print = x_inbag.rows(rows_node);
+ mat y_print = y_inbag.rows(rows_node);
+
+ uvec rows_event = find(y_print.col(1) == 1);
+ x_print = x_print.rows(rows_event);
+
+ Rcout << " --- Column " << j << " was sampled but ";
+ Rcout << " unique values of column " << j << " are ";
+ Rcout << unique(x_print.col(j)) << std::endl;
+
+ }
+
+ return(false);
+
+ }
+
+ bool TreeSurvival::is_node_splittable_internal(){
+
+ double n_risk = sum(w_node);
+ double n_events = sum(y_node.col(1) % w_node);
+
+ return(n_events >= 2*leaf_min_events &&
+ n_risk >= 2*leaf_min_obs &&
+ n_events >= split_min_events &&
+ n_risk >= split_min_obs);
+
+ }
+
+ uvec TreeSurvival::find_cutpoints(){
+
+ vec y_status = y_node.unsafe_col(1);
+
+ // placeholder with values indicating invalid cps
+ uvec output;
+
+ uword i, j, k;
+
+ uvec::iterator it, it_min, it_max;
+
+ double n_events = 0, n_risk = 0;
+
+ if(VERBOSITY > 1){
+ Rcout << "----- finding lower bound for cut-points -----" << std::endl;
+ }
+
+ // stop at end-1 b/c we access it+1 in lincomb_sort
+ for(it = lincomb_sort.begin(); it < lincomb_sort.end()-1; ++it){
+
+ n_events += y_status[*it] * w_node[*it];
+ n_risk += w_node[*it];
+
+
+ if(VERBOSITY > 2){
+ Rcout << "current value: "<< lincomb(*it) << " -- ";
+ Rcout << "next: "<< lincomb(*(it+1)) << " -- ";
+ Rcout << "N events: " << n_events << " -- ";
+ Rcout << "N risk: " << n_risk << std::endl;
+ }
+
+ // If we want to make the current value of lincomb a cut-point, we need
+ // to make sure the next value of lincomb isn't equal to this current value.
+ // Otherwise, we will have the same value of lincomb in both groups!
+
+ if(lincomb[*it] != lincomb[*(it+1)]){
+
+ if( n_events >= leaf_min_events &&
+ n_risk >= leaf_min_obs ) {
+
+ if(VERBOSITY > 0){
+ Rcout << std::endl;
+ Rcout << "lower cutpoint: " << lincomb(*it) << std::endl;
+ Rcout << " - n_events, left node: " << n_events << std::endl;
+ Rcout << " - n_risk, left node: " << n_risk << std::endl;
+ Rcout << std::endl;
+ }
+
+ break;
+
+ }
+
+ }
+
+ }
+
+ it_min = it;
+
+ if(it == lincomb_sort.end()-1) {
+
+ if(VERBOSITY > 1){
+ Rcout << "Could not find a valid cut-point" << std::endl;
+ }
+
+ return(output);
+
+ }
+
+ // j = number of steps we have taken forward in lincomb
+ j = it - lincomb_sort.begin();
+
+ // reset before finding the upper limit
+ n_events=0, n_risk=0;
+
+ if(VERBOSITY > 1){
+ Rcout << "----- finding upper bound for cut-points -----" << std::endl;
+ }
+
+ // stop at beginning+1 b/c we access it-1 in lincomb_sort
+ for(it = lincomb_sort.end()-1; it >= lincomb_sort.begin()+1; --it){
+
+ n_events += y_status[*it] * w_node[*it];
+ n_risk += w_node[*it];
+
+ if(VERBOSITY > 2){
+ Rcout << "current value: "<< lincomb(*it) << " ---- ";
+ Rcout << "next value: "<< lincomb(*(it-1)) << " ---- ";
+ Rcout << "N events: " << n_events << " ---- ";
+ Rcout << "N risk: " << n_risk << std::endl;
+ }
+
+ if(lincomb[*it] != lincomb[*(it-1)]){
+
+ if( n_events >= leaf_min_events &&
+ n_risk >= leaf_min_obs ) {
+
+ // the upper cutpoint needs to be one step below the current
+ // it value, because we use x <= cp to determine whether a
+ // value x goes to the left node versus the right node. So,
+ // if it currently points to 3, and the next value down is 2,
+ // then we want to say the cut-point is 2 because then all
+ // values <= 2 will go left, and 3 will go right. This matters
+ // when 3 is the highest value in the vector.
+
+ --it;
+
+ if(VERBOSITY > 0){
+ Rcout << std::endl;
+ Rcout << "upper cutpoint: " << lincomb(*it) << std::endl;
+ Rcout << " - n_events, right node: " << n_events << std::endl;
+ Rcout << " - n_risk, right node: " << n_risk << std::endl;
+ Rcout << std::endl;
+ }
+
+ break;
+
+ }
+
+ }
+
+ }
+
+ it_max = it;
+
+ // k = n steps from beginning of sorted lincomb to current it
+ k = it - lincomb_sort.begin();
+
+ if(j > k){
+
+ if(VERBOSITY > 0) {
+ Rcout << "Could not find valid cut-points" << std::endl;
+ }
+
+ return(output);
+
+ }
+
+ // only one valid cutpoint
+ if(j == k){
+
+ output = {j};
+ return(output);
+
+ }
+
+ i = 0;
+ uvec output_middle(k-j);
+
+ for(it = it_min+1;
+ it < it_max; ++it){
+ if(lincomb[*it] != lincomb[*(it+1)]){
+ output_middle[i] = it - lincomb_sort.begin();
+ i++;
+ }
+ }
+
+ output_middle.resize(i);
+
+ uvec output_left = {j};
+ uvec output_right = {k};
+
+ output = join_vert(output_left, output_middle, output_right);
+
+ return(output);
+
+ }
+
+ double TreeSurvival::compute_split_score(){
+
+ double result=0;
+
+ switch (split_rule) {
+
+ case SPLIT_LOGRANK: {
+ result = compute_logrank(y_node, w_node, g_node);
+ break;
+ }
+
+ case SPLIT_CONCORD: {
+ result = compute_cstat(y_node, w_node, g_node, true);
+ break;
+ }
+
+ }
+
+ return(result);
+
+ }
+
+ double TreeSurvival::score_logrank(){
+
+ double
+ n_risk=0,
+ g_risk=0,
+ observed=0,
+ expected=0,
+ V=0,
+ temp1,
+ temp2,
+ n_events;
+
+ vec y_time = y_node.unsafe_col(0);
+ vec y_status = y_node.unsafe_col(1);
+
+ bool break_loop = false;
+
+ uword i = y_node.n_rows-1;
+
+ // breaking condition of outer loop governed by inner loop
+ for (; ;){
+
+ temp1 = y_time[i];
+
+ n_events = 0;
+
+ for ( ; y_time[i] == temp1; i--) {
+
+ n_risk += w_node[i];
+ n_events += y_status[i] * w_node[i];
+ g_risk += g_node[i] * w_node[i];
+ observed += y_status[i] * g_node[i] * w_node[i];
+
+ if(i == 0){
+ break_loop = true;
+ break;
+ }
+
+ }
+
+ // should only do these calculations if n_events > 0,
+ // but in practice its often faster to multiply by 0
+ // versus check if n_events is > 0.
+
+ temp2 = g_risk / n_risk;
+ expected += n_events * temp2;
+
+ // update variance if n_risk > 1 (if n_risk == 1, variance is 0)
+ // definitely check if n_risk is > 1 b/c otherwise divide by 0
+ if (n_risk > 1){
+ temp1 = n_events * temp2 * (n_risk-n_events) / (n_risk-1);
+ V += temp1 * (1 - temp2);
+ }
+
+ if(break_loop) break;
+
+ }
+
+ return(pow(expected-observed, 2) / V);
+
+ }
+
+ void TreeSurvival::sprout_leaf(uword node_id){
+
+ if(verbosity > 2){
+ Rcout << "-- sprouting node " << node_id << " into a leaf";
+ Rcout << " (N = " << sum(w_node) << ")";
+ Rcout << std::endl;
+ Rcout << std::endl;
+ }
+
+ // reserve as much size as could be needed (probably more)
+ mat leaf_data(y_node.n_rows, 3);
+
+ uword person = 0;
+
+ // find the first unique event time
+ while(y_node.at(person, 1) == 0 && person < y_node.n_rows){
+ person++;
+ }
+
+ // person corresponds to first event or last censor time
+ leaf_data.at(0, 0) = y_node.at(person, 0);
+
+ // if no events in this node:
+ // (TODO: should this case even occur? consider removing)
+ if(person == y_node.n_rows){
+
+ vec temp_surv(1, arma::fill::ones);
+ vec temp_chf(1, arma::fill::zeros);
+
+ leaf_pred_indx[node_id] = leaf_data.col(0);
+ leaf_pred_prob[node_id] = temp_surv;
+ leaf_pred_chaz[node_id] = temp_chf;
+ leaf_summary[node_id] = 0.0;
+
+ return;
+
+ }
+
+ double temp_time = y_node.at(person, 0);
+
+ uword i = 1;
+
+ // find the rest of the unique event times
+ for( ; person < y_node.n_rows; person++){
+
+ if(temp_time != y_node.at(person, 0) && y_node.at(person, 1) == 1){
+
+ leaf_data.at(i, 0) = y_node.at(person,0);
+ temp_time = y_node.at(person, 0);
+ i++;
+
+ }
+
+ }
+
+ leaf_data.resize(i, 3);
+
+ // reset for kaplan meier loop
+ person = 0; i = 0;
+ double n_risk = sum(w_node);
+ double temp_surv = 1.0;
+ double temp_haz = 0.0;
+
+ do {
+
+ double n_events = 0;
+ double n_risk_sub = 0;
+ temp_time = y_node.at(person, 0);
+
+ while(y_node.at(person, 0) == temp_time){
+
+ n_risk_sub += w_node.at(person);
+ n_events += y_node.at(person, 1) * w_node.at(person);
+
+ if(person == y_node.n_rows-1) break;
+
+ person++;
+
+ }
+
+ // only do km if a death was observed
+
+ if(n_events > 0){
+
+ temp_surv = temp_surv * (n_risk - n_events) / n_risk;
+
+ temp_haz = temp_haz + n_events / n_risk;
+
+ leaf_data.at(i, 1) = temp_surv;
+ leaf_data.at(i, 2) = temp_haz;
+ i++;
+
+ }
+
+ n_risk -= n_risk_sub;
+
+ } while (i < leaf_data.n_rows);
+
+
+ if(verbosity > 3){
+ mat tmp_mat = join_horiz(y_node, w_node);
+ print_mat(tmp_mat, "time & status & weights in this node", 10, 10);
+ print_mat(leaf_data, "leaf_data (showing up to 5 rows)", 5, 5);
+ }
+
+ leaf_pred_indx[node_id] = leaf_data.col(0);
+ leaf_pred_prob[node_id] = leaf_data.col(1);
+ leaf_pred_chaz[node_id] = leaf_data.col(2);
+ leaf_summary[node_id] = compute_mortality(leaf_data);
+
+ }
+
+ double TreeSurvival::compute_mortality(arma::mat& leaf_data){
+
+ double result = 0;
+ uword i=0, j=0;
+
+ for( ; i < (*unique_event_times).size(); i++){
+
+ if((*unique_event_times)[i] >= leaf_data.at(j, 0) &&
+ j < (leaf_data.n_rows-1)) {j++;}
+
+ result += leaf_data.at(j, 2);
+
+ }
+
+ return(result);
+
+ }
+
+ void TreeSurvival::predict_value(arma::mat* pred_output,
+ arma::vec* pred_denom,
+ PredType pred_type,
+ bool oobag){
+
+ uvec pred_leaf_sort = sort_index(pred_leaf, "ascend");
+
+ uvec::iterator it = pred_leaf_sort.begin();
+
+ if(verbosity > 2){
+ uvec tmp_uvec = find(pred_leaf < max_nodes);
+ Rcout << " -- N preds expected: " << tmp_uvec.size() << std::endl;
+ }
+
+ uword leaf_id = pred_leaf[*it];
+
+ // default for risk or survival at time 0
+ double pred_t0 = 1;
+
+ // default otherwise
+ if (pred_type == PRED_CHAZ ||
+ pred_type == PRED_MORTALITY) {
+ pred_t0 = 0;
+ }
+
+ uword i, j;
+
+ uword n_preds_made = 0;
+
+ vec leaf_times, leaf_values;
+
+ vec temp_vec((*pred_horizon).size());
+
+ double temp_dbl = pred_t0;
+ bool break_loop = false;
+
+ for(; ;) {
+
+
+ // copies of leaf data using same aux memory
+ leaf_times = vec(leaf_pred_indx[leaf_id].begin(),
+ leaf_pred_indx[leaf_id].size(),
+ false);
+
+ switch (pred_type) {
+
+ case PRED_RISK: case PRED_SURVIVAL: {
+
+ leaf_values = vec(leaf_pred_prob[leaf_id].begin(),
+ leaf_pred_prob[leaf_id].size(),
+ false);
+
+ break;
+
+ }
+
+ case PRED_CHAZ: {
+
+ leaf_values = vec(leaf_pred_chaz[leaf_id].begin(),
+ leaf_pred_chaz[leaf_id].size(),
+ false);
+
+ break;
+
+ }
+
+ case PRED_MORTALITY: {
+
+ temp_vec.fill(leaf_summary[leaf_id]);
+
+ break;
+
+ }
+
+ default:
+ Rcout << "Invalid pred type; R will crash";
+ break;
+
+ }
+
+ // don't reset i in the loop b/c leaf_times ascend
+ i = 0;
+
+ if(pred_type != PRED_MORTALITY){
+
+ for(j = 0; j < (*pred_horizon).size(); j++){
+
+ // t is the current prediction time
+ double t = (*pred_horizon)[j];
+
+ // if t < t', where t' is the max time in this leaf,
+ // then we may find a time t* such that t* < t < t'.
+ // If so, prediction should be anchored to t*.
+ // But, there may be multiple t* < t, and we want to
+ // find the largest t* that is < t, so we find the
+ // first t** > t and assign t* to be whatever came
+ // right before t**.
+ if(t < leaf_times.back()){
+
+ for(; i < leaf_times.size(); i++){
+
+ // we found t**
+ if (leaf_times[i] > t){
+
+ if(i == 0)
+ // first leaf event occurred after prediction time
+ temp_dbl = pred_t0;
+ else
+ // t* is the time value just before t**, so use i-1
+ temp_dbl = leaf_values[i-1];
+
+ break;
+
+ } else if (leaf_times[i] == t){
+ // pred_horizon just happens to equal a leaf time
+ temp_dbl = leaf_values[i];
+
+ break;
+
+ }
+
+ }
+
+ } else {
+ // if t > t' use the last recorded prediction
+ temp_dbl = leaf_values.back();
+
+ }
+
+ temp_vec[j] = temp_dbl;
+
+ }
+
+ }
+
+ if(pred_type == PRED_RISK) temp_vec = 1 - temp_vec;
+
+ (*pred_output).row(*it) += temp_vec.t();
+ n_preds_made++;
+ if(oobag) (*pred_denom)[*it]++;
+
+ // Rcout << "npreds: " << n_preds_made << ", ";
+ // Rcout << "*it: " << (*it) << std::endl;
+
+ // in case the last obs has a unique leaf assignment
+ if(it == pred_leaf_sort.end()-1) break;
+
+ for(; ;){
+
+ ++it;
+ if (it == pred_leaf_sort.end()-1){
+ // we've reached the final value of pred_leaf
+ // check to see if it's the same leaf as the obs before:
+ if (leaf_id == pred_leaf[*it]){
+ // if it is, add the value to the pred_output, and be done
+ (*pred_output).row(*it) += temp_vec.t();
+ n_preds_made++;
+ if(oobag) (*pred_denom)[*it]++;
+ break_loop = true;
+ break;
+ }
+
+ }
+
+ if(leaf_id != pred_leaf[*it]) break;
+
+ (*pred_output).row(*it) += temp_vec.t();
+ n_preds_made++;
+ if(oobag) (*pred_denom)[*it]++;
+
+ // Rcout << "npreds: " << n_preds_made << ", ";
+ // Rcout << "*it (inner loop): " << (*it) << std::endl;
+
+ }
+
+ if(break_loop) break;
+
+ leaf_id = pred_leaf(*it);
+
+ // case 3: we've finished out-of-bag predictions
+ if(leaf_id == max_nodes) break;
+
+ }
+
+ if(verbosity > 2){
+ Rcout << " -- N preds made: " << n_preds_made;
+ Rcout << std::endl;
+ Rcout << std::endl;
+ }
+
+
+ }
+
+ double TreeSurvival::compute_prediction_accuracy_internal(arma::vec& preds){
+
+ return compute_cstat(y_oobag, w_oobag, preds, true);
+
+ }
+
+
+ } // namespace aorsf
+
diff --git a/src/TreeSurvival.h b/src/TreeSurvival.h
new file mode 100644
index 00000000..0306616c
--- /dev/null
+++ b/src/TreeSurvival.h
@@ -0,0 +1,96 @@
+/*-----------------------------------------------------------------------------
+ This file is part of aorsf.
+ Author: Byron C Jaeger
+ aorsf may be modified and distributed under the terms of the MIT license.
+#----------------------------------------------------------------------------*/
+
+#ifndef TREESURVIVAL_H_
+#define TREESURVIVAL_H_
+
+
+#include "Data.h"
+#include "globals.h"
+#include "Tree.h"
+
+ namespace aorsf {
+
+ class TreeSurvival: public Tree {
+
+ public:
+
+ TreeSurvival();
+
+ TreeSurvival(const TreeSurvival&) = delete;
+ TreeSurvival& operator=(const TreeSurvival&) = delete;
+
+ TreeSurvival(double leaf_min_events,
+ double split_min_events,
+ arma::vec* unique_event_times,
+ arma::vec* pred_horizon);
+
+ TreeSurvival(arma::uvec& rows_oobag,
+ std::vector& cutpoint,
+ std::vector& child_left,
+ std::vector& coef_values,
+ std::vector& coef_indices,
+ std::vector& leaf_pred_indx,
+ std::vector& leaf_pred_prob,
+ std::vector& leaf_pred_chaz,
+ std::vector& leaf_summary,
+ arma::vec* pred_horizon);
+
+ double compute_max_leaves() override;
+
+ void resize_leaves(arma::uword new_size) override;
+
+ bool is_col_splittable(arma::uword j) override;
+
+ bool is_node_splittable_internal() override;
+
+ arma::uvec find_cutpoints() override;
+
+ double compute_split_score() override;
+
+ double score_logrank();
+
+ double compute_mortality(arma::mat& leaf_data);
+
+ void sprout_leaf(uword node_id) override;
+
+ void predict_value(arma::mat* pred_output,
+ arma::vec* pred_denom,
+ PredType pred_type,
+ bool oobag) override;
+
+ std::vector& get_leaf_pred_indx(){
+ return(leaf_pred_indx);
+ }
+
+ std::vector& get_leaf_pred_prob(){
+ return(leaf_pred_prob);
+ }
+
+ std::vector& get_leaf_pred_chaz(){
+ return(leaf_pred_chaz);
+ }
+
+ double compute_prediction_accuracy_internal(arma::vec& preds) override;
+
+ std::vector leaf_pred_indx;
+ std::vector leaf_pred_prob;
+ std::vector leaf_pred_chaz;
+
+ // pointer to event times in forest
+ arma::vec* unique_event_times;
+
+ // prediction times
+ arma::vec* pred_horizon;
+
+ double leaf_min_events;
+ double split_min_events;
+
+ };
+
+ } // namespace aorsf
+
+#endif /* TREESURVIVAL_H_ */
diff --git a/src/globals.h b/src/globals.h
new file mode 100644
index 00000000..d865270c
--- /dev/null
+++ b/src/globals.h
@@ -0,0 +1,108 @@
+/*-----------------------------------------------------------------------------
+ This file is part of aorsf.
+ Author: Byron C Jaeger
+ aorsf may be modified and distributed under the terms of the MIT license.
+#----------------------------------------------------------------------------*/
+
+#ifndef GLOBALS_H_
+#define GLOBALS_H_
+
+ namespace aorsf {
+
+ typedef unsigned int uint;
+
+ // Tree types
+ enum TreeType {
+ TREE_CLASSIFICATION = 1,
+ TREE_REGRESSION = 2,
+ TREE_SURVIVAL = 3,
+ TREE_PROBABILITY = 4
+ };
+
+ // Variable importance
+ enum VariableImportance {
+ VI_NONE = 0,
+ VI_NEGATE = 1,
+ VI_PERMUTE = 2,
+ VI_ANOVA = 3
+ };
+
+ // Split mode
+ enum SplitRule {
+ SPLIT_LOGRANK = 1,
+ SPLIT_CONCORD = 2
+ };
+
+ enum EvalType {
+ EVAL_NONE = 0,
+ EVAL_CONCORD = 1,
+ EVAL_R_FUNCTION = 2
+ };
+
+ enum PartialDepType {
+ PD_NONE = 0,
+ PD_SUMMARY = 1,
+ PD_ICE = 2
+ };
+
+ // Linear combination method
+ enum LinearCombo {
+ LC_NEWTON_RAPHSON = 1,
+ LC_RANDOM_COEFS = 2,
+ LC_GLMNET = 3,
+ LC_R_FUNCTION = 4
+ };
+
+ // Prediction type
+ enum PredType {
+ PRED_NONE = 0,
+ PRED_RISK = 1,
+ PRED_SURVIVAL = 2,
+ PRED_CHAZ = 3,
+ PRED_MORTALITY = 4,
+ PRED_MEAN = 5,
+ PRED_PROBABILITY = 6,
+ PRED_CLASS = 7,
+ PRED_TERMINAL_NODES = 8
+ };
+
+ // Default values
+ const int DEFAULT_N_TREE = 500;
+ const int DEFAULT_N_THREADS = 1;
+
+ const VariableImportance DEFAULT_IMPORTANCE = VI_NONE;
+
+ const double DEFAULT_SPLIT_MAX_RETRY = 1;
+
+
+ const double DEFAULT_LEAF_MIN_EVENTS = 1;
+ const double DEFAULT_LEAF_MIN_OBS = 5;
+
+ const SplitRule DEFAULT_SPLITRULE = SPLIT_LOGRANK;
+ const double DEFAULT_SPLIT_MIN_EVENTS = 5;
+ const double DEFAULT_SPLIT_MIN_OBS = 10;
+ const double DEFAULT_SPLIT_MIN_STAT = 3.84;
+
+ const arma::uword DEFAULT_SPLIT_MAX_CUTS = 5;
+ const arma::uword DEFAULT_MAX_RETRY = 3;
+
+ const LinearCombo DEFAULT_LINCOMB = LC_NEWTON_RAPHSON;
+ const double DEFAULT_LINCOMB_EPS = 1e-9;
+ const arma::uword DEFAULT_LINCOMB_ITER_MAX = 20;
+ const bool DEFAULT_LINCOMB_SCALE = true;
+ const double DEFAULT_LINCOMB_ALPHA = 0.5;
+ const arma::uword DEFAULT_LINCOMB_TIES_METHOD = 1;
+
+ const double DEFAULT_ANOVA_VI_PVALUE = 0.01;
+
+ const PredType DEFAULT_PRED_TYPE = PRED_RISK;
+
+ const int VERBOSITY = 0;
+
+ // Interval to print progress in seconds
+ const double STATUS_INTERVAL = 1.0;
+
+
+ } // namespace aorsf
+
+#endif /* GLOBALS_H_ */
diff --git a/src/orsf.cpp b/src/orsf.cpp
index a21e3568..94bf2fe1 100644
--- a/src/orsf.cpp
+++ b/src/orsf.cpp
@@ -1,4113 +1,4111 @@
-#include
-#include
-
-// [[Rcpp::depends(RcppArmadillo)]]
-
-
-using namespace Rcpp;
-using namespace arma;
-
-// ----------------------------------------------------------------------------
-// ---------------------------- global parameters -----------------------------
-// ----------------------------------------------------------------------------
-
-// special note: dont change these doubles to uword,
-// even though some of them could be uwords;
-// operations involving uwords and doubles are not
-// straightforward and may break the routine.
-// also: double + uword is slower than double + double.
-
-double
- weight_avg,
- weight_events,
- w_node_sum,
- denom_events,
- denom,
- cph_eps,
- // the n_ variables could be integers but it
- // is safer and faster when they are doubles
- n_events,
- n_events_total,
- n_events_right,
- n_events_left,
- n_risk,
- n_risk_right,
- n_risk_left,
- n_risk_sub,
- g_risk,
- temp1,
- temp2,
- temp3,
- halving,
- stat_current,
- stat_best,
- w_node_person,
- xb,
- risk,
- loglik,
- cutpoint,
- observed,
- expected,
- V,
- pred_t0,
- leaf_min_obs,
- leaf_min_events,
- split_min_events,
- split_min_obs,
- split_min_stat,
- time_pred,
- ll_second,
- ll_init,
- net_alpha;
-
-int
- // verbose=0,
- max_retry,
- n_retry,
- tree,
- mtry_int,
- net_df_target,
- oobag_eval_every;
-
-char
- type_beta,
- type_oobag_eval,
- oobag_pred_type,
- oobag_importance_type,
- pred_type_dflt = 'S';
-
-// armadillo unsigned integers
-uword
- i,
- j,
- k,
- iter,
- mtry,
- mtry_temp,
- person,
- person_leaf,
- person_ref_index,
- n_vars,
- n_rows,
- cph_method,
- cph_iter_max,
- n_split,
- nodes_max_guess,
- nodes_max_true,
- n_cols_to_sample,
- nn_left,
- leaf_node_counter,
- leaf_node_index_counter,
- leaf_node_col,
- oobag_eval_counter;
-
-bool
- break_loop, // a delayed break statement
- oobag_pred,
- oobag_importance,
- use_tree_seed,
- cph_do_scale;
-
-// armadillo vectors (doubles)
-vec
- vec_temp,
- times_pred,
- eval_oobag,
- node_assignments,
- nodes_grown,
- surv_pvec,
- surv_pvec_output,
- denom_pred,
- beta_current,
- beta_new,
- beta_fit,
- vi_pval_numer,
- vi_pval_denom,
- cutpoints,
- w_input,
- w_inbag,
- w_user,
- w_node,
- group,
- u,
- a,
- a2,
- XB,
- Risk;
-
-// armadillo unsigned integer vectors
-uvec
- iit_vals,
- jit_vals,
- rows_inbag,
- rows_oobag,
- rows_node,
- rows_leaf,
- rows_node_combined,
- cols_to_sample_01,
- cols_to_sample,
- cols_node,
- leaf_node_index,
- nodes_to_grow,
- nodes_to_grow_next,
- obs_in_node,
- children_left,
- leaf_pred;
-
-// armadillo iterators for unsigned integer vectors
-uvec::iterator
- iit,
- iit_best,
- jit,
- node;
-
-// armadillo matrices (doubles)
-mat
- x_input,
- x_transforms,
- y_input,
- x_inbag,
- y_inbag,
- x_node,
- y_node,
- x_pred,
- // x_mean,
- vmat,
- cmat,
- cmat2,
- betas,
- leaf_node,
- leaf_nodes,
- surv_pmat;
-
-umat
- col_indices,
- leaf_indices;
-
-cube
- surv_pcube;
-
-List ostree;
-
-NumericMatrix
- beta_placeholder,
- xx,
- yy;
-
-CharacterVector yy_names = CharacterVector::create("time","status");
-
-NumericVector ww;
-
-Environment base_env("package:base");
-
-Function set_seed_r = base_env["set.seed"];
-
-// Set difference for arma vectors
-//
-// @description the same as setdiff() in R
-//
-// @param x first vector
-// @param y second vector
-//
-// [[Rcpp::export]]
-arma::uvec std_setdiff(arma::uvec& x, arma::uvec& y) {
-
- std::vector a = conv_to< std::vector >::from(sort(x));
- std::vector b = conv_to< std::vector >::from(sort(y));
- std::vector out;
-
- std::set_difference(a.begin(), a.end(),
- b.begin(), b.end(),
- std::inserter(out, out.end()));
-
- return conv_to::from(out);
-
-}
-
-// ----------------------------------------------------------------------------
-// ---------------------------- scaling functions -----------------------------
-// ----------------------------------------------------------------------------
-
-// scale observations in predictor matrix
-//
-// @description this scales inputs in the same way as
-// the survival::coxph() function. The main reasons we do this
-// are to avoid exponential overflow and to prevent the scale
-// of inputs from impacting the estimated beta coefficients.
-// E.g., you can try multiplying numeric inputs by 100 prior
-// to calling orsf() with orsf_control_fast(do_scale = FALSE)
-// and you will see that you get back a different forest.
-//
-// @param x_node matrix of predictors
-// @param w_node replication weights
-// @param x_transforms matrix used to store the means and scales
-//
-// @return modified x_node and x_transform filled with values
-//
-void x_node_scale(){
-
- // set aside memory for outputs
- // first column holds the mean values
- // second column holds the scale values
-
- x_transforms.zeros(n_vars, 2);
- vec means = x_transforms.unsafe_col(0); // Reference to column 1
- vec scales = x_transforms.unsafe_col(1); // Reference to column 2
-
- w_node_sum = sum(w_node);
-
- for(i = 0; i < n_vars; i++) {
-
- means.at(i) = sum( w_node % x_node.col(i) ) / w_node_sum;
-
- x_node.col(i) -= means.at(i);
-
- scales.at(i) = sum(w_node % abs(x_node.col(i)));
-
- if(scales(i) > 0)
- scales.at(i) = w_node_sum / scales.at(i);
- else
- scales.at(i) = 1.0; // rare case of constant covariate;
-
- x_node.col(i) *= scales.at(i);
-
- }
-
-}
-
-// same as above function, but just the means
-// (currently not used)
-void x_node_means(){
-
- x_transforms.zeros(n_vars, 1);
- w_node_sum = sum(w_node);
-
- for(i = 0; i < n_vars; i++) {
-
- x_transforms.at(i, 0) = sum( w_node % x_node.col(i) ) / w_node_sum;
-
- }
-
-}
-
-// Same as x_node_scale, but this can be called from R
-// [[Rcpp::export]]
-List x_node_scale_exported(NumericMatrix& x_,
- NumericVector& w_){
-
- x_node = mat(x_.begin(), x_.nrow(), x_.ncol(), false);
- w_node = vec(w_.begin(), w_.length(), false);
- n_vars = x_node.n_cols;
-
- x_node_scale();
-
- return(
- List::create(
- _["x_scaled"] = x_node,
- _["x_transforms"] = x_transforms
- )
- );
-
-}
-
-// ----------------------------------------------------------------------------
-// -------------------------- leaf_surv functions -----------------------------
-// ----------------------------------------------------------------------------
-
-// Create kaplan-meier survival curve in leaf node
-//
-// @description Modifies leaf_nodes by adding data from the current node,
-// where the current node is one that is too small to be split and will
-// be converted to a leaf.
-//
-// @param y the outcome matrix in the current leaf
-// @param w the weights vector in the current leaf
-// @param leaf_indices a matrix that indicates where leaf nodes are
-// inside of leaf_nodes. leaf_indices has three columns:
-// - first column: the id for the leaf
-// - second column: starting row for the leaf
-// - third column: ending row for the leaf
-// @param leaf_node_index_counter keeps track of where we are in leaf_node
-// @param leaf_node_counter keeps track of which leaf node we are in
-// @param leaf_nodes a matrix with three columns:
-// - first column: time
-// - second column: survival probability
-// - third column: cumulative hazard
-
-void leaf_kaplan(const arma::mat& y,
- const arma::vec& w){
-
- leaf_indices(leaf_node_index_counter, 1) = leaf_node_counter;
- i = leaf_node_counter;
-
- // find the first unique event time
- person = 0;
-
- while(y.at(person, 1) == 0){
- person++;
- }
-
- // now person corresponds to the first event time
- leaf_nodes.at(i, 0) = y.at(person, 0); // see above
- temp2 = y.at(person, 0);
-
- i++;
-
- // find the rest of the unique event times
- for( ; person < y.n_rows; person++){
-
- if(temp2 != y.at(person, 0) && y.at(person, 1) == 1){
-
- leaf_nodes.at(i, 0) = y.at(person,0);
- temp2 = y.at(person, 0);
- i++;
-
- }
-
- }
-
- // reset for kaplan meier loop
- n_risk = sum(w);
- person = 0;
- temp1 = 1.0;
- temp3 = 0.0;
-
- do {
-
- n_events = 0;
- n_risk_sub = 0;
- temp2 = y.at(person, 0);
-
- while(y.at(person, 0) == temp2){
-
- n_risk_sub += w.at(person);
- n_events += y.at(person, 1) * w.at(person);
-
- if(person == y.n_rows-1) break;
-
- person++;
-
- }
-
- // only do km if a death was observed
-
- if(n_events > 0){
-
- temp1 = temp1 * (n_risk - n_events) / n_risk;
-
- temp3 = temp3 + n_events / n_risk;
-
- leaf_nodes.at(leaf_node_counter, 1) = temp1;
- leaf_nodes.at(leaf_node_counter, 2) = temp3;
- leaf_node_counter++;
-
- }
-
- n_risk -= n_risk_sub;
-
- } while (leaf_node_counter < i);
-
-
- leaf_indices(leaf_node_index_counter, 2) = leaf_node_counter-1;
- leaf_node_index_counter++;
-
- if(leaf_node_index_counter >= leaf_indices.n_rows){
- leaf_indices.insert_rows(leaf_indices.n_rows, 10);
- }
-
-}
-
-// Same as above, but this function can be called from R and is
-// used to run tests with testthat (hence the name). Note: this
-// needs to be updated to include CHF, which was added to the
-// function above recently.
-// [[Rcpp::export]]
-arma::mat leaf_kaplan_testthat(const arma::mat& y,
- const arma::vec& w){
-
-
- leaf_nodes.set_size(y.n_rows, 3);
- leaf_node_counter = 0;
-
- // find the first unique event time
- person = 0;
-
- while(y.at(person, 1) == 0){
- person++;
- }
-
- // now person corresponds to the first event time
- leaf_nodes.at(leaf_node_counter, 0) = y.at(person, 0); // see above
- temp2 = y.at(person, 0);
-
- leaf_node_counter++;
-
- // find the rest of the unique event times
- for( ; person < y.n_rows; person++){
-
- if(temp2 != y.at(person, 0) && y.at(person, 1) == 1){
-
- leaf_nodes.at(leaf_node_counter, 0) = y.at(person,0);
- temp2 = y.at(person, 0);
- leaf_node_counter++;
-
- }
-
- }
-
-
- // reset for kaplan meier loop
- i = leaf_node_counter;
- n_risk = sum(w);
- person = 0;
- temp1 = 1.0;
- leaf_node_counter = 0;
-
-
- do {
-
- n_events = 0;
- n_risk_sub = 0;
- temp2 = y.at(person, 0);
-
- while(y.at(person, 0) == temp2){
-
- n_risk_sub += w.at(person);
- n_events += y.at(person, 1) * w.at(person);
-
- if(person == y.n_rows-1) break;
-
- person++;
-
- }
-
- // only do km if a death was observed
-
- if(n_events > 0){
-
- temp1 = temp1 * (n_risk - n_events) / n_risk;
- leaf_nodes.at(leaf_node_counter, 1) = temp1;
- leaf_node_counter++;
-
- }
-
- n_risk -= n_risk_sub;
-
- } while (leaf_node_counter < i);
-
- leaf_nodes.resize(leaf_node_counter, 3);
-
- return(leaf_nodes);
-
-}
-
-
-
-
-// ----------------------------------------------------------------------------
-// ---------------------------- cholesky functions ----------------------------
-// ----------------------------------------------------------------------------
-
-// cholesky decomposition
-//
-// @description this function is copied from the survival package and
-// translated into arma.
-//
-// @param vmat matrix with covariance estimates
-// @param n_vars the number of predictors used in the current node
-//
-// prepares vmat for cholesky_solve()
-
-
-void cholesky(){
-
- double eps_chol = 0;
- double toler = 1e-8;
- double pivot;
-
- for(i = 0; i < n_vars; i++){
-
- if(vmat.at(i,i) > eps_chol) eps_chol = vmat.at(i,i);
-
- // copy upper right values to bottom left
- for(j = (i+1); j eps_chol) {
-
- for(j = (i+1); j < n_vars; j++){
-
- temp1 = vmat.at(j,i) / pivot;
- vmat.at(j,i) = temp1;
- vmat.at(j,j) -= temp1*temp1*pivot;
-
- for(k = (j+1); k < n_vars; k++){
-
- vmat.at(k, j) -= temp1 * vmat.at(k, i);
-
- }
-
- }
-
- } else {
-
- vmat.at(i, i) = 0;
-
- }
-
- }
-
-}
-
-// solve cholesky decomposition
-//
-// @description this function is copied from the survival package and
-// translated into arma. Prepares u, the vector used to update beta.
-//
-// @param vmat matrix with covariance estimates
-// @param n_vars the number of predictors used in the current node
-//
-//
-void cholesky_solve(){
-
- for (i = 0; i < n_vars; i++) {
-
- temp1 = u[i];
-
- for (j = 0; j < i; j++){
-
- temp1 -= u[j] * vmat.at(i, j);
- u[i] = temp1;
-
- }
-
- }
-
-
- for (i = n_vars; i >= 1; i--){
-
- if (vmat.at(i-1, i-1) == 0){
-
- u[i-1] = 0;
-
- } else {
-
- temp1 = u[i-1] / vmat.at(i-1, i-1);
-
- for (j = i; j < n_vars; j++){
- temp1 -= u[j] * vmat.at(j, i-1);
- }
-
- u[i-1] = temp1;
-
- }
-
- }
-
-}
-
-// invert the cholesky in the lower triangle
-//
-// @description this function is copied from the survival package and
-// translated into arma. Inverts vmat
-//
-// @param vmat matrix with covariance estimates
-// @param n_vars the number of predictors used in the current node
-//
-
-void cholesky_invert(){
-
- for (i=0; i0) {
-
- // take full advantage of the cholesky's diagonal of 1's
- vmat.at(i,i) = 1.0 / vmat.at(i,i);
-
- for (j=(i+1); j 0) {
-
- if (cph_method == 0 || n_events == 1) { // Breslow
-
- denom += denom_events;
- loglik -= weight_events * log(denom);
-
- for (i=0; i 0) {
-
- if (cph_method == 0 || n_events == 1) { // Breslow
-
- denom += denom_events;
- loglik -= denom_events * log(denom);
-
- for (i=0; i 1 && stat_best < R_PosInf){
-
- for(iter = 1; iter < cph_iter_max; iter++){
-
- // if(verbose > 0){
- //
- // Rcout << "--------- Newt-Raph algo; iter " << iter;
- // Rcout << " ---------" << std::endl;
- // Rcout << "beta: " << beta_new.t();
- // Rcout << "loglik: " << stat_best;
- // Rcout << std::endl;
- // Rcout << "------------------------------------------";
- // Rcout << std::endl << std::endl << std::endl;
- //
- // }
-
- // do the next iteration
- stat_current = newtraph_cph_iter(beta_new);
-
- cholesky();
-
- // don't go trying to fix this, just use the last
- // set of valid coefficients
- if(std::isinf(stat_current)) break;
-
- // check for convergence
- // break the loop if the new ll is ~ same as old best ll
- if(fabs(1 - stat_best / stat_current) < cph_eps){
- break;
- }
-
- if(stat_current < stat_best){ // it's not converging!
-
- halving++; // get more aggressive when it doesn't work
-
- // reduce the magnitude by which beta_new modifies beta_current
- for (i = 0; i < n_vars; i++){
- beta_new[i] = (beta_new[i]+halving*beta_current[i]) / (halving+1.0);
- }
-
- // yeah its not technically the best but I need to do this for
- // more reasonable output when verbose = true; I should remove
- // this line when verbosity is taken out.
- stat_best = stat_current;
-
- } else { // it's converging!
-
- halving = 0;
- stat_best = stat_current;
-
- cholesky_solve();
-
- for (i = 0; i < n_vars; i++) {
-
- beta_current[i] = beta_new[i];
- beta_new[i] = beta_new[i] + u[i];
-
- }
-
- }
-
- }
-
- }
-
- // invert vmat
- cholesky_invert();
-
- for (i=0; i < n_vars; i++) {
-
- beta_current[i] = beta_new[i];
-
- if(std::isinf(beta_current[i]) || std::isnan(beta_current[i])){
- beta_current[i] = 0;
- }
-
- if(std::isinf(vmat.at(i, i)) || std::isnan(vmat.at(i, i))){
- vmat.at(i, i) = 1.0;
- }
-
- // if(verbose > 0) Rcout << "scaled beta: " << beta_current[i] << "; ";
-
- if(cph_do_scale){
- beta_current.at(i) *= x_transforms.at(i, 1);
- vmat.at(i, i) *= x_transforms.at(i, 1) * x_transforms.at(i, 1);
- }
-
- // if(verbose > 0) Rcout << "un-scaled beta: " << beta_current[i] << std::endl;
-
- if(oobag_importance_type == 'A'){
-
- if(beta_current.at(i) != 0){
-
- temp1 = R::pchisq(pow(beta_current[i], 2) / vmat.at(i, i),
- 1, false, false);
-
- if(temp1 < 0.01) vi_pval_numer[cols_node[i]]++;
-
- }
-
- vi_pval_denom[cols_node[i]]++;
-
- }
-
- }
-
- // if(verbose > 1) Rcout << std::endl;
-
- return(beta_current);
-
-}
-
-// same function as above, but exported to R for testing
-// [[Rcpp::export]]
-arma::vec newtraph_cph_testthat(NumericMatrix& x_in,
- NumericMatrix& y_in,
- NumericVector& w_in,
- int method,
- double cph_eps_,
- int iter_max){
-
-
- x_node = mat(x_in.begin(), x_in.nrow(), x_in.ncol(), false);
- y_node = mat(y_in.begin(), y_in.nrow(), y_in.ncol(), false);
- w_node = vec(w_in.begin(), w_in.length(), false);
-
- cph_do_scale = true;
-
- cph_method = method;
- cph_eps = cph_eps_;
- cph_iter_max = iter_max;
- n_vars = x_node.n_cols;
-
- vi_pval_numer.zeros(x_node.n_cols);
- vi_pval_denom.zeros(x_node.n_cols);
- cols_node = regspace(0, x_node.n_cols - 1);
-
- x_node_scale();
-
- vec out = newtraph_cph();
-
- return(out);
-
-}
-
-// ----------------------------------------------------------------------------
-// ---------------------------- node functions --------------------------------
-// ----------------------------------------------------------------------------
-
-// Log rank test w/multiple cutpoints
-//
-// this function returns a cutpoint obtaining a local maximum
-// of the log-rank test (lrt) statistic. The default value (+Inf)
-// is really for diagnostic purposes. Put another way, if the
-// return value is +Inf (an impossible value for a cutpoint),
-// that means that we didn't find any valid cut-points and
-// the node cannot be grown with the current XB.
-//
-// if there is a valid cut-point, then the main side effect
-// of this function is to modify the group vector, which
-// will be used to assign observations to the two new nodes.
-//
-// @param group the vector that determines which node to send each
-// observation to (left node = 0, right node = 1)
-// @param y_node matrix of outcomes
-// @param w_node vector of weights
-// @param XB linear combination of predictors
-//
-// the group vector is modified by this function and the value returned
-// is the maximal log-rank statistic across all the possible cutpoints.
-double lrt_multi(){
-
- break_loop = false;
-
- // group should be initialized as all 0s
- group.zeros(y_node.n_rows);
-
- // initialize at the lowest possible LRT stat value
- stat_best = 0;
-
- // sort XB- we need to iterate over the sorted indices
- iit_vals = sort_index(XB, "ascend");
-
- // unsafe columns point to cols in y_node.
- vec y_status = y_node.unsafe_col(1);
- vec y_time = y_node.unsafe_col(0);
-
- // first determine the lowest value of XB that will
- // be a valid cut-point to split a node. A valid cut-point
- // is one that, if used, will result in at least leaf_min_obs
- // and leaf_min_events in both the left and right node.
-
- n_events = 0;
- n_risk = 0;
-
- // if(verbose > 1){
- // Rcout << "----- finding cut-point boundaries -----" << std::endl;
- // }
-
- // Iterate through the sorted values of XB, in ascending order.
-
- for(iit = iit_vals.begin(); iit < iit_vals.end()-1; ++iit){
-
- n_events += y_status[*iit] * w_node[*iit];
- n_risk += w_node[*iit];
-
- // If we want to make the current value of XB a cut-point, we need
- // to make sure the next value of XB isn't equal to this current value.
- // Otherwise, we will have the same value of XB in both groups!
-
- // if(verbose > 1){
- // Rcout << XB[*iit] << " ---- ";
- // Rcout << XB[*(iit+1)] << " ---- ";
- // Rcout << n_events << " ---- ";
- // Rcout << n_risk << std::endl;
- // }
-
- if(XB[*iit] != XB[*(iit+1)]){
-
- // if(verbose > 1){
- // Rcout << "********* New cut-point here ********" << std::endl;
- // }
-
-
- if( n_events >= leaf_min_events &&
- n_risk >= leaf_min_obs) {
-
- // if(verbose > 1){
- // Rcout << std::endl;
- // Rcout << "lower cutpoint: " << XB[*iit] << std::endl;
- // Rcout << " - n_events, left node: " << n_events << std::endl;
- // Rcout << " - n_risk, left node: " << n_risk << std::endl;
- // Rcout << std::endl;
- // }
-
- break;
-
- }
-
- }
-
- }
-
- // if(verbose > 1){
- // if(iit >= iit_vals.end()-1) {
- // Rcout << "Could not find a valid lower cut-point" << std::endl;
- // }
- // }
-
-
- j = iit - iit_vals.begin();
-
- // got to reset these before finding the upper limit
- n_events=0;
- n_risk=0;
-
- // do the first step in the loop manually since we need to
- // refer to iit+1 in all proceeding steps.
-
- for(iit = iit_vals.end()-1; iit >= iit_vals.begin()+1; --iit){
-
- n_events += y_status[*iit] * w_node[*iit];
- n_risk += w_node[*iit];
- group[*iit] = 1;
-
- // if(verbose > 1){
- // Rcout << XB[*iit] << " ---- ";
- // Rcout << XB(*(iit-1)) << " ---- ";
- // Rcout << n_events << " ---- ";
- // Rcout << n_risk << std::endl;
- // }
-
- if ( XB[*iit] != XB[*(iit-1)] ) {
-
- // if(verbose > 1){
- // Rcout << "********* New cut-point here ********" << std::endl;
- // }
-
- if( n_events >= leaf_min_events &&
- n_risk >= leaf_min_obs ) {
-
- // the upper cutpoint needs to be one step below the current
- // iit value, because we use x <= cp to determine whether a
- // value x goes to the left node versus the right node. So,
- // if iit currently points to 3, and the next value down is 2,
- // then we want to say the cut-point is 2 because then all
- // values <= 2 will go left, and 3 will go right. This matters
- // when 3 is the highest value in the vector.
-
- --iit;
-
- // if(verbose > 1){
- // Rcout << std::endl;
- // Rcout << "upper cutpoint: " << XB[*iit] << std::endl;
- // Rcout << " - n_events, right node: " << n_events << std::endl;
- // Rcout << " - n_risk, right node: " << n_risk << std::endl;
- // }
-
- break;
-
- }
-
- }
-
- }
-
- // number of steps taken
- k = iit + 1 - iit_vals.begin();
-
- // if(verbose > 1){
- // Rcout << "----------------------------------------" << std::endl;
- // Rcout << std::endl << std::endl;
- // Rcout << "sorted XB: " << std::endl << XB(iit_vals).t() << std::endl;
- // }
-
- // initialize cut-point as the value of XB iit currently points to.
- iit_best = iit;
-
- // what happens if we don't have enough events or obs to split?
- // the first valid lower cut-point (at iit_vals(k)) is > the first
- // valid upper cutpoint (current value of n_risk). Put another way,
- // k (the number of steps taken from beginning of the XB vec)
- // will be > n_rows - p, where the difference on the RHS is
- // telling us where we are after taking p steps from the end
- // of the XB vec. Returning the infinite cp is a red flag.
-
- // if(verbose > 1){
- // Rcout << "j: " << j << std::endl;
- // Rcout << "k: " << k << std::endl;
- // }
-
- if (j > k){
-
- // if(verbose > 1) {
- // Rcout << "Could not find a cut-point for this XB" << std::endl;
- // }
-
- return(R_PosInf);
- }
-
- // if(verbose > 1){
- //
- // Rcout << "----- initializing log-rank test cutpoints -----" << std::endl;
- // Rcout << "n potential cutpoints: " << k-j << std::endl;
- //
- // }
-
-
- // adjust k to indicate the number of valid cut-points
- k -= j;
-
- if(k > n_split){
-
- jit_vals = linspace(0, k, n_split);
-
- } else {
-
- // what happens if there are only 5 potential cut-points
- // but the value of n_split is > 5? We will just check out
- // the 5 valid cutpoints.
- jit_vals = linspace(0, k, k);
-
- }
-
- vec_temp.resize( jit_vals.size() );
-
- // protection from going out of bounds with jit_vals(k) below
- if(j == 0) jit_vals.at(jit_vals.size()-1)--;
-
- // put the indices of potential cut-points into vec_temp
- for(k = 0; k < vec_temp.size(); k++){
- vec_temp[k] = XB.at(*(iit_best - jit_vals[k]));
- }
-
- // back to how it was!
- if(j == 0) jit_vals.at(jit_vals.size()-1)++;
-
- // if(verbose > 1){
- //
- // Rcout << "cut-points chosen: ";
- //
- // Rcout << vec_temp.t();
- //
- // Rcout << "----------------------------------------" << std::endl <<
- // std::endl << std::endl;
- //
- // }
-
- bool do_lrt = true;
-
- k = 0;
- j = 1;
-
- // begin outer loop - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
- for(jit = jit_vals.begin(); jit != jit_vals.end(); ++jit){
-
-
- // if(verbose > 1){
- // Rcout << "jit points to " << *jit << std::endl;
- // }
-
- // switch group values from 0 to 1 until you get to the next cut-point
- for( ; j < *jit; j++){
- group[*iit] = 1;
- --iit;
- }
-
- if(jit == jit_vals.begin() ||
- jit == jit_vals.end()-1){
-
- do_lrt = true;
-
- } else {
-
- if( vec_temp[k] == vec_temp[k+1] ||
- vec_temp[k] == vec_temp[0] ||
- *jit <= 1){
-
- do_lrt = false;
-
- } else {
-
- while( XB[*iit] == XB[*(iit - 1)] ){
-
- group[*iit] = 1;
- --iit;
- ++j;
-
- // if(verbose > 1){
- // Rcout << "cutpoint dropped down one spot: ";
- // Rcout << XB[*iit] << std::endl;
- // }
-
- }
-
- do_lrt = true;
-
- }
-
- }
-
- ++k;
-
- if(do_lrt){
-
- n_risk=0;
- g_risk=0;
-
- observed=0;
- expected=0;
-
- V=0;
-
- break_loop = false;
-
- i = y_node.n_rows-1;
-
- // if(verbose > 1){
- // Rcout << "sum(group==1): " << sum(group) << "; ";
- // Rcout << "sum(group==1 * w_node): " << sum(group % w_node);
- // Rcout << std::endl;
- // if(verbose > 1){
- // Rcout << "group:" << std::endl;
- // Rcout << group(iit_vals).t() << std::endl;
- // }
- // }
-
-
- // begin inner loop - - - - - - - - - - - - - - - - - - - - - - - - - -
- for (; ;){
-
- temp1 = y_time[i];
-
- n_events = 0;
-
- for ( ; y_time[i] == temp1; i--) {
-
- n_risk += w_node[i];
- n_events += y_status[i] * w_node[i];
- g_risk += group[i] * w_node[i];
- observed += y_status[i] * group[i] * w_node[i];
-
- if(i == 0){
- break_loop = true;
- break;
- }
-
- }
-
- // should only do these calculations if n_events > 0,
- // but turns out its faster to multiply by 0 than
- // it is to check whether n_events is > 0
-
- temp2 = g_risk / n_risk;
- expected += n_events * temp2;
-
- // update variance if n_risk > 1 (if n_risk == 1, variance is 0)
- // definitely check if n_risk is > 1 b/c otherwise divide by 0
- if (n_risk > 1){
- temp1 = n_events * temp2 * (n_risk-n_events) / (n_risk-1);
- V += temp1 * (1 - temp2);
- }
-
- if(break_loop) break;
-
- }
- // end inner loop - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-
- stat_current = pow(expected-observed, 2) / V;
-
- // if(verbose > 1){
- //
- // Rcout << "-------- log-rank test results --------" << std::endl;
- // Rcout << "cutpoint: " << XB[*iit] << std::endl;
- // Rcout << "lrt stat: " << stat_current << std::endl;
- // Rcout << "---------------------------------------" << std::endl <<
- // std::endl << std::endl;
- //
- // }
-
- if(stat_current > stat_best){
- iit_best = iit;
- stat_best = stat_current;
- n_events_right = observed;
- n_risk_right = g_risk;
- n_risk_left = n_risk - g_risk;
- }
-
- }
- // end outer loop - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-
- }
-
- // if the log-rank test does not detect a difference at 0.05 alpha,
- // maybe it's not a good idea to split this node.
-
- if(stat_best < split_min_stat) return(R_PosInf);
-
- // if(verbose > 1){
- // Rcout << "Best LRT stat: " << stat_best << std::endl;
- // }
-
- // rewind iit until it is back where it was when we got the
- // best lrt stat. While rewinding iit, also reset the group
- // values so that group is as it was when we got the best
- // lrt stat.
-
-
- while(iit <= iit_best){
- group[*iit] = 0;
- ++iit;
- }
-
- // XB at *iit_best is the cut-point that maximized the log-rank test
- return(XB[*iit_best]);
-
-}
-
-// this function is the same as above, but is exported to R for testing
-// [[Rcpp::export]]
-List lrt_multi_testthat(NumericMatrix& y_node_,
- NumericVector& w_node_,
- NumericVector& XB_,
- int n_split_,
- int leaf_min_events_,
- int leaf_min_obs_
-){
-
- y_node = mat(y_node_.begin(), y_node_.nrow(), y_node_.ncol(), false);
- w_node = vec(w_node_.begin(), w_node_.length(), false);
- XB = vec(XB_.begin(), XB_.length(), false);
-
- n_split = n_split_;
- leaf_min_events = leaf_min_events_;
- leaf_min_obs = leaf_min_obs_;
-
- // about this function - - - - - - - - - - - - - - - - - - - - - - - - - - -
- //
- // this function returns a cutpoint obtaining a local maximum
- // of the log-rank test (lrt) statistic. The default value (+Inf)
- // is really for diagnostic purposes. Put another way, if the
- // return value is +Inf (an impossible value for a cutpoint),
- // that means that we didn't find any valid cut-points and
- // the node cannot be grown with the current XB.
- //
- // if there is a valid cut-point, then the main side effect
- // of this function is to modify the group vector, which
- // will be used to assign observations to the two new nodes.
- //
- // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-
- break_loop = false;
-
- vec cutpoints_used(n_split);
- vec lrt_statistics(n_split);
- uword list_counter = 0;
-
- // group should be initialized as all 0s
- group.zeros(y_node.n_rows);
-
- // initialize at the lowest possible LRT stat value
- stat_best = 0;
-
- // sort XB- we need to iterate over the sorted indices
- iit_vals = sort_index(XB, "ascend");
-
- // unsafe columns point to cols in y_node.
- vec y_status = y_node.unsafe_col(1);
- vec y_time = y_node.unsafe_col(0);
-
- // first determine the lowest value of XB that will
- // be a valid cut-point to split a node. A valid cut-point
- // is one that, if used, will result in at least leaf_min_obs
- // and leaf_min_events in both the left and right node.
-
- n_events = 0;
- n_risk = 0;
-
- // if(verbose > 1){
- // Rcout << "----- finding cut-point boundaries -----" << std::endl;
- // }
-
- // Iterate through the sorted values of XB, in ascending order.
-
- for(iit = iit_vals.begin(); iit < iit_vals.end()-1; ++iit){
-
- n_events += y_status(*iit) * w_node(*iit);
- n_risk += w_node(*iit);
-
- // If we want to make the current value of XB a cut-point, we need
- // to make sure the next value of XB isn't equal to this current value.
- // Otherwise, we will have the same value of XB in both groups!
-
- // if(verbose > 1){
- // Rcout << XB(*iit) << " ---- ";
- // Rcout << XB(*(iit+1)) << " ---- ";
- // Rcout << n_events << " ---- ";
- // Rcout << n_risk << std::endl;
- // }
-
- if(XB(*iit) != XB(*(iit+1))){
-
- // if(verbose > 1){
- // Rcout << "********* New cut-point here ********" << std::endl;
- // }
-
-
- if( n_events >= leaf_min_events &&
- n_risk >= leaf_min_obs) {
-
- // if(verbose > 1){
- // Rcout << std::endl;
- // Rcout << "lower cutpoint: " << XB(*iit) << std::endl;
- // Rcout << " - n_events, left node: " << n_events << std::endl;
- // Rcout << " - n_risk, left node: " << n_risk << std::endl;
- // Rcout << std::endl;
- // }
-
- break;
-
- }
-
- }
-
- }
-
- // if(verbose > 1){
- // if(iit >= iit_vals.end()-1) {
- // Rcout << "Could not find a valid lower cut-point" << std::endl;
- // }
- // }
-
-
- j = iit - iit_vals.begin();
-
- // got to reset these before finding the upper limit
- n_events=0;
- n_risk=0;
-
- // do the first step in the loop manually since we need to
- // refer to iit+1 in all proceeding steps.
-
- for(iit = iit_vals.end()-1; iit >= iit_vals.begin()+1; --iit){
-
- n_events += y_status(*iit) * w_node(*iit);
- n_risk += w_node(*iit);
- group(*iit) = 1;
-
- // if(verbose > 1){
- // Rcout << XB(*iit) << " ---- ";
- // Rcout << XB(*(iit-1)) << " ---- ";
- // Rcout << n_events << " ---- ";
- // Rcout << n_risk << std::endl;
- // }
-
- if(XB(*iit) != XB(*(iit-1))){
-
- // if(verbose > 1){
- // Rcout << "********* New cut-point here ********" << std::endl;
- // }
-
- if( n_events >= leaf_min_events &&
- n_risk >= leaf_min_obs ) {
-
- // the upper cutpoint needs to be one step below the current
- // iit value, because we use x <= cp to determine whether a
- // value x goes to the left node versus the right node. So,
- // if iit currently points to 3, and the next value down is 2,
- // then we want to say the cut-point is 2 because then all
- // values <= 2 will go left, and 3 will go right. This matters
- // when 3 is the highest value in the vector.
-
- --iit;
-
- // if(verbose > 1){
- // Rcout << std::endl;
- // Rcout << "upper cutpoint: " << XB(*iit) << std::endl;
- // Rcout << " - n_events, right node: " << n_events << std::endl;
- // Rcout << " - n_risk, right node: " << n_risk << std::endl;
- // }
-
- break;
-
- }
-
- }
-
- }
-
- // number of steps taken
- k = iit + 1 - iit_vals.begin();
-
- // if(verbose > 1){
- // Rcout << "----------------------------------------" << std::endl;
- // Rcout << std::endl << std::endl;
- // Rcout << "sorted XB: " << std::endl << XB(iit_vals).t() << std::endl;
- // }
-
- // initialize cut-point as the value of XB iit currently points to.
- iit_best = iit;
-
- // what happens if we don't have enough events or obs to split?
- // the first valid lower cut-point (at iit_vals(k)) is > the first
- // valid upper cutpoint (current value of n_risk). Put another way,
- // k (the number of steps taken from beginning of the XB vec)
- // will be > n_rows - p, where the difference on the RHS is
- // telling us where we are after taking p steps from the end
- // of the XB vec. Returning the infinite cp is a red flag.
-
- // if(verbose > 1){
- // Rcout << "j: " << j << std::endl;
- // Rcout << "k: " << k << std::endl;
- // }
-
- if (j > k){
-
- // if(verbose > 1) {
- // Rcout << "Could not find a cut-point for this XB" << std::endl;
- // }
-
- return(R_PosInf);
- }
-
- // if(verbose > 1){
- //
- // Rcout << "----- initializing log-rank test cutpoints -----" << std::endl;
- // Rcout << "n potential cutpoints: " << k-j << std::endl;
- //
- // }
-
- // what happens if there are only 5 potential cut-points
- // but the value of n_split is > 5? We will just check out
- // the 5 valid cutpoints.
-
- // adjust k to indicate steps taken in the outer loop.
- k -= j;
-
- if(k > n_split){
-
- jit_vals = linspace(0, k, n_split);
-
- } else {
-
- jit_vals = linspace(0, k, k);
-
- }
-
- vec_temp.resize( jit_vals.size() );
-
- if(j == 0) jit_vals(jit_vals.size()-1)--;
-
- for(k = 0; k < vec_temp.size(); k++){
- vec_temp(k) = XB(*(iit_best - jit_vals(k)));
- }
-
- if(j == 0) jit_vals(jit_vals.size()-1)++;
-
-
- // if(verbose > 1){
- //
- // Rcout << "cut-points chosen: ";
- //
- // Rcout << vec_temp.t();
- //
- // Rcout << "----------------------------------------" << std::endl <<
- // std::endl << std::endl;
- //
- // }
-
- bool do_lrt = true;
-
- k = 0;
- j = 1;
-
- // begin outer loop - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
- for(jit = jit_vals.begin(); jit != jit_vals.end(); ++jit){
-
-
- // if(verbose > 1){
- // Rcout << "jit points to " << *jit << std::endl;
- // }
-
- for( ; j < *jit; j++){
- group(*iit) = 1;
- --iit;
- }
-
- if(jit == jit_vals.begin() ||
- jit == jit_vals.end()-1){
-
- do_lrt = true;
-
- } else {
-
- if( vec_temp(k) == vec_temp(k+1) ||
- vec_temp(k) == vec_temp(0) ||
- *jit <= 1){
-
- do_lrt = false;
-
- } else {
-
- while(XB(*iit) == XB(*(iit - 1))){
-
- group(*iit) = 1;
- --iit;
- ++j;
-
- // if(verbose > 1){
- // Rcout << "cutpoint dropped down one spot: ";
- // Rcout << XB(*iit) << std::endl;
- // }
-
- }
-
- do_lrt = true;
-
- }
-
- }
-
- ++k;
-
- if(do_lrt){
-
- cutpoints_used(list_counter) = XB(*iit);
-
- n_risk=0;
- g_risk=0;
-
- observed=0;
- expected=0;
-
- V=0;
-
- break_loop = false;
-
- i = y_node.n_rows-1;
-
- // if(verbose > 1){
- // Rcout << "sum(group==1): " << sum(group) << "; ";
- // Rcout << "sum(group==1 * w_node): " << sum(group % w_node);
- // Rcout << std::endl;
- // if(verbose > 1){
- // Rcout << "group:" << std::endl;
- // Rcout << group(iit_vals).t() << std::endl;
- // }
- // }
-
-
- // begin inner loop - - - - - - - - - - - - - - - - - - - - - - - - - -
- for (; ;){
-
- temp1 = y_time[i];
-
- n_events = 0;
-
- for ( ; y_time[i] == temp1; i--) {
-
- n_risk += w_node[i];
- n_events += y_status[i] * w_node[i];
- g_risk += group[i] * w_node[i];
- observed += y_status[i] * group[i] * w_node[i];
-
- if(i == 0){
- break_loop = true;
- break;
- }
-
- }
-
- // should only do these calculations if n_events > 0,
- // but turns out its faster to multiply by 0 than
- // it is to check whether n_events is > 0
-
- temp2 = g_risk / n_risk;
- expected += n_events * temp2;
-
- // update variance if n_risk > 1 (if n_risk == 1, variance is 0)
- // definitely check if n_risk is > 1 b/c otherwise divide by 0
- if (n_risk > 1){
- temp1 = n_events * temp2 * (n_risk-n_events) / (n_risk-1);
- V += temp1 * (1 - temp2);
- }
-
- if(break_loop) break;
-
- }
- // end inner loop - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-
- stat_current = pow(expected-observed, 2) / V;
-
- lrt_statistics(list_counter) = stat_current;
-
- list_counter++;
-
- // if(verbose > 1){
- //
- // Rcout << "-------- log-rank test results --------" << std::endl;
- // Rcout << "cutpoint: " << XB(*iit) << std::endl;
- // Rcout << "lrt stat: " << stat_current << std::endl;
- // Rcout << "---------------------------------------" << std::endl <<
- // std::endl << std::endl;
- //
- // }
-
- if(stat_current > stat_best){
- iit_best = iit;
- stat_best = stat_current;
- n_events_right = observed;
- n_risk_right = g_risk;
- n_risk_left = n_risk - g_risk;
- }
-
- }
- // end outer loop - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-
- }
-
- // if the log-rank test does not detect a difference at 0.05 alpha,
- // maybe it's not a good idea to split this node.
-
- if(stat_best < 3.841459) return(R_PosInf);
-
- // if(verbose > 1){
- // Rcout << "Best LRT stat: " << stat_best << std::endl;
- // }
-
- // rewind iit until it is back where it was when we got the
- // best lrt stat. While rewinding iit, also reset the group
- // values so that group is as it was when we got the best
- // lrt stat.
-
-
- while(iit <= iit_best){
- group(*iit) = 0;
- ++iit;
- }
-
- return(List::create(_["cutpoints"] = cutpoints_used,
- _["statistic"] = lrt_statistics));
-
-}
-
-
-// out-of-bag prediction for single prediction horizon
-//
-// @param pred_type indicates what type of prediction to compute
-// @param leaf_pred a vector indicating which leaf each observation
-// landed in.
-// @param leaf_indices a matrix that contains indices for each leaf node
-// inside of leaf_nodes
-// @param leaf_nodes a matrix with ids, survival, and cumulative hazard
-// functions for each leaf node.
-//
-// @return matrix with predictions, dimension n by 1
-
-void oobag_pred_surv_uni(char pred_type){
-
- iit_vals = sort_index(leaf_pred, "ascend");
- iit = iit_vals.begin();
-
- switch(pred_type){
-
- case 'S': case 'R':
-
- leaf_node_col = 1;
- pred_t0 = 1;
- break;
-
- case 'H':
-
- leaf_node_col = 2;
- pred_t0 = 0;
- break;
-
- }
-
- do {
-
- person_leaf = leaf_pred[*iit];
-
- // find the current leaf
- for(i = 0; i < leaf_indices.n_rows; i++){
- if(leaf_indices.at(i, 0) == person_leaf){
- break;
- }
- }
-
- // get submat view for this leaf
- leaf_node = leaf_nodes.rows(leaf_indices(i, 1),
- leaf_indices(i, 2));
-
- // if(verbose > 1){
- // Rcout << "leaf_node:" << std::endl << leaf_node << std::endl;
- // }
-
- i = 0;
-
- if(time_pred < leaf_node.at(leaf_node.n_rows - 1, 0)){
-
- for(; i < leaf_node.n_rows; i++){
- if (leaf_node.at(i, 0) > time_pred){
- if(i == 0)
- temp1 = pred_t0;
- else
- temp1 = leaf_node.at(i-1, leaf_node_col);
- break;
- } else if (leaf_node.at(i, 0) == time_pred){
- temp1 = leaf_node.at(i, leaf_node_col);
- break;
- }
- }
-
- } else {
-
- // go here if prediction horizon > max time in current leaf.
- temp1 = leaf_node.at(leaf_node.n_rows - 1, leaf_node_col);
-
- }
-
- // running mean: mean_k = mean_{k-1} + (new val - old val) / k
- // compute new val - old val
- // be careful, every oob row has a different denom!
- temp2 = temp1 - surv_pvec[rows_oobag[*iit]];
- surv_pvec[rows_oobag[*iit]] += temp2 / denom_pred[rows_oobag[*iit]];
- ++iit;
-
- if(iit < iit_vals.end()){
-
- while(person_leaf == leaf_pred(*iit)){
-
- temp2 = temp1 - surv_pvec[rows_oobag[*iit]];
- surv_pvec[rows_oobag[*iit]] += temp2 / denom_pred[rows_oobag[*iit]];
-
- ++iit;
-
- if (iit == iit_vals.end()) break;
-
- }
-
- }
-
- } while (iit < iit_vals.end());
-
- // if(verbose > 0){
- // Rcout << "surv_pvec:" << std::endl << surv_pvec.t() << std::endl;
- // }
-
-}
-
-// out-of-bag prediction evaluation, Harrell's C-statistic
-//
-// @param pred_type indicates what type of prediction to compute
-// @param y_input matrix of outcomes from input
-//
-// @return the C-statistic
-
-double oobag_c_harrell(char pred_type){
-
- vec time = y_input.unsafe_col(0);
- vec status = y_input.unsafe_col(1);
- iit_vals = find(status == 1);
-
- k = y_input.n_rows;
-
- double total=0, concordant=0;
-
- switch(pred_type){
-
- case 'S': case 'R':
- for (iit = iit_vals.begin(); iit < iit_vals.end(); ++iit) {
-
- for(j = *iit + 1; j < k; ++j){
-
- if (time[j] > time[*iit]) { // ties not counted
-
- total++;
-
- // for survival, current value > next vals is good
- // risk is the same as survival until just before we output
- // the oobag predictions, when we say pvec = 1-pvec,
- if (surv_pvec[j] > surv_pvec[*iit]){
-
- concordant++;
-
- } else if (surv_pvec[j] == surv_pvec[*iit]){
-
- concordant+= 0.5;
-
- }
-
- }
-
- }
-
- }
- break;
-
- case 'H':
- for (iit = iit_vals.begin(); iit < iit_vals.end(); ++iit) {
-
- for(j = *iit + 1; j < k; ++j){
-
- if (time[j] > time[*iit]) { // ties not counted
-
- total++;
-
- // for risk & chf current value < next vals is good.
- if (surv_pvec[j] < surv_pvec[*iit]){
-
- concordant++;
-
- } else if (surv_pvec[j] == surv_pvec[*iit]){
-
- concordant+= 0.5;
-
- }
-
- }
-
- }
-
- }
- break;
- }
-
- return(concordant / total);
-
-}
-
-// same function as above but exported to R for testing
-// [[Rcpp::export]]
-double oobag_c_harrell_testthat(NumericMatrix y_mat,
- NumericVector s_vec) {
-
- y_input = mat(y_mat.begin(), y_mat.nrow(), y_mat.ncol(), false);
- surv_pvec = vec(s_vec.begin(), s_vec.length(), false);
-
- return(oobag_c_harrell(pred_type_dflt));
-
-}
-
-// this function is the same as oobag_pred_surv_uni,
-// but it operates on new data rather than out-of-bag data
-// and it allows for multiple prediction horizons instead of one
-void new_pred_surv_multi(char pred_type){
-
- // allocate memory for output
- // surv_pvec.zeros(x_pred.n_rows);
-
- surv_pvec.set_size(times_pred.size());
- iit_vals = sort_index(leaf_pred, "ascend");
- iit = iit_vals.begin();
-
- switch(pred_type){
-
- case 'S': case 'R':
-
- leaf_node_col = 1;
- pred_t0 = 1;
- break;
-
- case 'H':
-
- leaf_node_col = 2;
- pred_t0 = 0;
- break;
-
- }
-
- do {
-
- person_leaf = leaf_pred(*iit);
-
- for(i = 0; i < leaf_indices.n_rows; i++){
- if(leaf_indices.at(i, 0) == person_leaf){
- break;
- }
- }
-
- leaf_node = leaf_nodes.rows(leaf_indices(i, 1),
- leaf_indices(i, 2));
-
- // if(verbose > 1){
- // Rcout << "leaf_node:" << std::endl << leaf_node << std::endl;
- // }
-
- i = 0;
-
- for(j = 0; j < times_pred.size(); j++){
-
- time_pred = times_pred.at(j);
-
- if(time_pred < leaf_node.at(leaf_node.n_rows - 1, 0)){
-
- for(; i < leaf_node.n_rows; i++){
-
- if (leaf_node.at(i, 0) > time_pred){
-
- if(i == 0)
- temp1 = pred_t0;
- else
- temp1 = leaf_node.at(i-1, leaf_node_col);
-
- break;
-
- } else if (leaf_node.at(i, 0) == time_pred){
-
- temp1 = leaf_node.at(i, leaf_node_col);
- break;
-
- }
-
- }
-
- } else {
-
- // go here if prediction horizon > max time in current leaf.
- temp1 = leaf_node.at(leaf_node.n_rows - 1, leaf_node_col);
-
- }
-
- surv_pvec.at(j) = temp1;
-
- }
-
- surv_pmat.row(*iit) += surv_pvec.t();
- ++iit;
-
- if(iit < iit_vals.end()){
-
- while(person_leaf == leaf_pred.at(*iit)){
-
- surv_pmat.row(*iit) += surv_pvec.t();
- ++iit;
-
- if (iit == iit_vals.end()) break;
-
- }
-
- }
-
- } while (iit < iit_vals.end());
-
-}
-
-// this function is the same as new_pred_surv_multi,
-// but only uses one prediction horizon
-void new_pred_surv_uni(char pred_type){
-
- iit_vals = sort_index(leaf_pred, "ascend");
- iit = iit_vals.begin();
-
- switch(pred_type){
-
- case 'S': case 'R':
-
- leaf_node_col = 1;
- pred_t0 = 1;
- break;
-
- case 'H':
-
- leaf_node_col = 2;
- pred_t0 = 0;
- break;
-
- }
-
- do {
-
- person_leaf = leaf_pred(*iit);
-
- for(i = 0; i < leaf_indices.n_rows; i++){
- if(leaf_indices.at(i, 0) == person_leaf){
- break;
- }
- }
-
- leaf_node = leaf_nodes.rows(leaf_indices.at(i, 1),
- leaf_indices.at(i, 2));
-
- // if(verbose > 1){
- // Rcout << "leaf_node:" << std::endl << leaf_node << std::endl;
- // }
-
- i = 0;
-
- if(time_pred < leaf_node.at(leaf_node.n_rows - 1, 0)){
-
- for(; i < leaf_node.n_rows; i++){
- if (leaf_node.at(i, 0) > time_pred){
-
- if(i == 0){
-
- temp1 = pred_t0;
-
- } else {
-
- temp1 = leaf_node.at(i - 1, leaf_node_col);
-
- // experimental - does not seem to help!
- // weighted average of surv est from before and after time of pred
- // temp2 = leaf_node(i, 0) - leaf_node(i-1, 0);
- //
- // temp1 = leaf_node(i, 1) * (time_pred - leaf_node(i-1,0)) / temp2 +
- // leaf_node(i-1, 1) * (leaf_node(i,0) - time_pred) / temp2;
-
- }
-
- break;
-
- } else if (leaf_node.at(i, 0) == time_pred){
- temp1 = leaf_node.at(i, leaf_node_col);
- break;
- }
- }
-
- } else if (time_pred == leaf_node.at(leaf_node.n_rows - 1, 0)){
-
- temp1 = leaf_node.at(leaf_node.n_rows - 1, leaf_node_col);
-
- } else {
-
- // go here if prediction horizon > max time in current leaf.
- temp1 = leaf_node.at(leaf_node.n_rows - 1, leaf_node_col);
-
- // --- EXPERIMENTAL ADD-ON --- //
- // if you are predicting beyond the max time in a node,
- // then determine how much further out you are and assume
- // the survival probability decays at the same rate.
-
- // temp2 = (1.0 - temp1) *
- // (time_pred - leaf_node(leaf_node.n_rows - 1, 0)) / time_pred;
- //
- // temp1 = temp1 * (1.0-temp2);
-
- }
-
- surv_pvec.at(*iit) += temp1;
- ++iit;
-
- if(iit < iit_vals.end()){
-
- while(person_leaf == leaf_pred.at(*iit)){
-
- surv_pvec.at(*iit) += temp1;
- ++iit;
-
- if (iit == iit_vals.end()) break;
-
- }
-
- }
-
- } while (iit < iit_vals.end());
-
- // if(verbose > 1){
- // Rcout << "pred_surv:" << std::endl << surv_pvec.t() << std::endl;
- // }
-
-}
-
-
-// ----------------------------------------------------------------------------
-// --------------------------- ostree functions -------------------------------
-// ----------------------------------------------------------------------------
-
-// increase the memory allocated to a tree
-//
-// this function is used if the initial memory allocation isn't enough
-// to grow the tree. It modifies all elements of the tree, including
-// betas, col_indices, children_left, and cutpoints
-//
-void ostree_size_buffer(){
-
- // if(verbose > 1){
- // Rcout << "---------- buffering outputs ----------" << std::endl;
- // Rcout << "betas before: " << std::endl << betas.t() << std::endl;
- // }
-
- betas.insert_cols(betas.n_cols, 10);
- // x_mean.insert_cols(x_mean.n_cols, 10);
- col_indices.insert_cols(col_indices.n_cols, 10);
- children_left.insert_rows(children_left.size(), 10);
- cutpoints.insert_rows(cutpoints.size(), 10);
-
- // if(verbose > 1){
- // Rcout << "betas after: " << std::endl << betas.t() << std::endl;
- // Rcout << "---------------------------------------";
- // Rcout << std::endl << std::endl;
- // }
-
-
-}
-
-// transfer memory from R into arma types
-//
-// when trees are passed from R, they need to be converted back into
-// arma objects. The intent of this function is to convert everything
-// back into an arma object without copying any data.
-//
-// nothing is modified apart from types
-
-void ostree_mem_xfer(){
-
- // no data copied according to tracemem.
- // not including boot rows or x_mean (don't always need them)
-
- NumericMatrix leaf_nodes_ = ostree["leaf_nodes"];
- NumericMatrix betas_ = ostree["betas"];
- NumericVector cutpoints_ = ostree["cut_points"];
- IntegerMatrix col_indices_ = ostree["col_indices"];
- IntegerMatrix leaf_indices_ = ostree["leaf_node_index"];
- IntegerVector children_left_ = ostree["children_left"];
-
- leaf_nodes = mat(leaf_nodes_.begin(),
- leaf_nodes_.nrow(),
- leaf_nodes_.ncol(),
- false);
-
- betas = mat(betas_.begin(),
- betas_.nrow(),
- betas_.ncol(),
- false);
-
- cutpoints = vec(cutpoints_.begin(), cutpoints_.length(), false);
-
- col_indices = conv_to::from(
- imat(col_indices_.begin(),
- col_indices_.nrow(),
- col_indices_.ncol(),
- false)
- );
-
- leaf_indices = conv_to::from(
- imat(leaf_indices_.begin(),
- leaf_indices_.nrow(),
- leaf_indices_.ncol(),
- false)
- );
-
- children_left = conv_to::from(
- ivec(children_left_.begin(),
- children_left_.length(),
- false)
- );
-
-}
-
-// drop observations down the tree
-//
-// @description Determine the leaves that are assigned to new data.
-//
-// @param children_left vector of child node ids (right node = left node + 1)
-// @param x_pred matrix of predictors from new data
-//
-// @return a vector indicating which leaf each observation was mapped to
-void ostree_pred_leaf(){
-
- // reset values
- // this is needed for pred_leaf since every obs gets a new leaf in
- // the next tree, but it isn't needed for pred_surv because survival
- // probs get aggregated over all the trees.
- leaf_pred.fill(0);
-
- for(i = 0; i < betas.n_cols; i++){
-
- if(children_left[i] != 0){
-
- if(i == 0){
- obs_in_node = regspace(0, 1, leaf_pred.size()-1);
- } else {
- obs_in_node = find(leaf_pred == i);
- }
-
-
- if(obs_in_node.size() > 0){
-
- // Fastest sub-matrix multiplication i can think of.
- // Matrix product = linear combination of columns
- // (this is faster b/c armadillo is great at making
- // pointers to the columns of an arma mat)
- // I had to stop using this b/c it fails on
- // XB.zeros(obs_in_node.size());
- //
- // uvec col_indices_i = col_indices.unsafe_col(i);
- //
- // j = 0;
- //
- // jit = col_indices_i.begin();
- //
- // for(; jit < col_indices_i.end(); ++jit, ++j){
- //
- // vec x_j = x_pred.unsafe_col(*jit);
- //
- // XB += x_j(obs_in_node) * betas.at(j, i);
- //
- // }
-
- // this is slower but more clear matrix multiplication
- XB = x_pred(obs_in_node, col_indices.col(i)) * betas.col(i);
-
- jit = obs_in_node.begin();
-
- for(j = 0; j < XB.size(); ++j, ++jit){
-
- if(XB[j] <= cutpoints[i]) {
-
- leaf_pred[*jit] = children_left[i];
-
- } else {
-
- leaf_pred[*jit] = children_left[i]+1;
-
- }
-
- }
-
- // if(verbose > 0){
- //
- // uvec in_left = find(leaf_pred == children_left(i));
- // uvec in_right = find(leaf_pred == children_left(i)+1);
- //
- // Rcout << "N to node_" << children_left(i) << ": ";
- // Rcout << in_left.size() << "; ";
- // Rcout << "N to node_" << children_left(i)+1 << ": ";
- // Rcout << in_right.size() << std::endl;
- //
- // }
-
- }
-
- }
-
- }
-
-
-
-}
-
-// same as above but exported to R for testins
-// [[Rcpp::export]]
-arma::uvec ostree_pred_leaf_testthat(List& tree,
- NumericMatrix& x_pred_){
-
-
- x_pred = mat(x_pred_.begin(),
- x_pred_.nrow(),
- x_pred_.ncol(),
- false);
-
- leaf_pred.set_size(x_pred.n_rows);
-
- ostree = tree;
- ostree_mem_xfer();
- ostree_pred_leaf();
-
- return(leaf_pred);
-
-}
-
-// Fit an oblique survival tree
-//
-// @description used in orsf_fit, which has parameters defined below.
-//
-// @param f_beta the function used to find linear combinations of predictors
-//
-// @return a fitted oblique survival tree
-//
-List ostree_fit(Function f_beta){
-
- betas.fill(0);
- // x_mean.fill(0);
- col_indices.fill(0);
- cutpoints.fill(0);
- children_left.fill(0);
- node_assignments.fill(0);
- leaf_nodes.fill(0);
-
- node_assignments.zeros(x_inbag.n_rows);
- nodes_to_grow.zeros(1);
- nodes_max_true = 0;
- leaf_node_counter = 0;
- leaf_node_index_counter = 0;
-
- // ----------------------
- // ---- main do loop ----
- // ----------------------
-
- do {
-
- nodes_to_grow_next.set_size(0);
-
- // if(verbose > 0){
- //
- // Rcout << "----------- nodes to grow -----------" << std::endl;
- // Rcout << "nodes: "<< nodes_to_grow.t() << std::endl;
- // Rcout << "-------------------------------------" << std::endl <<
- // std::endl << std::endl;
- //
- //
- // }
-
- for(node = nodes_to_grow.begin(); node != nodes_to_grow.end(); ++node){
-
- if(nodes_to_grow[0] == 0){
-
- // when growing the first node, there is no need to find
- // which rows are in the node.
- rows_node = linspace(0,
- x_inbag.n_rows-1,
- x_inbag.n_rows);
-
- } else {
-
- // identify which rows are in the current node.
- rows_node = find(node_assignments == *node);
-
- }
-
- y_node = y_inbag.rows(rows_node);
- w_node = w_inbag(rows_node);
-
- // if(verbose > 0){
- //
- // n_risk = sum(w_node);
- // n_events = sum(y_node.col(1) % w_node);
- // Rcout << "-------- Growing node " << *node << " --------" << std::endl;
- // Rcout << "No. of observations in node: " << n_risk << std::endl;
- // Rcout << "No. of events in node: " << n_events << std::endl;
- // Rcout << "No. of rows in node: " << w_node.size() << std::endl;
- // Rcout << "--------------------------------" << std::endl;
- // Rcout << std::endl << std::endl;
- //
- // }
-
- // initialize an impossible cut-point value
- // if cutpoint is still infinite later, node should not be split
- cutpoint = R_PosInf;
-
- // ------------------------------------------------------------------
- // ---- sample a random subset of columns with non-zero variance ----
- // ------------------------------------------------------------------
-
- mtry_int = mtry;
- cols_to_sample_01.fill(0);
-
- // constant columns are constant in the rows where events occurred
-
- for(j = 0; j < cols_to_sample_01.size(); j++){
-
- temp1 = R_PosInf;
-
- for(iit = rows_node.begin()+1; iit != rows_node.end(); ++iit){
-
- if(y_inbag.at(*iit, 1) == 1){
-
- if (temp1 < R_PosInf){
-
- if(x_inbag.at(*iit, j) != temp1){
-
- cols_to_sample_01[j] = 1;
- break;
-
- }
-
- } else {
-
- temp1 = x_inbag.at(*iit, j);
-
- }
-
- }
-
- }
-
- }
-
- n_cols_to_sample = sum(cols_to_sample_01);
-
- if(n_cols_to_sample >= 1){
-
- n_events_total = sum(y_node.col(1) % w_node);
-
- if(n_cols_to_sample < mtry){
-
- mtry_int = n_cols_to_sample;
-
- // if(verbose > 0){
- // Rcout << " ---- >=1 constant column in node rows ----" << std::endl;
- // Rcout << "mtry reduced to " << mtry_temp << " from " << mtry;
- // Rcout << std::endl;
- // Rcout << "-------------------------------------------" << std::endl;
- // Rcout << std::endl << std::endl;
- // }
-
- }
-
- if (type_beta == 'C'){
-
- // make sure there are at least 3 event per predictor variable.
- // (if using CPH)
- while(n_events_total / mtry_int < 3 && mtry_int > 1){
- --mtry_int;
- }
-
- }
-
-
- n_cols_to_sample = mtry_int;
-
- // if(verbose > 0){
- // Rcout << "n_events: " << n_events_total << std::endl;
- // Rcout << "mtry: " << mtry_int << std::endl;
- // Rcout << "n_events per column: " << n_events_total/mtry_int << std::endl;
- // }
-
- if(mtry_int >= 1){
-
- cols_to_sample = find(cols_to_sample_01);
-
- // re-try hinge point
- n_retry = 0;
- cutpoint = R_PosInf;
-
- while(n_retry <= max_retry){
-
- // if(n_retry > 0) Rcout << "trying again!" << std::endl;
-
- cols_node = Rcpp::RcppArmadillo::sample(cols_to_sample,
- mtry_int,
- false);
-
- x_node = x_inbag(rows_node, cols_node);
-
- // here is where n_vars gets updated to match the current node
- // originally it matched the number of variables in the input x.
-
- n_vars = x_node.n_cols;
-
- if(cph_do_scale){
- x_node_scale();
- }
-
- // if(verbose > 0){
- //
- // uword temp_uword_1 = min(uvec {x_node.n_rows, 5});
- // Rcout << "x node scaled: " << std::endl;
- // Rcout << x_node.submat(0, 0, temp_uword_1-1, x_node.n_cols-1);
- // Rcout << std::endl;
- //
- // }
-
- switch(type_beta) {
-
- case 'C' :
-
- beta_fit = newtraph_cph();
-
- if(cph_do_scale){
- for(i = 0; i < x_transforms.n_rows; i++){
- x_node.col(i) /= x_transforms(i,1);
- x_node.col(i) += x_transforms(i,0);
- }
-
- }
-
- break;
-
- case 'N' :
-
- xx = wrap(x_node);
- yy = wrap(y_node);
- ww = wrap(w_node);
- colnames(yy) = yy_names;
-
- beta_placeholder = f_beta(xx, yy, ww,
- net_alpha,
- net_df_target);
-
- beta_fit = mat(beta_placeholder.begin(),
- beta_placeholder.nrow(),
- beta_placeholder.ncol(),
- false);
-
- break;
-
- case 'U' :
-
- xx = wrap(x_node);
- yy = wrap(y_node);
- ww = wrap(w_node);
- colnames(yy) = yy_names;
-
- beta_placeholder = f_beta(xx, yy, ww);
-
- beta_fit = mat(beta_placeholder.begin(),
- beta_placeholder.nrow(),
- beta_placeholder.ncol(),
- false);
-
- break;
-
- }
-
-
- if(any(beta_fit)){
-
- // if(verbose > 0){
- //
- // uword temp_uword_1 = min(uvec {x_node.n_rows, 5});
- // Rcout << "x node unscaled: " << std::endl;
- // Rcout << x_node.submat(0, 0, temp_uword_1-1, x_node.n_cols-1);
- // Rcout << std::endl;
- //
- // }
-
- XB = x_node * beta_fit;
- cutpoint = lrt_multi();
-
- }
-
- if(!std::isinf(cutpoint)) break;
- n_retry++;
-
- }
-
- }
-
- }
-
- if(!std::isinf(cutpoint)){
-
- // make new nodes if a valid cutpoint was found
- nn_left = nodes_max_true + 1;
- nodes_max_true = nodes_max_true + 2;
-
-
- // if(verbose > 0){
- //
- // Rcout << "-------- New nodes created --------" << std::endl;
- // Rcout << "Left node: " << nn_left << std::endl;
- // Rcout << "Right node: " << nodes_max_true << std::endl;
- // Rcout << "-----------------------------------" << std::endl <<
- // std::endl << std::endl;
- //
- // }
-
- n_events_left = n_events_total - n_events_right;
-
- // if(verbose > 0){
- // Rcout << "n_events_left: " << n_events_left << std::endl;
- // Rcout << "n_risk_left: " << n_risk_left << std::endl;
- // Rcout << "n_events_right: " << n_events_right << std::endl;
- // Rcout << "n_risk_right: " << n_risk_right << std::endl;
- // }
-
- i=0;
-
- for(iit = rows_node.begin(); iit != rows_node.end(); ++iit, ++i){
-
- node_assignments[*iit] = nn_left + group[i];
-
- }
-
- if(n_events_left >= 2*leaf_min_events &&
- n_risk_left >= 2*leaf_min_obs &&
- n_events_left >= split_min_events &&
- n_risk_left >= split_min_obs){
-
- nodes_to_grow_next = join_cols(nodes_to_grow_next,
- uvec{nn_left});
-
- } else {
-
- rows_leaf = find(group==0);
- leaf_indices(leaf_node_index_counter, 0) = nn_left;
- leaf_kaplan(y_node.rows(rows_leaf), w_node(rows_leaf));
-
- // if(verbose > 0){
- // Rcout << "-------- creating a new leaf --------" << std::endl;
- // Rcout << "name: node_" << nn_left << std::endl;
- // Rcout << "n_obs: " << sum(w_node(rows_leaf));
- // Rcout << std::endl;
- // Rcout << "n_events: ";
- // vec_temp = y_node.col(1);
- // Rcout << sum(w_node(rows_leaf) % vec_temp(rows_leaf));
- // Rcout << std::endl;
- // Rcout << "------------------------------------";
- // Rcout << std::endl << std::endl << std::endl;
- // }
-
- }
-
- if(n_events_right >= 2*leaf_min_events &&
- n_risk_right >= 2*leaf_min_obs &&
- n_events_right >= split_min_events &&
- n_risk_right >= split_min_obs){
-
- nodes_to_grow_next = join_cols(nodes_to_grow_next,
- uvec{nodes_max_true});
-
- } else {
-
- rows_leaf = find(group==1);
- leaf_indices(leaf_node_index_counter, 0) = nodes_max_true;
- leaf_kaplan(y_node.rows(rows_leaf), w_node(rows_leaf));
-
- // if(verbose > 0){
- // Rcout << "-------- creating a new leaf --------" << std::endl;
- // Rcout << "name: node_" << nodes_max_true << std::endl;
- // Rcout << "n_obs: " << sum(w_node(rows_leaf));
- // Rcout << std::endl;
- // Rcout << "n_events: ";
- // vec_temp = y_node.col(1);
- // Rcout << sum(w_node(rows_leaf) % vec_temp(rows_leaf));
- // Rcout << std::endl;
- // Rcout << "------------------------------------";
- // Rcout << std::endl << std::endl << std::endl;
- // }
-
- }
-
- if(nodes_max_true >= betas.n_cols) ostree_size_buffer();
-
- for(i = 0; i < n_cols_to_sample; i++){
- betas.at(i, *node) = beta_fit[i];
- // x_mean.at(i, *node) = x_transforms(i, 0);
- col_indices.at(i, *node) = cols_node[i];
- }
-
- children_left[*node] = nn_left;
- cutpoints[*node] = cutpoint;
-
- } else {
-
- // make a leaf node if a valid cutpoint could not be found
- leaf_indices(leaf_node_index_counter, 0) = *node;
- leaf_kaplan(y_node, w_node);
-
- // if(verbose > 0){
- // Rcout << "-------- creating a new leaf --------" << std::endl;
- // Rcout << "name: node_" << *node << std::endl;
- // Rcout << "n_obs: " << sum(w_node) << std::endl;
- // Rcout << "n_events: " << sum(w_node % y_node.col(1));
- // Rcout << std::endl;
- // Rcout << "Couldn't find a cutpoint??" << std::endl;
- // Rcout << "------------------------------------" << std::endl;
- // Rcout << std::endl << std::endl;
- // }
-
- }
-
- }
-
- nodes_to_grow = nodes_to_grow_next;
-
- } while (nodes_to_grow.size() > 0);
-
- return(
- List::create(
-
- _["leaf_nodes"] = leaf_nodes.rows(span(0, leaf_node_counter-1)),
-
- _["leaf_node_index"] = conv_to::from(
- leaf_indices.rows(span(0, leaf_node_index_counter-1))
- ),
-
- _["betas"] = betas.cols(span(0, nodes_max_true)),
-
- // _["x_mean"] = x_mean.cols(span(0, nodes_max_true)),
-
- _["col_indices"] = conv_to::from(
- col_indices.cols(span(0, nodes_max_true))
- ),
-
- _["cut_points"] = cutpoints(span(0, nodes_max_true)),
-
- _["children_left"] = conv_to::from(
- children_left(span(0, nodes_max_true))
- ),
-
- _["rows_oobag"] = conv_to::from(rows_oobag)
-
- )
- );
-
-
-}
-
-// ----------------------------------------------------------------------------
-// ---------------------------- orsf functions --------------------------------
-// ----------------------------------------------------------------------------
-
-// fit an oblique random survival forest.
-//
-// @param x matrix of predictors
-// @param y matrix of outcomes
-// @param weights vector of weights
-// @param n_tree number of trees to fit
-// @param n_split_ number of splits to try with lrt
-// @param mtry_ number of predictors to try
-// @param leaf_min_events_ min number of events in a leaf
-// @param leaf_min_obs_ min number of observations in a leaf
-// @param split_min_events_ min number of events to split a node
-// @param split_min_obs_ min number of observations to split a node
-// @param split_min_stat_ min lrt to split a node
-// @param cph_method_ method for ties
-// @param cph_eps_ criteria for convergence of newton raphson algorithm
-// @param cph_iter_max_ max number of newton raphson iterations
-// @param cph_do_scale_ to scale or not to scale
-// @param net_alpha_ alpha parameter for glmnet
-// @param net_df_target_ degrees of freedom for glmnet
-// @param oobag_pred_ whether to predict out-of-bag preds or not
-// @param oobag_pred_type_ what type of out-of-bag preds to compute
-// @param oobag_pred_horizon_ out-of-bag prediction horizon
-// @param oobag_eval_every_ trees between each evaluation of oob error
-// @param oobag_importance_ to compute importance or not
-// @param oobag_importance_type_ type of importance to compute
-// @param tree_seeds vector of seeds to set before each tree is fit
-// @param max_retry_ max number of retries for linear combinations
-// @param f_beta function to find linear combinations of predictors
-// @param type_beta_ what type of linear combination to find
-// @param f_oobag_eval function to evaluate out-of-bag error
-// @param type_oobag_eval_ whether to use default or custom out-of-bag error
-//
-// @return an orsf_fit object sent back to R
-
-// [[Rcpp::export]]
-List orsf_fit(NumericMatrix& x,
- NumericMatrix& y,
- NumericVector& weights,
- const int& n_tree,
- const int& n_split_,
- const int& mtry_,
- const double& leaf_min_events_,
- const double& leaf_min_obs_,
- const double& split_min_events_,
- const double& split_min_obs_,
- const double& split_min_stat_,
- const int& cph_method_,
- const double& cph_eps_,
- const int& cph_iter_max_,
- const bool& cph_do_scale_,
- const double& net_alpha_,
- const int& net_df_target_,
- const bool& oobag_pred_,
- const char& oobag_pred_type_,
- const double& oobag_pred_horizon_,
- const int& oobag_eval_every_,
- const bool& oobag_importance_,
- const char& oobag_importance_type_,
- IntegerVector& tree_seeds,
- const int& max_retry_,
- Function f_beta,
- const char& type_beta_,
- Function f_oobag_eval,
- const char& type_oobag_eval_,
- const bool verbose_progress){
-
-
- // convert inputs into arma objects
- x_input = mat(x.begin(), x.nrow(), x.ncol(), false);
-
- y_input = mat(y.begin(), y.nrow(), y.ncol(), false);
-
- w_user = vec(weights.begin(), weights.length(), false);
-
- // these change later in ostree_fit()
- n_rows = x_input.n_rows;
- n_vars = x_input.n_cols;
-
- // initialize the variable importance (vi) vectors
- vi_pval_numer.zeros(n_vars);
- vi_pval_denom.zeros(n_vars);
-
- // if(verbose > 0){
- // Rcout << "------------ dimensions ------------" << std::endl;
- // Rcout << "N obs total: " << n_rows << std::endl;
- // Rcout << "N columns total: " << n_vars << std::endl;
- // Rcout << "------------------------------------";
- // Rcout << std::endl << std::endl << std::endl;
- // }
-
- n_split = n_split_;
- mtry = mtry_;
- leaf_min_events = leaf_min_events_;
- leaf_min_obs = leaf_min_obs_;
- split_min_events = split_min_events_;
- split_min_obs = split_min_obs_;
- split_min_stat = split_min_stat_;
- cph_method = cph_method_;
- cph_eps = cph_eps_;
- cph_iter_max = cph_iter_max_;
- cph_do_scale = cph_do_scale_;
- net_alpha = net_alpha_;
- net_df_target = net_df_target_;
- oobag_pred = oobag_pred_;
- oobag_pred_type = oobag_pred_type_;
- oobag_eval_every = oobag_eval_every_;
- oobag_eval_counter = 0;
- oobag_importance = oobag_importance_;
- oobag_importance_type = oobag_importance_type_;
- use_tree_seed = tree_seeds.length() > 0;
- max_retry = max_retry_;
- type_beta = type_beta_;
- type_oobag_eval = type_oobag_eval_;
- temp1 = 1.0 / n_rows;
-
- if(cph_iter_max > 1) cph_do_scale = true;
-
- if((type_beta == 'N') || (type_beta == 'U')) cph_do_scale = false;
-
- if(cph_iter_max == 1) cph_do_scale = false;
-
-
- if(oobag_pred){
-
- time_pred = oobag_pred_horizon_;
-
- if(time_pred == 0) time_pred = median(y_input.col(0));
-
- eval_oobag.set_size(std::floor(n_tree / oobag_eval_every));
-
- } else {
-
- eval_oobag.set_size(0);
-
- }
-
- // if(verbose > 0){
- // Rcout << "------------ input variables ------------" << std::endl;
- // Rcout << "n_split: " << n_split << std::endl;
- // Rcout << "mtry: " << mtry << std::endl;
- // Rcout << "leaf_min_events: " << leaf_min_events << std::endl;
- // Rcout << "leaf_min_obs: " << leaf_min_obs << std::endl;
- // Rcout << "cph_method: " << cph_method << std::endl;
- // Rcout << "cph_eps: " << cph_eps << std::endl;
- // Rcout << "cph_iter_max: " << cph_iter_max << std::endl;
- // Rcout << "-----------------------------------------" << std::endl;
- // Rcout << std::endl << std::endl;
- // }
-
- // ----------------------------------------------------
- // ---- sample weights to mimic a bootstrap sample ----
- // ----------------------------------------------------
-
- // s is the number of times you might get selected into
- // a bootstrap sample. Realistically this won't be >10,
- // but it could technically be as big as n_row.
- IntegerVector s = seq(0, 10);
-
- // compute probability of being selected into the bootstrap
- // 0 times, 1, times, ..., 9 times, or 10 times.
- NumericVector probs = dbinom(s, n_rows, temp1, false);
-
- // ---------------------------------------------
- // ---- preallocate memory for tree outputs ----
- // ---------------------------------------------
-
- cols_to_sample_01.zeros(n_vars);
- leaf_nodes.zeros(n_rows, 3);
-
- if(oobag_pred){
-
- surv_pvec.zeros(n_rows);
- denom_pred.zeros(n_rows);
-
- } else {
-
- surv_pvec.set_size(0);
- denom_pred.set_size(0);
-
- }
-
- // guessing the number of nodes needed to grow a tree
- nodes_max_guess = std::ceil(0.5 * n_rows / leaf_min_events);
-
- betas.zeros(mtry, nodes_max_guess);
- // x_mean.zeros(mtry, nodes_max_guess);
- col_indices.zeros(mtry, nodes_max_guess);
- cutpoints.zeros(nodes_max_guess);
- children_left.zeros(nodes_max_guess);
- leaf_indices.zeros(nodes_max_guess, 3);
-
- // some great variable names here
- List forest(n_tree);
-
- for(tree = 0; tree < n_tree; ){
-
- // Abort the routine if user has pressed Ctrl + C or Escape in R.
- Rcpp::checkUserInterrupt();
-
- // --------------------------------------------
- // ---- initialize parameters to grow tree ----
- // --------------------------------------------
-
- // rows_inbag = find(w_inbag != 0);
-
- if(use_tree_seed) set_seed_r(tree_seeds[tree]);
-
- w_input = as(sample(s, n_rows, true, probs));
-
- // if the user gives a weight vector, then each bootstrap weight
- // should be multiplied by the corresponding user weight.
- if(w_user.size() > 0) w_input = w_input % w_user;
-
- rows_oobag = find(w_input == 0);
- rows_inbag = regspace(0, n_rows-1);
- rows_inbag = std_setdiff(rows_inbag, rows_oobag);
- w_inbag = w_input(rows_inbag);
-
- // if(verbose > 0){
- //
- // Rcout << "------------ boot weights ------------" << std::endl;
- // Rcout << "pr(inbag): " << 1-pow(1-temp1,n_rows) << std::endl;
- // Rcout << "total: " << sum(w_inbag) << std::endl;
- // Rcout << "N > 0: " << rows_inbag.size() << std::endl;
- // Rcout << "--------------------------------------" <<
- // std::endl << std::endl << std::endl;
- //
- // }
-
- x_inbag = x_input.rows(rows_inbag);
- y_inbag = y_input.rows(rows_inbag);
-
- if(oobag_pred){
- x_pred = x_input.rows(rows_oobag);
- leaf_pred.set_size(rows_oobag.size());
- }
-
- // if(verbose > 0){
- //
- // uword temp_uword_1, temp_uword_2;
- //
- // if(x_inbag.n_rows < 5)
- // temp_uword_1 = x_inbag.n_rows-1;
- // else
- // temp_uword_1 = 5;
- //
- // if(x_inbag.n_cols < 5)
- // temp_uword_2 = x_inbag.n_cols-1;
- // else
- // temp_uword_2 = 4;
- //
- // Rcout << "x inbag: " << std::endl <<
- // x_inbag.submat(0, 0,
- // temp_uword_1,
- // temp_uword_2) << std::endl;
- //
- // }
-
- if(verbose_progress){
- Rcout << "\r growing tree no. " << tree << " of " << n_tree;
- }
-
-
- forest[tree] = ostree_fit(f_beta);
-
- // add 1 to tree here instead of end of loop
- // (more convenient to compute tree % oobag_eval_every)
- tree++;
-
-
- if(oobag_pred){
-
- denom_pred(rows_oobag) += 1;
- ostree_pred_leaf();
- oobag_pred_surv_uni(oobag_pred_type);
-
- if(tree % oobag_eval_every == 0){
-
- switch(type_oobag_eval) {
-
- // H stands for Harrell's C-statistic
- case 'H' :
-
- eval_oobag[oobag_eval_counter] = oobag_c_harrell(oobag_pred_type);
- oobag_eval_counter++;
-
- break;
-
- // U stands for a user-supplied function
- case 'U' :
-
- ww = wrap(surv_pvec);
-
- eval_oobag[oobag_eval_counter] = as(
- f_oobag_eval(y, ww)
- );
-
- oobag_eval_counter++;
-
- break;
-
- }
-
-
- }
-
- }
-
- }
-
- if(verbose_progress){
- Rcout << std::endl;
- }
-
- vec vimp(x_input.n_cols);
-
- // ANOVA importance
- if(oobag_importance_type == 'A') vimp = vi_pval_numer / vi_pval_denom;
-
- // if we are computing variable importance, surv_pvec is about
- // to get modified, and we don't want to return the modified
- // version of surv_pvec.
- // So make a deep copy if oobag_importance is true.
- // Make a shallow copy if oobag_importance is false
- surv_pvec_output = vec(surv_pvec.begin(),
- surv_pvec.size(),
- oobag_importance);
-
- if(oobag_importance && n_tree > 0){
-
- uvec betas_to_flip;
- // vec betas_temp;
- oobag_eval_counter--;
-
- for(uword variable = 0; variable < x_input.n_cols; ++variable){
-
- surv_pvec.fill(0);
- denom_pred.fill(0);
-
- for(tree = 0; tree < n_tree; ++tree){
-
- ostree = forest[tree];
-
- IntegerMatrix rows_oobag_ = ostree["rows_oobag"];
-
- rows_oobag = conv_to::from(
- ivec(rows_oobag_.begin(),
- rows_oobag_.length(),
- false)
- );
-
- x_pred = x_input.rows(rows_oobag);
-
- if(oobag_importance_type == 'P'){
- x_pred.col(variable) = shuffle(x_pred.col(variable));
- }
-
- ostree_mem_xfer();
-
-
- if(oobag_importance_type == 'N'){
- betas_to_flip = find(col_indices == variable);
- //betas_temp = betas.elem( betas_to_flip );
- betas.elem( betas_to_flip ) *= (-1);
- //betas.elem( betas_to_flip ) *= 0;
- }
-
- denom_pred(rows_oobag) += 1;
-
- leaf_pred.set_size(rows_oobag.size());
-
- ostree_pred_leaf();
-
- oobag_pred_surv_uni(oobag_pred_type);
-
- if(oobag_importance_type == 'N'){
- betas.elem( betas_to_flip ) *= (-1);
- // betas.elem( betas_to_flip ) = betas_temp;
- }
-
- }
-
- switch(type_oobag_eval) {
-
- // H stands for Harrell's C-statistic
- case 'H' :
-
- vimp(variable) = eval_oobag[oobag_eval_counter] -
- oobag_c_harrell(oobag_pred_type);
-
- break;
-
- // U stands for a user-supplied function
- case 'U' :
-
- ww = wrap(surv_pvec);
-
- vimp(variable) =
- eval_oobag[oobag_eval_counter] - as(f_oobag_eval(y, ww));
-
-
- break;
-
- }
-
- }
-
- }
-
- if(oobag_pred_type == 'R') surv_pvec_output = 1 - surv_pvec_output;
-
- return(
- List::create(
- _["forest"] = forest,
- _["pred_oobag"] = surv_pvec_output,
- _["pred_horizon"] = time_pred,
- _["eval_oobag"] = List::create(_["stat_values"] = eval_oobag,
- _["stat_type"] = type_oobag_eval),
- _["importance"] = vimp
- )
- );
-
-
-}
-
-// @description compute negation importance
-//
-// @param x matrix of predictors
-// @param y outcome matrix
-// @param forest forest object from an orsf_fit
-// @param last_eval_stat the last estimate of out-of-bag error
-// @param time_pred_ the prediction horizon
-// @param f_oobag_eval function used to evaluate out-of-bag error
-// @param pred_type_ the type of prediction to compute
-// @param type_oobag_eval_ custom or default out-of-bag predictions
-//
-// @return a vector of importance values
-//
-// [[Rcpp::export]]
-arma::vec orsf_oob_negate_vi(NumericMatrix& x,
- NumericMatrix& y,
- List& forest,
- const double& last_eval_stat,
- const double& time_pred_,
- Function f_oobag_eval,
- const char& pred_type_,
- const char& type_oobag_eval_){
-
- x_input = mat(x.begin(), x.nrow(), x.ncol(), false);
- y_input = mat(y.begin(), y.nrow(), y.ncol(), false);
-
- time_pred = time_pred_;
- type_oobag_eval = type_oobag_eval_;
- oobag_pred_type = pred_type_;
-
- vec vimp(x_input.n_cols);
-
- uvec betas_to_flip;
- // vec betas_temp;
- uword variable;
-
- denom_pred.set_size(x_input.n_rows);
- surv_pvec.set_size(x_input.n_rows);
-
- for(variable = 0; variable < x_input.n_cols; ++variable){
-
- // Abort the routine if user has pressed Ctrl + C or Escape in R.
- Rcpp::checkUserInterrupt();
-
- surv_pvec.fill(0);
- denom_pred.fill(0);
-
- for(tree = 0; tree < forest.length(); ++tree){
-
- ostree = forest[tree];
-
- IntegerMatrix rows_oobag_ = ostree["rows_oobag"];
-
- rows_oobag = conv_to::from(
- ivec(rows_oobag_.begin(),
- rows_oobag_.length(),
- false)
- );
-
- x_pred = x_input.rows(rows_oobag);
-
- ostree_mem_xfer();
-
- betas_to_flip = find(col_indices == variable);
-
- // betas_temp = betas.elem( betas_to_flip );
- // betas.elem( betas_to_flip ) *= 0;
-
- betas.elem( betas_to_flip ) *= (-1);
-
- denom_pred(rows_oobag) += 1;
-
- leaf_pred.set_size(rows_oobag.size());
-
- ostree_pred_leaf();
-
- oobag_pred_surv_uni(oobag_pred_type);
-
- betas.elem( betas_to_flip ) *= (-1);
- // betas.elem( betas_to_flip ) = betas_temp;
-
- }
-
- switch(type_oobag_eval) {
-
- // H stands for Harrell's C-statistic
- case 'H' :
-
- vimp(variable) = last_eval_stat - oobag_c_harrell(oobag_pred_type);
-
- break;
-
- // U stands for a user-supplied function
- case 'U' :
-
- ww = wrap(surv_pvec);
-
- vimp(variable) = last_eval_stat - as(f_oobag_eval(y, ww));
-
- break;
-
- }
-
- }
-
- return(vimp);
-
-}
-
-// same as above but computes permutation importance instead of negation
-// [[Rcpp::export]]
-arma::vec orsf_oob_permute_vi(NumericMatrix& x,
- NumericMatrix& y,
- List& forest,
- const double& last_eval_stat,
- const double& time_pred_,
- Function f_oobag_eval,
- const char& pred_type_,
- const char& type_oobag_eval_){
-
- x_input = mat(x.begin(), x.nrow(), x.ncol(), false);
- y_input = mat(y.begin(), y.nrow(), y.ncol(), false);
-
- time_pred = time_pred_;
- type_oobag_eval = type_oobag_eval_;
- oobag_pred_type = pred_type_;
-
- vec vimp(x_input.n_cols);
-
- uword variable;
-
- denom_pred.set_size(x_input.n_rows);
- surv_pvec.set_size(x_input.n_rows);
-
- for(variable = 0; variable < x_input.n_cols; ++variable){
-
- // Abort the routine if user has pressed Ctrl + C or Escape in R.
- Rcpp::checkUserInterrupt();
-
- surv_pvec.fill(0);
- denom_pred.fill(0);
-
- for(tree = 0; tree < forest.length(); ++tree){
-
- ostree = forest[tree];
-
- IntegerMatrix rows_oobag_ = ostree["rows_oobag"];
-
- rows_oobag = conv_to::from(
- ivec(rows_oobag_.begin(),
- rows_oobag_.length(),
- false)
- );
-
- x_pred = x_input.rows(rows_oobag);
-
- x_pred.col(variable) = shuffle(x_pred.col(variable));
-
- ostree_mem_xfer();
-
- denom_pred(rows_oobag) += 1;
-
- leaf_pred.set_size(rows_oobag.size());
-
- ostree_pred_leaf();
-
- oobag_pred_surv_uni(oobag_pred_type);
-
- // x_variable = x_variable_original;
- // x_input.col(variable) = x_variable;
-
- }
-
- switch(type_oobag_eval) {
-
- // H stands for Harrell's C-statistic
- case 'H' :
-
- vimp(variable) = last_eval_stat - oobag_c_harrell(oobag_pred_type);
-
- break;
-
- // U stands for a user-supplied function
- case 'U' :
-
- ww = wrap(surv_pvec);
-
- vimp(variable) = last_eval_stat - as(f_oobag_eval(y, ww));
-
- break;
-
- }
-
- }
-
- return(vimp);
-
-}
-
-// predictions from an oblique random survival forest
-//
-// @description makes predictions based on a single horizon
-//
-// @param forest forest object from orsf_fit object
-// @param x_new matrix of predictors
-// @param time_dbl prediction horizon
-// @param pred_type type of prediction to compute
-//
-// [[Rcpp::export]]
-arma::mat orsf_pred_uni(List& forest,
- NumericMatrix& x_new,
- double time_dbl,
- char pred_type){
-
- x_pred = mat(x_new.begin(), x_new.nrow(), x_new.ncol(), false);
- time_pred = time_dbl;
-
- // memory for outputs
- leaf_pred.set_size(x_pred.n_rows);
- surv_pvec.zeros(x_pred.n_rows);
-
- for(tree = 0; tree < forest.length(); ++tree){
- ostree = forest[tree];
- ostree_mem_xfer();
- ostree_pred_leaf();
- new_pred_surv_uni(pred_type);
- }
-
- surv_pvec /= tree;
-
- if(pred_type == 'R'){
- return(1 - surv_pvec);
- } else {
- return(surv_pvec);
- }
-
-}
-
-// same as above but makes predictions for multiple horizons
-// [[Rcpp::export]]
-arma::mat orsf_pred_multi(List& forest,
- NumericMatrix& x_new,
- NumericVector& time_vec,
- char pred_type){
-
- x_pred = mat(x_new.begin(), x_new.nrow(), x_new.ncol(), false);
- times_pred = vec(time_vec.begin(), time_vec.length(), false);
-
- // memory for outputs
- // initial values don't matter for leaf_pred,
- // but do matter for surv_pmat
- leaf_pred.set_size(x_pred.n_rows);
- surv_pmat.zeros(x_pred.n_rows, times_pred.size());
-
- for(tree = 0; tree < forest.length(); ++tree){
- ostree = forest[tree];
- ostree_mem_xfer();
- ostree_pred_leaf();
- new_pred_surv_multi(pred_type);
- }
-
- surv_pmat /= tree;
-
- if(pred_type == 'R'){
- return(1 - surv_pmat);
- } else {
- return(surv_pmat);
- }
-
-}
-
-// partial dependence for new data
-//
-// @description calls predict on the data with a predictor fixed
-// and then summarizes the predictions.
-//
-// @param forest a forest object from an orsf_fit object
-// @param x_new_ matrix of predictors
-// @param x_cols_ columns of variables of interest
-// @param x_vals_ values to set these columsn to
-// @param probs_ for quantiles
-// @param time_dbl prediction horizon
-// @param pred_type prediction type
-//
-// @return matrix with partial dependence
-// [[Rcpp::export]]
-arma::mat pd_new_smry(List& forest,
- NumericMatrix& x_new_,
- IntegerVector& x_cols_,
- NumericMatrix& x_vals_,
- NumericVector& probs_,
- const double time_dbl,
- char pred_type){
-
-
- uword pd_i;
-
- time_pred = time_dbl;
-
- x_pred = mat(x_new_.begin(), x_new_.nrow(), x_new_.ncol(), false);
-
- mat x_vals = mat(x_vals_.begin(), x_vals_.nrow(), x_vals_.ncol(), false);
-
- uvec x_cols = conv_to::from(
- ivec(x_cols_.begin(), x_cols_.length(), false)
- );
-
- vec probs = vec(probs_.begin(), probs_.length(), false);
-
- mat output_quantiles(probs.size(), x_vals.n_rows);
- mat output_means(1, x_vals.n_rows);
-
- leaf_pred.set_size(x_pred.n_rows);
- surv_pvec.set_size(x_pred.n_rows);
-
- for(pd_i = 0; pd_i < x_vals.n_rows; pd_i++){
-
- // Abort the routine if user has pressed Ctrl + C or Escape in R.
- Rcpp::checkUserInterrupt();
-
- j = 0;
-
- surv_pvec.fill(0);
-
- for(jit = x_cols.begin(); jit < x_cols.end(); ++jit, ++j){
-
- x_pred.col(*jit).fill(x_vals(pd_i, j));
-
- }
-
- for(tree = 0; tree < forest.length(); ++tree){
- ostree = forest[tree];
- ostree_mem_xfer();
- ostree_pred_leaf();
- new_pred_surv_uni(pred_type);
- }
-
- surv_pvec /= tree;
-
- if(pred_type == 'R'){ surv_pvec = 1 - surv_pvec; }
-
- output_means.col(pd_i) = mean(surv_pvec);
- output_quantiles.col(pd_i) = quantile(surv_pvec, probs);
-
-
- }
-
- return(join_vert(output_means, output_quantiles));
-
-}
-
-
-// same as above but for out-of-bag data
-// [[Rcpp::export]]
-arma::mat pd_oob_smry(List& forest,
- NumericMatrix& x_new_,
- IntegerVector& x_cols_,
- NumericMatrix& x_vals_,
- NumericVector& probs_,
- const double time_dbl,
- char pred_type){
-
-
- uword pd_i;
-
- time_pred = time_dbl;
-
- mat x_vals = mat(x_vals_.begin(), x_vals_.nrow(), x_vals_.ncol(), false);
-
- uvec x_cols = conv_to::from(
- ivec(x_cols_.begin(), x_cols_.length(), false)
- );
-
- vec probs = vec(probs_.begin(), probs_.length(), false);
-
- mat output_quantiles(probs.size(), x_vals.n_rows);
- mat output_means(1, x_vals.n_rows);
-
- x_input = mat(x_new_.begin(), x_new_.nrow(), x_new_.ncol(), false);
- denom_pred.set_size(x_input.n_rows);
- surv_pvec.set_size(x_input.n_rows);
-
- for(pd_i = 0; pd_i < x_vals.n_rows; pd_i++){
-
- // Abort the routine if user has pressed Ctrl + C or Escape in R.
- Rcpp::checkUserInterrupt();
-
- j = 0;
- denom_pred.fill(0);
- surv_pvec.fill(0);
-
- for(jit = x_cols.begin(); jit < x_cols.end(); ++jit, ++j){
-
- x_input.col(*jit).fill(x_vals(pd_i, j));
-
- }
-
- for(tree = 0; tree < forest.length(); ++tree){
-
- ostree = forest[tree];
-
- IntegerMatrix rows_oobag_ = ostree["rows_oobag"];
-
- rows_oobag = conv_to::from(
- ivec(rows_oobag_.begin(),
- rows_oobag_.length(),
- false)
- );
-
- x_pred = x_input.rows(rows_oobag);
- leaf_pred.set_size(x_pred.n_rows);
- denom_pred(rows_oobag) += 1;
-
- ostree_mem_xfer();
- ostree_pred_leaf();
- oobag_pred_surv_uni(pred_type);
-
-
- }
-
- if(pred_type == 'R'){ surv_pvec = 1 - surv_pvec; }
-
- output_means.col(pd_i) = mean(surv_pvec);
- output_quantiles.col(pd_i) = quantile(surv_pvec, probs);
-
-
- }
-
-
- return(join_vert(output_means, output_quantiles));
-
-}
-
-// same as above but doesn't summarize the predictions
-// [[Rcpp::export]]
-arma::mat pd_new_ice(List& forest,
- NumericMatrix& x_new_,
- IntegerVector& x_cols_,
- NumericMatrix& x_vals_,
- NumericVector& probs_,
- const double time_dbl,
- char pred_type){
-
-
- uword pd_i;
-
- time_pred = time_dbl;
-
- x_pred = mat(x_new_.begin(), x_new_.nrow(), x_new_.ncol(), false);
-
- mat x_vals = mat(x_vals_.begin(), x_vals_.nrow(), x_vals_.ncol(), false);
-
- uvec x_cols = conv_to::from(
- ivec(x_cols_.begin(), x_cols_.length(), false)
- );
-
- vec probs = vec(probs_.begin(), probs_.length(), false);
-
- mat output_ice(x_vals.n_rows * x_pred.n_rows, 2);
- vec output_ids = output_ice.unsafe_col(0);
- vec output_pds = output_ice.unsafe_col(1);
-
- uvec pd_rows = regspace(0, 1, x_pred.n_rows - 1);
-
- leaf_pred.set_size(x_pred.n_rows);
- surv_pvec.set_size(x_pred.n_rows);
-
- for(pd_i = 0; pd_i < x_vals.n_rows; pd_i++){
-
- // Abort the routine if user has pressed Ctrl + C or Escape in R.
- Rcpp::checkUserInterrupt();
-
- j = 0;
-
- surv_pvec.fill(0);
-
- for(jit = x_cols.begin(); jit < x_cols.end(); ++jit, ++j){
-
- x_pred.col(*jit).fill(x_vals(pd_i, j));
-
- }
-
- for(tree = 0; tree < forest.length(); ++tree){
- ostree = forest[tree];
- ostree_mem_xfer();
- ostree_pred_leaf();
- new_pred_surv_uni(pred_type);
- }
-
- surv_pvec /= tree;
-
- if(pred_type == 'R'){ surv_pvec = 1 - surv_pvec; }
-
- output_ids(pd_rows).fill(pd_i+1);
- output_pds(pd_rows) = surv_pvec;
- pd_rows += x_pred.n_rows;
-
-
- }
-
- return(output_ice);
-
-}
-
-// same as above but out-of-bag and doesn't summarize the predictions
-// [[Rcpp::export]]
-arma::mat pd_oob_ice(List& forest,
- NumericMatrix& x_new_,
- IntegerVector& x_cols_,
- NumericMatrix& x_vals_,
- NumericVector& probs_,
- const double time_dbl,
- char pred_type){
-
-
- uword pd_i;
-
- time_pred = time_dbl;
-
- mat x_vals = mat(x_vals_.begin(), x_vals_.nrow(), x_vals_.ncol(), false);
-
- uvec x_cols = conv_to::from(
- ivec(x_cols_.begin(), x_cols_.length(), false)
- );
-
- x_input = mat(x_new_.begin(), x_new_.nrow(), x_new_.ncol(), false);
-
- mat output_ice(x_vals.n_rows * x_input.n_rows, 2);
- vec output_ids = output_ice.unsafe_col(0);
- vec output_pds = output_ice.unsafe_col(1);
-
- uvec pd_rows = regspace(0, 1, x_input.n_rows - 1);
-
- denom_pred.set_size(x_input.n_rows);
- surv_pvec.set_size(x_input.n_rows);
-
- for(pd_i = 0; pd_i < x_vals.n_rows; pd_i++){
-
- // Abort the routine if user has pressed Ctrl + C or Escape in R.
- Rcpp::checkUserInterrupt();
-
- j = 0;
- denom_pred.fill(0);
- surv_pvec.fill(0);
-
- for(jit = x_cols.begin(); jit < x_cols.end(); ++jit, ++j){
-
- x_input.col(*jit).fill(x_vals(pd_i, j));
-
- }
-
- for(tree = 0; tree < forest.length(); ++tree){
-
- ostree = forest[tree];
-
- IntegerMatrix rows_oobag_ = ostree["rows_oobag"];
-
- rows_oobag = conv_to::from(
- ivec(rows_oobag_.begin(),
- rows_oobag_.length(),
- false)
- );
-
- x_pred = x_input.rows(rows_oobag);
- leaf_pred.set_size(x_pred.n_rows);
- denom_pred(rows_oobag) += 1;
-
- ostree_mem_xfer();
- ostree_pred_leaf();
- oobag_pred_surv_uni(pred_type);
-
-
- }
-
- if(pred_type == 'R'){ surv_pvec = 1 - surv_pvec; }
-
- output_ids(pd_rows).fill(pd_i+1);
- output_pds(pd_rows) = surv_pvec;
- pd_rows += x_input.n_rows;
-
-
- }
-
- return(output_ice);
-
-}
-
-
-
+//
+// #include
+// #include
+//
+// // [[Rcpp::depends(RcppArmadillo)]]
+//
+//
+// using namespace Rcpp;
+// using namespace arma;
+//
+// // ----------------------------------------------------------------------------
+// // ---------------------------- global parameters -----------------------------
+// // ----------------------------------------------------------------------------
+//
+// // special note: dont change these doubles to uword,
+// // even though some of them could be uwords;
+// // operations involving uwords and doubles are not
+// // straightforward and may break the routine.
+// // also: double + uword is slower than double + double.
+//
+// double
+// weight_avg,
+// weight_events,
+// w_node_sum,
+// denom_events,
+// denom,
+// cph_eps,
+// // the n_ variables could be integers but it
+// // is safer and faster when they are doubles
+// n_events,
+// n_events_total,
+// n_events_right,
+// n_events_left,
+// n_risk,
+// n_risk_right,
+// n_risk_left,
+// n_risk_sub,
+// g_risk,
+// temp1,
+// temp2,
+// temp3,
+// halving,
+// stat_current,
+// stat_best,
+// w_node_person,
+// xb,
+// risk,
+// loglik,
+// cutpoint,
+// observed,
+// expected,
+// V,
+// pred_t0,
+// leaf_min_obs,
+// leaf_min_events,
+// split_min_events,
+// split_min_obs,
+// split_min_stat,
+// time_pred,
+// ll_second,
+// ll_init,
+// net_alpha;
+//
+// int
+// // verbose=0,
+// max_retry,
+// n_retry,
+// tree,
+// mtry_int,
+// net_df_target,
+// oobag_eval_every;
+//
+// char
+// type_beta,
+// type_oobag_eval,
+// oobag_pred_type,
+// oobag_importance_type,
+// pred_type_dflt = 'S';
+//
+// // armadillo unsigned integers
+// uword
+// i,
+// j,
+// k,
+// iter,
+// mtry,
+// mtry_temp,
+// person,
+// person_leaf,
+// person_ref_index,
+// n_vars,
+// n_rows,
+// cph_method,
+// cph_iter_max,
+// n_split,
+// nodes_max_guess,
+// nodes_max_true,
+// n_cols_to_sample,
+// nn_left,
+// leaf_node_counter,
+// leaf_node_index_counter,
+// leaf_node_col,
+// oobag_eval_counter;
+//
+// bool
+// break_loop, // a delayed break statement
+// oobag_pred,
+// oobag_importance,
+// use_tree_seed,
+// cph_do_scale;
+//
+// // armadillo vectors (doubles)
+// vec
+// vec_temp,
+// times_pred,
+// eval_oobag,
+// node_assignments,
+// nodes_grown,
+// surv_pvec,
+// surv_pvec_output,
+// denom_pred,
+// beta_current,
+// beta_new,
+// beta_fit,
+// vi_pval_numer,
+// vi_pval_denom,
+// cutpoints,
+// w_input,
+// w_inbag,
+// w_user,
+// w_node,
+// group,
+// u,
+// a,
+// a2,
+// XB,
+// Risk;
+//
+// // armadillo unsigned integer vectors
+// uvec
+// iit_vals,
+// jit_vals,
+// rows_inbag,
+// rows_oobag,
+// rows_node,
+// rows_leaf,
+// rows_node_combined,
+// cols_to_sample_01,
+// cols_to_sample,
+// cols_node,
+// leaf_node_index,
+// nodes_to_grow,
+// nodes_to_grow_next,
+// obs_in_node,
+// children_left,
+// leaf_pred;
+//
+// // armadillo iterators for unsigned integer vectors
+// uvec::iterator
+// iit,
+// iit_best,
+// jit,
+// node;
+//
+// // armadillo matrices (doubles)
+// mat
+// x_input,
+// x_transforms,
+// y_input,
+// x_inbag,
+// y_inbag,
+// x_node,
+// y_node,
+// x_pred,
+// // x_mean,
+// vmat,
+// cmat,
+// cmat2,
+// betas,
+// leaf_node,
+// leaf_nodes,
+// surv_pmat;
+//
+// umat
+// col_indices,
+// leaf_indices;
+//
+// cube
+// surv_pcube;
+//
+// List ostree;
+//
+// NumericMatrix
+// beta_placeholder,
+// xx,
+// yy;
+//
+// CharacterVector yy_names = CharacterVector::create("time","status");
+//
+// NumericVector ww;
+//
+// Environment base_env("package:base");
+//
+// Function set_seed_r = base_env["set.seed"];
+//
+// // Set difference for arma vectors
+// //
+// // @description the same as setdiff() in R
+// //
+// // @param x first vector
+// // @param y second vector
+// //
+// // [[Rcpp::export]]
+// arma::uvec std_setdiff(arma::uvec& x, arma::uvec& y) {
+//
+// std::vector a = conv_to< std::vector >::from(sort(x));
+// std::vector b = conv_to< std::vector >::from(sort(y));
+// std::vector out;
+//
+// std::set_difference(a.begin(), a.end(),
+// b.begin(), b.end(),
+// std::inserter(out, out.end()));
+//
+// return conv_to::from(out);
+//
+// }
+//
+// // ----------------------------------------------------------------------------
+// // ---------------------------- scaling functions -----------------------------
+// // ----------------------------------------------------------------------------
+//
+// // scale observations in predictor matrix
+// //
+// // @description this scales inputs in the same way as
+// // the survival::coxph() function. The main reasons we do this
+// // are to avoid exponential overflow and to prevent the scale
+// // of inputs from impacting the estimated beta coefficients.
+// // E.g., you can try multiplying numeric inputs by 100 prior
+// // to calling orsf() with orsf_control_fast(do_scale = FALSE)
+// // and you will see that you get back a different forest.
+// //
+// // @param x_node matrix of predictors
+// // @param w_node replication weights
+// // @param x_transforms matrix used to store the means and scales
+// //
+// // @return modified x_node and x_transform filled with values
+// //
+// void x_node_scale(){
+//
+// // set aside memory for outputs
+// // first column holds the mean values
+// // second column holds the scale values
+//
+// x_transforms.zeros(n_vars, 2);
+// vec means = x_transforms.unsafe_col(0); // Reference to column 1
+// vec scales = x_transforms.unsafe_col(1); // Reference to column 2
+//
+// w_node_sum = sum(w_node);
+//
+// for(i = 0; i < n_vars; i++) {
+//
+// means.at(i) = sum( w_node % x_node.col(i) ) / w_node_sum;
+//
+// x_node.col(i) -= means.at(i);
+//
+// scales.at(i) = sum(w_node % abs(x_node.col(i)));
+//
+// if(scales(i) > 0)
+// scales.at(i) = w_node_sum / scales.at(i);
+// else
+// scales.at(i) = 1.0; // rare case of constant covariate;
+//
+// x_node.col(i) *= scales.at(i);
+//
+// }
+//
+// }
+//
+// // same as above function, but just the means
+// // (currently not used)
+// void x_node_means(){
+//
+// x_transforms.zeros(n_vars, 1);
+// w_node_sum = sum(w_node);
+//
+// for(i = 0; i < n_vars; i++) {
+//
+// x_transforms.at(i, 0) = sum( w_node % x_node.col(i) ) / w_node_sum;
+//
+// }
+//
+// }
+//
+// // Same as x_node_scale, but this can be called from R
+// // [[Rcpp::export]]
+// List x_node_scale_exported(NumericMatrix& x_,
+// NumericVector& w_){
+//
+// x_node = mat(x_.begin(), x_.nrow(), x_.ncol(), false);
+// w_node = vec(w_.begin(), w_.length(), false);
+// n_vars = x_node.n_cols;
+//
+// x_node_scale();
+//
+// return(
+// List::create(
+// _["x_scaled"] = x_node,
+// _["x_transforms"] = x_transforms
+// )
+// );
+//
+// }
+//
+// // ----------------------------------------------------------------------------
+// // -------------------------- leaf_surv functions -----------------------------
+// // ----------------------------------------------------------------------------
+//
+// // Create kaplan-meier survival curve in leaf node
+// //
+// // @description Modifies leaf_nodes by adding data from the current node,
+// // where the current node is one that is too small to be split and will
+// // be converted to a leaf.
+// //
+// // @param y the outcome matrix in the current leaf
+// // @param w the weights vector in the current leaf
+// // @param leaf_indices a matrix that indicates where leaf nodes are
+// // inside of leaf_nodes. leaf_indices has three columns:
+// // - first column: the id for the leaf
+// // - second column: starting row for the leaf
+// // - third column: ending row for the leaf
+// // @param leaf_node_index_counter keeps track of where we are in leaf_node
+// // @param leaf_node_counter keeps track of which leaf node we are in
+// // @param leaf_nodes a matrix with three columns:
+// // - first column: time
+// // - second column: survival probability
+// // - third column: cumulative hazard
+//
+// void leaf_kaplan(const arma::mat& y,
+// const arma::vec& w){
+//
+// leaf_indices(leaf_node_index_counter, 1) = leaf_node_counter;
+// i = leaf_node_counter;
+//
+// // find the first unique event time
+// person = 0;
+//
+// while(y.at(person, 1) == 0){
+// person++;
+// }
+//
+// // now person corresponds to the first event time
+// leaf_nodes.at(i, 0) = y.at(person, 0); // see above
+// temp2 = y.at(person, 0);
+//
+// i++;
+//
+// // find the rest of the unique event times
+// for( ; person < y.n_rows; person++){
+//
+// if(temp2 != y.at(person, 0) && y.at(person, 1) == 1){
+//
+// leaf_nodes.at(i, 0) = y.at(person,0);
+// temp2 = y.at(person, 0);
+// i++;
+//
+// }
+//
+// }
+//
+// // reset for kaplan meier loop
+// n_risk = sum(w);
+// person = 0;
+// temp1 = 1.0;
+// temp3 = 0.0;
+//
+// do {
+//
+// n_events = 0;
+// n_risk_sub = 0;
+// temp2 = y.at(person, 0);
+//
+// while(y.at(person, 0) == temp2){
+//
+// n_risk_sub += w.at(person);
+// n_events += y.at(person, 1) * w.at(person);
+//
+// if(person == y.n_rows-1) break;
+//
+// person++;
+//
+// }
+//
+// // only do km if a death was observed
+//
+// if(n_events > 0){
+//
+// temp1 = temp1 * (n_risk - n_events) / n_risk;
+//
+// temp3 = temp3 + n_events / n_risk;
+//
+// leaf_nodes.at(leaf_node_counter, 1) = temp1;
+// leaf_nodes.at(leaf_node_counter, 2) = temp3;
+// leaf_node_counter++;
+//
+// }
+//
+// n_risk -= n_risk_sub;
+//
+// } while (leaf_node_counter < i);
+//
+//
+// leaf_indices(leaf_node_index_counter, 2) = leaf_node_counter-1;
+// leaf_node_index_counter++;
+//
+// if(leaf_node_index_counter >= leaf_indices.n_rows){
+// leaf_indices.insert_rows(leaf_indices.n_rows, 10);
+// }
+//
+// }
+//
+// // Same as above, but this function can be called from R and is
+// // used to run tests with testthat (hence the name). Note: this
+// // needs to be updated to include CHF, which was added to the
+// // function above recently.
+// // [[Rcpp::export]]
+// arma::mat leaf_kaplan_testthat(const arma::mat& y,
+// const arma::vec& w){
+//
+//
+// leaf_nodes.set_size(y.n_rows, 3);
+// leaf_node_counter = 0;
+//
+// // find the first unique event time
+// person = 0;
+//
+// while(y.at(person, 1) == 0){
+// person++;
+// }
+//
+// // now person corresponds to the first event time
+// leaf_nodes.at(leaf_node_counter, 0) = y.at(person, 0); // see above
+// temp2 = y.at(person, 0);
+//
+// leaf_node_counter++;
+//
+// // find the rest of the unique event times
+// for( ; person < y.n_rows; person++){
+//
+// if(temp2 != y.at(person, 0) && y.at(person, 1) == 1){
+//
+// leaf_nodes.at(leaf_node_counter, 0) = y.at(person,0);
+// temp2 = y.at(person, 0);
+// leaf_node_counter++;
+//
+// }
+//
+// }
+//
+//
+// // reset for kaplan meier loop
+// i = leaf_node_counter;
+// n_risk = sum(w);
+// person = 0;
+// temp1 = 1.0;
+// leaf_node_counter = 0;
+//
+//
+// do {
+//
+// n_events = 0;
+// n_risk_sub = 0;
+// temp2 = y.at(person, 0);
+//
+// while(y.at(person, 0) == temp2){
+//
+// n_risk_sub += w.at(person);
+// n_events += y.at(person, 1) * w.at(person);
+//
+// if(person == y.n_rows-1) break;
+//
+// person++;
+//
+// }
+//
+// // only do km if a death was observed
+//
+// if(n_events > 0){
+//
+// temp1 = temp1 * (n_risk - n_events) / n_risk;
+// leaf_nodes.at(leaf_node_counter, 1) = temp1;
+// leaf_node_counter++;
+//
+// }
+//
+// n_risk -= n_risk_sub;
+//
+// } while (leaf_node_counter < i);
+//
+// leaf_nodes.resize(leaf_node_counter, 3);
+//
+// return(leaf_nodes);
+//
+// }
+//
+//
+//
+//
+// // ----------------------------------------------------------------------------
+// // ---------------------------- cholesky functions ----------------------------
+// // ----------------------------------------------------------------------------
+//
+// // cholesky decomposition
+// //
+// // @description this function is copied from the survival package and
+// // translated into arma.
+// //
+// // @param vmat matrix with covariance estimates
+// // @param n_vars the number of predictors used in the current node
+// //
+// // prepares vmat for cholesky_solve()
+//
+//
+// void cholesky(){
+//
+// double eps_chol = 0;
+// double toler = 1e-8;
+// double pivot;
+//
+// for(i = 0; i < n_vars; i++){
+//
+// if(vmat.at(i,i) > eps_chol) eps_chol = vmat.at(i,i);
+//
+// // copy upper right values to bottom left
+// for(j = (i+1); j eps_chol) {
+//
+// for(j = (i+1); j < n_vars; j++){
+//
+// temp1 = vmat.at(j,i) / pivot;
+// vmat.at(j,i) = temp1;
+// vmat.at(j,j) -= temp1*temp1*pivot;
+//
+// for(k = (j+1); k < n_vars; k++){
+//
+// vmat.at(k, j) -= temp1 * vmat.at(k, i);
+//
+// }
+//
+// }
+//
+// } else {
+//
+// vmat.at(i, i) = 0;
+//
+// }
+//
+// }
+//
+// }
+//
+// // solve cholesky decomposition
+// //
+// // @description this function is copied from the survival package and
+// // translated into arma. Prepares u, the vector used to update beta.
+// //
+// // @param vmat matrix with covariance estimates
+// // @param n_vars the number of predictors used in the current node
+// //
+// //
+// void cholesky_solve(){
+//
+// for (i = 0; i < n_vars; i++) {
+//
+// temp1 = u[i];
+//
+// for (j = 0; j < i; j++){
+//
+// temp1 -= u[j] * vmat.at(i, j);
+// u[i] = temp1;
+//
+// }
+//
+// }
+//
+//
+// for (i = n_vars; i >= 1; i--){
+//
+// if (vmat.at(i-1, i-1) == 0){
+//
+// u[i-1] = 0;
+//
+// } else {
+//
+// temp1 = u[i-1] / vmat.at(i-1, i-1);
+//
+// for (j = i; j < n_vars; j++){
+// temp1 -= u[j] * vmat.at(j, i-1);
+// }
+//
+// u[i-1] = temp1;
+//
+// }
+//
+// }
+//
+// }
+//
+// // invert the cholesky in the lower triangle
+// //
+// // @description this function is copied from the survival package and
+// // translated into arma. Inverts vmat
+// //
+// // @param vmat matrix with covariance estimates
+// // @param n_vars the number of predictors used in the current node
+// //
+//
+// void cholesky_invert(){
+//
+// for (i=0; i0) {
+//
+// // take full advantage of the cholesky's diagonal of 1's
+// vmat.at(i,i) = 1.0 / vmat.at(i,i);
+//
+// for (j=(i+1); j 0) {
+//
+// if (cph_method == 0 || n_events == 1) { // Breslow
+//
+// denom += denom_events;
+// loglik -= weight_events * log(denom);
+//
+// for (i=0; i 0) {
+//
+// if (cph_method == 0 || n_events == 1) { // Breslow
+//
+// denom += denom_events;
+// loglik -= denom_events * log(denom);
+//
+// for (i=0; i 1 && stat_best < R_PosInf){
+//
+// for(iter = 1; iter < cph_iter_max; iter++){
+//
+// // if(verbose > 0){
+// //
+// // Rcout << "--------- Newt-Raph algo; iter " << iter;
+// // Rcout << " ---------" << std::endl;
+// // Rcout << "beta: " << beta_new.t();
+// // Rcout << "loglik: " << stat_best;
+// // Rcout << std::endl;
+// // Rcout << "------------------------------------------";
+// // Rcout << std::endl << std::endl << std::endl;
+// //
+// // }
+//
+// // do the next iteration
+// stat_current = newtraph_cph_iter(beta_new);
+//
+// cholesky();
+//
+// // don't go trying to fix this, just use the last
+// // set of valid coefficients
+// if(std::isinf(stat_current)) break;
+//
+// // check for convergence
+// // break the loop if the new ll is ~ same as old best ll
+// if(fabs(1 - stat_best / stat_current) < cph_eps){
+// break;
+// }
+//
+// if(stat_current < stat_best){ // it's not converging!
+//
+// halving++; // get more aggressive when it doesn't work
+//
+// // reduce the magnitude by which beta_new modifies beta_current
+// for (i = 0; i < n_vars; i++){
+// beta_new[i] = (beta_new[i]+halving*beta_current[i]) / (halving+1.0);
+// }
+//
+// // yeah its not technically the best but I need to do this for
+// // more reasonable output when verbose = true; I should remove
+// // this line when verbosity is taken out.
+// stat_best = stat_current;
+//
+// } else { // it's converging!
+//
+// halving = 0;
+// stat_best = stat_current;
+//
+// cholesky_solve();
+//
+// for (i = 0; i < n_vars; i++) {
+//
+// beta_current[i] = beta_new[i];
+// beta_new[i] = beta_new[i] + u[i];
+//
+// }
+//
+// }
+//
+// }
+//
+// }
+//
+// // invert vmat
+// cholesky_invert();
+//
+// for (i=0; i < n_vars; i++) {
+//
+// beta_current[i] = beta_new[i];
+//
+// if(std::isinf(beta_current[i]) || std::isnan(beta_current[i])){
+// beta_current[i] = 0;
+// }
+//
+// if(std::isinf(vmat.at(i, i)) || std::isnan(vmat.at(i, i))){
+// vmat.at(i, i) = 1.0;
+// }
+//
+// // if(verbose > 0) Rcout << "scaled beta: " << beta_current[i] << "; ";
+//
+// if(cph_do_scale){
+// beta_current.at(i) *= x_transforms.at(i, 1);
+// vmat.at(i, i) *= x_transforms.at(i, 1) * x_transforms.at(i, 1);
+// }
+//
+// // if(verbose > 0) Rcout << "un-scaled beta: " << beta_current[i] << std::endl;
+//
+// if(oobag_importance_type == 'A'){
+//
+// if(beta_current.at(i) != 0){
+//
+// temp1 = R::pchisq(pow(beta_current[i], 2) / vmat.at(i, i),
+// 1, false, false);
+//
+// if(temp1 < 0.01) vi_pval_numer[cols_node[i]]++;
+//
+// }
+//
+// vi_pval_denom[cols_node[i]]++;
+//
+// }
+//
+// }
+//
+// // if(verbose > 1) Rcout << std::endl;
+//
+// return(beta_current);
+//
+// }
+//
+// // same function as above, but exported to R for testing
+// // [[Rcpp::export]]
+// arma::vec newtraph_cph_testthat(NumericMatrix& x_in,
+// NumericMatrix& y_in,
+// NumericVector& w_in,
+// int method,
+// double cph_eps_,
+// int iter_max){
+//
+//
+// x_node = mat(x_in.begin(), x_in.nrow(), x_in.ncol(), false);
+// y_node = mat(y_in.begin(), y_in.nrow(), y_in.ncol(), false);
+// w_node = vec(w_in.begin(), w_in.length(), false);
+//
+// cph_do_scale = true;
+//
+// cph_method = method;
+// cph_eps = cph_eps_;
+// cph_iter_max = iter_max;
+// n_vars = x_node.n_cols;
+//
+// vi_pval_numer.zeros(x_node.n_cols);
+// vi_pval_denom.zeros(x_node.n_cols);
+// cols_node = regspace(0, x_node.n_cols - 1);
+//
+// x_node_scale();
+//
+// vec out = newtraph_cph();
+//
+// return(out);
+//
+// }
+//
+// // ----------------------------------------------------------------------------
+// // ---------------------------- node functions --------------------------------
+// // ----------------------------------------------------------------------------
+//
+// // Log rank test w/multiple cutpoints
+// //
+// // this function returns a cutpoint obtaining a local maximum
+// // of the log-rank test (lrt) statistic. The default value (+Inf)
+// // is really for diagnostic purposes. Put another way, if the
+// // return value is +Inf (an impossible value for a cutpoint),
+// // that means that we didn't find any valid cut-points and
+// // the node cannot be grown with the current XB.
+// //
+// // if there is a valid cut-point, then the main side effect
+// // of this function is to modify the group vector, which
+// // will be used to assign observations to the two new nodes.
+// //
+// // @param group the vector that determines which node to send each
+// // observation to (left node = 0, right node = 1)
+// // @param y_node matrix of outcomes
+// // @param w_node vector of weights
+// // @param XB linear combination of predictors
+// //
+// // the group vector is modified by this function and the value returned
+// // is the maximal log-rank statistic across all the possible cutpoints.
+// double lrt_multi(){
+//
+// break_loop = false;
+//
+// // group should be initialized as all 0s
+// group.zeros(y_node.n_rows);
+//
+// // initialize at the lowest possible LRT stat value
+// stat_best = 0;
+//
+// // sort XB- we need to iterate over the sorted indices
+// iit_vals = sort_index(XB, "ascend");
+//
+// // unsafe columns point to cols in y_node.
+// vec y_status = y_node.unsafe_col(1);
+// vec y_time = y_node.unsafe_col(0);
+//
+// // first determine the lowest value of XB that will
+// // be a valid cut-point to split a node. A valid cut-point
+// // is one that, if used, will result in at least leaf_min_obs
+// // and leaf_min_events in both the left and right node.
+//
+// n_events = 0;
+// n_risk = 0;
+//
+// // if(verbose > 1){
+// // Rcout << "----- finding cut-point boundaries -----" << std::endl;
+// // }
+//
+// // Iterate through the sorted values of XB, in ascending order.
+//
+// for(iit = iit_vals.begin(); iit < iit_vals.end()-1; ++iit){
+//
+// n_events += y_status[*iit] * w_node[*iit];
+// n_risk += w_node[*iit];
+//
+// // If we want to make the current value of XB a cut-point, we need
+// // to make sure the next value of XB isn't equal to this current value.
+// // Otherwise, we will have the same value of XB in both groups!
+//
+// // if(verbose > 1){
+// // Rcout << XB[*iit] << " ---- ";
+// // Rcout << XB[*(iit+1)] << " ---- ";
+// // Rcout << n_events << " ---- ";
+// // Rcout << n_risk << std::endl;
+// // }
+//
+// if(XB[*iit] != XB[*(iit+1)]){
+//
+// // if(verbose > 1){
+// // Rcout << "********* New cut-point here ********" << std::endl;
+// // }
+//
+//
+// if( n_events >= leaf_min_events &&
+// n_risk >= leaf_min_obs) {
+//
+// // if(verbose > 1){
+// // Rcout << std::endl;
+// // Rcout << "lower cutpoint: " << XB[*iit] << std::endl;
+// // Rcout << " - n_events, left node: " << n_events << std::endl;
+// // Rcout << " - n_risk, left node: " << n_risk << std::endl;
+// // Rcout << std::endl;
+// // }
+//
+// break;
+//
+// }
+//
+// }
+//
+// }
+//
+// // if(verbose > 1){
+// // if(iit >= iit_vals.end()-1) {
+// // Rcout << "Could not find a valid lower cut-point" << std::endl;
+// // }
+// // }
+//
+//
+// j = iit - iit_vals.begin();
+//
+// // got to reset these before finding the upper limit
+// n_events=0;
+// n_risk=0;
+//
+// // do the first step in the loop manually since we need to
+// // refer to iit+1 in all proceeding steps.
+//
+// for(iit = iit_vals.end()-1; iit >= iit_vals.begin()+1; --iit){
+//
+// n_events += y_status[*iit] * w_node[*iit];
+// n_risk += w_node[*iit];
+// group[*iit] = 1;
+//
+// // if(verbose > 1){
+// // Rcout << XB[*iit] << " ---- ";
+// // Rcout << XB(*(iit-1)) << " ---- ";
+// // Rcout << n_events << " ---- ";
+// // Rcout << n_risk << std::endl;
+// // }
+//
+// if ( XB[*iit] != XB[*(iit-1)] ) {
+//
+// // if(verbose > 1){
+// // Rcout << "********* New cut-point here ********" << std::endl;
+// // }
+//
+// if( n_events >= leaf_min_events &&
+// n_risk >= leaf_min_obs ) {
+//
+// // the upper cutpoint needs to be one step below the current
+// // iit value, because we use x <= cp to determine whether a
+// // value x goes to the left node versus the right node. So,
+// // if iit currently points to 3, and the next value down is 2,
+// // then we want to say the cut-point is 2 because then all
+// // values <= 2 will go left, and 3 will go right. This matters
+// // when 3 is the highest value in the vector.
+//
+// --iit;
+//
+// // if(verbose > 1){
+// // Rcout << std::endl;
+// // Rcout << "upper cutpoint: " << XB[*iit] << std::endl;
+// // Rcout << " - n_events, right node: " << n_events << std::endl;
+// // Rcout << " - n_risk, right node: " << n_risk << std::endl;
+// // }
+//
+// break;
+//
+// }
+//
+// }
+//
+// }
+//
+// // number of steps taken
+// k = iit + 1 - iit_vals.begin();
+//
+// // if(verbose > 1){
+// // Rcout << "----------------------------------------" << std::endl;
+// // Rcout << std::endl << std::endl;
+// // Rcout << "sorted XB: " << std::endl << XB(iit_vals).t() << std::endl;
+// // }
+//
+// // initialize cut-point as the value of XB iit currently points to.
+// iit_best = iit;
+//
+// // what happens if we don't have enough events or obs to split?
+// // the first valid lower cut-point (at iit_vals(k)) is > the first
+// // valid upper cutpoint (current value of n_risk). Put another way,
+// // k (the number of steps taken from beginning of the XB vec)
+// // will be > n_rows - p, where the difference on the RHS is
+// // telling us where we are after taking p steps from the end
+// // of the XB vec. Returning the infinite cp is a red flag.
+//
+// // if(verbose > 1){
+// // Rcout << "j: " << j << std::endl;
+// // Rcout << "k: " << k << std::endl;
+// // }
+//
+// if (j > k){
+//
+// // if(verbose > 1) {
+// // Rcout << "Could not find a cut-point for this XB" << std::endl;
+// // }
+//
+// return(R_PosInf);
+// }
+//
+// // if(verbose > 1){
+// //
+// // Rcout << "----- initializing log-rank test cutpoints -----" << std::endl;
+// // Rcout << "n potential cutpoints: " << k-j << std::endl;
+// //
+// // }
+//
+//
+// // adjust k to indicate the number of valid cut-points
+// k -= j;
+//
+// if(k > n_split){
+//
+// jit_vals = linspace(0, k, n_split);
+//
+// } else {
+//
+// // what happens if there are only 5 potential cut-points
+// // but the value of n_split is > 5? We will just check out
+// // the 5 valid cutpoints.
+// jit_vals = linspace(0, k, k);
+//
+// }
+//
+// vec_temp.resize( jit_vals.size() );
+//
+// // protection from going out of bounds with jit_vals(k) below
+// if(j == 0) jit_vals.at(jit_vals.size()-1)--;
+//
+// // put the indices of potential cut-points into vec_temp
+// for(k = 0; k < vec_temp.size(); k++){
+// vec_temp[k] = XB.at(*(iit_best - jit_vals[k]));
+// }
+//
+// // back to how it was!
+// if(j == 0) jit_vals.at(jit_vals.size()-1)++;
+//
+// // if(verbose > 1){
+// //
+// // Rcout << "cut-points chosen: ";
+// //
+// // Rcout << vec_temp.t();
+// //
+// // Rcout << "----------------------------------------" << std::endl <<
+// // std::endl << std::endl;
+// //
+// // }
+//
+// bool do_lrt = true;
+//
+// k = 0;
+// j = 1;
+//
+// // begin outer loop - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+// for(jit = jit_vals.begin(); jit != jit_vals.end(); ++jit){
+//
+//
+// // if(verbose > 1){
+// // Rcout << "jit points to " << *jit << std::endl;
+// // }
+//
+// // switch group values from 0 to 1 until you get to the next cut-point
+// for( ; j < *jit; j++){
+// group[*iit] = 1;
+// --iit;
+// }
+//
+// if(jit == jit_vals.begin() ||
+// jit == jit_vals.end()-1){
+//
+// do_lrt = true;
+//
+// } else {
+//
+// if( vec_temp[k] == vec_temp[k+1] ||
+// vec_temp[k] == vec_temp[0] ||
+// *jit <= 1){
+//
+// do_lrt = false;
+//
+// } else {
+//
+// while( XB[*iit] == XB[*(iit - 1)] ){
+//
+// group[*iit] = 1;
+// --iit;
+// ++j;
+//
+// // if(verbose > 1){
+// // Rcout << "cutpoint dropped down one spot: ";
+// // Rcout << XB[*iit] << std::endl;
+// // }
+//
+// }
+//
+// do_lrt = true;
+//
+// }
+//
+// }
+//
+// ++k;
+//
+// if(do_lrt){
+//
+// n_risk=0;
+// g_risk=0;
+//
+// observed=0;
+// expected=0;
+//
+// V=0;
+//
+// break_loop = false;
+//
+// i = y_node.n_rows-1;
+//
+// // if(verbose > 1){
+// // Rcout << "sum(group==1): " << sum(group) << "; ";
+// // Rcout << "sum(group==1 * w_node): " << sum(group % w_node);
+// // Rcout << std::endl;
+// // if(verbose > 1){
+// // Rcout << "group:" << std::endl;
+// // Rcout << group(iit_vals).t() << std::endl;
+// // }
+// // }
+//
+//
+// // begin inner loop - - - - - - - - - - - - - - - - - - - - - - - - - -
+// for (; ;){
+//
+// temp1 = y_time[i];
+//
+// n_events = 0;
+//
+// for ( ; y_time[i] == temp1; i--) {
+//
+// n_risk += w_node[i];
+// n_events += y_status[i] * w_node[i];
+// g_risk += group[i] * w_node[i];
+// observed += y_status[i] * group[i] * w_node[i];
+//
+// if(i == 0){
+// break_loop = true;
+// break;
+// }
+//
+// }
+//
+// // should only do these calculations if n_events > 0,
+// // but turns out its faster to multiply by 0 than
+// // it is to check whether n_events is > 0
+//
+// temp2 = g_risk / n_risk;
+// expected += n_events * temp2;
+//
+// // update variance if n_risk > 1 (if n_risk == 1, variance is 0)
+// // definitely check if n_risk is > 1 b/c otherwise divide by 0
+// if (n_risk > 1){
+// temp1 = n_events * temp2 * (n_risk-n_events) / (n_risk-1);
+// V += temp1 * (1 - temp2);
+// }
+//
+// if(break_loop) break;
+//
+// }
+// // end inner loop - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+//
+// stat_current = pow(expected-observed, 2) / V;
+//
+// // if(verbose > 1){
+// //
+// // Rcout << "-------- log-rank test results --------" << std::endl;
+// // Rcout << "cutpoint: " << XB[*iit] << std::endl;
+// // Rcout << "lrt stat: " << stat_current << std::endl;
+// // Rcout << "---------------------------------------" << std::endl <<
+// // std::endl << std::endl;
+// //
+// // }
+//
+// if(stat_current > stat_best){
+// iit_best = iit;
+// stat_best = stat_current;
+// n_events_right = observed;
+// n_risk_right = g_risk;
+// n_risk_left = n_risk - g_risk;
+// }
+//
+// }
+// // end outer loop - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+//
+// }
+//
+// // if the log-rank test does not detect a difference at 0.05 alpha,
+// // maybe it's not a good idea to split this node.
+//
+// if(stat_best < split_min_stat) return(R_PosInf);
+//
+// // if(verbose > 1){
+// // Rcout << "Best LRT stat: " << stat_best << std::endl;
+// // }
+//
+// // rewind iit until it is back where it was when we got the
+// // best lrt stat. While rewinding iit, also reset the group
+// // values so that group is as it was when we got the best
+// // lrt stat.
+//
+//
+// while(iit <= iit_best){
+// group[*iit] = 0;
+// ++iit;
+// }
+//
+// // XB at *iit_best is the cut-point that maximized the log-rank test
+// return(XB[*iit_best]);
+//
+// }
+//
+// // this function is the same as above, but is exported to R for testing
+// // [[Rcpp::export]]
+// List lrt_multi_testthat(NumericMatrix& y_node_,
+// NumericVector& w_node_,
+// NumericVector& XB_,
+// int n_split_,
+// int leaf_min_events_,
+// int leaf_min_obs_
+// ){
+//
+// y_node = mat(y_node_.begin(), y_node_.nrow(), y_node_.ncol(), false);
+// w_node = vec(w_node_.begin(), w_node_.length(), false);
+// XB = vec(XB_.begin(), XB_.length(), false);
+//
+// n_split = n_split_;
+// leaf_min_events = leaf_min_events_;
+// leaf_min_obs = leaf_min_obs_;
+//
+// // about this function - - - - - - - - - - - - - - - - - - - - - - - - - - -
+// //
+// // this function returns a cutpoint obtaining a local maximum
+// // of the log-rank test (lrt) statistic. The default value (+Inf)
+// // is really for diagnostic purposes. Put another way, if the
+// // return value is +Inf (an impossible value for a cutpoint),
+// // that means that we didn't find any valid cut-points and
+// // the node cannot be grown with the current XB.
+// //
+// // if there is a valid cut-point, then the main side effect
+// // of this function is to modify the group vector, which
+// // will be used to assign observations to the two new nodes.
+// //
+// // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+//
+// break_loop = false;
+//
+// vec cutpoints_used(n_split);
+// vec lrt_statistics(n_split);
+// uword list_counter = 0;
+//
+// // group should be initialized as all 0s
+// group.zeros(y_node.n_rows);
+//
+// // initialize at the lowest possible LRT stat value
+// stat_best = 0;
+//
+// // sort XB- we need to iterate over the sorted indices
+// iit_vals = sort_index(XB, "ascend");
+//
+// // unsafe columns point to cols in y_node.
+// vec y_status = y_node.unsafe_col(1);
+// vec y_time = y_node.unsafe_col(0);
+//
+// // first determine the lowest value of XB that will
+// // be a valid cut-point to split a node. A valid cut-point
+// // is one that, if used, will result in at least leaf_min_obs
+// // and leaf_min_events in both the left and right node.
+//
+// n_events = 0;
+// n_risk = 0;
+//
+// // if(verbose > 1){
+// // Rcout << "----- finding cut-point boundaries -----" << std::endl;
+// // }
+//
+// // Iterate through the sorted values of XB, in ascending order.
+//
+// for(iit = iit_vals.begin(); iit < iit_vals.end()-1; ++iit){
+//
+// n_events += y_status(*iit) * w_node(*iit);
+// n_risk += w_node(*iit);
+//
+// // If we want to make the current value of XB a cut-point, we need
+// // to make sure the next value of XB isn't equal to this current value.
+// // Otherwise, we will have the same value of XB in both groups!
+//
+// // if(verbose > 1){
+// // Rcout << XB(*iit) << " ---- ";
+// // Rcout << XB(*(iit+1)) << " ---- ";
+// // Rcout << n_events << " ---- ";
+// // Rcout << n_risk << std::endl;
+// // }
+//
+// if(XB(*iit) != XB(*(iit+1))){
+//
+// // if(verbose > 1){
+// // Rcout << "********* New cut-point here ********" << std::endl;
+// // }
+//
+//
+// if( n_events >= leaf_min_events &&
+// n_risk >= leaf_min_obs) {
+//
+// // if(verbose > 1){
+// // Rcout << std::endl;
+// // Rcout << "lower cutpoint: " << XB(*iit) << std::endl;
+// // Rcout << " - n_events, left node: " << n_events << std::endl;
+// // Rcout << " - n_risk, left node: " << n_risk << std::endl;
+// // Rcout << std::endl;
+// // }
+//
+// break;
+//
+// }
+//
+// }
+//
+// }
+//
+// // if(verbose > 1){
+// // if(iit >= iit_vals.end()-1) {
+// // Rcout << "Could not find a valid lower cut-point" << std::endl;
+// // }
+// // }
+//
+//
+// j = iit - iit_vals.begin();
+//
+// // got to reset these before finding the upper limit
+// n_events=0;
+// n_risk=0;
+//
+// // do the first step in the loop manually since we need to
+// // refer to iit+1 in all proceeding steps.
+//
+// for(iit = iit_vals.end()-1; iit >= iit_vals.begin()+1; --iit){
+//
+// n_events += y_status(*iit) * w_node(*iit);
+// n_risk += w_node(*iit);
+// group(*iit) = 1;
+//
+// // if(verbose > 1){
+// // Rcout << XB(*iit) << " ---- ";
+// // Rcout << XB(*(iit-1)) << " ---- ";
+// // Rcout << n_events << " ---- ";
+// // Rcout << n_risk << std::endl;
+// // }
+//
+// if(XB(*iit) != XB(*(iit-1))){
+//
+// // if(verbose > 1){
+// // Rcout << "********* New cut-point here ********" << std::endl;
+// // }
+//
+// if( n_events >= leaf_min_events &&
+// n_risk >= leaf_min_obs ) {
+//
+// // the upper cutpoint needs to be one step below the current
+// // iit value, because we use x <= cp to determine whether a
+// // value x goes to the left node versus the right node. So,
+// // if iit currently points to 3, and the next value down is 2,
+// // then we want to say the cut-point is 2 because then all
+// // values <= 2 will go left, and 3 will go right. This matters
+// // when 3 is the highest value in the vector.
+//
+// --iit;
+//
+// // if(verbose > 1){
+// // Rcout << std::endl;
+// // Rcout << "upper cutpoint: " << XB(*iit) << std::endl;
+// // Rcout << " - n_events, right node: " << n_events << std::endl;
+// // Rcout << " - n_risk, right node: " << n_risk << std::endl;
+// // }
+//
+// break;
+//
+// }
+//
+// }
+//
+// }
+//
+// // number of steps taken
+// k = iit + 1 - iit_vals.begin();
+//
+// // if(verbose > 1){
+// // Rcout << "----------------------------------------" << std::endl;
+// // Rcout << std::endl << std::endl;
+// // Rcout << "sorted XB: " << std::endl << XB(iit_vals).t() << std::endl;
+// // }
+//
+// // initialize cut-point as the value of XB iit currently points to.
+// iit_best = iit;
+//
+// // what happens if we don't have enough events or obs to split?
+// // the first valid lower cut-point (at iit_vals(k)) is > the first
+// // valid upper cutpoint (current value of n_risk). Put another way,
+// // k (the number of steps taken from beginning of the XB vec)
+// // will be > n_rows - p, where the difference on the RHS is
+// // telling us where we are after taking p steps from the end
+// // of the XB vec. Returning the infinite cp is a red flag.
+//
+// // if(verbose > 1){
+// // Rcout << "j: " << j << std::endl;
+// // Rcout << "k: " << k << std::endl;
+// // }
+//
+// if (j > k){
+//
+// // if(verbose > 1) {
+// // Rcout << "Could not find a cut-point for this XB" << std::endl;
+// // }
+//
+// return(R_PosInf);
+// }
+//
+// // if(verbose > 1){
+// //
+// // Rcout << "----- initializing log-rank test cutpoints -----" << std::endl;
+// // Rcout << "n potential cutpoints: " << k-j << std::endl;
+// //
+// // }
+//
+// // what happens if there are only 5 potential cut-points
+// // but the value of n_split is > 5? We will just check out
+// // the 5 valid cutpoints.
+//
+// // adjust k to indicate steps taken in the outer loop.
+// k -= j;
+//
+// if(k > n_split){
+//
+// jit_vals = linspace(0, k, n_split);
+//
+// } else {
+//
+// jit_vals = linspace(0, k, k);
+//
+// }
+//
+// vec_temp.resize( jit_vals.size() );
+//
+// if(j == 0) jit_vals(jit_vals.size()-1)--;
+//
+// for(k = 0; k < vec_temp.size(); k++){
+// vec_temp(k) = XB(*(iit_best - jit_vals(k)));
+// }
+//
+// if(j == 0) jit_vals(jit_vals.size()-1)++;
+//
+//
+// // if(verbose > 1){
+// //
+// // Rcout << "cut-points chosen: ";
+// //
+// // Rcout << vec_temp.t();
+// //
+// // Rcout << "----------------------------------------" << std::endl <<
+// // std::endl << std::endl;
+// //
+// // }
+//
+// bool do_lrt = true;
+//
+// k = 0;
+// j = 1;
+//
+// // begin outer loop - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+// for(jit = jit_vals.begin(); jit != jit_vals.end(); ++jit){
+//
+//
+// // if(verbose > 1){
+// // Rcout << "jit points to " << *jit << std::endl;
+// // }
+//
+// for( ; j < *jit; j++){
+// group(*iit) = 1;
+// --iit;
+// }
+//
+// if(jit == jit_vals.begin() ||
+// jit == jit_vals.end()-1){
+//
+// do_lrt = true;
+//
+// } else {
+//
+// if( vec_temp(k) == vec_temp(k+1) ||
+// vec_temp(k) == vec_temp(0) ||
+// *jit <= 1){
+//
+// do_lrt = false;
+//
+// } else {
+//
+// while(XB(*iit) == XB(*(iit - 1))){
+//
+// group(*iit) = 1;
+// --iit;
+// ++j;
+//
+// // if(verbose > 1){
+// // Rcout << "cutpoint dropped down one spot: ";
+// // Rcout << XB(*iit) << std::endl;
+// // }
+//
+// }
+//
+// do_lrt = true;
+//
+// }
+//
+// }
+//
+// ++k;
+//
+// if(do_lrt){
+//
+// cutpoints_used(list_counter) = XB(*iit);
+//
+// n_risk=0;
+// g_risk=0;
+//
+// observed=0;
+// expected=0;
+//
+// V=0;
+//
+// break_loop = false;
+//
+// i = y_node.n_rows-1;
+//
+// // if(verbose > 1){
+// // Rcout << "sum(group==1): " << sum(group) << "; ";
+// // Rcout << "sum(group==1 * w_node): " << sum(group % w_node);
+// // Rcout << std::endl;
+// // if(verbose > 1){
+// // Rcout << "group:" << std::endl;
+// // Rcout << group(iit_vals).t() << std::endl;
+// // }
+// // }
+//
+//
+// // begin inner loop - - - - - - - - - - - - - - - - - - - - - - - - - -
+// for (; ;){
+//
+// temp1 = y_time[i];
+//
+// n_events = 0;
+//
+// for ( ; y_time[i] == temp1; i--) {
+//
+// n_risk += w_node[i];
+// n_events += y_status[i] * w_node[i];
+// g_risk += group[i] * w_node[i];
+// observed += y_status[i] * group[i] * w_node[i];
+//
+// if(i == 0){
+// break_loop = true;
+// break;
+// }
+//
+// }
+//
+// // should only do these calculations if n_events > 0,
+// // but turns out its faster to multiply by 0 than
+// // it is to check whether n_events is > 0
+//
+// temp2 = g_risk / n_risk;
+// expected += n_events * temp2;
+//
+// // update variance if n_risk > 1 (if n_risk == 1, variance is 0)
+// // definitely check if n_risk is > 1 b/c otherwise divide by 0
+// if (n_risk > 1){
+// temp1 = n_events * temp2 * (n_risk-n_events) / (n_risk-1);
+// V += temp1 * (1 - temp2);
+// }
+//
+// if(break_loop) break;
+//
+// }
+// // end inner loop - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+//
+// stat_current = pow(expected-observed, 2) / V;
+//
+// lrt_statistics(list_counter) = stat_current;
+//
+// list_counter++;
+//
+// // if(verbose > 1){
+// //
+// // Rcout << "-------- log-rank test results --------" << std::endl;
+// // Rcout << "cutpoint: " << XB(*iit) << std::endl;
+// // Rcout << "lrt stat: " << stat_current << std::endl;
+// // Rcout << "---------------------------------------" << std::endl <<
+// // std::endl << std::endl;
+// //
+// // }
+//
+// if(stat_current > stat_best){
+// iit_best = iit;
+// stat_best = stat_current;
+// n_events_right = observed;
+// n_risk_right = g_risk;
+// n_risk_left = n_risk - g_risk;
+// }
+//
+// }
+// // end outer loop - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+//
+// }
+//
+// // if the log-rank test does not detect a difference at 0.05 alpha,
+// // maybe it's not a good idea to split this node.
+//
+// if(stat_best < 3.841459) return(R_PosInf);
+//
+// // if(verbose > 1){
+// // Rcout << "Best LRT stat: " << stat_best << std::endl;
+// // }
+//
+// // rewind iit until it is back where it was when we got the
+// // best lrt stat. While rewinding iit, also reset the group
+// // values so that group is as it was when we got the best
+// // lrt stat.
+//
+//
+// while(iit <= iit_best){
+// group(*iit) = 0;
+// ++iit;
+// }
+//
+// return(List::create(_["cutpoints"] = cutpoints_used,
+// _["statistic"] = lrt_statistics));
+//
+// }
+//
+//
+// // out-of-bag prediction for single prediction horizon
+// //
+// // @param pred_type indicates what type of prediction to compute
+// // @param leaf_pred a vector indicating which leaf each observation
+// // landed in.
+// // @param leaf_indices a matrix that contains indices for each leaf node
+// // inside of leaf_nodes
+// // @param leaf_nodes a matrix with ids, survival, and cumulative hazard
+// // functions for each leaf node.
+// //
+// // @return matrix with predictions, dimension n by 1
+//
+// void oobag_pred_surv_uni(char pred_type){
+//
+// iit_vals = sort_index(leaf_pred, "ascend");
+// iit = iit_vals.begin();
+//
+// switch(pred_type){
+//
+// case 'S': case 'R':
+//
+// leaf_node_col = 1;
+// pred_t0 = 1;
+// break;
+//
+// case 'H':
+//
+// leaf_node_col = 2;
+// pred_t0 = 0;
+// break;
+//
+// }
+//
+// do {
+//
+// person_leaf = leaf_pred[*iit];
+//
+// // find the current leaf
+// for(i = 0; i < leaf_indices.n_rows; i++){
+// if(leaf_indices.at(i, 0) == person_leaf){
+// break;
+// }
+// }
+//
+// // get submat view for this leaf
+// leaf_node = leaf_nodes.rows(leaf_indices(i, 1),
+// leaf_indices(i, 2));
+//
+// // if(verbose > 1){
+// // Rcout << "leaf_node:" << std::endl << leaf_node << std::endl;
+// // }
+//
+// i = 0;
+//
+// if(time_pred < leaf_node.at(leaf_node.n_rows - 1, 0)){
+//
+// for(; i < leaf_node.n_rows; i++){
+// if (leaf_node.at(i, 0) > time_pred){
+// if(i == 0)
+// temp1 = pred_t0;
+// else
+// temp1 = leaf_node.at(i-1, leaf_node_col);
+// break;
+// } else if (leaf_node.at(i, 0) == time_pred){
+// temp1 = leaf_node.at(i, leaf_node_col);
+// break;
+// }
+// }
+//
+// } else {
+//
+// // go here if prediction horizon > max time in current leaf.
+// temp1 = leaf_node.at(leaf_node.n_rows - 1, leaf_node_col);
+//
+// }
+//
+// // running mean: mean_k = mean_{k-1} + (new val - old val) / k
+// // compute new val - old val
+// // be careful, every oob row has a different denom!
+// temp2 = temp1 - surv_pvec[rows_oobag[*iit]];
+// surv_pvec[rows_oobag[*iit]] += temp2 / denom_pred[rows_oobag[*iit]];
+// ++iit;
+//
+// if(iit < iit_vals.end()){
+//
+// while(person_leaf == leaf_pred(*iit)){
+//
+// temp2 = temp1 - surv_pvec[rows_oobag[*iit]];
+// surv_pvec[rows_oobag[*iit]] += temp2 / denom_pred[rows_oobag[*iit]];
+//
+// ++iit;
+//
+// if (iit == iit_vals.end()) break;
+//
+// }
+//
+// }
+//
+// } while (iit < iit_vals.end());
+//
+// // if(verbose > 0){
+// // Rcout << "surv_pvec:" << std::endl << surv_pvec.t() << std::endl;
+// // }
+//
+// }
+//
+// // out-of-bag prediction evaluation, Harrell's C-statistic
+// //
+// // @param pred_type indicates what type of prediction to compute
+// // @param y_input matrix of outcomes from input
+// //
+// // @return the C-statistic
+//
+// double oobag_c_harrell(char pred_type){
+//
+// vec time = y_input.unsafe_col(0);
+// vec status = y_input.unsafe_col(1);
+// iit_vals = find(status == 1);
+//
+// k = y_input.n_rows;
+//
+// double total=0, concordant=0;
+//
+// switch(pred_type){
+//
+// case 'S': case 'R':
+// for (iit = iit_vals.begin(); iit < iit_vals.end(); ++iit) {
+//
+// for(j = *iit + 1; j < k; ++j){
+//
+// if (time[j] > time[*iit]) { // ties not counted
+//
+// total++;
+//
+// // for survival, current value > next vals is good
+// // risk is the same as survival until just before we output
+// // the oobag predictions, when we say pvec = 1-pvec,
+// if (surv_pvec[j] > surv_pvec[*iit]){
+//
+// concordant++;
+//
+// } else if (surv_pvec[j] == surv_pvec[*iit]){
+//
+// concordant+= 0.5;
+//
+// }
+//
+// }
+//
+// }
+//
+// }
+// break;
+//
+// case 'H':
+// for (iit = iit_vals.begin(); iit < iit_vals.end(); ++iit) {
+//
+// for(j = *iit + 1; j < k; ++j){
+//
+// if (time[j] > time[*iit]) { // ties not counted
+//
+// total++;
+//
+// // for risk & chf current value < next vals is good.
+// if (surv_pvec[j] < surv_pvec[*iit]){
+//
+// concordant++;
+//
+// } else if (surv_pvec[j] == surv_pvec[*iit]){
+//
+// concordant+= 0.5;
+//
+// }
+//
+// }
+//
+// }
+//
+// }
+// break;
+// }
+//
+// return(concordant / total);
+//
+// }
+//
+// // same function as above but exported to R for testing
+// // [[Rcpp::export]]
+// double oobag_c_harrell_testthat(NumericMatrix y_mat,
+// NumericVector s_vec) {
+//
+// y_input = mat(y_mat.begin(), y_mat.nrow(), y_mat.ncol(), false);
+// surv_pvec = vec(s_vec.begin(), s_vec.length(), false);
+//
+// return(oobag_c_harrell(pred_type_dflt));
+//
+// }
+//
+// // this function is the same as oobag_pred_surv_uni,
+// // but it operates on new data rather than out-of-bag data
+// // and it allows for multiple prediction horizons instead of one
+// void new_pred_surv_multi(char pred_type){
+//
+// // allocate memory for output
+// // surv_pvec.zeros(x_pred.n_rows);
+//
+// surv_pvec.set_size(times_pred.size());
+// iit_vals = sort_index(leaf_pred, "ascend");
+// iit = iit_vals.begin();
+//
+// switch(pred_type){
+//
+// case 'S': case 'R':
+//
+// leaf_node_col = 1;
+// pred_t0 = 1;
+// break;
+//
+// case 'H':
+//
+// leaf_node_col = 2;
+// pred_t0 = 0;
+// break;
+//
+// }
+//
+// do {
+//
+// person_leaf = leaf_pred(*iit);
+//
+// for(i = 0; i < leaf_indices.n_rows; i++){
+// if(leaf_indices.at(i, 0) == person_leaf){
+// break;
+// }
+// }
+//
+// leaf_node = leaf_nodes.rows(leaf_indices(i, 1),
+// leaf_indices(i, 2));
+//
+// // if(verbose > 1){
+// // Rcout << "leaf_node:" << std::endl << leaf_node << std::endl;
+// // }
+//
+// i = 0;
+//
+// for(j = 0; j < times_pred.size(); j++){
+//
+// time_pred = times_pred.at(j);
+//
+// if(time_pred < leaf_node.at(leaf_node.n_rows - 1, 0)){
+//
+// for(; i < leaf_node.n_rows; i++){
+//
+// if (leaf_node.at(i, 0) > time_pred){
+//
+// if(i == 0)
+// temp1 = pred_t0;
+// else
+// temp1 = leaf_node.at(i-1, leaf_node_col);
+//
+// break;
+//
+// } else if (leaf_node.at(i, 0) == time_pred){
+//
+// temp1 = leaf_node.at(i, leaf_node_col);
+// break;
+//
+// }
+//
+// }
+//
+// } else {
+//
+// // go here if prediction horizon > max time in current leaf.
+// temp1 = leaf_node.at(leaf_node.n_rows - 1, leaf_node_col);
+//
+// }
+//
+// surv_pvec.at(j) = temp1;
+//
+// }
+//
+// surv_pmat.row(*iit) += surv_pvec.t();
+// ++iit;
+//
+// if(iit < iit_vals.end()){
+//
+// while(person_leaf == leaf_pred.at(*iit)){
+//
+// surv_pmat.row(*iit) += surv_pvec.t();
+// ++iit;
+//
+// if (iit == iit_vals.end()) break;
+//
+// }
+//
+// }
+//
+// } while (iit < iit_vals.end());
+//
+// }
+//
+// // this function is the same as new_pred_surv_multi,
+// // but only uses one prediction horizon
+// void new_pred_surv_uni(char pred_type){
+//
+// iit_vals = sort_index(leaf_pred, "ascend");
+// iit = iit_vals.begin();
+//
+// switch(pred_type){
+//
+// case 'S': case 'R':
+//
+// leaf_node_col = 1;
+// pred_t0 = 1;
+// break;
+//
+// case 'H':
+//
+// leaf_node_col = 2;
+// pred_t0 = 0;
+// break;
+//
+// }
+//
+// do {
+//
+// person_leaf = leaf_pred(*iit);
+//
+// for(i = 0; i < leaf_indices.n_rows; i++){
+// if(leaf_indices.at(i, 0) == person_leaf){
+// break;
+// }
+// }
+//
+// leaf_node = leaf_nodes.rows(leaf_indices.at(i, 1),
+// leaf_indices.at(i, 2));
+//
+// // if(verbose > 1){
+// // Rcout << "leaf_node:" << std::endl << leaf_node << std::endl;
+// // }
+//
+// i = 0;
+//
+// if(time_pred < leaf_node.at(leaf_node.n_rows - 1, 0)){
+//
+// for(; i < leaf_node.n_rows; i++){
+// if (leaf_node.at(i, 0) > time_pred){
+//
+// if(i == 0){
+//
+// temp1 = pred_t0;
+//
+// } else {
+//
+// temp1 = leaf_node.at(i - 1, leaf_node_col);
+//
+// // experimental - does not seem to help!
+// // weighted average of surv est from before and after time of pred
+// // temp2 = leaf_node(i, 0) - leaf_node(i-1, 0);
+// //
+// // temp1 = leaf_node(i, 1) * (time_pred - leaf_node(i-1,0)) / temp2 +
+// // leaf_node(i-1, 1) * (leaf_node(i,0) - time_pred) / temp2;
+//
+// }
+//
+// break;
+//
+// } else if (leaf_node.at(i, 0) == time_pred){
+// temp1 = leaf_node.at(i, leaf_node_col);
+// break;
+// }
+// }
+//
+// } else if (time_pred == leaf_node.at(leaf_node.n_rows - 1, 0)){
+//
+// temp1 = leaf_node.at(leaf_node.n_rows - 1, leaf_node_col);
+//
+// } else {
+//
+// // go here if prediction horizon > max time in current leaf.
+// temp1 = leaf_node.at(leaf_node.n_rows - 1, leaf_node_col);
+//
+// // --- EXPERIMENTAL ADD-ON --- //
+// // if you are predicting beyond the max time in a node,
+// // then determine how much further out you are and assume
+// // the survival probability decays at the same rate.
+//
+// // temp2 = (1.0 - temp1) *
+// // (time_pred - leaf_node(leaf_node.n_rows - 1, 0)) / time_pred;
+// //
+// // temp1 = temp1 * (1.0-temp2);
+//
+// }
+//
+// surv_pvec.at(*iit) += temp1;
+// ++iit;
+//
+// if(iit < iit_vals.end()){
+//
+// while(person_leaf == leaf_pred.at(*iit)){
+//
+// surv_pvec.at(*iit) += temp1;
+// ++iit;
+//
+// if (iit == iit_vals.end()) break;
+//
+// }
+//
+// }
+//
+// } while (iit < iit_vals.end());
+//
+// // if(verbose > 1){
+// // Rcout << "pred_surv:" << std::endl << surv_pvec.t() << std::endl;
+// // }
+//
+// }
+//
+//
+// // ----------------------------------------------------------------------------
+// // --------------------------- ostree functions -------------------------------
+// // ----------------------------------------------------------------------------
+//
+// // increase the memory allocated to a tree
+// //
+// // this function is used if the initial memory allocation isn't enough
+// // to grow the tree. It modifies all elements of the tree, including
+// // betas, col_indices, children_left, and cutpoints
+// //
+// void ostree_size_buffer(){
+//
+// // if(verbose > 1){
+// // Rcout << "---------- buffering outputs ----------" << std::endl;
+// // Rcout << "betas before: " << std::endl << betas.t() << std::endl;
+// // }
+//
+// betas.insert_cols(betas.n_cols, 10);
+// // x_mean.insert_cols(x_mean.n_cols, 10);
+// col_indices.insert_cols(col_indices.n_cols, 10);
+// children_left.insert_rows(children_left.size(), 10);
+// cutpoints.insert_rows(cutpoints.size(), 10);
+//
+// // if(verbose > 1){
+// // Rcout << "betas after: " << std::endl << betas.t() << std::endl;
+// // Rcout << "---------------------------------------";
+// // Rcout << std::endl << std::endl;
+// // }
+//
+//
+// }
+//
+// // transfer memory from R into arma types
+// //
+// // when trees are passed from R, they need to be converted back into
+// // arma objects. The intent of this function is to convert everything
+// // back into an arma object without copying any data.
+// //
+// // nothing is modified apart from types
+//
+// void ostree_mem_xfer(){
+//
+// // no data copied according to tracemem.
+// // not including boot rows or x_mean (don't always need them)
+//
+// NumericMatrix leaf_nodes_ = ostree["leaf_nodes"];
+// NumericMatrix betas_ = ostree["betas"];
+// NumericVector cutpoints_ = ostree["cut_points"];
+// IntegerMatrix col_indices_ = ostree["col_indices"];
+// IntegerMatrix leaf_indices_ = ostree["leaf_node_index"];
+// IntegerVector children_left_ = ostree["children_left"];
+//
+// leaf_nodes = mat(leaf_nodes_.begin(),
+// leaf_nodes_.nrow(),
+// leaf_nodes_.ncol(),
+// false);
+//
+// betas = mat(betas_.begin(),
+// betas_.nrow(),
+// betas_.ncol(),
+// false);
+//
+// cutpoints = vec(cutpoints_.begin(), cutpoints_.length(), false);
+//
+// col_indices = conv_to::from(
+// imat(col_indices_.begin(),
+// col_indices_.nrow(),
+// col_indices_.ncol(),
+// false)
+// );
+//
+// leaf_indices = conv_to::from(
+// imat(leaf_indices_.begin(),
+// leaf_indices_.nrow(),
+// leaf_indices_.ncol(),
+// false)
+// );
+//
+// children_left = conv_to::from(
+// ivec(children_left_.begin(),
+// children_left_.length(),
+// false)
+// );
+//
+// }
+//
+// // drop observations down the tree
+// //
+// // @description Determine the leaves that are assigned to new data.
+// //
+// // @param children_left vector of child node ids (right node = left node + 1)
+// // @param x_pred matrix of predictors from new data
+// //
+// // @return a vector indicating which leaf each observation was mapped to
+// void ostree_pred_leaf(){
+//
+// // reset values
+// // this is needed for pred_leaf since every obs gets a new leaf in
+// // the next tree, but it isn't needed for pred_surv because survival
+// // probs get aggregated over all the trees.
+// leaf_pred.fill(0);
+//
+// for(i = 0; i < betas.n_cols; i++){
+//
+// if(children_left[i] != 0){
+//
+// if(i == 0){
+// obs_in_node = regspace(0, 1, leaf_pred.size()-1);
+// } else {
+// obs_in_node = find(leaf_pred == i);
+// }
+//
+//
+// if(obs_in_node.size() > 0){
+//
+// // Fastest sub-matrix multiplication i can think of.
+// // Matrix product = linear combination of columns
+// // (this is faster b/c armadillo is great at making
+// // pointers to the columns of an arma mat)
+// // I had to stop using this b/c it fails on
+// // XB.zeros(obs_in_node.size());
+// //
+// // uvec col_indices_i = col_indices.unsafe_col(i);
+// //
+// // j = 0;
+// //
+// // jit = col_indices_i.begin();
+// //
+// // for(; jit < col_indices_i.end(); ++jit, ++j){
+// //
+// // vec x_j = x_pred.unsafe_col(*jit);
+// //
+// // XB += x_j(obs_in_node) * betas.at(j, i);
+// //
+// // }
+//
+// // this is slower but more clear matrix multiplication
+// XB = x_pred(obs_in_node, col_indices.col(i)) * betas.col(i);
+//
+// jit = obs_in_node.begin();
+//
+// for(j = 0; j < XB.size(); ++j, ++jit){
+//
+// if(XB[j] <= cutpoints[i]) {
+//
+// leaf_pred[*jit] = children_left[i];
+//
+// } else {
+//
+// leaf_pred[*jit] = children_left[i]+1;
+//
+// }
+//
+// }
+//
+// // if(verbose > 0){
+// //
+// // uvec in_left = find(leaf_pred == children_left(i));
+// // uvec in_right = find(leaf_pred == children_left(i)+1);
+// //
+// // Rcout << "N to node_" << children_left(i) << ": ";
+// // Rcout << in_left.size() << "; ";
+// // Rcout << "N to node_" << children_left(i)+1 << ": ";
+// // Rcout << in_right.size() << std::endl;
+// //
+// // }
+//
+// }
+//
+// }
+//
+// }
+//
+//
+//
+// }
+//
+// // same as above but exported to R for testins
+// // [[Rcpp::export]]
+// arma::uvec ostree_pred_leaf_testthat(List& tree,
+// NumericMatrix& x_pred_){
+//
+//
+// x_pred = mat(x_pred_.begin(),
+// x_pred_.nrow(),
+// x_pred_.ncol(),
+// false);
+//
+// leaf_pred.set_size(x_pred.n_rows);
+//
+// ostree = tree;
+// ostree_mem_xfer();
+// ostree_pred_leaf();
+//
+// return(leaf_pred);
+//
+// }
+//
+// // Fit an oblique survival tree
+// //
+// // @description used in orsf_fit, which has parameters defined below.
+// //
+// // @param f_beta the function used to find linear combinations of predictors
+// //
+// // @return a fitted oblique survival tree
+// //
+// List ostree_fit(Function f_beta){
+//
+// betas.fill(0);
+// // x_mean.fill(0);
+// col_indices.fill(0);
+// cutpoints.fill(0);
+// children_left.fill(0);
+// node_assignments.fill(0);
+// leaf_nodes.fill(0);
+//
+// node_assignments.zeros(x_inbag.n_rows);
+// nodes_to_grow.zeros(1);
+// nodes_max_true = 0;
+// leaf_node_counter = 0;
+// leaf_node_index_counter = 0;
+//
+// // ----------------------
+// // ---- main do loop ----
+// // ----------------------
+//
+// do {
+//
+// nodes_to_grow_next.set_size(0);
+//
+// // if(verbose > 0){
+// //
+// // Rcout << "----------- nodes to grow -----------" << std::endl;
+// // Rcout << "nodes: "<< nodes_to_grow.t() << std::endl;
+// // Rcout << "-------------------------------------" << std::endl <<
+// // std::endl << std::endl;
+// //
+// //
+// // }
+//
+// for(node = nodes_to_grow.begin(); node != nodes_to_grow.end(); ++node){
+//
+// if(nodes_to_grow[0] == 0){
+//
+// // when growing the first node, there is no need to find
+// // which rows are in the node.
+// rows_node = linspace(0,
+// x_inbag.n_rows-1,
+// x_inbag.n_rows);
+//
+// } else {
+//
+// // identify which rows are in the current node.
+// rows_node = find(node_assignments == *node);
+//
+// }
+//
+// y_node = y_inbag.rows(rows_node);
+// w_node = w_inbag(rows_node);
+//
+// // if(verbose > 0){
+// //
+// // n_risk = sum(w_node);
+// // n_events = sum(y_node.col(1) % w_node);
+// // Rcout << "-------- Growing node " << *node << " --------" << std::endl;
+// // Rcout << "No. of observations in node: " << n_risk << std::endl;
+// // Rcout << "No. of events in node: " << n_events << std::endl;
+// // Rcout << "No. of rows in node: " << w_node.size() << std::endl;
+// // Rcout << "--------------------------------" << std::endl;
+// // Rcout << std::endl << std::endl;
+// //
+// // }
+//
+// // initialize an impossible cut-point value
+// // if cutpoint is still infinite later, node should not be split
+// cutpoint = R_PosInf;
+//
+// // ------------------------------------------------------------------
+// // ---- sample a random subset of columns with non-zero variance ----
+// // ------------------------------------------------------------------
+//
+// mtry_int = mtry;
+// cols_to_sample_01.fill(0);
+//
+// // constant columns are constant in the rows where events occurred
+//
+// for(j = 0; j < cols_to_sample_01.size(); j++){
+//
+// temp1 = R_PosInf;
+//
+// for(iit = rows_node.begin()+1; iit != rows_node.end(); ++iit){
+//
+// if(y_inbag.at(*iit, 1) == 1){
+//
+// if (temp1 < R_PosInf){
+//
+// if(x_inbag.at(*iit, j) != temp1){
+//
+// cols_to_sample_01[j] = 1;
+// break;
+//
+// }
+//
+// } else {
+//
+// temp1 = x_inbag.at(*iit, j);
+//
+// }
+//
+// }
+//
+// }
+//
+// }
+//
+// n_cols_to_sample = sum(cols_to_sample_01);
+//
+// if(n_cols_to_sample >= 1){
+//
+// n_events_total = sum(y_node.col(1) % w_node);
+//
+// if(n_cols_to_sample < mtry){
+//
+// mtry_int = n_cols_to_sample;
+//
+// // if(verbose > 0){
+// // Rcout << " ---- >=1 constant column in node rows ----" << std::endl;
+// // Rcout << "mtry reduced to " << mtry_temp << " from " << mtry;
+// // Rcout << std::endl;
+// // Rcout << "-------------------------------------------" << std::endl;
+// // Rcout << std::endl << std::endl;
+// // }
+//
+// }
+//
+// if (type_beta == 'C'){
+//
+// // make sure there are at least 3 event per predictor variable.
+// // (if using CPH)
+// while(n_events_total / mtry_int < 3 && mtry_int > 1){
+// --mtry_int;
+// }
+//
+// }
+//
+//
+// n_cols_to_sample = mtry_int;
+//
+// // if(verbose > 0){
+// // Rcout << "n_events: " << n_events_total << std::endl;
+// // Rcout << "mtry: " << mtry_int << std::endl;
+// // Rcout << "n_events per column: " << n_events_total/mtry_int << std::endl;
+// // }
+//
+// if(mtry_int >= 1){
+//
+// cols_to_sample = find(cols_to_sample_01);
+//
+// // re-try hinge point
+// n_retry = 0;
+// cutpoint = R_PosInf;
+//
+// while(n_retry <= max_retry){
+//
+// // if(n_retry > 0) Rcout << "trying again!" << std::endl;
+//
+// cols_node = Rcpp::RcppArmadillo::sample(cols_to_sample,
+// mtry_int,
+// false);
+//
+// x_node = x_inbag(rows_node, cols_node);
+//
+// // here is where n_vars gets updated to match the current node
+// // originally it matched the number of variables in the input x.
+//
+// n_vars = x_node.n_cols;
+//
+// if(cph_do_scale){
+// x_node_scale();
+// }
+//
+// // if(verbose > 0){
+// //
+// // uword temp_uword_1 = min(uvec {x_node.n_rows, 5});
+// // Rcout << "x node scaled: " << std::endl;
+// // Rcout << x_node.submat(0, 0, temp_uword_1-1, x_node.n_cols-1);
+// // Rcout << std::endl;
+// //
+// // }
+//
+// switch(type_beta) {
+//
+// case 'C' :
+//
+// beta_fit = newtraph_cph();
+//
+// if(cph_do_scale){
+// for(i = 0; i < x_transforms.n_rows; i++){
+// x_node.col(i) /= x_transforms(i,1);
+// x_node.col(i) += x_transforms(i,0);
+// }
+//
+// }
+//
+// break;
+//
+// case 'N' :
+//
+// xx = wrap(x_node);
+// yy = wrap(y_node);
+// ww = wrap(w_node);
+// colnames(yy) = yy_names;
+//
+// beta_placeholder = f_beta(xx, yy, ww,
+// net_alpha,
+// net_df_target);
+//
+// beta_fit = mat(beta_placeholder.begin(),
+// beta_placeholder.nrow(),
+// beta_placeholder.ncol(),
+// false);
+//
+// break;
+//
+// case 'U' :
+//
+// xx = wrap(x_node);
+// yy = wrap(y_node);
+// ww = wrap(w_node);
+// colnames(yy) = yy_names;
+//
+// beta_placeholder = f_beta(xx, yy, ww);
+//
+// beta_fit = mat(beta_placeholder.begin(),
+// beta_placeholder.nrow(),
+// beta_placeholder.ncol(),
+// false);
+//
+// break;
+//
+// }
+//
+//
+// if(any(beta_fit)){
+//
+// // if(verbose > 0){
+// //
+// // uword temp_uword_1 = min(uvec {x_node.n_rows, 5});
+// // Rcout << "x node unscaled: " << std::endl;
+// // Rcout << x_node.submat(0, 0, temp_uword_1-1, x_node.n_cols-1);
+// // Rcout << std::endl;
+// //
+// // }
+//
+// XB = x_node * beta_fit;
+// cutpoint = lrt_multi();
+//
+// }
+//
+// if(!std::isinf(cutpoint)) break;
+// n_retry++;
+//
+// }
+//
+// }
+//
+// }
+//
+// if(!std::isinf(cutpoint)){
+//
+// // make new nodes if a valid cutpoint was found
+// nn_left = nodes_max_true + 1;
+// nodes_max_true = nodes_max_true + 2;
+//
+//
+// // if(verbose > 0){
+// //
+// // Rcout << "-------- New nodes created --------" << std::endl;
+// // Rcout << "Left node: " << nn_left << std::endl;
+// // Rcout << "Right node: " << nodes_max_true << std::endl;
+// // Rcout << "-----------------------------------" << std::endl <<
+// // std::endl << std::endl;
+// //
+// // }
+//
+// n_events_left = n_events_total - n_events_right;
+//
+// // if(verbose > 0){
+// // Rcout << "n_events_left: " << n_events_left << std::endl;
+// // Rcout << "n_risk_left: " << n_risk_left << std::endl;
+// // Rcout << "n_events_right: " << n_events_right << std::endl;
+// // Rcout << "n_risk_right: " << n_risk_right << std::endl;
+// // }
+//
+// i=0;
+//
+// for(iit = rows_node.begin(); iit != rows_node.end(); ++iit, ++i){
+//
+// node_assignments[*iit] = nn_left + group[i];
+//
+// }
+//
+// if(n_events_left >= 2*leaf_min_events &&
+// n_risk_left >= 2*leaf_min_obs &&
+// n_events_left >= split_min_events &&
+// n_risk_left >= split_min_obs){
+//
+// nodes_to_grow_next = join_cols(nodes_to_grow_next,
+// uvec{nn_left});
+//
+// } else {
+//
+// rows_leaf = find(group==0);
+// leaf_indices(leaf_node_index_counter, 0) = nn_left;
+// leaf_kaplan(y_node.rows(rows_leaf), w_node(rows_leaf));
+//
+// // if(verbose > 0){
+// // Rcout << "-------- creating a new leaf --------" << std::endl;
+// // Rcout << "name: node_" << nn_left << std::endl;
+// // Rcout << "n_obs: " << sum(w_node(rows_leaf));
+// // Rcout << std::endl;
+// // Rcout << "n_events: ";
+// // vec_temp = y_node.col(1);
+// // Rcout << sum(w_node(rows_leaf) % vec_temp(rows_leaf));
+// // Rcout << std::endl;
+// // Rcout << "------------------------------------";
+// // Rcout << std::endl << std::endl << std::endl;
+// // }
+//
+// }
+//
+// if(n_events_right >= 2*leaf_min_events &&
+// n_risk_right >= 2*leaf_min_obs &&
+// n_events_right >= split_min_events &&
+// n_risk_right >= split_min_obs){
+//
+// nodes_to_grow_next = join_cols(nodes_to_grow_next,
+// uvec{nodes_max_true});
+//
+// } else {
+//
+// rows_leaf = find(group==1);
+// leaf_indices(leaf_node_index_counter, 0) = nodes_max_true;
+// leaf_kaplan(y_node.rows(rows_leaf), w_node(rows_leaf));
+//
+// // if(verbose > 0){
+// // Rcout << "-------- creating a new leaf --------" << std::endl;
+// // Rcout << "name: node_" << nodes_max_true << std::endl;
+// // Rcout << "n_obs: " << sum(w_node(rows_leaf));
+// // Rcout << std::endl;
+// // Rcout << "n_events: ";
+// // vec_temp = y_node.col(1);
+// // Rcout << sum(w_node(rows_leaf) % vec_temp(rows_leaf));
+// // Rcout << std::endl;
+// // Rcout << "------------------------------------";
+// // Rcout << std::endl << std::endl << std::endl;
+// // }
+//
+// }
+//
+// if(nodes_max_true >= betas.n_cols) ostree_size_buffer();
+//
+// for(i = 0; i < n_cols_to_sample; i++){
+// betas.at(i, *node) = beta_fit[i];
+// // x_mean.at(i, *node) = x_transforms(i, 0);
+// col_indices.at(i, *node) = cols_node[i];
+// }
+//
+// children_left[*node] = nn_left;
+// cutpoints[*node] = cutpoint;
+//
+// } else {
+//
+// // make a leaf node if a valid cutpoint could not be found
+// leaf_indices(leaf_node_index_counter, 0) = *node;
+// leaf_kaplan(y_node, w_node);
+//
+// // if(verbose > 0){
+// // Rcout << "-------- creating a new leaf --------" << std::endl;
+// // Rcout << "name: node_" << *node << std::endl;
+// // Rcout << "n_obs: " << sum(w_node) << std::endl;
+// // Rcout << "n_events: " << sum(w_node % y_node.col(1));
+// // Rcout << std::endl;
+// // Rcout << "Couldn't find a cutpoint??" << std::endl;
+// // Rcout << "------------------------------------" << std::endl;
+// // Rcout << std::endl << std::endl;
+// // }
+//
+// }
+//
+// }
+//
+// nodes_to_grow = nodes_to_grow_next;
+//
+// } while (nodes_to_grow.size() > 0);
+//
+// return(
+// List::create(
+//
+// _["leaf_nodes"] = leaf_nodes.rows(span(0, leaf_node_counter-1)),
+//
+// _["leaf_node_index"] = conv_to::from(
+// leaf_indices.rows(span(0, leaf_node_index_counter-1))
+// ),
+//
+// _["betas"] = betas.cols(span(0, nodes_max_true)),
+//
+// // _["x_mean"] = x_mean.cols(span(0, nodes_max_true)),
+//
+// _["col_indices"] = conv_to::from(
+// col_indices.cols(span(0, nodes_max_true))
+// ),
+//
+// _["cut_points"] = cutpoints(span(0, nodes_max_true)),
+//
+// _["children_left"] = conv_to::from(
+// children_left(span(0, nodes_max_true))
+// ),
+//
+// _["rows_oobag"] = conv_to::from(rows_oobag)
+//
+// )
+// );
+//
+//
+// }
+//
+// // ----------------------------------------------------------------------------
+// // ---------------------------- orsf functions --------------------------------
+// // ----------------------------------------------------------------------------
+//
+// // fit an oblique random survival forest.
+// //
+// // @param x matrix of predictors
+// // @param y matrix of outcomes
+// // @param weights vector of weights
+// // @param n_tree number of trees to fit
+// // @param n_split_ number of splits to try with lrt
+// // @param mtry_ number of predictors to try
+// // @param leaf_min_events_ min number of events in a leaf
+// // @param leaf_min_obs_ min number of observations in a leaf
+// // @param split_min_events_ min number of events to split a node
+// // @param split_min_obs_ min number of observations to split a node
+// // @param split_min_stat_ min lrt to split a node
+// // @param cph_method_ method for ties
+// // @param cph_eps_ criteria for convergence of newton raphson algorithm
+// // @param cph_iter_max_ max number of newton raphson iterations
+// // @param cph_do_scale_ to scale or not to scale
+// // @param net_alpha_ alpha parameter for glmnet
+// // @param net_df_target_ degrees of freedom for glmnet
+// // @param oobag_pred_ whether to predict out-of-bag preds or not
+// // @param oobag_pred_type_ what type of out-of-bag preds to compute
+// // @param oobag_pred_horizon_ out-of-bag prediction horizon
+// // @param oobag_eval_every_ trees between each evaluation of oob error
+// // @param oobag_importance_ to compute importance or not
+// // @param oobag_importance_type_ type of importance to compute
+// // @param tree_seeds vector of seeds to set before each tree is fit
+// // @param max_retry_ max number of retries for linear combinations
+// // @param f_beta function to find linear combinations of predictors
+// // @param type_beta_ what type of linear combination to find
+// // @param f_oobag_eval function to evaluate out-of-bag error
+// // @param type_oobag_eval_ whether to use default or custom out-of-bag error
+// //
+// // @return an orsf_fit object sent back to R
+//
+// // [[Rcpp::export]]
+// List orsf_fit(NumericMatrix& x,
+// NumericMatrix& y,
+// NumericVector& weights,
+// const int& n_tree,
+// const int& n_split_,
+// const int& mtry_,
+// const double& leaf_min_events_,
+// const double& leaf_min_obs_,
+// const double& split_min_events_,
+// const double& split_min_obs_,
+// const double& split_min_stat_,
+// const int& cph_method_,
+// const double& cph_eps_,
+// const int& cph_iter_max_,
+// const bool& cph_do_scale_,
+// const double& net_alpha_,
+// const int& net_df_target_,
+// const bool& oobag_pred_,
+// const char& oobag_pred_type_,
+// const double& oobag_pred_horizon_,
+// const int& oobag_eval_every_,
+// const bool& oobag_importance_,
+// const char& oobag_importance_type_,
+// IntegerVector& tree_seeds,
+// const int& max_retry_,
+// Function f_beta,
+// const char& type_beta_,
+// Function f_oobag_eval,
+// const char& type_oobag_eval_,
+// const bool verbose_progress){
+//
+//
+// // convert inputs into arma objects
+// x_input = mat(x.begin(), x.nrow(), x.ncol(), false);
+//
+// y_input = mat(y.begin(), y.nrow(), y.ncol(), false);
+//
+// w_user = vec(weights.begin(), weights.length(), false);
+//
+// // these change later in ostree_fit()
+// n_rows = x_input.n_rows;
+// n_vars = x_input.n_cols;
+//
+// // initialize the variable importance (vi) vectors
+// vi_pval_numer.zeros(n_vars);
+// vi_pval_denom.zeros(n_vars);
+//
+// // if(verbose > 0){
+// // Rcout << "------------ dimensions ------------" << std::endl;
+// // Rcout << "N obs total: " << n_rows << std::endl;
+// // Rcout << "N columns total: " << n_vars << std::endl;
+// // Rcout << "------------------------------------";
+// // Rcout << std::endl << std::endl << std::endl;
+// // }
+//
+// n_split = n_split_;
+// mtry = mtry_;
+// leaf_min_events = leaf_min_events_;
+// leaf_min_obs = leaf_min_obs_;
+// split_min_events = split_min_events_;
+// split_min_obs = split_min_obs_;
+// split_min_stat = split_min_stat_;
+// cph_method = cph_method_;
+// cph_eps = cph_eps_;
+// cph_iter_max = cph_iter_max_;
+// cph_do_scale = cph_do_scale_;
+// net_alpha = net_alpha_;
+// net_df_target = net_df_target_;
+// oobag_pred = oobag_pred_;
+// oobag_pred_type = oobag_pred_type_;
+// oobag_eval_every = oobag_eval_every_;
+// oobag_eval_counter = 0;
+// oobag_importance = oobag_importance_;
+// oobag_importance_type = oobag_importance_type_;
+// use_tree_seed = tree_seeds.length() > 0;
+// max_retry = max_retry_;
+// type_beta = type_beta_;
+// type_oobag_eval = type_oobag_eval_;
+// temp1 = 1.0 / n_rows;
+//
+// if(cph_iter_max > 1) cph_do_scale = true;
+//
+// if((type_beta == 'N') || (type_beta == 'U')) cph_do_scale = false;
+//
+// if(cph_iter_max == 1) cph_do_scale = false;
+//
+//
+// if(oobag_pred){
+//
+// time_pred = oobag_pred_horizon_;
+//
+// if(time_pred == 0) time_pred = median(y_input.col(0));
+//
+// eval_oobag.set_size(std::floor(n_tree / oobag_eval_every));
+//
+// } else {
+//
+// eval_oobag.set_size(0);
+//
+// }
+//
+// // if(verbose > 0){
+// // Rcout << "------------ input variables ------------" << std::endl;
+// // Rcout << "n_split: " << n_split << std::endl;
+// // Rcout << "mtry: " << mtry << std::endl;
+// // Rcout << "leaf_min_events: " << leaf_min_events << std::endl;
+// // Rcout << "leaf_min_obs: " << leaf_min_obs << std::endl;
+// // Rcout << "cph_method: " << cph_method << std::endl;
+// // Rcout << "cph_eps: " << cph_eps << std::endl;
+// // Rcout << "cph_iter_max: " << cph_iter_max << std::endl;
+// // Rcout << "-----------------------------------------" << std::endl;
+// // Rcout << std::endl << std::endl;
+// // }
+//
+// // ----------------------------------------------------
+// // ---- sample weights to mimic a bootstrap sample ----
+// // ----------------------------------------------------
+//
+// // s is the number of times you might get selected into
+// // a bootstrap sample. Realistically this won't be >10,
+// // but it could technically be as big as n_row.
+// IntegerVector s = seq(0, 10);
+//
+// // compute probability of being selected into the bootstrap
+// // 0 times, 1, times, ..., 9 times, or 10 times.
+// NumericVector probs = dbinom(s, n_rows, temp1, false);
+//
+// // ---------------------------------------------
+// // ---- preallocate memory for tree outputs ----
+// // ---------------------------------------------
+//
+// cols_to_sample_01.zeros(n_vars);
+// leaf_nodes.zeros(n_rows, 3);
+//
+// if(oobag_pred){
+//
+// surv_pvec.zeros(n_rows);
+// denom_pred.zeros(n_rows);
+//
+// } else {
+//
+// surv_pvec.set_size(0);
+// denom_pred.set_size(0);
+//
+// }
+//
+// // guessing the number of nodes needed to grow a tree
+// nodes_max_guess = std::ceil(0.5 * n_rows / leaf_min_events);
+//
+// betas.zeros(mtry, nodes_max_guess);
+// // x_mean.zeros(mtry, nodes_max_guess);
+// col_indices.zeros(mtry, nodes_max_guess);
+// cutpoints.zeros(nodes_max_guess);
+// children_left.zeros(nodes_max_guess);
+// leaf_indices.zeros(nodes_max_guess, 3);
+//
+// // some great variable names here
+// List forest(n_tree);
+//
+// for(tree = 0; tree < n_tree; ){
+//
+// // Abort the routine if user has pressed Ctrl + C or Escape in R.
+// Rcpp::checkUserInterrupt();
+//
+// // --------------------------------------------
+// // ---- initialize parameters to grow tree ----
+// // --------------------------------------------
+//
+// // rows_inbag = find(w_inbag != 0);
+//
+// if(use_tree_seed) set_seed_r(tree_seeds[tree]);
+//
+// w_input = as(sample(s, n_rows, true, probs));
+//
+// // if the user gives a weight vector, then each bootstrap weight
+// // should be multiplied by the corresponding user weight.
+// if(w_user.size() > 0) w_input = w_input % w_user;
+//
+// rows_oobag = find(w_input == 0);
+// rows_inbag = regspace(0, n_rows-1);
+// rows_inbag = std_setdiff(rows_inbag, rows_oobag);
+// w_inbag = w_input(rows_inbag);
+//
+// // if(verbose > 0){
+// //
+// // Rcout << "------------ boot weights ------------" << std::endl;
+// // Rcout << "pr(inbag): " << 1-pow(1-temp1,n_rows) << std::endl;
+// // Rcout << "total: " << sum(w_inbag) << std::endl;
+// // Rcout << "N > 0: " << rows_inbag.size() << std::endl;
+// // Rcout << "--------------------------------------" <<
+// // std::endl << std::endl << std::endl;
+// //
+// // }
+//
+// x_inbag = x_input.rows(rows_inbag);
+// y_inbag = y_input.rows(rows_inbag);
+//
+// if(oobag_pred){
+// x_pred = x_input.rows(rows_oobag);
+// leaf_pred.set_size(rows_oobag.size());
+// }
+//
+// // if(verbose > 0){
+// //
+// // uword temp_uword_1, temp_uword_2;
+// //
+// // if(x_inbag.n_rows < 5)
+// // temp_uword_1 = x_inbag.n_rows-1;
+// // else
+// // temp_uword_1 = 5;
+// //
+// // if(x_inbag.n_cols < 5)
+// // temp_uword_2 = x_inbag.n_cols-1;
+// // else
+// // temp_uword_2 = 4;
+// //
+// // Rcout << "x inbag: " << std::endl <<
+// // x_inbag.submat(0, 0,
+// // temp_uword_1,
+// // temp_uword_2) << std::endl;
+// //
+// // }
+//
+// if(verbose_progress){
+// Rcout << "\r growing tree no. " << tree << " of " << n_tree;
+// }
+//
+//
+// forest[tree] = ostree_fit(f_beta);
+//
+// // add 1 to tree here instead of end of loop
+// // (more convenient to compute tree % oobag_eval_every)
+// tree++;
+//
+//
+// if(oobag_pred){
+//
+// denom_pred(rows_oobag) += 1;
+// ostree_pred_leaf();
+// oobag_pred_surv_uni(oobag_pred_type);
+//
+// if(tree % oobag_eval_every == 0){
+//
+// switch(type_oobag_eval) {
+//
+// // H stands for Harrell's C-statistic
+// case 'H' :
+//
+// eval_oobag[oobag_eval_counter] = oobag_c_harrell(oobag_pred_type);
+// oobag_eval_counter++;
+//
+// break;
+//
+// // U stands for a user-supplied function
+// case 'U' :
+//
+// ww = wrap(surv_pvec);
+//
+// eval_oobag[oobag_eval_counter] = as