From 30a793677c048410f1c7de8854fd52d65d5ff51f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Mon, 30 Sep 2024 09:47:24 -0400 Subject: [PATCH] more changes for tidymodels/parsnip#1203 --- DESCRIPTION | 3 +- R/censored-package.R | 4 +- R/survival_reg-data.R | 12 ++-- tests/testthat/test-survival_reg-flexsurv.R | 55 +++++++--------- .../test-survival_reg-flexsurvspline.R | 63 +++++-------------- tests/testthat/test-survival_reg-survival.R | 24 ++++--- 6 files changed, 61 insertions(+), 100 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index dce6cfa..8ca55a3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -35,7 +35,8 @@ Imports: rlang (>= 1.0.0), stats, tibble (>= 3.1.3), - tidyr (>= 1.0.0) + tidyr (>= 1.0.0), + vctrs Suggests: aorsf (>= 0.1.2), coin, diff --git a/R/censored-package.R b/R/censored-package.R index d1b2383..663cc12 100644 --- a/R/censored-package.R +++ b/R/censored-package.R @@ -62,12 +62,12 @@ utils::globalVariables( ".id", ".tmp", "engine", "predictor_indicators", ".strata", "group", ".pred_quantile", ".quantile", "interval", "level", ".pred_linear_pred", ".pred_link", ".pred_time", ".pred_survival", "next_event_time", - "sum_component", "time_interval" + "sum_component", "time_interval", "quantile_levels" ) ) # quiet R-CMD-check NOTEs that prodlim is unused -# (parsnip uses it for all censored regression models +# (parsnip uses it for all censored regression models # but only has it in Suggests) #' @importFrom prodlim prodlim NULL diff --git a/R/survival_reg-data.R b/R/survival_reg-data.R index fa43298..2b5bc9f 100644 --- a/R/survival_reg-data.R +++ b/R/survival_reg-data.R @@ -86,14 +86,14 @@ make_survival_reg_survival <- function() { type = "quantile", value = list( pre = NULL, - post = survreg_quant, + post = parsnip::matrix_to_quantile_pred, func = c(fun = "predict"), args = list( object = expr(object$fit), newdata = expr(new_data), type = "quantile", - p = expr(quantile) + p = expr(quantile_levels) ) ) ) @@ -236,14 +236,14 @@ make_survival_reg_flexsurv <- function() { type = "quantile", value = list( pre = NULL, - post = NULL, + post = flexsurv_to_quantile_pred, func = c(fun = "predict"), args = list( object = rlang::expr(object$fit), newdata = rlang::expr(new_data), type = "quantile", - p = rlang::expr(quantile), + p = rlang::expr(quantile_levels), conf.int = rlang::expr(interval == "confidence"), conf.level = rlang::expr(level) ) @@ -393,14 +393,14 @@ make_survival_reg_flexsurvspline <- function() { type = "quantile", value = list( pre = NULL, - post = NULL, + post = flexsurv_to_quantile_pred, func = c(fun = "predict"), args = list( object = rlang::expr(object$fit), newdata = rlang::expr(new_data), type = "quantile", - p = rlang::expr(quantile), + p = rlang::expr(quantile_levels), conf.int = rlang::expr(interval == "confidence"), conf.level = rlang::expr(level) ) diff --git a/tests/testthat/test-survival_reg-flexsurv.R b/tests/testthat/test-survival_reg-flexsurv.R index cce2b41..701f065 100644 --- a/tests/testthat/test-survival_reg-flexsurv.R +++ b/tests/testthat/test-survival_reg-flexsurv.R @@ -2,7 +2,7 @@ library(testthat) test_that("model object", { skip_if_not_installed("flexsurv") - + set.seed(1234) exp_f_fit <- flexsurv::flexsurvreg( Surv(time, status) ~ age + ph.ecog, @@ -149,7 +149,7 @@ test_that("survival probabilities for single eval time point", { test_that("can predict for out-of-domain timepoints", { skip_if_not_installed("flexsurv") - + eval_time_obs_max_and_ood <- c(1022, 2000) obs_without_NA <- lung[2,] @@ -236,41 +236,28 @@ test_that("quantile predictions", { ) expect_s3_class(pred, "tbl_df") - expect_equal(names(pred), ".pred") + expect_equal(names(pred), ".pred_quantile") expect_equal(nrow(pred), 3) - expect_true( - all(purrr::map_lgl( - pred$.pred, - ~ all(dim(.x) == c(9, 2)) - )) - ) - expect_true( - all(purrr::map_lgl( - pred$.pred, - ~ all(names(.x) == c(".quantile", ".pred_quantile")) - )) - ) - expect_equal( - tidyr::unnest(pred, cols = .pred)$.pred_quantile, - do.call(rbind, exp_pred)$est - ) + expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list")) + + for (.row in 1:nrow(pred)) { + expect_equal( + unclass(pred$.pred_quantile[.row])[[1]], + exp_pred[[.row]]$est + ) + } # add confidence interval - pred <- predict(fit_s, + pred_ci <- predict(fit_s, new_data = bladder[1:3, ], type = "quantile", interval = "confidence", level = 0.7 ) - expect_true( - all(purrr::map_lgl( - pred$.pred, - ~ all(names(.x) == c( - ".quantile", - ".pred_quantile", - ".pred_lower", - ".pred_upper" - )) - )) - ) + expect_s3_class(pred_ci, "tbl_df") + expect_equal(names(pred_ci), c(".pred_quantile", ".pred_lower", ".pred_upper")) + expect_equal(nrow(pred_ci), 3) + expect_s3_class(pred_ci$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list")) + expect_s3_class(pred_ci$.pred_lower, c("quantile_pred", "vctrs_vctr", "list")) + expect_s3_class(pred_ci$.pred_upper, c("quantile_pred", "vctrs_vctr", "list")) # single observation f_pred_1 <- predict(fit_s, bladder[2,], type = "quantile") @@ -354,7 +341,7 @@ test_that("hazard for single eval time point", { test_that("`fix_xy()` works", { skip_if_not_installed("flexsurv") - + lung_x <- as.matrix(lung[, c("age", "ph.ecog")]) lung_y <- Surv(lung$time, lung$status) lung_pred <- lung[1:5, ] @@ -401,13 +388,13 @@ test_that("`fix_xy()` works", { f_fit, new_data = lung_pred, type = "quantile", - quantile = c(0.2, 0.8) + quantile_levels = c(0.2, 0.8) ) xy_pred_quantile <- predict( xy_fit, new_data = lung_pred, type = "quantile", - quantile = c(0.2, 0.8) + quantile_levels = c(0.2, 0.8) ) expect_equal(f_pred_quantile, xy_pred_quantile) diff --git a/tests/testthat/test-survival_reg-flexsurvspline.R b/tests/testthat/test-survival_reg-flexsurvspline.R index 6d837c1..8c2ba13 100644 --- a/tests/testthat/test-survival_reg-flexsurvspline.R +++ b/tests/testthat/test-survival_reg-flexsurvspline.R @@ -61,7 +61,7 @@ test_that("survival probability prediction", { head(lung), type = "survival", times = c(0, 500, 1000) - ) + ) if (packageVersion("flexsurv") < "2.3") { exp_pred <- exp_pred %>% dplyr::rowwise() %>% @@ -211,59 +211,26 @@ test_that("quantile predictions", { set_mode("censored regression") %>% fit(Surv(stop, event) ~ rx + size + enum, data = bladder) pred <- predict(fit_s, new_data = bladder[1:3, ], type = "quantile") - - set.seed(1) - exp_fit <- flexsurv::flexsurvspline( - Surv(stop, event) ~ rx + size + enum, - data = bladder, - k = 1 - ) - exp_pred <- summary( - exp_fit, - newdata = bladder[1:3, ], - type = "quantile", - quantiles = (1:9) / 10 - ) - expect_s3_class(pred, "tbl_df") - expect_equal(names(pred), ".pred") + expect_equal(names(pred), ".pred_quantile") expect_equal(nrow(pred), 3) - expect_true( - all(purrr::map_lgl( - pred$.pred, - ~ all(dim(.x) == c(9, 2)) - )) - ) - expect_true( - all(purrr::map_lgl( - pred$.pred, - ~ all(names(.x) == c(".quantile", ".pred_quantile")) - )) - ) - expect_equal( - tidyr::unnest(pred, cols = .pred)$.pred_quantile, - do.call(rbind, exp_pred)$est - ) + expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list")) + # add confidence interval - pred <- predict( + pred_ci <- predict( fit_s, new_data = bladder[1:3, ], type = "quantile", interval = "confidence", level = 0.7 ) - expect_true( - all(purrr::map_lgl( - pred$.pred, - ~ all(names(.x) == c( - ".quantile", - ".pred_quantile", - ".pred_lower", - ".pred_upper" - )) - )) - ) + expect_s3_class(pred_ci, "tbl_df") + expect_equal(names(pred_ci), c(".pred_quantile", ".pred_lower", ".pred_upper")) + expect_equal(nrow(pred_ci), 3) + expect_s3_class(pred_ci$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list")) + expect_s3_class(pred_ci$.pred_lower, c("quantile_pred", "vctrs_vctr", "list")) + expect_s3_class(pred_ci$.pred_upper, c("quantile_pred", "vctrs_vctr", "list")) # single observation f_pred_1 <- predict(fit_s, bladder[2,], type = "quantile") @@ -284,7 +251,7 @@ test_that("hazard prediction", { head(lung), type = "hazard", times = c(0, 500, 1000) - ) + ) if (packageVersion("flexsurv") < "2.3") { exp_pred <- exp_pred %>% dplyr::rowwise() %>% @@ -409,13 +376,13 @@ test_that("`fix_xy()` works", { f_fit, new_data = lung_pred, type = "quantile", - quantile = c(0.2, 0.8) + quantile_levels = c(0.2, 0.8) ) xy_pred_quantile <- predict( xy_fit, new_data = lung_pred, type = "quantile", - quantile = c(0.2, 0.8) + quantile_levels = c(0.2, 0.8) ) expect_equal(f_pred_quantile, xy_pred_quantile) @@ -438,7 +405,7 @@ test_that("`fix_xy()` works", { test_that("can handle case weights", { skip_if_not_installed("flexsurv") - + # flexsurv engine can only take weights > 0 set.seed(1) wts <- runif(nrow(lung)) diff --git a/tests/testthat/test-survival_reg-survival.R b/tests/testthat/test-survival_reg-survival.R index 958d337..4d1782f 100644 --- a/tests/testthat/test-survival_reg-survival.R +++ b/tests/testthat/test-survival_reg-survival.R @@ -122,16 +122,22 @@ test_that("prediction of survival time quantile", { fit(Surv(time, status) ~ age + sex, data = lung) exp_quant <- predict(res$fit, head(lung), p = (2:4) / 5, type = "quantile") - exp_quant <- apply(exp_quant, 1, function(x) { - tibble::tibble(.quantile = (2:4) / 5, .pred_quantile = x) - }) - exp_quant <- tibble::tibble(.pred = exp_quant) - obs_quant <- predict(res, head(lung), type = "quantile", quantile = (2:4) / 5) + obs_quant <- predict(res, head(lung), type = "quantile", quantile_levels = (2:4) / 5) - expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant)) + expect_s3_class(obs_quant, "tbl_df") + expect_equal(names(obs_quant), ".pred_quantile") + expect_equal(nrow(obs_quant), 6) + expect_s3_class(obs_quant$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list")) + + for (.row in 1:nrow(obs_quant)) { + expect_equal( + unclass(obs_quant$.pred_quantile[.row])[[1]], + exp_quant[.row,] + ) + } # single observation - f_pred_1 <- predict(res, lung[1, ], type = "quantile") + f_pred_1 <- predict(res, lung[1, ], type = "quantile", quantile_levels = .5) expect_identical(nrow(f_pred_1), 1L) }) @@ -213,13 +219,13 @@ test_that("`fix_xy()` works", { f_fit, new_data = lung_pred, type = "quantile", - quantile = c(0.2, 0.8) + quantile_levels = c(0.2, 0.8) ) xy_pred_quantile <- predict( xy_fit, new_data = lung_pred, type = "quantile", - quantile = c(0.2, 0.8) + quantile_levels = c(0.2, 0.8) ) expect_equal(f_pred_quantile, xy_pred_quantile)