Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New metric request: Survival PRC-AUC function #496

Open
asb2111 opened this issue Mar 13, 2024 · 1 comment
Open

New metric request: Survival PRC-AUC function #496

asb2111 opened this issue Mar 13, 2024 · 1 comment
Labels
feature a feature request or enhancement metric 📏 a new yardstick metric

Comments

@asb2111
Copy link

asb2111 commented Mar 13, 2024

In situations with major class imbalance, ROC-AUC may not be a good metric to assess model concordance. Instead, as suggested in numerous places such as the scikit-learn documentation, the area under the precision-recall curve (PRC-AUC) may be preferred. yardstick already provides PRC-AUC for the standard (non-survival) setting, but there is currently no equivalent function for the survival setting.

I've attached some code below, adapted from `roc_auc_survival_vec()` and the functions it depends on, that I believe implements the survival version of PRC-AUC. It follows the principles of Vock et al., who provide a general recipe for incorporating inverse probability of censoring weights into any model. The final step, after estimating the weights, is:

Apply an existing prediction method to a weighted version of the training set where each member i of the training set is weighted by a factor of $\omega_i$. In other words, if $\omega_i=3$ it is as if the observation appeared three times in the data set.

# PRC ####

prc_auc_survival_vec <- function(truth,
                                 estimate,
                                 na_rm = TRUE,
                                 case_weights = NULL,
                                 ...) {
  # Area under the survival precision-recall curve, one value per
  # evaluation time. Input validation is delegated to
  # prc_curve_survival_vec(), so none is repeated here.
  pr_curves <- prc_curve_survival_vec(
    truth = truth,
    estimate = estimate,
    na_rm = na_rm,
    case_weights = case_weights
  )

  # One trapezoidal AUC per `.eval_time`, reported in `.estimate`.
  curves_by_time <- dplyr::group_by(pr_curves, .eval_time)
  dplyr::summarize(curves_by_time, .estimate = prc_trap_auc(pr, re))
}

# Validate and clean inputs, then build survival precision-recall curves.
#
# Arguments:
#   truth:        survival outcome for each observation.
#   estimate:     per-observation predictions; downstream code unnests it and
#                 reads `.eval_time`, `.pred_survival` and `.weight_censored`.
#   na_rm:        if TRUE, rows with missing values are removed; if FALSE, any
#                 missing value is an error.
#   case_weights: optional numeric case weights, or NULL.
#   ...:          currently unused.
#
# Returns the curve data produced by prc_curve_survival_impl().
prc_curve_survival_vec <- function(truth,
                                   estimate,
                                   na_rm = TRUE,
                                   case_weights = NULL,
                                   ...) {
  yardstick::check_dynamic_survival_metric(truth, estimate, case_weights)

  if (na_rm) {
    # `estimate` is passed as row positions so that the rows kept for `truth`
    # and `case_weights` can be reused to subset it afterwards.
    result <- yardstick::yardstick_remove_missing(
      truth,
      seq_along(estimate),
      case_weights
    )

    truth <- result$truth
    estimate <- estimate[result$estimate]
    case_weights <- result$case_weights
  } else if (yardstick::yardstick_any_missing(truth, estimate, case_weights)) {
    cli::cli_abort(
      c(x = "Missing values were detected and {.code na_rm = FALSE}.",
        i = "Not able to perform calculations.")
    )
  }

  prc_curve_survival_impl(
    truth = truth,
    estimate = estimate,
    case_weights = case_weights
  )
}

prc_curve_survival_impl <- function(truth,
                                    estimate,
                                    case_weights) {
  # Build one precision-recall curve per evaluation time and stack them,
  # tagging each with its `.eval_time`.
  #
  # NOTE(review): .extract_surv_time() / .extract_surv_status() are not
  # namespace-qualified here; confirm they are in scope where this runs.
  event_time <- .extract_surv_time(truth)
  delta <- .extract_surv_status(truth)

  case_weights <- vctrs::vec_cast(case_weights, double())
  if (is.null(case_weights)) {
    case_weights <- rep(1, length(delta))
  }

  # Zero-weight rows cannot change the curve, but they would still add
  # spurious thresholds, so they are removed up front.
  keep <- case_weights != 0
  if (!all(keep)) {
    event_time <- event_time[keep]
    delta <- delta[keep]
    case_weights <- case_weights[keep]
    estimate <- estimate[keep]
  }

  long <- tidyr::unnest(
    dplyr::tibble(event_time, delta, case_weights, estimate),
    cols = estimate
  )
  has_pred <- !is.na(long$.pred_survival)

  curves <- lapply(unique(long$.eval_time), function(eval_time) {
    rows <- long$.eval_time == eval_time & has_pred

    curve <- prc_curve_survival_impl_one(
      long$event_time[rows],
      long$delta[rows],
      long[rows, ],
      long$case_weights[rows]
    )
    curve$.eval_time <- eval_time
    curve
  })

  dplyr::bind_rows(curves)
}

