---
title: "Bayesian Causal Forest (BCF)"
share:
permalink: "https://book.martinez.fyi/bcf.html"
description: "Business Data Science: What Does it Mean to Be Data-Driven?"
linkedin: true
email: true
mastodon: true
---
## Introduction to Bayesian Causal Forest
While BART has proven to be a powerful tool for causal inference, it has some
limitations when applied to heterogeneous treatment effect estimation. To
address these shortcomings, @hahn2020bayesian introduced the
Bayesian Causal Forest (BCF) model. BCF builds upon BART's foundation but
incorporates key modifications that make it particularly well-suited for
estimating heterogeneous treatment effects.
## The BCF Model
The BCF model can be expressed as:
$$
Y_i = \mu(x_i, \hat{\pi}(x_i)) + \tau(x_i)z_i + \epsilon_i, \quad \epsilon_i \sim N(0, \sigma^2)
$$
where:
- $Y_i$ is the outcome for individual $i$
- $x_i$ are the covariates
- $z_i$ is the treatment indicator
- $\hat{\pi}(x_i)$ is an estimate of the propensity score
- $\mu(\cdot)$ is the prognostic function
- $\tau(\cdot)$ is the treatment effect function
The key innovation of BCF lies in its separation of the prognostic function
$\mu(\cdot)$ and the treatment effect function $\tau(\cdot)$. Both functions are
modeled using BART, but with different priors that reflect their distinct roles.
## Key Features of BCF
1. **Separation of Prognostic and Treatment Effects**: By modeling $\mu(\cdot)$
and $\tau(\cdot)$ separately, BCF allows for different levels of
regularization for each component. This is particularly useful when the
treatment effect is expected to be simpler or more homogeneous than the
overall prognostic effect.
2. **Inclusion of the Propensity Score**: Adding $\hat{\pi}(x_i)$ as a covariate
   in the prognostic function helps mitigate regularization-induced confounding,
   which can arise when confounding is strong (a sketch of such a first-stage
   propensity fit follows this list).
3. **Targeted Regularization**: BCF employs a prior on the treatment effect
function that encourages shrinkage towards homogeneous effects. This can
lead to more stable and accurate estimates, especially when the true
treatment effect heterogeneity is modest.
4. **Handling of Targeted Selection**: BCF is designed to perform well in
scenarios where treatment assignment is based on expected outcomes under
control, a phenomenon referred to as "targeted selection".
5. **Improved Treatment Effect Estimation**: BCF often yields more accurate and
stable estimates of conditional average treatment effects (CATEs),
especially in scenarios with strong confounding.
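In practice the true propensity score is rarely known, so $\hat{\pi}(x_i)$ comes from a first-stage model fit before BCF. Here is a minimal sketch, assuming a covariate matrix `X` and a binary treatment vector `Z` (both hypothetical at this point); plain logistic regression is shown for brevity, though a more flexible learner such as probit BART is common in applied work:
```{r}
#| eval: false
# Hypothetical first-stage propensity model: logistic regression shown
# for brevity; any flexible binary classifier could be substituted.
ps_dat <- data.frame(X, Z = Z)
ps_fit <- glm(Z ~ ., data = ps_dat, family = binomial())
pi_hat <- fitted(ps_fit)  # estimated propensity scores to pass to BCF
```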
## Examples
Let's consider the following data generating process from @hahn2020bayesian, which is also covered in [one of the vignettes](https://stochastictree.github.io/stochtree-r/articles/CausalInference.html#demo-1-nonlinear-outcome-model-heterogeneous-treatment-effect) from @stochastictree.
$$
\begin{aligned}
y &= \mu(X) + \tau(X) Z + \epsilon\\
\epsilon &\sim N\left(0,\sigma^2\right)\\
\mu(X) &= 1 + g(X) + 6 \lvert X_3 - 1 \rvert\\
\tau(X) &= 1 + 2 X_2 X_4\\
g(X) &= \mathbb{I}(X_5=1) \times 2 - \mathbb{I}(X_5=2) \times 1 - \mathbb{I}(X_5=3) \times 4\\
s_{\mu} &= \sqrt{\mathbb{V}(\mu(X))}\\
\pi(X) &= 0.8 \Phi\left(\frac{3\mu(X)}{s_{\mu}} - \frac{X_1}{2}\right) + \frac{2U+1}{20}\\
X_1,X_2,X_3 &\sim N\left(0,1\right)\\
X_4 &\sim \text{Bernoulli}(1/2)\\
X_5 &\sim \text{Categorical}(1/3,1/3,1/3)\\
U &\sim \text{Uniform}\left(0,1\right)\\
Z &\sim \text{Bernoulli}\left(\pi(X)\right)
\end{aligned}
$$
Let's generate data from this DGP and fit a BCF model using the {stochtree} package:
```{r}
#| label: data
#| warning: false
#| message: false
set.seed(1982)
# Define the functions based on the provided model
g <- function(x5) {
return(ifelse(x5 == 1, 2, ifelse(x5 == 2, -1, ifelse(x5 == 3, -4, 0))))
}
mu <- function(X) {
return(1 + g(X[, 5]) + 6 * abs(X[, 3] - 1))
}
tau <- function(X) {
return(1 + 2 * X[, 2] * X[, 4])
}
# Set parameters
n <- 500
snr <- 3
# Generate covariates
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- as.numeric(rbinom(n, 1, 0.5))
x5 <- as.numeric(sample(1:3, n, replace = TRUE))
X <- cbind(x1, x2, x3, x4, x5)
colnames(X) <- paste0("X", 1:5)
# Calculate mu(X) and tau(X)
mu_x <- mu(X)
tau_x <- tau(X)
# Calculate s_mu
s_mu <- sd(mu_x)
# Calculate pi(X)
pi_x <- 0.8 * pnorm((3 * mu_x / s_mu) - 0.5 * X[, 1]) + 0.05 + runif(n) / 10
# Generate treatment assignment Z
Z <- rbinom(n, 1, pi_x)
# Generate the outcome y
E_XZ <- mu_x + Z * tau_x
y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)
# Convert X to data frame and factorize categorical variables
X <- as.data.frame(X)
X$X4 <- factor(X$X4, ordered = TRUE)
X$X5 <- factor(X$X5, ordered = TRUE)
# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- setdiff(1:n, test_inds)
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]
```
To sample using {stochtree}, we run a small number of grow-from-root (GFR) warm-start iterations, followed by burn-in and retained MCMC iterations:
```{r fit1}
library(stochtree)
num_gfr <- 10
num_burnin <- 500
num_mcmc <- 1500
num_samples <- num_gfr + num_burnin + num_mcmc
bcf_model_warmstart <- bcf(
X_train = X_train,
Z_train = Z_train,
y_train = y_train,
pi_train = pi_train,
X_test = X_test,
Z_test = Z_test,
pi_test = pi_test,
num_gfr = num_gfr,
num_burnin = num_burnin,
num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE,
  sample_sigma_leaf_tau = FALSE
)
```
After fitting the model, it's crucial to assess convergence. One way to do this is by examining the traceplot for $\sigma^2$:
```{r traceplot}
library(ggplot2)
df <- tibble::tibble(
  sample = seq_along(bcf_model_warmstart$sigma2_samples),
sigma2_samples = bcf_model_warmstart$sigma2_samples
)
# Create the plot
ggplot(df, aes(x = sample, y = sigma2_samples)) +
geom_line() +
labs(
x = "Sample",
y = expression(sigma^2),
title = "Global Variance Parameter"
) +
theme_minimal()
```
The traceplot shows no obvious trends, suggesting that the MCMC chain has likely converged.
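Beyond the visual check, we can quantify how well the chain mixes with the effective sample size. A quick sketch, assuming the {coda} package is installed (it is not loaded elsewhere in this chapter):
```{r}
#| eval: false
# Effective sample size of the sigma^2 chain; values close to the
# number of retained draws indicate good mixing.
coda::effectiveSize(coda::mcmc(bcf_model_warmstart$sigma2_samples))
```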
Next, we can evaluate the model's performance in predicting the prognostic function:
```{r prognostic_plot}
df <- tibble::tibble(
predicted = rowMeans(bcf_model_warmstart$mu_hat_test),
actual = mu_test
)
# Create the plot
ggplot(df, aes(x = predicted, y = actual)) +
geom_point() +
geom_abline(
slope = 1,
intercept = 0,
color = "red",
linetype = "dashed",
linewidth = 1
) +
labs(x = "Predicted",
y = "Actual",
title = "Prognostic Function") +
theme_minimal()
```
The plot shows a strong correlation between predicted and actual values of the
prognostic function, indicating that the BCF model has captured the nonlinear
relationships in the data well.
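We can also summarize the fit numerically, for instance with the test-set root mean squared error (RMSE) of the posterior-mean prognostic estimates:
```{r}
# Test-set RMSE of the posterior-mean prognostic estimates.
sqrt(mean((rowMeans(bcf_model_warmstart$mu_hat_test) - mu_test)^2))
```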
Given that we know the true data generating process, we can assess how well the
model estimates the true treatment effects:
```{r effect_plot}
df <- tibble::tibble(
predicted = rowMeans(bcf_model_warmstart$tau_hat_test),
actual = tau_test
)
# Calculate the limits for the axes
limits <- range(c(df$predicted, df$actual))
# Create the plot
ggplot(df, aes(x = predicted, y = actual)) +
geom_point() +
geom_abline(slope = 1, intercept = 0, color = "red", linetype = "dashed", linewidth = 1) +
labs(
x = "Predicted",
y = "Actual",
title = "Treatment Effect"
) +
coord_fixed(ratio = 1, xlim = limits, ylim = limits) +
theme_minimal()
```
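The points track the 45-degree line reasonably well; the corresponding test-set RMSE for the CATE estimates is:
```{r}
# Test-set RMSE of the posterior-mean CATE estimates.
sqrt(mean((rowMeans(bcf_model_warmstart$tau_hat_test) - tau_test)^2))
```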
Finally, let's check the frequentist coverage of the pointwise 95% credible intervals for the CATEs:
```{r}
#| label: coverage
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.025)
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, 0.975)
cover <- (test_lb <= tau_test) & (test_ub >= tau_test)
cat("95% Credible Interval Coverage Rate:", round(mean(cover) * 100, 2), "%\n")
```
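The same posterior draws also give us the average treatment effect. Since rows of `tau_hat_test` index test observations and columns index posterior draws (as in the interval calculation above), averaging within each column yields a posterior for the test-set ATE:
```{r}
# Posterior draws of the test-set ATE: average the CATE draws across
# observations within each MCMC draw (columns index draws here).
ate_draws <- colMeans(bcf_model_warmstart$tau_hat_test)
round(c(mean = mean(ate_draws), quantile(ate_draws, c(0.025, 0.975))), 2)
```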
## Conclusion
Bayesian Causal Forest represents a significant advancement in the application
of tree-based methods to causal inference. By building upon the strengths of
BART and addressing some of its limitations, BCF offers a powerful tool for
estimating heterogeneous treatment effects. Its ability to handle strong
confounding and targeted selection, coupled with its interpretability, makes it
a valuable addition to the causal inference toolkit.
However, like all methods, BCF has its limitations. It assumes that all relevant
confounders are observed, and its performance can degrade in scenarios with
limited overlap between treated and control units. As always in causal
inference, careful consideration of the problem at hand and the assumptions of
the method is crucial.
::: {.callout-tip}
## Learn more
@hahn2020bayesian Bayesian regression tree models for causal inference:
Regularization, confounding, and heterogeneous effects (with discussion).
:::