Skip to content

Commit

Permalink
more changes for tidymodels/parsnip#1203
Browse files Browse the repository at this point in the history
  • Loading branch information
‘topepo’ committed Sep 30, 2024
1 parent 2dbac85 commit 30a7936
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 100 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ Imports:
rlang (>= 1.0.0),
stats,
tibble (>= 3.1.3),
tidyr (>= 1.0.0)
tidyr (>= 1.0.0),
vctrs
Suggests:
aorsf (>= 0.1.2),
coin,
Expand Down
4 changes: 2 additions & 2 deletions R/censored-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ utils::globalVariables(
".id", ".tmp", "engine", "predictor_indicators", ".strata", "group",
".pred_quantile", ".quantile", "interval", "level", ".pred_linear_pred",
".pred_link", ".pred_time", ".pred_survival", "next_event_time",
"sum_component", "time_interval"
"sum_component", "time_interval", "quantile_levels"
)
)

# quiet R-CMD-check NOTEs that prodlim is unused
# (parsnip uses it for all censored regression models
# (parsnip uses it for all censored regression models
# but only has it in Suggests)
#' @importFrom prodlim prodlim
NULL
12 changes: 6 additions & 6 deletions R/survival_reg-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ make_survival_reg_survival <- function() {
type = "quantile",
value = list(
pre = NULL,
post = survreg_quant,
post = parsnip::matrix_to_quantile_pred,
func = c(fun = "predict"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data),
type = "quantile",
p = expr(quantile)
p = expr(quantile_levels)
)
)
)
Expand Down Expand Up @@ -236,14 +236,14 @@ make_survival_reg_flexsurv <- function() {
type = "quantile",
value = list(
pre = NULL,
post = NULL,
post = flexsurv_to_quantile_pred,
func = c(fun = "predict"),
args =
list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "quantile",
p = rlang::expr(quantile),
p = rlang::expr(quantile_levels),
conf.int = rlang::expr(interval == "confidence"),
conf.level = rlang::expr(level)
)
Expand Down Expand Up @@ -393,14 +393,14 @@ make_survival_reg_flexsurvspline <- function() {
type = "quantile",
value = list(
pre = NULL,
post = NULL,
post = flexsurv_to_quantile_pred,
func = c(fun = "predict"),
args =
list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "quantile",
p = rlang::expr(quantile),
p = rlang::expr(quantile_levels),
conf.int = rlang::expr(interval == "confidence"),
conf.level = rlang::expr(level)
)
Expand Down
55 changes: 21 additions & 34 deletions tests/testthat/test-survival_reg-flexsurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ library(testthat)

test_that("model object", {
skip_if_not_installed("flexsurv")

set.seed(1234)
exp_f_fit <- flexsurv::flexsurvreg(
Surv(time, status) ~ age + ph.ecog,
Expand Down Expand Up @@ -149,7 +149,7 @@ test_that("survival probabilities for single eval time point", {

test_that("can predict for out-of-domain timepoints", {
skip_if_not_installed("flexsurv")

eval_time_obs_max_and_ood <- c(1022, 2000)
obs_without_NA <- lung[2,]

Expand Down Expand Up @@ -236,41 +236,28 @@ test_that("quantile predictions", {
)

expect_s3_class(pred, "tbl_df")
expect_equal(names(pred), ".pred")
expect_equal(names(pred), ".pred_quantile")
expect_equal(nrow(pred), 3)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(dim(.x) == c(9, 2))
))
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(".quantile", ".pred_quantile"))
))
)
expect_equal(
tidyr::unnest(pred, cols = .pred)$.pred_quantile,
do.call(rbind, exp_pred)$est
)
expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))

for (.row in 1:nrow(pred)) {
expect_equal(
unclass(pred$.pred_quantile[.row])[[1]],
exp_pred[[.row]]$est
)
}

# add confidence interval
pred <- predict(fit_s,
pred_ci <- predict(fit_s,
new_data = bladder[1:3, ], type = "quantile",
interval = "confidence", level = 0.7
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(
".quantile",
".pred_quantile",
".pred_lower",
".pred_upper"
))
))
)
expect_s3_class(pred_ci, "tbl_df")
expect_equal(names(pred_ci), c(".pred_quantile", ".pred_lower", ".pred_upper"))
expect_equal(nrow(pred_ci), 3)
expect_s3_class(pred_ci$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred_ci$.pred_lower, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred_ci$.pred_upper, c("quantile_pred", "vctrs_vctr", "list"))

