Skip to content

Commit

Permalink
Merge pull request #17 from ropensci/oop
Browse files Browse the repository at this point in the history
object oriented re-write
  • Loading branch information
bcjaeger authored Oct 3, 2023
2 parents e82c397 + 5788982 commit 786ff2b
Show file tree
Hide file tree
Showing 60 changed files with 11,891 additions and 5,831 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
^Rmd/
lastMiKTeXException
^\.zenodo\.json$
^scratch\.R$
10 changes: 9 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# aorsf 0.0.8
# aorsf 0.1.0 (unreleased)

* Re-worked `aorsf`'s C++, code following the design of `ranger`, to set it up for classification and regression trees.

* Allowed multi-threading to be performed in `orsf()`, `predict.orsf_fit()`, and functions in the `orsf_vi()` and `orsf_pd()` family.

* Allowed for sampling without replacement and sampling a specific fraction of observations in `orsf()`

* Included Harrell's C-statistic as an option for assessing goodness of splits while growing trees.

* Fixed an issue where an uninformative error message would occur when `pred_horizon` was > max(time) for `orsf_summarize_uni`. Thanks to @JyHao1 and @DustinMLong for finding this!

Expand Down
64 changes: 12 additions & 52 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,67 +1,27 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

std_setdiff <- function(x, y) {
.Call(`_aorsf_std_setdiff`, x, y)
coxph_fit_exported <- function(x_node, y_node, w_node, method, cph_eps, cph_iter_max) {
.Call(`_aorsf_coxph_fit_exported`, x_node, y_node, w_node, method, cph_eps, cph_iter_max)
}

x_node_scale_exported <- function(x_, w_) {
.Call(`_aorsf_x_node_scale_exported`, x_, w_)
compute_cstat_exported_vec <- function(y, w, p, pred_is_risklike) {
.Call(`_aorsf_compute_cstat_exported_vec`, y, w, p, pred_is_risklike)
}

leaf_kaplan_testthat <- function(y, w) {
.Call(`_aorsf_leaf_kaplan_testthat`, y, w)
compute_cstat_exported_uvec <- function(y, w, g, pred_is_risklike) {
.Call(`_aorsf_compute_cstat_exported_uvec`, y, w, g, pred_is_risklike)
}

newtraph_cph_testthat <- function(x_in, y_in, w_in, method, cph_eps_, iter_max) {
.Call(`_aorsf_newtraph_cph_testthat`, x_in, y_in, w_in, method, cph_eps_, iter_max)
compute_logrank_exported <- function(y, w, g) {
.Call(`_aorsf_compute_logrank_exported`, y, w, g)
}

lrt_multi_testthat <- function(y_node_, w_node_, XB_, n_split_, leaf_min_events_, leaf_min_obs_) {
.Call(`_aorsf_lrt_multi_testthat`, y_node_, w_node_, XB_, n_split_, leaf_min_events_, leaf_min_obs_)
cph_scale <- function(x, w) {
.Call(`_aorsf_cph_scale`, x, w)
}

oobag_c_harrell_testthat <- function(y_mat, s_vec) {
.Call(`_aorsf_oobag_c_harrell_testthat`, y_mat, s_vec)
}

ostree_pred_leaf_testthat <- function(tree, x_pred_) {
.Call(`_aorsf_ostree_pred_leaf_testthat`, tree, x_pred_)
}

orsf_fit <- function(x, y, weights, n_tree, n_split_, mtry_, leaf_min_events_, leaf_min_obs_, split_min_events_, split_min_obs_, split_min_stat_, cph_method_, cph_eps_, cph_iter_max_, cph_do_scale_, net_alpha_, net_df_target_, oobag_pred_, oobag_pred_type_, oobag_pred_horizon_, oobag_eval_every_, oobag_importance_, oobag_importance_type_, tree_seeds, max_retry_, f_beta, type_beta_, f_oobag_eval, type_oobag_eval_, verbose_progress) {
.Call(`_aorsf_orsf_fit`, x, y, weights, n_tree, n_split_, mtry_, leaf_min_events_, leaf_min_obs_, split_min_events_, split_min_obs_, split_min_stat_, cph_method_, cph_eps_, cph_iter_max_, cph_do_scale_, net_alpha_, net_df_target_, oobag_pred_, oobag_pred_type_, oobag_pred_horizon_, oobag_eval_every_, oobag_importance_, oobag_importance_type_, tree_seeds, max_retry_, f_beta, type_beta_, f_oobag_eval, type_oobag_eval_, verbose_progress)
}

orsf_oob_negate_vi <- function(x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_) {
.Call(`_aorsf_orsf_oob_negate_vi`, x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_)
}

orsf_oob_permute_vi <- function(x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_) {
.Call(`_aorsf_orsf_oob_permute_vi`, x, y, forest, last_eval_stat, time_pred_, f_oobag_eval, pred_type_, type_oobag_eval_)
}

orsf_pred_uni <- function(forest, x_new, time_dbl, pred_type) {
.Call(`_aorsf_orsf_pred_uni`, forest, x_new, time_dbl, pred_type)
}

orsf_pred_multi <- function(forest, x_new, time_vec, pred_type) {
.Call(`_aorsf_orsf_pred_multi`, forest, x_new, time_vec, pred_type)
}

pd_new_smry <- function(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type) {
.Call(`_aorsf_pd_new_smry`, forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type)
}

pd_oob_smry <- function(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type) {
.Call(`_aorsf_pd_oob_smry`, forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type)
}

pd_new_ice <- function(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type) {
.Call(`_aorsf_pd_new_ice`, forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type)
}

pd_oob_ice <- function(forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type) {
.Call(`_aorsf_pd_oob_ice`, forest, x_new_, x_cols_, x_vals_, probs_, time_dbl, pred_type)
orsf_cpp <- function(x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_type_R, tree_seeds, loaded_forest, lincomb_R_function, oobag_R_function, n_tree, mtry, sample_with_replacement, sample_fraction, vi_type_R, vi_max_pvalue, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_cuts, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, lincomb_ties_method, pred_mode, pred_type_R, pred_horizon, pred_aggregate, oobag, oobag_eval_type_R, oobag_eval_every, pd_type_R, pd_x_vals, pd_x_cols, pd_probs, n_thread, write_forest, run_forest, verbosity)
}

125 changes: 103 additions & 22 deletions R/check.R
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,13 @@ check_orsf_inputs <- function(data = NULL,
n_tree = NULL,
n_split = NULL,
n_retry = NULL,
n_thread = NULL,
mtry = NULL,
sample_with_replacement = NULL,
sample_fraction = NULL,
leaf_min_events = NULL,
leaf_min_obs = NULL,
split_rule = NULL,
split_min_events = NULL,
split_min_obs = NULL,
split_min_stat = NULL,
Expand Down Expand Up @@ -796,6 +800,25 @@ check_orsf_inputs <- function(data = NULL,

}

if(!is.null(n_thread)){

check_arg_type(arg_value = n_thread,
arg_name = 'n_thread',
expected_type = 'numeric')

check_arg_is_integer(arg_name = 'n_thread',
arg_value = n_thread)

check_arg_gteq(arg_name = 'n_thread',
arg_value = n_thread,
bound = 0)

check_arg_length(arg_name = 'n_thread',
arg_value = n_thread,
expected_length = 1)

}

if(!is.null(mtry)){

check_arg_type(arg_value = mtry,
Expand All @@ -815,6 +838,38 @@ check_orsf_inputs <- function(data = NULL,

}

if(!is.null(sample_with_replacement)){

check_arg_type(arg_value = sample_with_replacement,
arg_name = 'sample_with_replacement',
expected_type = 'logical')

check_arg_length(arg_name = 'sample_with_replacement',
arg_value = sample_with_replacement,
expected_length = 1)

}

if(!is.null(sample_fraction)){

check_arg_type(arg_value = sample_fraction,
arg_name = 'sample_fraction',
expected_type = 'numeric')

check_arg_gt(arg_value = sample_fraction,
arg_name = 'sample_fraction',
bound = 0)

check_arg_lteq(arg_value = sample_fraction,
arg_name = 'sample_fraction',
bound = 1)

check_arg_length(arg_value = sample_fraction,
arg_name = 'sample_fraction',
expected_length = 1)

}

if(!is.null(leaf_min_events)){

check_arg_type(arg_value = leaf_min_events,
Expand Down Expand Up @@ -852,6 +907,22 @@ check_orsf_inputs <- function(data = NULL,

}

if(!is.null(split_rule)){

check_arg_type(arg_value = split_rule,
arg_name = 'split_rule',
expected_type = 'character')

check_arg_length(arg_value = split_rule,
arg_name = 'split_rule',
expected_length = 1)

check_arg_is_valid(arg_value = split_rule,
arg_name = 'split_rule',
valid_options = c("logrank", "cstat"))

}

if(!is.null(split_min_events)){

check_arg_type(arg_value = split_min_events,
Expand Down Expand Up @@ -916,19 +987,14 @@ check_orsf_inputs <- function(data = NULL,
arg_name = 'oobag_pred_type',
expected_length = 1)

if(oobag_pred_type == 'mort') stop(
"Out-of-bag mortality predictions aren't supported yet. ",
" Sorry for the inconvenience - we plan on including this option",
" in a future update.",
call. = FALSE
)

check_arg_is_valid(arg_value = oobag_pred_type,
arg_name = 'oobag_pred_type',
valid_options = c("none",
"surv",
"risk",
"chf"))
"chf",
"mort",
"leaf"))

}

Expand All @@ -938,13 +1004,17 @@ check_orsf_inputs <- function(data = NULL,
arg_name = 'oobag_pred_horizon',
expected_type = 'numeric')

check_arg_length(arg_value = oobag_pred_horizon,
arg_name = 'oobag_pred_horizon',
expected_length = 1)
# check_arg_length(arg_value = oobag_pred_horizon,
# arg_name = 'oobag_pred_horizon',
# expected_length = 1)

check_arg_gteq(arg_value = oobag_pred_horizon,
arg_name = 'oobag_pred_horizon',
bound = 0)
for(i in seq_along(oobag_pred_horizon)){

check_arg_gteq(arg_value = oobag_pred_horizon[i],
arg_name = 'oobag_pred_horizon',
bound = 0)

}

}

Expand Down Expand Up @@ -1000,7 +1070,7 @@ check_orsf_inputs <- function(data = NULL,

check_arg_is_integer(tree_seeds, arg_name = 'tree_seeds')

if(length(tree_seeds) != n_tree){
if(length(tree_seeds) > 1 && length(tree_seeds) != n_tree){

stop('tree_seeds should have length <', n_tree,
"> (the number of trees) but instead has length <",
Expand Down Expand Up @@ -1550,7 +1620,8 @@ check_predict <- function(object,
valid_options = c("risk",
"surv",
"chf",
"mort"))
"mort",
"leaf"))

}

Expand Down Expand Up @@ -1615,8 +1686,8 @@ check_oobag_fun <- function(oobag_fun){

oobag_fun_args <- names(formals(oobag_fun))

if(length(oobag_fun_args) != 2) stop(
"oobag_fun should have 2 input arguments but instead has ",
if(length(oobag_fun_args) != 3) stop(
"oobag_fun should have 3 input arguments but instead has ",
length(oobag_fun_args),
call. = FALSE
)
Expand All @@ -1627,8 +1698,14 @@ check_oobag_fun <- function(oobag_fun){
call. = FALSE
)

if(oobag_fun_args[2] != 's_vec') stop(
"the second input argument of oobag_fun should be named 's_vec' ",
if(oobag_fun_args[2] != 'w_vec') stop(
"the second input argument of oobag_fun should be named 'w_vec' ",
"but is instead named '", oobag_fun_args[1], "'",
call. = FALSE
)

if(oobag_fun_args[3] != 's_vec') stop(
"the third input argument of oobag_fun should be named 's_vec' ",
"but is instead named '", oobag_fun_args[2], "'",
call. = FALSE
)
Expand All @@ -1637,9 +1714,12 @@ check_oobag_fun <- function(oobag_fun){
test_status <- rep(c(0,1), each = 50)

.y_mat <- cbind(time = test_time, status = test_status)
.w_vec <- rep(1, times = 100)
.s_vec <- seq(0.9, 0.1, length.out = 100)

test_output <- try(oobag_fun(y_mat = .y_mat, s_vec = .s_vec),
test_output <- try(oobag_fun(y_mat = .y_mat,
w_vec = .w_vec,
s_vec = .s_vec),
silent = FALSE)

if(is_error(test_output)){
Expand All @@ -1649,8 +1729,9 @@ check_oobag_fun <- function(oobag_fun){
"test_time <- seq(from = 1, to = 5, length.out = 100)\n",
"test_status <- rep(c(0,1), each = 50)\n\n",
"y_mat <- cbind(time = test_time, status = test_status)\n",
"w_vec <- rep(1, times = 100)\n",
"s_vec <- seq(0.9, 0.1, length.out = 100)\n\n",
"test_output <- oobag_fun(y_mat = y_mat, s_vec = s_vec)\n\n",
"test_output <- oobag_fun(y_mat = y_mat, w_vec = w_vec, s_vec = s_vec)\n\n",
"test_output should be a numeric value of length 1",
call. = FALSE)

Expand Down
15 changes: 15 additions & 0 deletions R/compute_mean_leaves.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@


compute_mean_leaves <- function(forest){

if(is.null(forest$leaf_summary)){
return(0)
}

collapse::fmean(
vapply(forest$leaf_summary,
function(leaf_smry) sum(leaf_smry != 0),
FUN.VALUE = integer(1))
)

}
49 changes: 49 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,55 @@ paste_collapse <- function(x, sep=', ', last = ' or '){

}

#' Find cut-point boundaries (R version)
#'
#' Used to test the cpp version for finding cutpoints
#'
#' @param y_node outcome matrix
#' @param w_node weight vector
#' @param XB linear combination of predictors
#' @param xb_uni unique values in XB
#' @param leaf_min_events min no. of events in a leaf
#' @param leaf_min_obs min no. of observations in a leaf
#'
#' @noRd
#'
#' @return data.frame with description of valid cutpoints
cp_find_bounds_R <- function(y_node,
w_node,
XB,
xb_uni,
leaf_min_events,
leaf_min_obs){

status = y_node[, 'status']

cp_stats <-
sapply(
X = xb_uni,
FUN = function(x){
c(
cp = x,
e_right = sum(status[XB > x]),
e_left = sum(status[XB <= x]),
n_right = sum(XB > x),
n_left = sum(XB <= x)
)
}
)

cp_stats <- as.data.frame(t(cp_stats))

cp_stats$valid_cp = with(
cp_stats,
e_right >= leaf_min_events & e_left >= leaf_min_events &
n_right >= leaf_min_obs & n_left >= leaf_min_obs
)

cp_stats

}

has_units <- function(x){
inherits(x, 'units')
}
Expand Down
Loading

0 comments on commit 786ff2b

Please sign in to comment.