
Connect discSurv as a reduction pipeline #194

Closed
RaphaelS1 opened this issue Apr 5, 2021 · 14 comments
Labels
Type: PipeOp Issue that wants a new PipeOp



RaphaelS1 commented Apr 5, 2021

This will require some careful thought as initial exploration (below) indicates sub-optimal results. We may want to consider discrete-survival representations, as well as discrete hazards, or think about sampling methods to solve the imbalance problem. (@adibender)

library(mlr3)
library(mlr3proba)
library(discSurv)
library(mlr3learners)
library(survival)

discSurv_redux <- function(task, n = 5, lrn = "classif.ranger") {
  out <- list()
  data <- as.data.frame(task$data())
  time <- task$target_names[[1]]
  status <- task$target_names[[2]]
  
  ## convert data to discrete time
  discData <- contToDisc(
    data, time, 
    seq.int(min(data[[time]]), max(data[[time]]) + 1, length.out = n))
  discData[[status]] <- as.integer(discData[[status]])
  ## convert data to long format
  longData <- dataLong(discData, timeColumn = "timeDisc", censColumn = status)
  ## get target and feature names
  long_data_sim <- longData[, c("timeInt", "y", task$feature_names)]
  long_data_sim$y <- factor(long_data_sim$y)
  ## create classif task
  task = TaskClassif$new("disc", long_data_sim, target = "y")
  
  ## make prediction
  p <- lrn(lrn, predict_type = "prob")$train(task)$predict(task)
  ## return brier and accuracy
  out$bbrier <- as.numeric(p$score(msr("classif.bbrier")))
  out$acc <- as.numeric(1 - p$score())
  
  ## get predictions
  pred <- cbind(longData, pred = p$prob[, 1])
  max_t <- max(table(pred$obj))
  ## convert hazards to surv as prod(1 - h(t))
  surv <- t(sapply(unique(pred$obj), function(i) {
    x <- cumprod(1 - pred[pred$obj == i, "pred"])
    if (length(x) < max_t) {
      x <- c(x, rep(tail(x, 1), max_t - length(x)))
    }
    x
  }))
  
  time <- as.numeric(as.character(discData$timeDisc))
  ## coerce to distribution and crank
  r <- .surv_return(seq(max(time)), surv = surv)
  ## create prediction object
  p <- PredictionSurv$new(
    row_ids = seq(nrow(discData)),
    crank = r$crank, distr = r$distr,
    truth = Surv(time, discData[[status]]))
  ## evaluate with Harrell's C and IGS
  out$H_C <- as.numeric(p$score())
  out$IGS <- as.numeric(p$score(msr("surv.graf")))
  
  out
}

discSurv_redux(tsk("rats"), 10, "classif.featureless")
## $bbrier
## [1] 0.01597566
## 
## $acc
## [1] 0.9840243
## 
## $H_C
## [1] 0.5
## 
## $IGS
## [1] 0.8860837
discSurv_redux(tsk("rats"), 10, "classif.ranger")
## $bbrier
## [1] 0.01058406
## 
## $acc
## [1] 0.9840243
## 
## $H_C
## [1] 0.07750325
## 
## $IGS
## [1] 0.8860771
discSurv_redux(tsk("rats"), 5, "classif.ranger")
## $bbrier
## [1] 0.01843896
## 
## $acc
## [1] 0.9709935
## 
## $H_C
## [1] 0.08337326
## 
## $IGS
## [1] 0.7862632
discSurv_redux(tsk("unemployment"), 5, "classif.ranger")
## $bbrier
## [1] 0.0647278
## 
## $acc
## [1] 0.900025
## 
## $H_C
## [1] 0.2372616
## 
## $IGS
## [1] 0.2770845
discSurv_redux(tgen("simsurv")$generate(100), 10, "classif.ranger")
## $bbrier
## [1] 0.03973285
## 
## $acc
## [1] 0.9230769
## 
## $H_C
## [1] 0.1248401
## 
## $IGS
## [1] 0.594151
@adibender

stupid question - what precisely is the "discrete survival" task for you? Is there a summary/description somewhere?

You discretize the follow-up into intervals and create pseudo data with a new "status" variable that is 0 if the subject survived the interval and 1 if the subject experienced the event in that interval. Then you fit any classifier to the pseudo status variable (with the interval as a covariate). The event probability in each interval is the discrete-time hazard, from which you can also calculate survival probabilities. If the grid is fine enough, you can approximate any survival time distribution reasonably well.
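
To make the last step concrete, a minimal sketch with made-up hazards (no fitted model involved): the survival probability is just the cumulative product of the interval-wise non-event probabilities.

## toy discrete hazards h(j) for intervals j = 1, ..., 5 (made-up numbers)
hazard <- c(0.05, 0.10, 0.15, 0.20, 0.25)
## survival curve: S(j) = prod over k <= j of (1 - h(k))
surv <- cumprod(1 - hazard)
round(surv, 3)
## [1] 0.950 0.855 0.727 0.581 0.436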

@adibender

@RaphaelS1 Looks legit, however,

  • n should probably be increased. As you assume that the baseline hazard is constant in each interval, 5 or 10 intervals might be too crude
  • carrying the last value forward could also be an issue here:
if (length(x) < max_t) {
  x <- c(x, rep(tail(x, 1), max_t - length(x)))
}

It's better to create new data with all intervals for each id (the covariates of each id remain constant across intervals), then predict the hazard for each interval and calculate the survival probability accordingly.

@RaphaelS1

stupid question - what precisely is the "discrete survival" task for you? Is there a summary/description somewhere?

https://www.springer.com/gp/book/9783319281568

n should probably be increased. As you assume that the baseline hazard is constant in each interval, 5 or 10 intervals might be too crude

Suggestions for a better default?

It's better to create new data with all intervals for each id (the covariates of each id remain constant across intervals), then predict the hazard for each interval and calculate the survival probability accordingly.

Not sure I understand this, could you provide code/pseudo-code/math example?

@adibender

Here is a proof of concept (the programming is terrible). Note that I use pammtools for the various data trafos just because I'm more familiar with it. Also, this would generalize to the continuous case: we would just need to use regr.* instead of classif.* with a Poisson loss and integrate the offset, which is also created by the as_ped function. Later, we would need to generalize a little to allow custom times at which the Brier score is evaluated, etc.
Btw, after rewriting with pammtools I noticed that it should be pred = p$prob[, 2] in your original code, because the hazard is the probability of the 1s, not the 0s, so your previous code probably also works better once this is fixed (but it's still better to use actual predictions for intervals where subjects weren't observed, rather than carrying the last value forward):

library(mlr3)
library(mlr3proba)
library(discSurv)
library(mlr3learners)
library(survival)
library(pammtools)

discSurv_redux <- function(task, cut = NULL, lrn = "classif.ranger") {

  out <- list()
  data <- as.data.frame(task$data())
  time <- task$target_names[[1]]
  status <- task$target_names[[2]]

  ## convert data to discrete time
  longData <- as_ped(data=data, Surv(time, status)~., cut = cut)
  ## get target and feature names
  long_data_sim <- longData[, c("tend", "ped_status", task$feature_names[1:3])]
  long_data_sim$ped_status <- factor(long_data_sim$ped_status)
  ## create classif task
  task = TaskClassif$new("disc", long_data_sim, target = "ped_status")

  data2 <- data
  data2$time <- max(data$time)
  new_data <- as_ped(longData, newdata = data2)
  new_data <- new_data[, c( "id", "tend", "ped_status", task$feature_names[1:3])]

  ## make prediction
  p <- lrn(lrn, predict_type = "prob")$train(task)$predict_newdata(new_data)
  ## get predictions
  pred <- cbind(new_data, pred = p$prob[, 2])
  max_t <- max(data$time)
  ## convert hazards to surv as prod(1 - h(t))
  surv <- t(sapply(unique(pred$id), function(i) {
    x <- cumprod((1 - pred[pred$id == i, "pred"]))
    x
  }))

  time <- sort(unique(new_data$tend))
  ## coerce to distribution and crank
  r <- .surv_return(time, surv = surv)
  ## create prediction object
  p <- PredictionSurv$new(
    row_ids = seq(nrow(data)),
    crank = r$crank, distr = r$distr,
    truth = Surv(data[["time"]], data[[status]]))
  ## evaluate with Harrell's C and IGS
  out$H_C <- as.numeric(p$score())
  out$IGS <- as.numeric(p$score(msr("surv.graf", proper = TRUE)))

  out

}
set.seed(18452505)
discSurv_redux(tsk("rats"), cut = seq(0, max(rats$time), length.out = 10), "classif.featureless")
# $H_C
# [1] 0.5

# $IGS
# [1] 0.05894565
discSurv_redux(tsk("rats"), cut = seq(0, max(rats$time), length.out = 100), "classif.featureless")
# $H_C
# [1] 0.5

# $IGS
# [1] 0.05894565
discSurv_redux(tsk("rats"), cut = seq(0, max(rats$time), length.out = 10), "classif.ranger")
# $H_C
# [1] 0.9465666

# $IGS
# [1] 0.07711993
discSurv_redux(tsk("rats"), cut = seq(0, max(rats$time), length.out = 100), "classif.ranger")
# $H_C
# [1] 0.9563922

# $IGS
# [1] 0.0356116
discSurv_redux(tsk("rats"), cut = NULL, "classif.ranger") # cuts at unique event times
# $H_C
# [1] 0.9568337

# $IGS
# [1] 0.04004635

Results make sense to me, although we might think about what featureless means in this context. The way it is used now, you estimate one constant baseline hazard. When you include the variable that indicates the intervals, you get a piece-wise constant baseline hazard (still without traditional features).


adibender commented May 25, 2021

Suggestions for a better default?

Maybe the number of unique event times (or the square root of that, if the number is large),
or use the unique event times themselves as cut points. At least that works well for PAMMs; not sure about discrete survival.
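
A rough sketch of what such a default could look like (purely illustrative, reusing the data/time/status variables from the function above; the quantile-based thinning to roughly sqrt(k) cut points is just one option):

## unique event times as candidate cut points
event_times <- sort(unique(data[[time]][data[[status]] == 1]))
## thin out to roughly sqrt(k) cut points if there are many of them
cut_default <- if (length(event_times) > 100) {
  unname(quantile(event_times,
    probs = seq(0, 1, length.out = ceiling(sqrt(length(event_times))))))
} else {
  event_times
}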

Not sure I understand this, could you provide code/pseudo-code/math example?

Note that new_data in the code above contains one row per subject and interval. The number of intervals, however, should be dynamic and depend on the times at which we want to predict survival, but that's a little more programming effort.

@adibender

predicting individual bins that are non-overlapping, or predicting bins that are overlapping. In the latter case you may need to assume compatibility between predicted probabilities

I'm not aware of cases where the bins are overlapping. See the example below for the usual workflow.

not just 0/1, but you have a third possible outcome, "censored" (!) - how do you deal with this?

For each subject, only bins/intervals in which the subject was at risk are included. If they were still alive at the end of the last interval in which they were at risk, all pseudo-status values (ped_status below) are 0; otherwise all ped_status values are 0 except the last one, in which the event occurred.

@adibender, in the distribution prediction context, featureless predictions = distribution/density estimation, for example Kaplan-Meier (or the discrete version thereof). See section 3.3 of https://arxiv.org/abs/1801.00753 for why I think this is the case.

It's definitely more sensible to estimate a more flexible baseline hazard, but we would have to think about how this model is called from within mlr3. If we want to use the reduction directly, i.e. call classification learners without modification, classif.featureless means intercept only (model m0 below), but what we want is model m1, a baseline hazard without covariates; in the discrete-time scheme this means estimating the probability of an event in each interval, conditional on the interval feature.

library(pammtools)
#> 
#> Attaching package: 'pammtools'
#> The following object is masked from 'package:stats':
#> 
#>     filter
library(mgcv)
#> Loading required package: nlme
#> This is mgcv 1.8-33. For overview type 'help("mgcv-package")'.
library(ggplot2)
theme_set(theme_bw())
library(survival)

set.seed(128)
data <- tumor[sample.int(nrow(tumor), 300, replace = FALSE), ]
data$id <- seq_len(nrow(data))
data <- data[, c("id", "days", "status", "age", "sex", "complications")]
data[c(1,3), ]
#> # A tibble: 2 x 6
#>      id  days status   age sex   complications
#>   <int> <dbl>  <int> <int> <fct> <fct>        
#> 1     1  1402      1    64 male  no           
#> 2     3   645      0    45 male  yes

### Data preparation
# discretize follow up: use unique event times as cut points
discretized_data <- data %>% as_ped(Surv(days, status)~., cut = NULL)
nrow(discretized_data)
#> [1] 25513
# show first and last two observations for subjects 1 and 3
discretized_data %>%
  filter(id %in% c(1, 3)) %>%
  group_by(id) %>%
  slice(1:2, (dplyr::n()-1):dplyr::n())
#> # A tibble: 8 x 9
#> # Groups:   id [2]
#>      id tstart  tend interval    offset ped_status   age sex   complications
#>   <int>  <dbl> <dbl> <fct>        <dbl>      <dbl> <int> <fct> <fct>        
#> 1     1      0     2 (0,2]        0.693          0    64 male  no           
#> 2     1      2     3 (2,3]        0              0    64 male  no           
#> 3     1   1383  1393 (1383,1393]  2.30           0    64 male  no           
#> 4     1   1393  1402 (1393,1402]  2.20           1    64 male  no           
#> 5     3      0     2 (0,2]        0.693          0    45 male  yes          
#> 6     3      2     3 (2,3]        0              0    45 male  yes          
#> 7     3    586   613 (586,613]    3.30           0    45 male  yes          
#> 8     3    613   646 (613,646]    3.47           0    45 male  yes
length(unique(discretized_data$interval))
#> [1] 142
# -> subject 1 experienced event at 1402 -> ped_status = 1 in this interval, otherwise 0
# -> subject 3 was censored at 645 -> last interval in risk set is (613, 646] -> ped_status = 0 always


### Model estimation

# constant baseline hazard

m0 <- glm(ped_status ~ 1, data = discretized_data, family = binomial())

# interval specific baseline hazard
# not so good, hazards volatile, many parameters to estimate
m1 <- glm(ped_status ~ interval, data = discretized_data, family = binomial())

# better: penalize differences between neighboring hazards
# tend is a representation of time in the j-th interval, here interval end point
m2 <- gam(ped_status ~ s(tend), data = discretized_data, family = binomial())

### Visualization

ndf <- discretized_data %>% make_newdata(tend = unique(tend))
head(ndf)
#>   tstart tend intlen interval       id    offset ped_status      age  sex
#> 1      0    2      2    (0,2] 151.1807 0.6931472          0 61.06683 male
#> 2      2    3      1    (2,3] 151.1807 0.0000000          0 61.06683 male
#> 3      3    5      2    (3,5] 151.1807 0.6931472          0 61.06683 male
#> 4      5    6      1    (5,6] 151.1807 0.0000000          0 61.06683 male
#> 5      6    7      1    (6,7] 151.1807 0.0000000          0 61.06683 male
#> 6      7    8      1    (7,8] 151.1807 0.0000000          0 61.06683 male
#>   complications
#> 1            no
#> 2            no
#> 3            no
#> 4            no
#> 5            no
#> 6            no

ndf$hazard0 <- predict(m0, newdata = ndf, type = "response")
ndf$hazard1 <- predict(m1, newdata = ndf, type = "response")
ndf$hazard2 <- predict(m2, newdata = ndf, type = "response")

ggplot(ndf, aes(x = tend)) +
  geom_step(aes(y = hazard0, col = "m0")) +
  geom_step(aes(y = hazard1, col = "m1")) +
  geom_step(aes(y = hazard2, col = "m2"))

### Survival Probability

ndf <- ndf %>%
  mutate(
    surv0 = cumprod(1 - hazard0),
    surv1 = cumprod(1 - hazard1),
    surv2 = cumprod(1 - hazard2),
  )

# cox for comparison
cox <- coxph(Surv(days, status)~1, data = data)
bh <- basehaz(cox)
bh$surv_cox <- as.numeric(exp(-bh$hazard))

ggplot(ndf, aes(x = tend)) +
  geom_step(aes(y = surv0, col = "m0"))+
  geom_step(aes(y = surv1, col = "m1"))+
  geom_step(aes(y = surv2, col = "m2")) +
  geom_step(data = bh, aes(x=time, y = surv_cox, col = "cox")) +
  ylim(c(0, 1))

# cox, m1 and m2 basically identical


## Stratified baseline hazard

cox_strata <- coxph(Surv(days, status)~strata(complications), data = data)
bh_strata <- basehaz(cox_strata)
bh_strata <- bh_strata %>%
  group_by(strata) %>%
  mutate(surv = exp(-hazard))

discrete_strata <- gam(ped_status ~ complications + s(tend, by = complications),
  data = discretized_data, family = binomial())

ndf <- discretized_data %>%
  make_newdata(tend = unique(tend), complications = unique(complications))
ndf$hazard_strata <- predict(discrete_strata, ndf, type = "response")
ndf <- ndf %>%
  group_by(complications) %>%
  mutate(surv_strata = cumprod(1 - hazard_strata))

ggplot(ndf, aes(x = tend)) +
  geom_step(aes(lty = complications, y = surv_strata, col = "discrete")) +
  geom_line(data = bh_strata, aes(x = time, y = surv, lty = strata, col = "cox"))

Created on 2021-06-01 by the reprex package (v0.3.0)

@RaphaelS1

This is great, let's use pammtools rather than discSurv in the reduction directly.

@RaphaelS1

Especially with the examples you show above it's very promising!!

@RaphaelS1

So, the available reductions / ones we have implemented:

  • discreteSurv/pammtools
  • surv -> det regr
  • surv -> prob regr
  • multi-class classif (my method)
  • stacking classif (Zhong/Tibshirani) - should be relatively easy to implement

Comparing these would make a very nice paper

@adibender

This is great, let's use pammtools not discSurv in the reduction directly

Ideally, we would extract some of the data-transformation functionality into a separate package, otherwise you get a lot of unnecessary dependencies. The data trafo in pammtools is itself a wrapper around survival::survSplit (quick illustration below).
The other question is whether those reductions go into proba or separate packages?
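
Just to illustrate the survSplit step mentioned above (this follows the standard example pattern from the survival package, not the pammtools internals):

library(survival)
## split the veteran data at two cut points: one row per subject and interval,
## with the event indicator re-coded within each interval
vet2 <- survSplit(Surv(time, status) ~ ., data = veteran,
  cut = c(90, 180), episode = "tgroup", id = "id")
head(vet2[, c("id", "tstart", "time", "status", "tgroup")])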

@adibender

So, the available reductions / ones we have implemented:

  • discreteSurv/pammtools
  • surv -> det regr
  • surv -> prob regr
  • multi-class classif (my method)
  • stacking classif (Zhong/Tibshirani) - should be relatively easy to implement

Comparing these would make a very nice paper

Unless I'm missing something, I think Zhong/Tibshirani is discrete time-to-event analysis (transformation of the survival task into a binomial likelihood optimization/classification task), just framed differently, so there is no difference to discSurv/pammtools. The latter are just implementations of the data trafo before applying some classification algorithm. For pammtools, this is essentially a by-product: originally, the data trafo was intended for piece-wise exponential models (i.e. transformation of the survival task into a Poisson likelihood optimization task), but the data trafo for discrete time-to-event and piece-wise exponential models is essentially the same.
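
To make that concrete, a sketch reusing discretized_data from the reprex above; the same pseudo data supports either likelihood (the offset column created by as_ped is the log of the time spent in the interval):

## discrete time-to-event: binomial model for the interval-wise event indicator
m_disc <- glm(ped_status ~ interval, data = discretized_data, family = binomial())
## piece-wise exponential model: Poisson model with the log-duration offset
m_pem <- glm(ped_status ~ interval + offset(offset),
  data = discretized_data, family = poisson())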

@RaphaelS1

The other question is whether those reductions go into proba or separate packages?

This could go into a different package, but we still want the interface to work with mlr easily, otherwise it's still a lot of glue code.

Unless I'm missing something, I think Zhong/Tibshirani is discrete time-to-event analysis (transformation of the survival task into a binomial likelihood optimization/classification task), just framed differently, so there is no difference to discSurv/pammtools.

You're the expert here, but my understanding is that their method differs in how the predictions are made (not necessarily in the transformation), and is therefore sufficiently different that it's fair to include it in a comparison. So in discSurv we would fit a single model over all unique times with T as a covariate, whereas in the stacking approach it's one model per unique event time? Obviously these may come to the same thing, but analytically it's still a different approach.
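
Roughly what I mean, sketched on the discretized_data object from your reprex above (purely illustrative; some intervals may have too few events for the per-interval fits to be sensible):

## (a) pooled discrete-time model: one fit, with time (tend) entering as a feature
pooled <- glm(ped_status ~ tend + age + complications,
  data = discretized_data, family = binomial())
## (b) stacking-style: one fit per interval, no time feature within each fit
per_interval <- lapply(split(discretized_data, discretized_data$interval), function(d) {
  glm(ped_status ~ age + complications, data = d, family = binomial())
})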

@adibender

What do we need to do for this reduction to be fully integrated into mlr3?

@mlr-org deleted several comments from fkiraly Jan 13, 2024
@bblodfon bblodfon added the Type: PipeOp Issue that wants a new PipeOp label Jun 10, 2024
bblodfon added a commit that referenced this issue Jul 25, 2024
@bblodfon

Finished in #391
