diff --git a/R/orsf.R b/R/orsf.R index c3607c23..1235b147 100644 --- a/R/orsf.R +++ b/R/orsf.R @@ -410,7 +410,7 @@ orsf <- function(data, if(sample_fraction == 1 && oobag_pred){ stop( "cannot compute out-of-bag predictions if no samples are out-of-bag.", - "To resolve this, set sample_fraction < 1 or oobag_pred_type = 'none'.", + " Try setting sample_fraction < 1 or oobag_pred_type = 'none'.", call. = FALSE ) } @@ -473,9 +473,9 @@ orsf <- function(data, type_oobag_eval <- 'user' if(oobag_pred_type == 'leaf'){ - stop("a user-supplied oobag function cannot be", - "applied when oobag_pred_type = 'leaf'", - call. = FALSE) + warning("a user-supplied oobag function cannot be", + "applied when oobag_pred_type = 'leaf'", + call. = FALSE) } } diff --git a/R/orsf_pd.R b/R/orsf_pd.R index b7c38fbe..0a539ba3 100644 --- a/R/orsf_pd.R +++ b/R/orsf_pd.R @@ -322,6 +322,7 @@ orsf_pred_dependence <- function(object, expand_grid = expand_grid, prob_values = prob_values, prob_labels = prob_labels, + oobag = oobag, boundary_checks = boundary_checks, new_data = pd_data, pred_type = pred_type, diff --git a/tests/testthat/helper-orsf.R b/tests/testthat/helper-orsf.R index a3d4f82a..0877d3c2 100644 --- a/tests/testthat/helper-orsf.R +++ b/tests/testthat/helper-orsf.R @@ -121,6 +121,16 @@ oobag_fun_bad_name_2 <- function(y_mat, w_vec, nope){ } +oobag_fun_bad_name_3 <- function(y_mat, nope, s_vec){ + + # risk = 1 - survival + r_vec <- 1 - s_vec + + # mean of the squared differences between predicted and observed risk + mean( (y_mat[, 2L] - r_vec)^2 ) + +} + oobag_fun_bad_out <- function(y_mat, w_vec, s_vec){ # risk = 1 - survival diff --git a/tests/testthat/test-orsf.R b/tests/testthat/test-orsf.R index 4746ce78..89648731 100644 --- a/tests/testthat/test-orsf.R +++ b/tests/testthat/test-orsf.R @@ -1,18 +1,22 @@ -f <- time + status ~ . - id +f <- time + status ~ . test_that( desc = 'non-formula inputs are vetted', code = { - expect_error(orsf(pbc_orsf, f, n_tree = 0), "should be >= 1") - expect_error(orsf(pbc_orsf, f, n_split = "3"), "should have type") - expect_error(orsf(pbc_orsf, f, mtry = 5000), 'should be <=') - expect_error(orsf(pbc_orsf, f, leaf_min_events = 5000), 'should be <=') - expect_error(orsf(pbc_orsf, f, leaf_min_obs = 5000), 'should be <=') - expect_error(orsf(pbc_orsf, f, attachData = TRUE), 'attach_data?') - expect_error(orsf(pbc_orsf, f, Control = 0), 'control?') + expect_error(orsf(pbc, f, n_tree = 0), "should be >= 1") + expect_error(orsf(pbc, f, n_split = "3"), "should have type") + expect_error(orsf(pbc, f, mtry = 5000), 'should be <=') + expect_error(orsf(pbc, f, leaf_min_events = 5000), 'should be <=') + expect_error(orsf(pbc, f, leaf_min_obs = 5000), 'should be <=') + expect_error(orsf(pbc, f, attachData = TRUE), 'attach_data?') + expect_error(orsf(pbc, f, Control = 0), 'control?') + expect_error(orsf(pbc, f, sample_fraction = 1, oobag_pred_type = 'risk'), + 'no samples are out-of-bag') + expect_error(orsf(pbc, f, split_rule = 'cstat', split_min_stat = 1), + 'must be < 1') pbc_orsf$date_var <- Sys.Date() expect_error(orsf(pbc_orsf, f), 'unsupported type') @@ -403,45 +407,40 @@ test_that( ) -if(Sys.getenv("run_all_aorsf_tests") == 'yes'){ - - test_that( - desc = 'orsf_time_to_train is reasonable at approximating time to train', - code = { +test_that( + desc = 'orsf_time_to_train is reasonable at approximating time to train', + code = { - # testing the seed behavior when no_fit is TRUE. You should get the same - # forest whether you train with orsf() or with orsf_train(). + # testing the seed behavior when no_fit is TRUE. You should get the same + # forest whether you train with orsf() or with orsf_train(). - for(.n_tree in c(100, 250, 1000)){ + for(.n_tree in c(100, 250, 1000)){ - object <- orsf(pbc_orsf, Surv(time, status) ~ . - id, - n_tree = .n_tree, no_fit = TRUE, - importance = 'anova') - set.seed(89) - time_estimated <- orsf_time_to_train(object, n_tree_subset = 50) + object <- orsf(pbc_orsf, Surv(time, status) ~ . - id, + n_tree = .n_tree, no_fit = TRUE, + importance = 'anova') + set.seed(89) + time_estimated <- orsf_time_to_train(object, n_tree_subset = 50) - set.seed(89) - time_true_start <- Sys.time() - fit_orsf_3 <- orsf_train(object) - time_true_stop <- Sys.time() + set.seed(89) + time_true_start <- Sys.time() + fit_orsf_3 <- orsf_train(object) + time_true_stop <- Sys.time() - time_true <- time_true_stop - time_true_start + time_true <- time_true_stop - time_true_start - diff_abs <- abs(as.numeric(time_true - time_estimated)) - diff_rel <- diff_abs / as.numeric(time_true) + diff_abs <- abs(as.numeric(time_true - time_estimated)) + diff_rel <- diff_abs / as.numeric(time_true) - # expect the difference between estimated and true time is < 5 second. - expect_lt(diff_abs, 5) - # expect that the difference is not greater than 5x the - # magnitude of the actual time it took to fit the forest - expect_lt(diff_rel, 5) + # expect the difference between estimated and true time is < 5 second. + expect_lt(diff_abs, 5) + # expect that the difference is not greater than 5x the + # magnitude of the actual time it took to fit the forest + expect_lt(diff_rel, 5) - } } - ) - -} - + } +) test_that( desc = 'orsf_train does not accept bad inputs', diff --git a/tests/testthat/test-orsf_control.R b/tests/testthat/test-orsf_control.R index 47e4adcd..8d7c8c88 100644 --- a/tests/testthat/test-orsf_control.R +++ b/tests/testthat/test-orsf_control.R @@ -15,61 +15,50 @@ test_that("inputs are vetted", { f_bad_1 <- function(a_node, y_node, w_node){ 1 } f_bad_2 <- function(x_node, a_node, w_node){ 1 } f_bad_3 <- function(x_node, y_node, a_node){ 1 } + f_bad_4 <- function(x_node, y_node){ 1 } - expect_error(orsf_control_custom(f_bad_1), 'x_node') - expect_error(orsf_control_custom(f_bad_2), 'y_node') - expect_error(orsf_control_custom(f_bad_3), 'w_node') - - f_bad_4 <- function(x_node, y_node, w_node) {runif(n = ncol(x_node))} - - expect_error(orsf_control_custom(f_bad_4), 'matrix output') - - # seems like this one can throw off github actions? - if (Sys.getenv("run_all_aorsf_tests") == 'yes') { - - f_bad_5 <- function(x_node, y_node, w_node){ - stop("IDK WHAT TO DO", call. = FALSE) - } - - expect_error(orsf_control_custom(f_bad_5), "encountered an error") + f_bad_5 <- function(x_node, y_node, w_node) { + stop("an expected error occurred") + } + f_bad_6 <- function(x_node, y_node, w_node){ + return(matrix(0, ncol = 2, nrow = ncol(x_node))) } + f_bad_7 <- function(x_node, y_node, w_node){ + return(matrix(0, ncol = 1, nrow = 2)) + } + f_bad_8 <- function(x_node, y_node, w_node) {runif(n = ncol(x_node))} + expect_error(orsf_control_custom(f_bad_1), 'x_node') + expect_error(orsf_control_custom(f_bad_2), 'y_node') + expect_error(orsf_control_custom(f_bad_3), 'w_node') + expect_error(orsf_control_custom(f_bad_4), 'should have 3') + expect_error(orsf_control_custom(f_bad_5), 'encountered an error') + expect_error(orsf_control_custom(f_bad_6), 'with 1 column') + expect_error(orsf_control_custom(f_bad_7), 'with 1 row for each') + expect_error(orsf_control_custom(f_bad_8), 'matrix output') - f <- function(x_node, y_node, w_node) { matrix(runif(ncol(x_node)), ncol=1) } + f_rando <- function(x_node, y_node, w_node) { matrix(runif(ncol(x_node)), ncol=1) } - expect_s3_class(orsf_control_custom(f), 'orsf_control') + expect_s3_class(orsf_control_custom(f_rando), 'orsf_control') }) test_that( - desc = 'outputs meet expectations on prediction accuracy', + desc = 'custom orsf_control predictions are good', code = { + fit_pca = orsf(pbc_orsf, + Surv(time, status) ~ ., + tree_seeds = seeds_standard, + control = orsf_control_custom(beta_fun = f_pca), + n_tree = n_tree_test) - f <- function(x_node, y_node, w_node) { matrix(runif(ncol(x_node)), ncol=1) } - - fit_cph = orsf(pbc_orsf, - Surv(time, status) ~ ., - tree_seeds = seq(500), - control = orsf_control_cph(), - n_tree = 500) - - - fit_rando = orsf(pbc_orsf, - Surv(time, status) ~ ., - tree_seeds = seq(500), - control = orsf_control_custom(beta_fun = f), - n_tree = 500) - - expect_lt(fit_rando$eval_oobag$stat_values, - fit_cph$eval_oobag$stat_values) - - expect_gt(fit_rando$eval_oobag$stat_values, .6) + expect_gt(fit_pca$eval_oobag$stat_values, .65) } ) diff --git a/tests/testthat/test-orsf_vi.R b/tests/testthat/test-orsf_vi.R index bf3929cb..6f259da6 100644 --- a/tests/testthat/test-orsf_vi.R +++ b/tests/testthat/test-orsf_vi.R @@ -188,6 +188,11 @@ test_that( regexp = 's_vec' ) + expect_error( + orsf_vi_negate(fit_no_vi, oobag_fun = oobag_fun_bad_name_3), + regexp = 'w_vec' + ) + expect_error( orsf_vi_negate(fit_no_vi, oobag_fun = oobag_fun_bad_out), regexp = 'length 1'