# Compute one survival precision-recall curve at a single evaluation time.
#
# Observations are grouped by unique `.pred_survival` value. The groups are
# accumulated in ascending order of predicted survival, so each step adds the
# next group to the set counted as "predicted event". Every observation is
# weighted by `.weight_censored * case_weights`, and NA contributions are
# dropped from the sums.
#
# Arguments:
#   event_time:   observed times, aligned with the rows of `data`.
#   delta:        event indicators, aligned with the rows of `data`.
#   data:         rows for one evaluation time; must contain `.eval_time`,
#                 `.pred_survival` and `.weight_censored`.
#   case_weights: numeric case weights aligned with the rows of `data`.
#
# Returns a tibble with columns `.threshold`, `re` (recall) and `pr`
# (precision). The first row is the recall-0 end point. Its precision is 1,
# the convention used by yardstick::pr_curve(). The last row is the recall-1
# end point. Its precision is the value reached once every group has been
# included.
prc_curve_survival_impl_one <- function(event_time, delta, data, case_weights) {
  # NOTE(review): `.threshold` is sorted in decreasing order, while the
  # cumulative sums below run over ascending `.pred_survival`; confirm the
  # intended pairing of thresholds with curve points.
  res <- dplyr::tibble(
    .threshold = sort(unique(c(-Inf, data$.pred_survival, Inf)), decreasing = TRUE)
  )

  data_df <- data.frame(
    le_time = event_time <= data$.eval_time,
    delta = delta,
    weight_censored = data$.weight_censored,
    case_weights = case_weights
  )

  # Weighted events observed by the evaluation time (the positives).
  event_weight <- function(x) {
    sum(x$le_time * x$delta * x$weight_censored * x$case_weights, na.rm = TRUE)
  }
  # Weighted count of everything, i.e. all observations predicted positive.
  total_weight <- function(x) {
    sum(x$case_weights * x$weight_censored, na.rm = TRUE)
  }

  re_denom <- event_weight(data_df)

  data_split <- vctrs::vec_split(data_df, data$.pred_survival)
  data_split <- data_split$val[order(data_split$key)]

  # The same cumulative event weight feeds both recall and precision.
  cum_events <- cumsum(vapply(data_split, event_weight, FUN.VALUE = numeric(1)))
  cum_total <- cumsum(vapply(data_split, total_weight, FUN.VALUE = numeric(1)))

  # Recall: share of all weighted events captured so far, clipped to [0, 1].
  re <- pmin(pmax(cum_events / re_denom, 0), 1)
  res$re <- c(0, re, 1)

  # Precision: weighted events among everything included so far. A zero
  # denominator (NaN) is reported as 0, matching the original behavior.
  pr <- cum_events / cum_total
  pr[is.na(pr)] <- 0
  pr <- pmin(pmax(pr, 0), 1)

  # End points. When nothing is predicted positive, use precision 1. When
  # everything is predicted positive, use the last cumulative precision
  # (NA if there were no rows at all).
  pr_all <- if (length(pr) > 0L) pr[[length(pr)]] else NA_real_
  res$pr <- c(1, pr, pr_all)

  res
}

# Area under a precision-recall curve by the trapezoidal rule.
#
# `re` (recall) is the x-axis and `pr` (precision) is the y-axis. Points where
# either value is missing are dropped first. `re` must already be in
# non-decreasing order, otherwise an error is raised. Returns a single number,
# which is 0 when fewer than two complete points remain.
#
# This is implemented locally instead of calling the unexported
# `yardstick:::auc()`, which may change without notice.
prc_trap_auc <- function(pr, re) {
  complete <- !is.na(pr) & !is.na(re)
  pr <- pr[complete]
  re <- re[complete]

  if (is.unsorted(re)) {
    stop("`re` must already be sorted in non-decreasing order.", call. = FALSE)
  }

  n <- length(re)
  if (n < 2L) {
    return(0)
  }

  # Width of each recall step times the mean precision at its two ends.
  sum(diff(re) * (pr[-1L] + pr[-n]) / 2)
}

