Summarizing HTE Outputs in a Multi-Arm Experiment #1479

shafayetShafee opened this issue Jan 12, 2025 · 0 comments


Hello, I need some guidance/direction/suggestions on how I can use the HTE estimates from multi_arm_causal_forest to create an insightful summary. After going through this paper, I can think of some approaches. But I am a bit confused, since these resources only discuss binary treatments, whereas my use case is a multi-arm treatment.

Let's consider a reproducible example to discuss the approaches.




Helper fns

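library(grf)      # multi_arm_causal_forest(), average_treatment_effect(), variable_importance()
library(dplyr)    # bind_cols(), bind_rows(), mutate(), filter(), group_by(), summarise()
library(ggplot2)  # plotting the per-quartile ATEs

# Predict the K - 1 treatment contrasts (each arm vs. the reference arm)
# together with their variance estimates, returned as one wide data frame.
predict_effect_and_ci <- function(multi_arm_causal_forest_model, newdata = NULL) {

  if (!inherits(multi_arm_causal_forest_model, "multi_arm_causal_forest")) {
    stop('This function only supports model objects of class "multi_arm_causal_forest".')
  }

  # drop = TRUE removes the single-outcome dimension, so predictions and
  # variance.estimates come back as [n, K - 1] matrices
  tau_hat <- predict(
    multi_arm_causal_forest_model,
    newdata = newdata,
    estimate.variance = TRUE,
    drop = TRUE
  )

  effect_estimate_df <- as.data.frame(tau_hat$predictions)
  contrasts_name <- colnames(effect_estimate_df)
  contrast_generic_name <- paste0("contrast_", seq(1, length(contrasts_name)))
  # maps generic names (contrast_1, contrast_2, ...) to labels such as "B - A"
  contrast_info <- setNames(contrasts_name, contrast_generic_name)
  colnames(effect_estimate_df) <- paste0(contrast_generic_name, "_estimate")

  effect_estimate_var_df <- as.data.frame(tau_hat$variance.estimates)
  colnames(effect_estimate_var_df) <- paste0(contrast_generic_name, "_var")

  effect_est_df <- bind_cols(effect_estimate_df, effect_estimate_var_df)

  list(
    contrast_info = contrast_info,
    data = effect_est_df
  )
}
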
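# Names of the n covariates with the highest grf variable importance.
# X must have the same columns, in the same order, as the data used to fit `forest`.
get_top_n_vars <- function(forest, X, n = 3) {
  varimp <- grf::variable_importance(forest)
  ranked_variables <- order(varimp, decreasing = TRUE)
  top_varnames <- colnames(X)[ranked_variables[1:n]]
  top_varnames
}
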
n <- 3000
p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)

exp_df <- data.frame(Y = Y, W = W, X)
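For checking the summaries against ground truth, the contrasts implied by this simulated outcome can be read off directly: the X1 term enters every arm identically and cancels, leaving $\tau_{b-a}(x) = x_2$ and $\tau_{c-a}(x) = -1.5\,x_2$. The helper `true_contrasts` is a hypothetical name (not part of grf) and is a minimal sketch that only applies to this particular simulation.

```r
# Hypothetical helper (not part of grf): the true contrasts implied by the
# simulated outcome Y = X1 + X2 * 1{W = "B"} - 1.5 * X2 * 1{W = "C"} + noise.
# X1 enters every arm identically, so it cancels in both contrasts.
true_contrasts <- function(X) {
  data.frame(
    `B - A` = X[, 2],         # tau_{b-a}(x) = x2
    `C - A` = -1.5 * X[, 2],  # tau_{c-a}(x) = -1.5 * x2
    check.names = FALSE       # keep the "B - A" style labels
  )
}

# Two hand-made rows with x2 = 2 and x2 = -1 (all other covariates are zero)
X_demo <- matrix(0, nrow = 2, ncol = 10)
X_demo[, 2] <- c(2, -1)
tau_true <- true_contrasts(X_demo)
print(tau_true)
```

Because the two true contrasts are negative multiples of each other, quartiles formed on $\hat{\tau}_{b-a}$ should order units roughly in reverse for the "C - A" contrast. On the real data, `true_contrasts(X[test, ])` could be compared column by column with `tau_hat_est$data`.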

Splitting Data into Train-Test

train = sample(nrow(X), 0.6 * nrow(X))
test = -train

Fit Forest Model on Training Set

mc.forest <- multi_arm_causal_forest(X[train, ], Y[train], W[train], seed = 1344)

Predict HTEs on Test Set

tau_hat_est <- predict_effect_and_ci(mc.forest, newdata = X[test, ])
tau_hat_est_df <- bind_cols(tau_hat_est$data, exp_df[test, ]) %>% 
  mutate(
    # 95% normal-approximation confidence intervals for each contrast
    c1_ci_low = contrast_1_estimate - 1.96 * sqrt(contrast_1_var),
    c1_ci_high = contrast_1_estimate + 1.96 * sqrt(contrast_1_var),
    c2_ci_low = contrast_2_estimate - 1.96 * sqrt(contrast_2_var),
    c2_ci_high = contrast_2_estimate + 1.96 * sqrt(contrast_2_var)
  )

head(tau_hat_est_df, 3)
  contrast_1_estimate contrast_2_estimate contrast_1_var contrast_2_var
1         -0.06310611          -0.1661815     0.01636501     0.01322940
2         -0.56899703           0.9466801     0.02781945     0.05492243
3         -0.49811420           0.9974250     0.02565718     0.06470612
           Y W          X1         X2          X3         X4          X5
1  0.9849406 A  0.54756844 -0.1014569  0.21716754 -2.0556520 -0.04809347
2 -2.1574977 A  0.08431498 -0.6160837 -0.46033781 -0.1537932  0.08784540
3  1.0177696 A -0.50059754 -0.6376908 -0.08594392  0.4529726 -1.98854317
          X6          X7         X8         X9        X10  c1_ci_low c1_ci_high
1  1.8194223 -0.04598789 -0.3885001 0.45111597 -1.9751646 -0.3138407  0.1876284
2 -0.8248944 -1.42140442 -0.8348958 0.06918902  0.8410156 -0.8959086 -0.2420854
3 -0.5929628  0.08853166  0.1790741 0.92633845  0.8261464 -0.8120642 -0.1841642
   c2_ci_low c2_ci_high
1 -0.3916190 0.05925604
2  0.4873436 1.40601655
3  0.4988520 1.49599794
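With more arms, writing the c1_/c2_ interval columns by hand gets repetitive. A minimal sketch, assuming the `contrast_k_estimate` / `contrast_k_var` naming produced by `predict_effect_and_ci`; `add_contrast_ci` is a hypothetical name. It uses the same 1.96 normal approximation, and the usage reproduces the first printed row above.

```r
# Hypothetical helper: add c<k>_ci_low / c<k>_ci_high columns for every
# contrast_<k>_estimate / contrast_<k>_var pair found in the data frame.
add_contrast_ci <- function(df, z = 1.96) {
  est_cols <- grep("^contrast_[0-9]+_estimate$", names(df), value = TRUE)
  for (est_col in est_cols) {
    k <- sub("^contrast_([0-9]+)_estimate$", "\\1", est_col)
    se <- sqrt(df[[paste0("contrast_", k, "_var")]])
    df[[paste0("c", k, "_ci_low")]]  <- df[[est_col]] - z * se
    df[[paste0("c", k, "_ci_high")]] <- df[[est_col]] + z * se
  }
  df
}

# Usage with the first printed row of tau_hat_est_df
row1 <- data.frame(
  contrast_1_estimate = -0.06310611, contrast_2_estimate = -0.1661815,
  contrast_1_var = 0.01636501, contrast_2_var = 0.01322940
)
ci <- add_contrast_ci(row1)
print(ci)
```

On the real objects this would be `add_contrast_ci(bind_cols(tau_hat_est$data, exp_df[test, ]))`.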

Creating HTE Quartile Groups

tau_hat_est_df contains two HTE estimates: $\hat{\tau}_{b-a}$, comparing treatment “B” with “A”, and $\hat{\tau}_{c-a}$, comparing treatment “C” with “A”. As a first step, we can create quartile groups based on $\hat{\tau}_{b-a}$.
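The binning step can be wrapped in a small helper so that either contrast column can be passed in. A minimal sketch using the same `cut()` / `quantile()` pattern; `make_cate_groups` is a hypothetical name.

```r
# Hypothetical helper: assign each unit to a quantile group of one
# contrast's estimated CATE (group 1 = lowest estimates).
make_cate_groups <- function(tau_hat, num.groups = 4) {
  breaks <- quantile(tau_hat, seq(0, 1, by = 1 / num.groups))
  cut(tau_hat, breaks, labels = seq_len(num.groups), include.lowest = TRUE)
}

# Toy check: 100 distinct values fall into four groups of 25
groups <- make_cate_groups(1:100)
print(table(groups))
```

For the "C - A" ranking this would be `make_cate_groups(tau_hat_est_df$contrast_2_estimate)`. If many estimates are tied, duplicate quantiles make `cut()` stop with a "breaks are not unique" error.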

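num.groups = 4

# Assign each test unit to a quartile of its estimated B - A effect
quartile = cut(
  tau_hat_est_df$contrast_1_estimate,
  quantile(tau_hat_est_df$contrast_1_estimate, seq(0, 1, by = 1 / num.groups)),
  labels = 1:num.groups,
  include.lowest = TRUE
)
samples.by.quartile = split(seq_along(quartile), quartile)

# Evaluation forest fit on the held-out test set only; its rows line up with
# tau_hat_est_df, so the quartile index sets can be passed as `subset`
eval.forest = multi_arm_causal_forest(X[test, ], Y[test], W[test], seed = 1345)
ate.by.quartile = lapply(samples.by.quartile, function(samples) {
  average_treatment_effect(eval.forest, subset = samples)
})

df.plot.ate = bind_rows(ate.by.quartile, .id = "group") %>% 
  mutate(
    group = paste0("Q", group)
  ) %>% 
  select(group, contrast, estimate, std.err)
rownames(df.plot.ate) <- NULL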

head(df.plot.ate, 10)
  group contrast   estimate   std.err
1    Q1    B - A -1.1571475 0.1582629
2    Q1    C - A  2.0199825 0.1652687
3    Q2    B - A -0.3894375 0.1472359
4    Q2    C - A  0.5131891 0.1584275
5    Q3    B - A  0.3887451 0.1407779
6    Q3    C - A -0.4980314 0.1409159
7    Q4    B - A  1.2435839 0.1512048
8    Q4    C - A -1.9012597 0.1586695
tau_BA_ate <- df.plot.ate %>% 
  filter(contrast == "B - A")

tau_BA_ate %>% 
  ggplot(aes(x = group, y = estimate)) +
  geom_hline(yintercept = 0, linetype = 2, linewidth = 0.5) +
  # 95% normal-approximation interval for each quartile's ATE
  geom_errorbar(
    aes(
      ymin = estimate - 1.96 * std.err, 
      ymax = estimate + 1.96 * std.err
    ),
    width = 0.09, color = "#4E79A7", linewidth = 0.7
  ) +
  geom_point(color = "#E15759", size = 3) +
  xlab("Estimated CATE Quartile") +
  ylab("Average treatment effect") + 
  theme_minimal() +
  theme(
    plot.title = element_text(size = 12, face = "bold", lineheight = 1.1),
    axis.text = element_text(size = 11),
    axis.title.x = element_text(margin = margin(t = 10))
  )

Note that, since I created the quartile groups based on $\hat{\tau}_{b-a}$, I only plotted the ATE estimates (and their SEs) for the “B - A” contrast and ignored the values for the “C - A” contrast. Likewise, if $\hat{\tau}_{c-a}$ were used to create the quartile groups, only the ATE estimates for the “C - A” contrast would be shown. At least, that is what I am thinking. So my question is: am I on the right track? Are there any better ways?
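Keeping only the contrast that was used to form the groups is what the `tau_BA_ate` filter does; a small helper makes the same step reusable when the groups come from $\hat{\tau}_{c-a}$ instead. A minimal sketch; `ate_for_ranking_contrast` is a hypothetical name, and the toy table simply re-enters the per-quartile values printed above.

```r
# Hypothetical helper: keep the rows for the ranking contrast and add a
# 95% normal-approximation interval for plotting.
ate_for_ranking_contrast <- function(ate_table, ranking_contrast, z = 1.96) {
  out <- ate_table[ate_table$contrast == ranking_contrast, ]
  out$ci_low  <- out$estimate - z * out$std.err
  out$ci_high <- out$estimate + z * out$std.err
  rownames(out) <- NULL
  out
}

# Per-quartile table printed above (groups ranked on B - A)
ate_table <- data.frame(
  group    = rep(c("Q1", "Q2", "Q3", "Q4"), each = 2),
  contrast = rep(c("B - A", "C - A"), times = 4),
  estimate = c(-1.1571475, 2.0199825, -0.3894375, 0.5131891,
                0.3887451, -0.4980314, 1.2435839, -1.9012597),
  std.err  = c(0.1582629, 0.1652687, 0.1472359, 0.1584275,
               0.1407779, 0.1409159, 0.1512048, 0.1586695),
  stringsAsFactors = FALSE
)
ba <- ate_for_ranking_contrast(ate_table, "B - A")
print(ba)
```

With groups built from $\hat{\tau}_{c-a}$, the call would be `ate_for_ranking_contrast(df.plot.ate, "C - A")`.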

Covariate Profiles for Quartile Groups

top_2_vars <- get_top_n_vars(
  mc.forest,
  exp_df %>% select(starts_with("X")), 
  n = 2
)

[1] "X2" "X5"
tau_hat_est_df %>% 
  mutate(
    Q = quartile,
    group = paste0("Q", Q)
  ) %>% 
  group_by(group) %>% 
  summarise(
    across(.cols = all_of(top_2_vars), .fns = mean, .names = "mean_{.col}")
  ) %>% 
  left_join(tau_BA_ate, by = "group")
# A tibble: 4 × 6
  group mean_X2 mean_X5 contrast estimate std.err
  <chr>   <dbl>   <dbl> <chr>       <dbl>   <dbl>
1 Q1     -1.28   0.124  B - A      -1.16    0.158
2 Q2     -0.301 -0.0852 B - A      -0.389   0.147
3 Q3      0.366 -0.138  B - A       0.389   0.141
4 Q4      1.30  -0.0303 B - A       1.24    0.151

Is the summary above a valid representation? Are there any better ways?
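The covariate table reports group means only. A minimal sketch that attaches a simple standard error (sd / sqrt(n), assuming independent rows) to each mean; `covariate_profile` is a hypothetical name. In this simulation X5 does not enter the outcome, so its differences across quartiles are expected to be noise, and a standard error makes that easier to judge.

```r
# Hypothetical helper: per-group mean and standard error of each covariate.
covariate_profile <- function(covariates, group) {
  group <- factor(group)  # drops unused levels so tapply() and levels() agree
  rows <- lapply(names(covariates), function(v) {
    x <- covariates[[v]]
    data.frame(
      variable = v,
      group    = levels(group),
      mean     = as.vector(tapply(x, group, mean)),
      std.err  = as.vector(tapply(x, group, function(z) sd(z) / sqrt(length(z))))
    )
  })
  do.call(rbind, rows)
}

# Toy data: two covariates, two groups of two rows each
toy <- data.frame(X2 = c(-1, -3, 1, 3), X5 = c(0, 2, 0, 2))
grp <- c("Q1", "Q1", "Q4", "Q4")
prof <- covariate_profile(toy, grp)
print(prof)
```

On the real objects this would be `covariate_profile(tau_hat_est_df[top_2_vars], paste0("Q", quartile))`.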

Additional Questions

  1. Is it incorrect to simply average $\hat{\tau}_{b-a}$ within each quartile, instead of using eval.forest to estimate the ATE for each quartile group separately?
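For reference, the simple average from the question can be written as below; `mean_tau_hat_by_group` is a hypothetical helper. Because the groups are defined by sorting the same predictions, these group means are non-decreasing from Q1 to Q4 by construction, whatever the true heterogeneity, and they do not by themselves provide a standard error for the subgroup effect. The `average_treatment_effect(eval.forest, subset = ...)` estimates are instead computed from a forest fit on the held-out test set.

```r
# Hypothetical helper: mean of the forest's own CATE predictions per group.
mean_tau_hat_by_group <- function(tau_hat, group) {
  group <- factor(group)
  data.frame(
    group        = levels(group),
    mean_tau_hat = as.vector(tapply(tau_hat, group, mean))
  )
}

tau_toy <- c(-1.2, -0.8, 0.9, 1.1)
grp_toy <- c("Q1", "Q1", "Q4", "Q4")
res <- mean_tau_hat_by_group(tau_toy, grp_toy)
print(res)
```

On the real objects this would be `mean_tau_hat_by_group(tau_hat_est_df$contrast_1_estimate, paste0("Q", quartile))`.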