Skip to content

Commit

Permalink
new function
Browse files Browse the repository at this point in the history
  • Loading branch information
gperrett committed May 24, 2024
1 parent 9c74cb5 commit ae24ca5
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 11 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# plotBart 0.1.27
- new function `table_moderator_c_bin()`

# plotBart 0.1.26
- new features for `plot_balance()`

Expand Down
98 changes: 98 additions & 0 deletions R/plot_moderators.R
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ plot_moderator_c_bin <- function(.model, moderator,type = c('density', 'histogra
summarise_all(mean) %>%
pivot_longer(cols = 2:ncol(posterior))



# plot it
p <- ggplot(posterior, aes(value, fill = moderator))

Expand Down Expand Up @@ -638,3 +640,99 @@ rpart_ggplot_ <- function(.model){

return(p)
}



#' @title Auto-Bin a table of a continuous moderating variable into a discrete moderating variable
#' @description Use a regression tree to bin a continuous moderator, then
#'   return a table with one row per bin holding the posterior mean, standard
#'   deviation and 95% interval (2.5% and 97.5% posterior quantiles) of the
#'   bin-level average of the individual conditional treatment effects.
#'
#' @param .model a model produced by `bartCause::bartc()`
#' @param moderator the moderator as a vector
#' @param .name string representing the name of the moderating variable; it is
#'   used as the prefix of each bin label (`<.name>:<min>-<max>`)
#'
#' @author George Perrett
#'
#'
#' @return a tibble (data.frame) with one row per bin and the columns
#'   `moderator` (bin label), `est`, `sd`, `lci` and `uci`
#' @export
#'
#' @import dplyr
#' @importFrom bartCause extract
#' @importFrom rpart rpart
#'
#' @examples
#' \donttest{
#' data(lalonde)
#' confounders <- c('age', 'educ', 'black', 'hisp', 'married', 'nodegr')
#' model_results <- bartCause::bartc(
#'  response = lalonde[['re78']],
#'  treatment = lalonde[['treat']],
#'  confounders = as.matrix(lalonde[, confounders]),
#'  estimand = 'ate',
#'  commonSuprule = 'none'
#' )
#' table_moderator_c_bin(model_results, lalonde$age, .name = 'age')
#' }
table_moderator_c_bin <- function(.model, moderator, .name = 'bin'){

  validate_model_(.model)
  is_numeric_vector_(moderator)

  # adjust moderator to match estimand
  moderator <- adjust_for_estimand_(.model, moderator)

  # extract the posterior (one column per observation)
  posterior <- bartCause::extract(.model, 'icate')

  # fail early with a clear message instead of a cryptic rpart/mutate error
  if (length(moderator) != ncol(posterior)) {
    stop(
      'moderator must have one value per observation in the posterior (',
      ncol(posterior), '), not ', length(moderator),
      call. = FALSE
    )
  }

  # get icate point est
  icate.m <- apply(posterior, 2, mean)

  # fit regression tree
  tree <- rpart::rpart(icate.m ~ moderator)

  # get bins from regression tree: tree$where is the leaf each observation
  # falls into
  bins <- dplyr::tibble(splits = tree$where, x = moderator)

  # label every leaf by the observed range of the moderator inside it
  subgroups <- bins %>%
    dplyr::group_by(splits) %>%
    dplyr::summarise(min = min(x), max = max(x)) %>%
    dplyr::arrange(min) %>%
    dplyr::mutate(
      subgroup = paste0(.name, ':', round(min, 2), '-', round(max, 2))
    )

  # explicit key avoids the "Joining with `by = ...`" message on every call
  bins <- dplyr::left_join(bins, subgroups, by = 'splits')

  # rotate posterior so rows are observations and columns are draws
  posterior <- posterior %>%
    t() %>%
    as.data.frame() %>%
    dplyr::as_tibble() %>%
    dplyr::mutate(moderator = bins$subgroup)

  # marginalize: average the individual effects within each bin for every
  # draw, then stack the draws into a single `value` column
  posterior <- posterior %>%
    dplyr::group_by(moderator) %>%
    dplyr::summarise_all(mean) %>%
    tidyr::pivot_longer(cols = -moderator)

  # summarise the per-bin posterior draws; summarise() returns one ungrouped
  # row per bin
  posterior %>%
    dplyr::group_by(moderator) %>%
    dplyr::summarise(
      est = mean(value),
      sd = stats::sd(value),
      lci = stats::quantile(value, probs = .025, names = FALSE),
      uci = stats::quantile(value, probs = .975, names = FALSE)
    )

}

4 changes: 2 additions & 2 deletions R/plot_overlap_pScores.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ plot_overlap_pScores <- function(.data, treatment, confounders, plot_type = c("h
p <- ggplot() +
geom_hline(yintercept = 0, linetype = 'dashed', color = 'grey60') +
geom_histogram(data = filter(dat, Z == 1),
aes(x = pscores, y = ..count.., fill = Z),
aes(x = pscores, y = after_stat(count), fill = Z),
alpha = 0.8, color = 'black') +
geom_histogram(data = filter(dat, Z == 0),
aes(x = pscores, y = -..count.., fill = Z),
aes(x = pscores, y = -after_stat(count), fill = Z),
alpha = 0.8, color = 'black') +
scale_y_continuous(labels = function(lbl) abs(lbl)) +
scale_fill_manual(values = c(4,2)) +
Expand Down
8 changes: 4 additions & 4 deletions R/plot_overlap_vars.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,17 @@ plot_overlap_vars <- function(.data, treatment, confounders, plot_type = c("hist

if(is.numeric(dat_pivoted$value)){
p <- p + geom_histogram(data = filter(dat_pivoted, Z_treat == 1),
aes(x = value, y = ..count.., fill = Z_treat),
aes(x = value, y = after_stat(count), fill = Z_treat),
alpha = 0.8, color = 'black') +
geom_histogram(data = filter(dat_pivoted, Z_treat == 0),
aes(x = value, y = -..count.., fill = Z_treat),
aes(x = value, y = after_stat(count), fill = Z_treat),
alpha = 0.8, color = 'black')
}else{
p <- p + geom_bar(data = filter(dat_pivoted, Z_treat == 1),
aes(x = value, y = ..count.., fill = Z_treat),
aes(x = value, y = after_stat(count), fill = Z_treat),
alpha = 0.8, color = 'black') +
geom_bar(data = filter(dat_pivoted, Z_treat == 0),
aes(x = value, y = -..count.., fill = Z_treat),
aes(x = value, y = -after_stat(count), fill = Z_treat),
alpha = 0.8, color = 'black')
}
p <- p + scale_y_continuous(labels = function(lbl) abs(lbl)) +
Expand Down
10 changes: 5 additions & 5 deletions renv.lock
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,10 @@
},
"fastmap": {
"Package": "fastmap",
"Version": "1.1.0",
"Version": "1.2.0",
"Source": "Repository",
"Repository": "CRAN",
"Hash": "77bd60a6157420d4ffa93b27cf6a58b8",
"Hash": "aa5e1cd11c2d15497494c5292d7ffcc8",
"Requirements": []
},
"forcats": {
Expand Down Expand Up @@ -478,10 +478,10 @@
},
"htmltools": {
"Package": "htmltools",
"Version": "0.5.3",
"Version": "0.5.8.1",
"Source": "Repository",
"Repository": "CRAN",
"Hash": "6496090a9e00f8354b811d1a2d47b566",
"Repository": "RSPM",
"Hash": "81d371a9cc60640e74e4ab6ac46dcedc",
"Requirements": [
"base64enc",
"digest",
Expand Down

0 comments on commit ae24ca5

Please sign in to comment.