-
Notifications
You must be signed in to change notification settings - Fork 0
/
Weibull AFT model in Stan.Rmd
156 lines (118 loc) · 3.73 KB
/
Weibull AFT model in Stan.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
---
title: "Weibull AFT Model in Stan"
author: "Arman Oganisian"
date: "3/9/2019"
output: html_document
---
```{r setup, include=FALSE}
# Show the source code of every chunk in the knitted output.
knitr::opts_chunk$set(echo = TRUE)
# rstan is used below for sampling()/extract(); survival for Surv()/survfit().
library(rstan)
library(survival)
```
```{r simulate_data, }
# Simulate right-censored Weibull survival data for a two-arm study and
# package it in the list layout expected by the Stan model's data block.
set.seed(1)
n <- 1000

# Binary treatment indicator and design matrix (intercept column + A).
A <- rbinom(n, 1, .5)
X <- model.matrix(~ A)

# True coefficients (intercept, treatment) and per-subject linear predictor.
true_beta <- (1 / 2) * matrix(c(-1 / 3, 2), ncol = 1)
true_mu <- X %*% true_beta

# Weibull shape is the reciprocal of true_sigma; the value handed to
# rweibull() as `scale` is exp(-mu * alpha).
true_sigma <- 1
true_alpha <- 1 / true_sigma
true_lambda <- exp(-1 * true_mu * true_alpha)

# Simulate survival and censoring times (both drawn from the same
# distribution). The order of the two draws is kept so results under
# set.seed(1) are unchanged.
survt <- rweibull(n, shape = true_alpha, scale = true_lambda)
cent <- rweibull(n, shape = true_alpha, scale = true_lambda)

## Observed data:
# Censoring indicator: TRUE when the censoring time precedes the event time.
delta <- cent < survt
survt[delta] <- cent[delta] # censored subjects are observed at censoring time

# Count of censored (i.e. unobserved/"missing") survival times.
n_miss <- sum(delta)

d_list <- list(
  N_m = n_miss, N_o = n - n_miss,
  P = ncol(X), # number of betas, derived from the design matrix
  # Data for censored subjects. drop = FALSE keeps a matrix even when a
  # group has a single row, which the Stan matrix declaration requires.
  y_m = survt[delta], X_m = X[delta, , drop = FALSE],
  # Data for uncensored subjects.
  y_o = survt[!delta], X_o = X[!delta, , drop = FALSE]
)
```
```{stan specify_stan_mod, output.var="weibull_mod"}
// Weibull regression for right-censored survival times.
// Uncensored subjects contribute the Weibull log density to the target;
// censored subjects contribute the log complementary CDF (log survival).
data {
  int<lower=0> P; // number of beta parameters

  // data for censored subjects
  int<lower=0> N_m;
  matrix[N_m, P] X_m;
  vector[N_m] y_m;

  // data for observed subjects
  int<lower=0> N_o;
  matrix[N_o, P] X_o;
  // Declared as a vector (like y_m) rather than with the legacy
  // `real y_o[N_o]` array syntax, which newer Stan releases no longer accept.
  vector[N_o] y_o;
}
parameters {
  vector[P] beta;       // regression coefficients
  real<lower=0> alpha;  // Weibull shape
}
transformed parameters {
  // Per-subject value passed as the second argument of the weibull_*
  // functions below, modelled as a function of covariates.
  vector[N_m] lambda_m;
  vector[N_o] lambda_o;

  // standard weibull AFT re-parameterization
  lambda_m = exp((X_m * beta) * alpha);
  lambda_o = exp((X_o * beta) * alpha);
}
model {
  // priors
  beta ~ normal(0, 100);
  alpha ~ exponential(1);

  // evaluate likelihood for censored and uncensored subjects
  target += weibull_lpdf(y_o | alpha, lambda_o);
  target += weibull_lccdf(y_m | alpha, lambda_m);
}
// generate posterior quantities of interest
generated quantities {
  vector[1000] post_pred_trt; // 1000 simulated times per draw, treated arm
  vector[1000] post_pred_pbo; // 1000 simulated times per draw, placebo arm
  real lambda_trt;
  real lambda_pbo;
  real hazard_ratio;

  // Assumes column 1 of X is the intercept and column 2 the treatment
  // indicator, so beta[1] + beta[2] is the treated-arm linear predictor.
  lambda_trt = exp((beta[1] + beta[2]) * alpha);
  lambda_pbo = exp((beta[1]) * alpha);

  // NOTE(review): Stan's weibull takes (shape, scale), so this quantity equals
  // lambda_trt / lambda_pbo, i.e. a ratio of scale parameters. Under that
  // parameterization the hazard ratio would be exp(-beta[2] * alpha^2).
  // Confirm the intended estimand before renaming or changing it; the
  // plotting code downstream depends on the current definition.
  hazard_ratio = exp(beta[2] * alpha);

  // generate survival times (for plotting survival curves)
  for (i in 1:1000) {
    post_pred_trt[i] = weibull_rng(alpha, lambda_trt);
    post_pred_pbo[i] = weibull_rng(alpha, lambda_pbo);
  }
}
```
```{r run_stan_mod, }
# Fit the compiled model with a single chain: 6000 iterations in total, the
# first 1000 of which are warmup, leaving 5000 retained draws. `pars` limits
# which quantities are kept in the fit object.
weibull_fit <- rstan::sampling(
  object = weibull_mod,
  data   = d_list,
  pars   = c('hazard_ratio', 'post_pred_trt', 'post_pred_pbo'),
  chains = 1,
  warmup = 1000,
  iter   = 6000
)

# Pull the retained draws out of the fit as a named list.
post_draws <- rstan::extract(weibull_fit)
```
```{r plot_hazard_ratio, }
# Posterior distribution of the model's `hazard_ratio` quantity.
hist(post_draws$hazard_ratio,
     xlab = 'Hazard Ratio', main = 'Hazard Ratio Posterior Distribution')

# Reference line at the value implied by the simulation settings
# (exp(-1), about 0.368, for the defaults), computed from the true parameters
# so it stays correct if they are changed.
abline(v = exp(-1 * true_beta[2] * true_alpha), col = 'red')

# Posterior mean and 95% equal-tailed credible interval.
mean(post_draws$hazard_ratio)
quantile(post_draws$hazard_ratio, probs = c(.025, .975))
```
```{r plot_survival,}
library(survival)

# Kaplan-Meier fit by arm. survfit() orders the strata A=0 (PBO) then
# A=1 (TRT), so colours are listed in that order: blue for PBO, black for TRT,
# matching the legend and the posterior-draw colours below.
km_fit <- survfit(Surv(survt, 1 - delta) ~ A)
km_cols <- c('blue', 'black')

# conf.int = TRUE is explicit because intervals are not drawn by default when
# the fit has more than one curve, yet the legend advertises them.
plot(km_fit, col = km_cols, conf.int = TRUE,
     xlab = 'Time', ylab = 'Survival Probability')

# Overlay one empirical survival curve per posterior draw for each arm,
# built from that draw's simulated survival times.
for (i in seq_len(nrow(post_draws$post_pred_trt))) {
  trt_ecdf <- ecdf(post_draws$post_pred_trt[i, ])
  curve(1 - trt_ecdf(x), from = 0, to = 4, add = TRUE, col = 'gray')

  pbo_ecdf <- ecdf(post_draws$post_pred_pbo[i, ])
  curve(1 - pbo_ecdf(x), from = 0, to = 4, add = TRUE, col = 'lightblue')
}

# Redraw the Kaplan-Meier curves so they sit on top of the posterior draws.
lines(km_fit, col = km_cols, conf.int = TRUE)

legend('topright',
       legend = c('KM Curve and Intervals (TRT)',
                  'Posterior Survival Draws (TRT)',
                  'KM Curve and Intervals (PBO)',
                  'Posterior Survival Draws (PBO)'),
       col = c('black', 'gray', 'blue', 'lightblue'),
       lty = c(1, 0, 1, 0), pch = c(NA, 15, NA, 15), bty = 'n')
```