From 2dbac8509d91b5f6e4a48fe99fc0ec4fc322f777 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Mon, 23 Sep 2024 18:19:48 -0400 Subject: [PATCH] data manipulation functions for tidymodels/parsnip#1203 --- DESCRIPTION | 4 +++- R/survival_reg-flexsurv.R | 45 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 9986f5c..dce6cfa 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,7 +27,7 @@ Imports: dplyr (>= 0.8.0.1), generics, glue, - hardhat (>= 1.1.0), + hardhat (>= 1.4.0.9002), lifecycle, mboost, prodlim (>= 2023.03.31), @@ -48,6 +48,8 @@ Suggests: rmarkdown, rpart, testthat (>= 3.0.0) +Remotes: + tidymodels/hardhat Config/Needs/website: tidymodels, tidyverse/tidytemplate diff --git a/R/survival_reg-flexsurv.R b/R/survival_reg-flexsurv.R index 77895f2..ee1fbc0 100644 --- a/R/survival_reg-flexsurv.R +++ b/R/survival_reg-flexsurv.R @@ -11,7 +11,7 @@ flexsurv_post <- function(pred, object) { tidyr::nest(.by = .row) %>% dplyr::select(-.row) } - pred + pred } flexsurv_rename_time <- function(pred){ @@ -27,3 +27,46 @@ flexsurv_rename_time <- function(pred){ dplyr::rename(.eval_time = .time) } } + +# ------------------------------------------------------------------------------ +# Conversion of quantile predictions to the vctrs format + +# For single quantile levels, flexsurv returns a data frame with column +# ".pred_quantile" and perhaps also ".pred_lower" and ".pred_upper" + +# With mutiple quantile levels, flexsurv returns a data frame with a ".pred" +# column with co.lumns ".quantile" and ".pred_quantile" and perhaps +# ".pred_lower" and ".pred_upper" + +flexsurv_to_quantile_pred <- function(x, object) { + # if one level, convert to nested format + if(!identical(names(x), ".pred")) { + # convert to the same format as predictions with mulitplel levels + x <- re_nest(x) + } + + # Get column names to convert to vctrs encoding + nms <- names(x$.pred[[1]]) + possible_cols <- c(".pred_quantile", ".pred_lower", ".pred_upper") + existing_cols <- intersect(possible_cols, nms) + + # loop over prediction column names + res <- list() + for (col in existing_cols) { + res[[col]] <- purrr::map_vec(x$.pred, nested_df_iter, col = col) + } + tibble::new_tibble(res) + +} + +re_nest <- function(df) { + .row <- 1:nrow(df) + df <- vctrs::vec_split(df, by = .row) + df$key <- NULL + names(df) <- ".pred" + df +} + +nested_df_iter <- function(df, col) { + hardhat::quantile_pred(matrix(df[[col]], nrow = 1), quantile_levels = df$.quantile) +}