Skip to content

Commit

Permalink
opening some tests up to see if they still cause hiccups
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Oct 8, 2023
1 parent 1f5a78f commit de2b9f0
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 80 deletions.
8 changes: 4 additions & 4 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ orsf <- function(data,
if(sample_fraction == 1 && oobag_pred){
stop(
"cannot compute out-of-bag predictions if no samples are out-of-bag.",
"To resolve this, set sample_fraction < 1 or oobag_pred_type = 'none'.",
" Try setting sample_fraction < 1 or oobag_pred_type = 'none'.",
call. = FALSE
)
}
Expand Down Expand Up @@ -473,9 +473,9 @@ orsf <- function(data,
type_oobag_eval <- 'user'

if(oobag_pred_type == 'leaf'){
stop("a user-supplied oobag function cannot be",
"applied when oobag_pred_type = 'leaf'",
call. = FALSE)
warning("a user-supplied oobag function cannot be",
"applied when oobag_pred_type = 'leaf'",
call. = FALSE)
}

}
Expand Down
1 change: 1 addition & 0 deletions R/orsf_pd.R
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ orsf_pred_dependence <- function(object,
expand_grid = expand_grid,
prob_values = prob_values,
prob_labels = prob_labels,
oobag = oobag,
boundary_checks = boundary_checks,
new_data = pd_data,
pred_type = pred_type,
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/helper-orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ oobag_fun_bad_name_2 <- function(y_mat, w_vec, nope){

}

oobag_fun_bad_name_3 <- function(y_mat, nope, s_vec){

# risk = 1 - survival
r_vec <- 1 - s_vec

# mean of the squared differences between predicted and observed risk
mean( (y_mat[, 2L] - r_vec)^2 )

}

oobag_fun_bad_out <- function(y_mat, w_vec, s_vec){

# risk = 1 - survival
Expand Down
75 changes: 37 additions & 38 deletions tests/testthat/test-orsf.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@


f <- time + status ~ . - id
f <- time + status ~ .

test_that(
desc = 'non-formula inputs are vetted',
code = {

expect_error(orsf(pbc_orsf, f, n_tree = 0), "should be >= 1")
expect_error(orsf(pbc_orsf, f, n_split = "3"), "should have type")
expect_error(orsf(pbc_orsf, f, mtry = 5000), 'should be <=')
expect_error(orsf(pbc_orsf, f, leaf_min_events = 5000), 'should be <=')
expect_error(orsf(pbc_orsf, f, leaf_min_obs = 5000), 'should be <=')
expect_error(orsf(pbc_orsf, f, attachData = TRUE), 'attach_data?')
expect_error(orsf(pbc_orsf, f, Control = 0), 'control?')
expect_error(orsf(pbc, f, n_tree = 0), "should be >= 1")
expect_error(orsf(pbc, f, n_split = "3"), "should have type")
expect_error(orsf(pbc, f, mtry = 5000), 'should be <=')
expect_error(orsf(pbc, f, leaf_min_events = 5000), 'should be <=')
expect_error(orsf(pbc, f, leaf_min_obs = 5000), 'should be <=')
expect_error(orsf(pbc, f, attachData = TRUE), 'attach_data?')
expect_error(orsf(pbc, f, Control = 0), 'control?')
expect_error(orsf(pbc, f, sample_fraction = 1, oobag_pred_type = 'risk'),
'no samples are out-of-bag')
expect_error(orsf(pbc, f, split_rule = 'cstat', split_min_stat = 1),
'must be < 1')

pbc_orsf$date_var <- Sys.Date()
expect_error(orsf(pbc_orsf, f), 'unsupported type')
Expand Down Expand Up @@ -403,45 +407,40 @@ test_that(
)


if(Sys.getenv("run_all_aorsf_tests") == 'yes'){

test_that(
desc = 'orsf_time_to_train is reasonable at approximating time to train',
code = {
test_that(
desc = 'orsf_time_to_train is reasonable at approximating time to train',
code = {

# testing the seed behavior when no_fit is TRUE. You should get the same
# forest whether you train with orsf() or with orsf_train().
# testing the seed behavior when no_fit is TRUE. You should get the same
# forest whether you train with orsf() or with orsf_train().

for(.n_tree in c(100, 250, 1000)){
for(.n_tree in c(100, 250, 1000)){

object <- orsf(pbc_orsf, Surv(time, status) ~ . - id,
n_tree = .n_tree, no_fit = TRUE,
importance = 'anova')
set.seed(89)
time_estimated <- orsf_time_to_train(object, n_tree_subset = 50)
object <- orsf(pbc_orsf, Surv(time, status) ~ . - id,
n_tree = .n_tree, no_fit = TRUE,
importance = 'anova')
set.seed(89)
time_estimated <- orsf_time_to_train(object, n_tree_subset = 50)

set.seed(89)
time_true_start <- Sys.time()
fit_orsf_3 <- orsf_train(object)
time_true_stop <- Sys.time()
set.seed(89)
time_true_start <- Sys.time()
fit_orsf_3 <- orsf_train(object)
time_true_stop <- Sys.time()

time_true <- time_true_stop - time_true_start
time_true <- time_true_stop - time_true_start

diff_abs <- abs(as.numeric(time_true - time_estimated))
diff_rel <- diff_abs / as.numeric(time_true)
diff_abs <- abs(as.numeric(time_true - time_estimated))
diff_rel <- diff_abs / as.numeric(time_true)

# expect the difference between estimated and true time is < 5 second.
expect_lt(diff_abs, 5)
# expect that the difference is not greater than 5x the
# magnitude of the actual time it took to fit the forest
expect_lt(diff_rel, 5)
# expect the difference between estimated and true time is < 5 second.
expect_lt(diff_abs, 5)
# expect that the difference is not greater than 5x the
# magnitude of the actual time it took to fit the forest
expect_lt(diff_rel, 5)

}
}
)

}

}
)

test_that(
desc = 'orsf_train does not accept bad inputs',
Expand Down
65 changes: 27 additions & 38 deletions tests/testthat/test-orsf_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,61 +15,50 @@ test_that("inputs are vetted", {
f_bad_1 <- function(a_node, y_node, w_node){ 1 }
f_bad_2 <- function(x_node, a_node, w_node){ 1 }
f_bad_3 <- function(x_node, y_node, a_node){ 1 }
f_bad_4 <- function(x_node, y_node){ 1 }

expect_error(orsf_control_custom(f_bad_1), 'x_node')
expect_error(orsf_control_custom(f_bad_2), 'y_node')
expect_error(orsf_control_custom(f_bad_3), 'w_node')

f_bad_4 <- function(x_node, y_node, w_node) {runif(n = ncol(x_node))}

expect_error(orsf_control_custom(f_bad_4), 'matrix output')

# seems like this one can throw off github actions?
if (Sys.getenv("run_all_aorsf_tests") == 'yes') {

f_bad_5 <- function(x_node, y_node, w_node){
stop("IDK WHAT TO DO", call. = FALSE)
}

expect_error(orsf_control_custom(f_bad_5), "encountered an error")
f_bad_5 <- function(x_node, y_node, w_node) {
stop("an expected error occurred")
}

f_bad_6 <- function(x_node, y_node, w_node){
return(matrix(0, ncol = 2, nrow = ncol(x_node)))
}

f_bad_7 <- function(x_node, y_node, w_node){
return(matrix(0, ncol = 1, nrow = 2))
}

f_bad_8 <- function(x_node, y_node, w_node) {runif(n = ncol(x_node))}

expect_error(orsf_control_custom(f_bad_1), 'x_node')
expect_error(orsf_control_custom(f_bad_2), 'y_node')
expect_error(orsf_control_custom(f_bad_3), 'w_node')
expect_error(orsf_control_custom(f_bad_4), 'should have 3')
expect_error(orsf_control_custom(f_bad_5), 'encountered an error')
expect_error(orsf_control_custom(f_bad_6), 'with 1 column')
expect_error(orsf_control_custom(f_bad_7), 'with 1 row for each')
expect_error(orsf_control_custom(f_bad_8), 'matrix output')

f <- function(x_node, y_node, w_node) { matrix(runif(ncol(x_node)), ncol=1) }
f_rando <- function(x_node, y_node, w_node) { matrix(runif(ncol(x_node)), ncol=1) }

expect_s3_class(orsf_control_custom(f), 'orsf_control')
expect_s3_class(orsf_control_custom(f_rando), 'orsf_control')


})


test_that(
desc = 'outputs meet expectations on prediction accuracy',
desc = 'custom orsf_control predictions are good',
code = {

fit_pca = orsf(pbc_orsf,
Surv(time, status) ~ .,
tree_seeds = seeds_standard,
control = orsf_control_custom(beta_fun = f_pca),
n_tree = n_tree_test)

f <- function(x_node, y_node, w_node) { matrix(runif(ncol(x_node)), ncol=1) }

fit_cph = orsf(pbc_orsf,
Surv(time, status) ~ .,
tree_seeds = seq(500),
control = orsf_control_cph(),
n_tree = 500)


fit_rando = orsf(pbc_orsf,
Surv(time, status) ~ .,
tree_seeds = seq(500),
control = orsf_control_custom(beta_fun = f),
n_tree = 500)

expect_lt(fit_rando$eval_oobag$stat_values,
fit_cph$eval_oobag$stat_values)

expect_gt(fit_rando$eval_oobag$stat_values, .6)
expect_gt(fit_pca$eval_oobag$stat_values, .65)

}
)
Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/test-orsf_vi.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ test_that(
regexp = 's_vec'
)

expect_error(
orsf_vi_negate(fit_no_vi, oobag_fun = oobag_fun_bad_name_3),
regexp = 'w_vec'
)

expect_error(
orsf_vi_negate(fit_no_vi, oobag_fun = oobag_fun_bad_out),
regexp = 'length 1'
Expand Down

0 comments on commit de2b9f0

Please sign in to comment.