# S3 generic for survival precision-recall curves. Dispatches on the class
# of `data`; the data-frame method does the actual work.
prc_curve_survival <- function(data, ...) {
  UseMethod("prc_curve_survival")
}

# Data-frame method for prc_curve_survival().
#
# Captures `truth` and `case_weights` as tidy-eval expressions. It forwards
# them, together with `...` (presumably the prediction list column — verify
# against yardstick's summarizer docs), to prc_curve_survival_vec() through
# yardstick's curve survival summarizer. The result is then passed to
# yardstick's curve_finalize() with the "prc_survival_df" /
# "grouped_prc_survival_df" class names, which appear to be the classes
# autoplot.prc_survival_df() dispatches on.
prc_curve_survival.data.frame <- function(data,
                                          truth,
                                          ...,
                                          na_rm = TRUE,
                                          case_weights = NULL) {
  result <- yardstick::curve_survival_metric_summarizer(
    name = "prc_curve_survival",
    fn = prc_curve_survival_vec,
    data = data,
    truth = !!enquo(truth),
    ...,
    na_rm = na_rm,
    case_weights = !!enquo(case_weights)
  )

  # NOTE(review): curve_finalize() is internal to yardstick (reached via
  # `:::`), so it may change between releases without notice.
  yardstick:::curve_finalize(result, data, "prc_survival_df", "grouped_prc_survival_df")
}

autoplot.prc_survival_df <- function(object, ...) {
  # Step plot of precision against recall, one line per evaluation time.
  `%+%` <- ggplot2::`%+%`

  # Format the evaluation time so it acts as a discrete colour/group key.
  object$.eval_time <- format(object$.eval_time)

  curve_aes <- ggplot2::aes(
    x = re,
    y = pr,
    color = .eval_time,
    group = .eval_time
  )

  ggplot2::ggplot(data = object) %+%
    ggplot2::geom_step(mapping = curve_aes, direction = "hv") %+%
    ggplot2::coord_equal() %+%
    ggplot2::theme_bw() %+%
    ggplot2::xlab("Recall") %+%
    ggplot2::ylab("Precision")
}

# S3 generic for the survival PRC-AUC metric. Dispatches on the class of
# `data`; the data-frame method does the actual work.
prc_auc_survival <- function(data, ...) {
  UseMethod("prc_auc_survival")
}

# Register the generic as a yardstick dynamic survival metric; the
# "maximize" direction means larger PRC-AUC values are better.
prc_auc_survival <- yardstick::new_dynamic_survival_metric(
  prc_auc_survival,
  direction = "maximize"
)

prc_auc_survival.data.frame <- function(data,
                                        truth,
                                        ...,
                                        na_rm = TRUE,
                                        case_weights = NULL) {
  # Data-frame method: capture the `truth` and `case_weights` expressions and
  # let yardstick's dynamic survival summarizer call prc_auc_survival_vec().
  truth_quo <- enquo(truth)
  weights_quo <- enquo(case_weights)

  yardstick::dynamic_survival_metric_summarizer(
    name = "prc_auc_survival",
    fn = prc_auc_survival_vec,
    data = data,
    truth = !!truth_quo,
    ...,
    na_rm = na_rm,
    case_weights = !!weights_quo
  )
}
@asb2111 asb2111 changed the title Survival PRC-AUC function New metrick request: Survival PRC-AUC function Mar 13, 2024
@asb2111 asb2111 changed the title New metrick request: Survival PRC-AUC function New metric request: Survival PRC-AUC function Mar 13, 2024
@EmilHvitfeldt EmilHvitfeldt added feature a feature request or enhancement metric 📏 a new yardstick metric labels Mar 16, 2024
@EmilHvitfeldt
Copy link
Member

Thank you for the suggestion!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature a feature request or enhancement metric 📏 a new yardstick metric
Projects
None yet
Development

No branches or pull requests

2 participants