Skip to content

Commit

Permalink
getting to coxph in tree grow
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Aug 12, 2023
1 parent c1820fe commit 41504b7
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 58 deletions.
12 changes: 6 additions & 6 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

coxph_scale_exported <- function(x_, w_) {
.Call(`_aorsf_coxph_scale_exported`, x_, w_)
coxph_scale_exported <- function(x_node, w_node) {
.Call(`_aorsf_coxph_scale_exported`, x_node, w_node)
}

coxph_fit_exported <- function(x_, y_, w_, method, cph_eps, cph_iter_max) {
.Call(`_aorsf_coxph_fit_exported`, x_, y_, w_, method, cph_eps, cph_iter_max)
coxph_fit_exported <- function(x_node, y_node, w_node, method, cph_eps, cph_iter_max) {
.Call(`_aorsf_coxph_fit_exported`, x_node, y_node, w_node, method, cph_eps, cph_iter_max)
}

node_find_cps_exported <- function(y_node, w_node, XB, leaf_min_events, leaf_min_obs) {
Expand All @@ -25,7 +25,7 @@ lrt_multi_exported <- function(y_, w_, XB_, n_split_, split_min_stat, leaf_min_e
.Call(`_aorsf_lrt_multi_exported`, y_, w_, XB_, n_split_, split_min_stat, leaf_min_events, leaf_min_obs)
}

orsf_cpp <- function(x, y, w, tree_seeds, lincomb_R_function, f_oobag_eval, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_seeds, lincomb_R_function, f_oobag_eval, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every)
orsf_cpp <- function(x, y, w, tree_seeds, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_seeds, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every)
}

23 changes: 23 additions & 0 deletions src/Coxph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <RcppArmadillo.h>
#include "globals.h"
#include "Coxph.h"
#include "utility.h"

using namespace arma;
using namespace Rcpp;
Expand Down Expand Up @@ -675,6 +676,18 @@

}

if(VERBOSITY > 1){

Rcout << "--------- Newt-Raph algo; before rescale " << std::endl;
Rcout << "beta: " << beta_new.t() << std::endl;
Rcout << std::endl;

}


print_mat(x_transforms, "x_transforms", 10, 10);


// invert vmat
cholesky_invert(vmat);

Expand All @@ -683,6 +696,7 @@
beta_current[i] = beta_new[i];

if(std::isinf(beta_current[i]) || std::isnan(beta_current[i])){
Rcout << beta_current[i] << std::endl;
beta_current[i] = 0;
}

Expand Down Expand Up @@ -715,6 +729,15 @@

}

if(VERBOSITY > 1){

Rcout << "--------- Newt-Raph algo; after rescale " << std::endl;
Rcout << "beta: " << beta_current.t() << std::endl;
Rcout << std::endl;

}


// if(verbose > 1) Rcout << std::endl;

return(beta_current);
Expand Down
1 change: 1 addition & 0 deletions src/Forest.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "Data.h"
#include "globals.h"
#include "utility.h"
#include "Tree.h"

