Skip to content

Commit

Permalink
store full posterior survival array (testing version)
Browse files Browse the repository at this point in the history
* will work with latest version of distr6
  • Loading branch information
bblodfon committed Sep 8, 2023
1 parent d52c6fa commit c8041b4
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions R/learner_BART_surv_bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,23 +180,37 @@ delayedAssign(
pred = pred_fun()
}

# Build survival matrix using the mean posterior estimates of the survival
# function, see page 34-35 in Sparapani (2021) for more details

# Number of test observations
N = task$nrow
# Number of unique times
K = pred$K
times = pred$times
# Number of posterior draws
M = nrow(pred$surv.test)

# save the full posterior survival matrix
# save the full posterior survival matrix and the mean for checking
# TODO: remove next two lines
self$model$surv.test = pred$surv.test
self$model$surv.test.mean = pred$surv.test.mean

# Convert full posterior survival matrix to 3D survival array
# See page 34-35 in Sparapani (2021) for more details
surv.array = aperm(
array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)),
c(3, 2, 1)
)

# create mean posterior survival matrix (N obs x K times)
# Mean posterior survival matrix (N obs x K times)
surv = matrix(pred$surv.test.mean, nrow = N, ncol = K, byrow = TRUE)

mlr3proba::.surv_return(times = pred$times, surv = surv)
# get crank as expected mortality using mean posterior
pred_list = mlr3proba::.surv_return(times = times, surv = surv)

# replace with the full survival posterior
pred_list$distr = surv.array

# return list with crank and distr
pred_list
}
)
)
Expand Down

0 comments on commit c8041b4

Please sign in to comment.