# single observation
f_pred_1 <- predict(fit_s, bladder[2,], type = "quantile")
Expand Down Expand Up @@ -354,7 +341,7 @@ test_that("hazard for single eval time point", {

test_that("`fix_xy()` works", {
skip_if_not_installed("flexsurv")

lung_x <- as.matrix(lung[, c("age", "ph.ecog")])
lung_y <- Surv(lung$time, lung$status)
lung_pred <- lung[1:5, ]
Expand Down Expand Up @@ -401,13 +388,13 @@ test_that("`fix_xy()` works", {
f_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
xy_pred_quantile <- predict(
xy_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
expect_equal(f_pred_quantile, xy_pred_quantile)

Expand Down
63 changes: 15 additions & 48 deletions tests/testthat/test-survival_reg-flexsurvspline.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ test_that("survival probability prediction", {
head(lung),
type = "survival",
times = c(0, 500, 1000)
)
)
if (packageVersion("flexsurv") < "2.3") {
exp_pred <- exp_pred %>%
dplyr::rowwise() %>%
Expand Down Expand Up @@ -211,59 +211,26 @@ test_that("quantile predictions", {
set_mode("censored regression") %>%
fit(Surv(stop, event) ~ rx + size + enum, data = bladder)
pred <- predict(fit_s, new_data = bladder[1:3, ], type = "quantile")

set.seed(1)
exp_fit <- flexsurv::flexsurvspline(
Surv(stop, event) ~ rx + size + enum,
data = bladder,
k = 1
)
exp_pred <- summary(
exp_fit,
newdata = bladder[1:3, ],
type = "quantile",
quantiles = (1:9) / 10
)

expect_s3_class(pred, "tbl_df")
expect_equal(names(pred), ".pred")
expect_equal(names(pred), ".pred_quantile")
expect_equal(nrow(pred), 3)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(dim(.x) == c(9, 2))
))
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(".quantile", ".pred_quantile"))
))
)
expect_equal(
tidyr::unnest(pred, cols = .pred)$.pred_quantile,
do.call(rbind, exp_pred)$est
)
expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))


# add confidence interval
pred <- predict(
pred_ci <- predict(
fit_s,
new_data = bladder[1:3, ],
type = "quantile",
interval = "confidence",
level = 0.7
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(
".quantile",
".pred_quantile",
".pred_lower",
".pred_upper"
))
))
)
expect_s3_class(pred_ci, "tbl_df")
expect_equal(names(pred_ci), c(".pred_quantile", ".pred_lower", ".pred_upper"))
expect_equal(nrow(pred_ci), 3)
expect_s3_class(pred_ci$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred_ci$.pred_lower, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred_ci$.pred_upper, c("quantile_pred", "vctrs_vctr", "list"))

# single observation
f_pred_1 <- predict(fit_s, bladder[2,], type = "quantile")
Expand All @@ -284,7 +251,7 @@ test_that("hazard prediction", {
head(lung),
type = "hazard",
times = c(0, 500, 1000)
)
)
if (packageVersion("flexsurv") < "2.3") {
exp_pred <- exp_pred %>%
dplyr::rowwise() %>%
Expand Down Expand Up @@ -409,13 +376,13 @@ test_that("`fix_xy()` works", {
f_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
xy_pred_quantile <- predict(
xy_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
expect_equal(f_pred_quantile, xy_pred_quantile)

Expand All @@ -438,7 +405,7 @@ test_that("`fix_xy()` works", {

test_that("can handle case weights", {
skip_if_not_installed("flexsurv")

# flexsurv engine can only take weights > 0
set.seed(1)
wts <- runif(nrow(lung))
Expand Down
24 changes: 15 additions & 9 deletions tests/testthat/test-survival_reg-survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,22 @@ test_that("prediction of survival time quantile", {
fit(Surv(time, status) ~ age + sex, data = lung)

exp_quant <- predict(res$fit, head(lung), p = (2:4) / 5, type = "quantile")
exp_quant <- apply(exp_quant, 1, function(x) {
tibble::tibble(.quantile = (2:4) / 5, .pred_quantile = x)
})
exp_quant <- tibble::tibble(.pred = exp_quant)
obs_quant <- predict(res, head(lung), type = "quantile", quantile = (2:4) / 5)
obs_quant <- predict(res, head(lung), type = "quantile", quantile_levels = (2:4) / 5)

expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant))
expect_s3_class(obs_quant, "tbl_df")
expect_equal(names(obs_quant), ".pred_quantile")
expect_equal(nrow(obs_quant), 6)
expect_s3_class(obs_quant$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))

for (.row in 1:nrow(obs_quant)) {
expect_equal(
unclass(obs_quant$.pred_quantile[.row])[[1]],
exp_quant[.row,]
)
}

# single observation
f_pred_1 <- predict(res, lung[1, ], type = "quantile")
f_pred_1 <- predict(res, lung[1, ], type = "quantile", quantile_levels = .5)
expect_identical(nrow(f_pred_1), 1L)
})

Expand Down Expand Up @@ -213,13 +219,13 @@ test_that("`fix_xy()` works", {
f_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
xy_pred_quantile <- predict(
xy_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
expect_equal(f_pred_quantile, xy_pred_quantile)

Expand Down

0 comments on commit 30a7936

Please sign in to comment.