namespace aorsf {
Expand Down
1 change: 1 addition & 0 deletions src/NodeSplitStats.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <armadillo>
#include <Rcpp.h>


namespace aorsf {

arma::uvec node_find_cps(const arma::mat& y_node,
Expand Down
30 changes: 15 additions & 15 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,30 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// coxph_scale_exported
List coxph_scale_exported(NumericMatrix& x_, NumericVector& w_);
RcppExport SEXP _aorsf_coxph_scale_exported(SEXP x_SEXP, SEXP w_SEXP) {
List coxph_scale_exported(arma::vec& x_node, arma::vec& w_node);
RcppExport SEXP _aorsf_coxph_scale_exported(SEXP x_nodeSEXP, SEXP w_nodeSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericMatrix& >::type x_(x_SEXP);
Rcpp::traits::input_parameter< NumericVector& >::type w_(w_SEXP);
rcpp_result_gen = Rcpp::wrap(coxph_scale_exported(x_, w_));
Rcpp::traits::input_parameter< arma::vec& >::type x_node(x_nodeSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w_node(w_nodeSEXP);
rcpp_result_gen = Rcpp::wrap(coxph_scale_exported(x_node, w_node));
return rcpp_result_gen;
END_RCPP
}
// coxph_fit_exported
List coxph_fit_exported(NumericMatrix& x_, NumericMatrix& y_, NumericVector& w_, int method, double cph_eps, int cph_iter_max);
RcppExport SEXP _aorsf_coxph_fit_exported(SEXP x_SEXP, SEXP y_SEXP, SEXP w_SEXP, SEXP methodSEXP, SEXP cph_epsSEXP, SEXP cph_iter_maxSEXP) {
List coxph_fit_exported(arma::vec& x_node, arma::vec& y_node, arma::vec& w_node, int method, double cph_eps, int cph_iter_max);
RcppExport SEXP _aorsf_coxph_fit_exported(SEXP x_nodeSEXP, SEXP y_nodeSEXP, SEXP w_nodeSEXP, SEXP methodSEXP, SEXP cph_epsSEXP, SEXP cph_iter_maxSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericMatrix& >::type x_(x_SEXP);
Rcpp::traits::input_parameter< NumericMatrix& >::type y_(y_SEXP);
Rcpp::traits::input_parameter< NumericVector& >::type w_(w_SEXP);
Rcpp::traits::input_parameter< arma::vec& >::type x_node(x_nodeSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type y_node(y_nodeSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w_node(w_nodeSEXP);
Rcpp::traits::input_parameter< int >::type method(methodSEXP);
Rcpp::traits::input_parameter< double >::type cph_eps(cph_epsSEXP);
Rcpp::traits::input_parameter< int >::type cph_iter_max(cph_iter_maxSEXP);
rcpp_result_gen = Rcpp::wrap(coxph_fit_exported(x_, y_, w_, method, cph_eps, cph_iter_max));
rcpp_result_gen = Rcpp::wrap(coxph_fit_exported(x_node, y_node, w_node, method, cph_eps, cph_iter_max));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -99,8 +99,8 @@ BEGIN_RCPP
END_RCPP
}
// orsf_cpp
List orsf_cpp(arma::mat& x, arma::mat& y, arma::uvec& w, Rcpp::IntegerVector& tree_seeds, Rcpp::Function& lincomb_R_function, Rcpp::Function f_oobag_eval, arma::uword n_tree, arma::uword mtry, arma::uword vi_type_R, double leaf_min_events, double leaf_min_obs, arma::uword split_rule_R, double split_min_events, double split_min_obs, double split_min_stat, arma::uword split_max_retry, arma::uword lincomb_type_R, double lincomb_eps, arma::uword lincomb_iter_max, bool lincomb_scale, double lincomb_alpha, arma::uword lincomb_df_target, bool pred_mode, arma::uword pred_type_R, double pred_horizon, bool oobag_pred, arma::uword oobag_eval_every);
RcppExport SEXP _aorsf_orsf_cpp(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP tree_seedsSEXP, SEXP lincomb_R_functionSEXP, SEXP f_oobag_evalSEXP, SEXP n_treeSEXP, SEXP mtrySEXP, SEXP vi_type_RSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP, SEXP split_rule_RSEXP, SEXP split_min_eventsSEXP, SEXP split_min_obsSEXP, SEXP split_min_statSEXP, SEXP split_max_retrySEXP, SEXP lincomb_type_RSEXP, SEXP lincomb_epsSEXP, SEXP lincomb_iter_maxSEXP, SEXP lincomb_scaleSEXP, SEXP lincomb_alphaSEXP, SEXP lincomb_df_targetSEXP, SEXP pred_modeSEXP, SEXP pred_type_RSEXP, SEXP pred_horizonSEXP, SEXP oobag_predSEXP, SEXP oobag_eval_everySEXP) {
List orsf_cpp(arma::mat& x, arma::mat& y, arma::uvec& w, Rcpp::IntegerVector& tree_seeds, Rcpp::Function& lincomb_R_function, Rcpp::Function& oobag_R_function, arma::uword n_tree, arma::uword mtry, arma::uword vi_type_R, double leaf_min_events, double leaf_min_obs, arma::uword split_rule_R, double split_min_events, double split_min_obs, double split_min_stat, arma::uword split_max_retry, arma::uword lincomb_type_R, double lincomb_eps, arma::uword lincomb_iter_max, bool lincomb_scale, double lincomb_alpha, arma::uword lincomb_df_target, bool pred_mode, arma::uword pred_type_R, double pred_horizon, bool oobag_pred, arma::uword oobag_eval_every);
RcppExport SEXP _aorsf_orsf_cpp(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP tree_seedsSEXP, SEXP lincomb_R_functionSEXP, SEXP oobag_R_functionSEXP, SEXP n_treeSEXP, SEXP mtrySEXP, SEXP vi_type_RSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP, SEXP split_rule_RSEXP, SEXP split_min_eventsSEXP, SEXP split_min_obsSEXP, SEXP split_min_statSEXP, SEXP split_max_retrySEXP, SEXP lincomb_type_RSEXP, SEXP lincomb_epsSEXP, SEXP lincomb_iter_maxSEXP, SEXP lincomb_scaleSEXP, SEXP lincomb_alphaSEXP, SEXP lincomb_df_targetSEXP, SEXP pred_modeSEXP, SEXP pred_type_RSEXP, SEXP pred_horizonSEXP, SEXP oobag_predSEXP, SEXP oobag_eval_everySEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -109,7 +109,7 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< arma::uvec& >::type w(wSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type tree_seeds(tree_seedsSEXP);
Rcpp::traits::input_parameter< Rcpp::Function& >::type lincomb_R_function(lincomb_R_functionSEXP);
Rcpp::traits::input_parameter< Rcpp::Function >::type f_oobag_eval(f_oobag_evalSEXP);
Rcpp::traits::input_parameter< Rcpp::Function& >::type oobag_R_function(oobag_R_functionSEXP);
Rcpp::traits::input_parameter< arma::uword >::type n_tree(n_treeSEXP);
Rcpp::traits::input_parameter< arma::uword >::type mtry(mtrySEXP);
Rcpp::traits::input_parameter< arma::uword >::type vi_type_R(vi_type_RSEXP);
Expand All @@ -131,7 +131,7 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< double >::type pred_horizon(pred_horizonSEXP);
Rcpp::traits::input_parameter< bool >::type oobag_pred(oobag_predSEXP);
Rcpp::traits::input_parameter< arma::uword >::type oobag_eval_every(oobag_eval_everySEXP);
rcpp_result_gen = Rcpp::wrap(orsf_cpp(x, y, w, tree_seeds, lincomb_R_function, f_oobag_eval, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every));
rcpp_result_gen = Rcpp::wrap(orsf_cpp(x, y, w, tree_seeds, lincomb_R_function, oobag_R_function, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every));
return rcpp_result_gen;
END_RCPP
}
Expand Down
75 changes: 53 additions & 22 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <RcppArmadillo.h>
#include "Tree.h"
#include "Coxph.h"

using namespace arma;
using namespace Rcpp;
Expand Down Expand Up @@ -33,7 +34,7 @@


this->data = data;
this->n_cols = data->n_cols;
this->n_cols_total = data->n_cols;

this->seed = seed;
this->mtry = mtry;
Expand Down Expand Up @@ -82,6 +83,12 @@
this->w_inbag = data->w_subvec(rows_inbag);
this->rows_oobag = find(boot_wts == 0);

if(VERBOSITY > 0){

print_mat(x_inbag, "x_inbag", 5, 5);

}

// all observations start in node 0
this->rows_node = linspace<uvec>(0, x_inbag.n_rows-1, x_inbag.n_rows);

Expand Down Expand Up @@ -131,14 +138,14 @@

// Start empty
std::vector<uword> cols_assessed, cols_accepted;
cols_assessed.reserve(n_cols);
cols_assessed.reserve(n_cols_total);
cols_accepted.reserve(mtry);

std::uniform_int_distribution<uword> unif_dist(0, n_cols - 1);
std::uniform_int_distribution<uword> unif_dist(0, n_cols_total - 1);

uword i, draw;

for (i = 0; i < n_cols; ++i) {
for (i = 0; i < n_cols_total; ++i) {

draw = unif_dist(random_number_generator);

Expand Down Expand Up @@ -171,39 +178,63 @@

void Tree::grow(){

// create inbag views of x, y, and w,
bootstrap();

sample_cols();
// assign all inbag observations to node 0
node_assignments.zeros(x_inbag.n_rows);

coef_indices.push_back(cols_node);
// coordinate the order that nodes are grown.
uvec nodes_open(1, fill::zeros);
uvec nodes_queued;

node_assignments.zeros(x_inbag.n_rows);
// iterate through nodes to be grown
uvec::iterator node;

uvec nodes_to_grow(1, fill::zeros);
uvec rows_node;
for(node = nodes_open.begin(); node != nodes_open.end(); ++node){

if(VERBOSITY > 0){
if(nodes_open[0] == 0){

// when growing the first node, there is no need to find
// which rows are in the node.
rows_node = linspace<uvec>(0,
x_inbag.n_rows-1,
x_inbag.n_rows);

} else {

uword temp_uword_1, temp_uword_2;
// identify which rows are in the current node.
rows_node = find(node_assignments == *node);

if(x_inbag.n_rows < 5)
temp_uword_1 = x_inbag.n_rows-1;
else
temp_uword_1 = 5;
}

y_node = y_inbag.rows(rows_node);
w_node = w_inbag(rows_node);

sample_cols();

x_node = x_inbag(rows_node, cols_node);

Rcout << x_node << std::endl;

if(x_inbag.n_cols < 5)
temp_uword_2 = x_inbag.n_cols-1;
else
temp_uword_2 = 4;
print_mat(x_node, "x_node", 5, 5);

Rcout << " ---- view of x_inbag ---- " << std::endl << std::endl;
Rcout << round(x_inbag.submat(0, 0, temp_uword_1, temp_uword_2));
Rcout << std::endl << std::endl;
vec w_node_doubles = conv_to<vec>::from(w_node);
vec beta = coxph_fit(x_node, y_node, w_node_doubles, 1, 1e-9, 20, 'A');

Rcout << beta << std::endl;

print_mat(beta, "beta", 5, 5);

coef_values.push_back(beta);

coef_indices.push_back(cols_node);

}




} // Tree::grow

} // namespace aorsf
Expand Down
9 changes: 5 additions & 4 deletions src/Tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "Data.h"
#include "globals.h"
#include "utility.h"

namespace aorsf {

Expand Down Expand Up @@ -58,7 +59,7 @@
// Pointer to original data
Data* data;

arma::uword n_cols;
arma::uword n_cols_total;

int seed;

Expand Down Expand Up @@ -91,13 +92,13 @@
double split_min_events;
double split_min_obs;
double split_min_stat;
arma::uword split_max_retry;
arma::uword split_max_retry;
LinearCombo lincomb_type;
double lincomb_eps;
arma::uword lincomb_iter_max;
arma::uword lincomb_iter_max;
bool lincomb_scale;
double lincomb_alpha;
arma::uword lincomb_df_target;
arma::uword lincomb_df_target;
double pred_horizon;


Expand Down
19 changes: 8 additions & 11 deletions src/orsf_oop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@
// Same as x_node_scale, but this can be called from R

// [[Rcpp::export]]
List coxph_scale_exported(NumericMatrix& x_,
NumericVector& w_){
List coxph_scale_exported(arma::vec& x_node,
arma::vec& w_node){

mat x_node = mat(x_.begin(), x_.nrow(), x_.ncol(), false);
vec w_node = vec(w_.begin(), w_.length(), false);
mat x_transforms = coxph_scale(x_node, w_node);

return(
Expand All @@ -59,16 +57,16 @@
}

// [[Rcpp::export]]
List coxph_fit_exported(NumericMatrix& x_,
NumericMatrix& y_,
NumericVector& w_,
List coxph_fit_exported(arma::vec& x_node,
arma::vec& y_node,
arma::vec& w_node,
int method,
double cph_eps,
int cph_iter_max){

mat x_node = mat(x_.begin(), x_.nrow(), x_.ncol(), false);
mat y_node = mat(y_.begin(), y_.nrow(), y_.ncol(), false);
vec w_node = vec(w_.begin(), w_.length(), false);
// mat x_node = mat(x_.begin(), x_.nrow(), x_.ncol(), false);
// mat y_node = mat(y_.begin(), y_.nrow(), y_.ncol(), false);
// uvec w_node = uvec(w_.begin(), w_.length(), false);

uword cph_iter_max_ = cph_iter_max;

Expand Down Expand Up @@ -178,7 +176,6 @@

}


// [[Rcpp::plugins("cpp17")]]
// [[Rcpp::export]]
List orsf_cpp(arma::mat& x,
Expand Down
Loading

0 comments on commit 41504b7

Please sign in to comment.