Skip to content

Commit

Permalink
data manipulation functions for tidymodels/parsnip#1203
Browse files Browse the repository at this point in the history
  • Loading branch information
‘topepo’ committed Sep 23, 2024
1 parent a297661 commit 2dbac85
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Imports:
dplyr (>= 0.8.0.1),
generics,
glue,
hardhat (>= 1.1.0),
hardhat (>= 1.4.0.9002),
lifecycle,
mboost,
prodlim (>= 2023.03.31),
Expand All @@ -48,6 +48,8 @@ Suggests:
rmarkdown,
rpart,
testthat (>= 3.0.0)
Remotes:
tidymodels/hardhat
Config/Needs/website:
tidymodels,
tidyverse/tidytemplate
Expand Down
45 changes: 44 additions & 1 deletion R/survival_reg-flexsurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ flexsurv_post <- function(pred, object) {
tidyr::nest(.by = .row) %>%
dplyr::select(-.row)
}
pred
pred
}

flexsurv_rename_time <- function(pred){
Expand All @@ -27,3 +27,46 @@ flexsurv_rename_time <- function(pred){
dplyr::rename(.eval_time = .time)
}
}

# ------------------------------------------------------------------------------
# Conversion of quantile predictions to the vctrs format

# For single quantile levels, flexsurv returns a data frame with column
# ".pred_quantile" and perhaps also ".pred_lower" and ".pred_upper"

# With mutiple quantile levels, flexsurv returns a data frame with a ".pred"
# column with co.lumns ".quantile" and ".pred_quantile" and perhaps
# ".pred_lower" and ".pred_upper"

flexsurv_to_quantile_pred <- function(x, object) {
# if one level, convert to nested format
if(!identical(names(x), ".pred")) {
# convert to the same format as predictions with mulitplel levels
x <- re_nest(x)
}

# Get column names to convert to vctrs encoding
nms <- names(x$.pred[[1]])
possible_cols <- c(".pred_quantile", ".pred_lower", ".pred_upper")
existing_cols <- intersect(possible_cols, nms)

# loop over prediction column names
res <- list()
for (col in existing_cols) {
res[[col]] <- purrr::map_vec(x$.pred, nested_df_iter, col = col)
}
tibble::new_tibble(res)

}

re_nest <- function(df) {
.row <- 1:nrow(df)
df <- vctrs::vec_split(df, by = .row)
df$key <- NULL
names(df) <- ".pred"
df
}

nested_df_iter <- function(df, col) {
hardhat::quantile_pred(matrix(df[[col]], nrow = 1), quantile_levels = df$.quantile)
}

0 comments on commit 2dbac85

Please sign in to comment.