Skip to content

Commit

Permalink
handling R function passage
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Aug 11, 2023
1 parent 70b2e6b commit c1820fe
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 248 deletions.
8 changes: 2 additions & 6 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,11 @@ node_fill_group_exported <- function(group, XB_sorted, start, stop, value) {
invisible(.Call(`_aorsf_node_fill_group_exported`, group, XB_sorted, start, stop, value))
}

which_cols_valid_exported <- function(y_inbag, x_inbag, rows_node, mtry) {
.Call(`_aorsf_which_cols_valid_exported`, y_inbag, x_inbag, rows_node, mtry)
}

lrt_multi_exported <- function(y_, w_, XB_, n_split_, split_min_stat, leaf_min_events, leaf_min_obs) {
.Call(`_aorsf_lrt_multi_exported`, y_, w_, XB_, n_split_, split_min_stat, leaf_min_events, leaf_min_obs)
}

orsf_cpp <- function(x, y, w, tree_seeds, f_beta, f_oobag_eval, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_seeds, f_beta, f_oobag_eval, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every)
orsf_cpp <- function(x, y, w, tree_seeds, lincomb_R_function, f_oobag_eval, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every) {
.Call(`_aorsf_orsf_cpp`, x, y, w, tree_seeds, lincomb_R_function, f_oobag_eval, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every)
}

6 changes: 3 additions & 3 deletions src/Data.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

Data(arma::mat& x,
arma::mat& y,
arma::vec& w) {
arma::uvec& w) {

this->x = x;
this->y = y;
Expand Down Expand Up @@ -72,15 +72,15 @@
return(y.submat(vector_of_row_indices, vector_of_column_indices));
}

arma::vec w_subvec(arma::uvec& vector_of_indices){
arma::uvec w_subvec(arma::uvec& vector_of_indices){
return(w(vector_of_indices));
}

// member variables

arma::uword n_cols;
arma::uword n_rows;
arma::vec w;
arma::uvec w;

bool has_weights;

Expand Down
47 changes: 23 additions & 24 deletions src/Forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@ Forest::Forest(){ }

void Forest::init(std::unique_ptr<Data> input_data,
Rcpp::IntegerVector& tree_seeds,
int n_tree,
int mtry,
arma::uword n_tree,
arma::uword mtry,
VariableImportance vi_type,
double leaf_min_events,
double leaf_min_obs,
SplitRule split_rule,
double split_min_events,
double split_min_obs,
double split_min_stat,
int split_max_retry,
arma::uword split_max_retry,
LinearCombo lincomb_type,
double lincomb_eps,
int lincomb_iter_max,
arma::uword lincomb_iter_max,
bool lincomb_scale,
double lincomb_alpha,
int lincomb_df_target,
arma::uword lincomb_df_target,
PredType pred_type,
bool pred_mode,
double pred_horizon,
bool oobag_pred,
int oobag_eval_every){
arma::uword oobag_eval_every){

this->data = std::move(input_data);
this->tree_seeds = tree_seeds;
Expand Down Expand Up @@ -65,35 +65,23 @@ void Forest::init(std::unique_ptr<Data> input_data,
Rcout << std::endl << std::endl;
}

// sample weights to mimic a bootstrap sample
this->bootstrap_select_times = seq(0, 10);

uword n_rows = data->get_n_rows();

// compute probability of being selected into the bootstrap
// 0 times, 1, times, ..., 9 times, or 10 times.
this->bootstrap_select_probs = dbinom(bootstrap_select_times,
n_rows,
1.0 / n_rows,
false);

}

// growInternal() in ranger
void Forest::plant() {

trees.reserve(n_tree);

for (int i = 0; i < n_tree; ++i) {
for (arma::uword i = 0; i < n_tree; ++i) {
trees.push_back(std::make_unique<Tree>());
}

}

void Forest::grow(){
void Forest::grow(Function& lincomb_R_function){


for(int i = 0; i < n_tree; ++i){
for(uword i = 0; i < n_tree; ++i){

trees[i]->init(data.get(),
tree_seeds[i],
Expand All @@ -110,15 +98,26 @@ void Forest::grow(){
lincomb_iter_max,
lincomb_scale,
lincomb_alpha,
lincomb_df_target,
&bootstrap_select_times,
&bootstrap_select_probs);
lincomb_df_target);

trees[i]->grow();


}

double x_dbl = 1.0;

NumericMatrix test_mat = lincomb_R_function(x_dbl);

arma::mat test_mat_arma(test_mat.begin(),
test_mat.nrow(),
test_mat.ncol(), false);

Rcout << "--- test R function output ---" << std::endl << std::endl;
Rcout << test_mat_arma << std::endl;

// result.push_back(test_mat_arma, "test");



}
Expand Down
58 changes: 31 additions & 27 deletions src/Forest.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class Forest {

void init(std::unique_ptr<Data> input_data,
Rcpp::IntegerVector& tree_seeds,
int n_tree,
int mtry,
arma::uword n_tree,
arma::uword mtry,
VariableImportance vi_type,
// leaves
double leaf_min_events,
Expand All @@ -38,46 +38,50 @@ class Forest {
double split_min_events,
double split_min_obs,
double split_min_stat,
int split_max_retry,
arma::uword split_max_retry,
// linear combinations
LinearCombo lincomb_type,
LinearCombo lincomb_type,
double lincomb_eps,
int lincomb_iter_max,
bool lincomb_scale,
arma::uword lincomb_iter_max,
bool lincomb_scale,
double lincomb_alpha,
int lincomb_df_target,
arma::uword lincomb_df_target,
// predictions
PredType pred_type,
bool pred_mode,
bool pred_mode,
double pred_horizon,
bool oobag_pred,
int oobag_eval_every);
bool oobag_pred,
arma::uword oobag_eval_every);

// virtual void initInternal() = 0;
// virtual void initarma::uwordernal() = 0;

// Grow or predict
void run();

void grow();
void grow(Function& lincomb_R_function);

void plant();

Rcpp::IntegerVector get_bootstrap_select_times(){
return bootstrap_select_times;
}

Rcpp::NumericVector get_bootstrap_select_probs(){
return bootstrap_select_probs;
}
std::vector<std::vector<arma::uvec>> get_coef_indices() {

std::vector<std::vector<arma::uvec>> result;

for (auto& tree : trees) {
result.push_back(tree->get_coef_indices());
}

return result;

}

// Member variables

Rcpp::IntegerVector bootstrap_select_times;
Rcpp::NumericVector bootstrap_select_probs;

int n_tree;
int mtry;
arma::uword n_tree;
arma::uword mtry;

Rcpp::IntegerVector tree_seeds;

Expand All @@ -93,24 +97,24 @@ class Forest {

// node splitting
SplitRule split_rule;
double split_min_events;
double split_min_obs;
double split_min_stat;
int split_max_retry;
double split_min_events;
double split_min_obs;
double split_min_stat;
arma::uword split_max_retry;

// linear combinations
LinearCombo lincomb_type;
double lincomb_eps;
int lincomb_iter_max;
bool lincomb_scale;
double lincomb_alpha;
int lincomb_df_target;
arma::uword lincomb_iter_max;
arma::uword lincomb_df_target;

// predictions
PredType pred_type;
double pred_horizon;
bool oobag_pred;
int oobag_eval_every;
arma::uword oobag_eval_every;


};
Expand Down
45 changes: 15 additions & 30 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,6 @@ BEGIN_RCPP
return R_NilValue;
END_RCPP
}
// which_cols_valid_exported
arma::uvec which_cols_valid_exported(const arma::mat& y_inbag, const arma::mat& x_inbag, arma::uvec& rows_node, const arma::uword mtry);
RcppExport SEXP _aorsf_which_cols_valid_exported(SEXP y_inbagSEXP, SEXP x_inbagSEXP, SEXP rows_nodeSEXP, SEXP mtrySEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const arma::mat& >::type y_inbag(y_inbagSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type x_inbag(x_inbagSEXP);
Rcpp::traits::input_parameter< arma::uvec& >::type rows_node(rows_nodeSEXP);
Rcpp::traits::input_parameter< const arma::uword >::type mtry(mtrySEXP);
rcpp_result_gen = Rcpp::wrap(which_cols_valid_exported(y_inbag, x_inbag, rows_node, mtry));
return rcpp_result_gen;
END_RCPP
}
// lrt_multi_exported
List lrt_multi_exported(NumericMatrix& y_, NumericVector& w_, NumericVector& XB_, int n_split_, double split_min_stat, double leaf_min_events, double leaf_min_obs);
RcppExport SEXP _aorsf_lrt_multi_exported(SEXP y_SEXP, SEXP w_SEXP, SEXP XB_SEXP, SEXP n_split_SEXP, SEXP split_min_statSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP) {
Expand All @@ -113,39 +99,39 @@ BEGIN_RCPP
END_RCPP
}
// orsf_cpp
List orsf_cpp(arma::mat& x, arma::mat& y, arma::vec& w, Rcpp::IntegerVector& tree_seeds, Rcpp::Function f_beta, Rcpp::Function f_oobag_eval, int n_tree, int mtry, int vi_type_R, double leaf_min_events, double leaf_min_obs, int split_rule_R, double split_min_events, double split_min_obs, double split_min_stat, int split_max_retry, int lincomb_type_R, double lincomb_eps, int lincomb_iter_max, bool lincomb_scale, double lincomb_alpha, int lincomb_df_target, bool pred_mode, int pred_type_R, double pred_horizon, bool oobag_pred, int oobag_eval_every);
RcppExport SEXP _aorsf_orsf_cpp(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP tree_seedsSEXP, SEXP f_betaSEXP, SEXP f_oobag_evalSEXP, SEXP n_treeSEXP, SEXP mtrySEXP, SEXP vi_type_RSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP, SEXP split_rule_RSEXP, SEXP split_min_eventsSEXP, SEXP split_min_obsSEXP, SEXP split_min_statSEXP, SEXP split_max_retrySEXP, SEXP lincomb_type_RSEXP, SEXP lincomb_epsSEXP, SEXP lincomb_iter_maxSEXP, SEXP lincomb_scaleSEXP, SEXP lincomb_alphaSEXP, SEXP lincomb_df_targetSEXP, SEXP pred_modeSEXP, SEXP pred_type_RSEXP, SEXP pred_horizonSEXP, SEXP oobag_predSEXP, SEXP oobag_eval_everySEXP) {
List orsf_cpp(arma::mat& x, arma::mat& y, arma::uvec& w, Rcpp::IntegerVector& tree_seeds, Rcpp::Function& lincomb_R_function, Rcpp::Function f_oobag_eval, arma::uword n_tree, arma::uword mtry, arma::uword vi_type_R, double leaf_min_events, double leaf_min_obs, arma::uword split_rule_R, double split_min_events, double split_min_obs, double split_min_stat, arma::uword split_max_retry, arma::uword lincomb_type_R, double lincomb_eps, arma::uword lincomb_iter_max, bool lincomb_scale, double lincomb_alpha, arma::uword lincomb_df_target, bool pred_mode, arma::uword pred_type_R, double pred_horizon, bool oobag_pred, arma::uword oobag_eval_every);
RcppExport SEXP _aorsf_orsf_cpp(SEXP xSEXP, SEXP ySEXP, SEXP wSEXP, SEXP tree_seedsSEXP, SEXP lincomb_R_functionSEXP, SEXP f_oobag_evalSEXP, SEXP n_treeSEXP, SEXP mtrySEXP, SEXP vi_type_RSEXP, SEXP leaf_min_eventsSEXP, SEXP leaf_min_obsSEXP, SEXP split_rule_RSEXP, SEXP split_min_eventsSEXP, SEXP split_min_obsSEXP, SEXP split_min_statSEXP, SEXP split_max_retrySEXP, SEXP lincomb_type_RSEXP, SEXP lincomb_epsSEXP, SEXP lincomb_iter_maxSEXP, SEXP lincomb_scaleSEXP, SEXP lincomb_alphaSEXP, SEXP lincomb_df_targetSEXP, SEXP pred_modeSEXP, SEXP pred_type_RSEXP, SEXP pred_horizonSEXP, SEXP oobag_predSEXP, SEXP oobag_eval_everySEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type x(xSEXP);
Rcpp::traits::input_parameter< arma::mat& >::type y(ySEXP);
Rcpp::traits::input_parameter< arma::vec& >::type w(wSEXP);
Rcpp::traits::input_parameter< arma::uvec& >::type w(wSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type tree_seeds(tree_seedsSEXP);
Rcpp::traits::input_parameter< Rcpp::Function >::type f_beta(f_betaSEXP);
Rcpp::traits::input_parameter< Rcpp::Function& >::type lincomb_R_function(lincomb_R_functionSEXP);
Rcpp::traits::input_parameter< Rcpp::Function >::type f_oobag_eval(f_oobag_evalSEXP);
Rcpp::traits::input_parameter< int >::type n_tree(n_treeSEXP);
Rcpp::traits::input_parameter< int >::type mtry(mtrySEXP);
Rcpp::traits::input_parameter< int >::type vi_type_R(vi_type_RSEXP);
Rcpp::traits::input_parameter< arma::uword >::type n_tree(n_treeSEXP);
Rcpp::traits::input_parameter< arma::uword >::type mtry(mtrySEXP);
Rcpp::traits::input_parameter< arma::uword >::type vi_type_R(vi_type_RSEXP);
Rcpp::traits::input_parameter< double >::type leaf_min_events(leaf_min_eventsSEXP);
Rcpp::traits::input_parameter< double >::type leaf_min_obs(leaf_min_obsSEXP);
Rcpp::traits::input_parameter< int >::type split_rule_R(split_rule_RSEXP);
Rcpp::traits::input_parameter< arma::uword >::type split_rule_R(split_rule_RSEXP);
Rcpp::traits::input_parameter< double >::type split_min_events(split_min_eventsSEXP);
Rcpp::traits::input_parameter< double >::type split_min_obs(split_min_obsSEXP);
Rcpp::traits::input_parameter< double >::type split_min_stat(split_min_statSEXP);
Rcpp::traits::input_parameter< int >::type split_max_retry(split_max_retrySEXP);
Rcpp::traits::input_parameter< int >::type lincomb_type_R(lincomb_type_RSEXP);
Rcpp::traits::input_parameter< arma::uword >::type split_max_retry(split_max_retrySEXP);
Rcpp::traits::input_parameter< arma::uword >::type lincomb_type_R(lincomb_type_RSEXP);
Rcpp::traits::input_parameter< double >::type lincomb_eps(lincomb_epsSEXP);
Rcpp::traits::input_parameter< int >::type lincomb_iter_max(lincomb_iter_maxSEXP);
Rcpp::traits::input_parameter< arma::uword >::type lincomb_iter_max(lincomb_iter_maxSEXP);
Rcpp::traits::input_parameter< bool >::type lincomb_scale(lincomb_scaleSEXP);
Rcpp::traits::input_parameter< double >::type lincomb_alpha(lincomb_alphaSEXP);
Rcpp::traits::input_parameter< int >::type lincomb_df_target(lincomb_df_targetSEXP);
Rcpp::traits::input_parameter< arma::uword >::type lincomb_df_target(lincomb_df_targetSEXP);
Rcpp::traits::input_parameter< bool >::type pred_mode(pred_modeSEXP);
Rcpp::traits::input_parameter< int >::type pred_type_R(pred_type_RSEXP);
Rcpp::traits::input_parameter< arma::uword >::type pred_type_R(pred_type_RSEXP);
Rcpp::traits::input_parameter< double >::type pred_horizon(pred_horizonSEXP);
Rcpp::traits::input_parameter< bool >::type oobag_pred(oobag_predSEXP);
Rcpp::traits::input_parameter< int >::type oobag_eval_every(oobag_eval_everySEXP);
rcpp_result_gen = Rcpp::wrap(orsf_cpp(x, y, w, tree_seeds, f_beta, f_oobag_eval, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every));
Rcpp::traits::input_parameter< arma::uword >::type oobag_eval_every(oobag_eval_everySEXP);
rcpp_result_gen = Rcpp::wrap(orsf_cpp(x, y, w, tree_seeds, lincomb_R_function, f_oobag_eval, n_tree, mtry, vi_type_R, leaf_min_events, leaf_min_obs, split_rule_R, split_min_events, split_min_obs, split_min_stat, split_max_retry, lincomb_type_R, lincomb_eps, lincomb_iter_max, lincomb_scale, lincomb_alpha, lincomb_df_target, pred_mode, pred_type_R, pred_horizon, oobag_pred, oobag_eval_every));
return rcpp_result_gen;
END_RCPP
}
Expand All @@ -156,7 +142,6 @@ static const R_CallMethodDef CallEntries[] = {
{"_aorsf_node_find_cps_exported", (DL_FUNC) &_aorsf_node_find_cps_exported, 5},
{"_aorsf_node_compute_lrt_exported", (DL_FUNC) &_aorsf_node_compute_lrt_exported, 3},
{"_aorsf_node_fill_group_exported", (DL_FUNC) &_aorsf_node_fill_group_exported, 5},
{"_aorsf_which_cols_valid_exported", (DL_FUNC) &_aorsf_which_cols_valid_exported, 4},
{"_aorsf_lrt_multi_exported", (DL_FUNC) &_aorsf_lrt_multi_exported, 7},
{"_aorsf_orsf_cpp", (DL_FUNC) &_aorsf_orsf_cpp, 27},
{NULL, NULL, 0}
Expand Down
Loading

0 comments on commit c1820fe

Please sign in to comment.