Skip to content

Commit

Permalink
add print functions
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkliming committed Aug 13, 2024
1 parent 419c99d commit b0c4d25
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 11 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
S3method(predict_counterfactual,glm)
S3method(predict_counterfactual,lm)
S3method(print,prediction_cf)
S3method(print,treatment_effect)
S3method(treatment_effect,glm)
S3method(treatment_effect,lm)
S3method(treatment_effect,prediction_cf)
Expand Down
23 changes: 12 additions & 11 deletions R/predict_couterfactual.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ predict_counterfactual <- function(fit, treatment, data, unbiased) {

#' @export
predict_counterfactual.lm <- function(fit, treatment, data, unbiased = TRUE) {
treatment <- h_get_vars(treatment)
trt_vars <- h_get_vars(treatment)
assert_data_frame(data)
assert_subset(unlist(treatment), colnames(data))
assert_subset(unlist(trt_vars), colnames(data))
formula <- formula(fit)
assert_subset(treatment$treatment, all.vars(formula[[3]]))
assert_subset(trt_vars$treatment, all.vars(formula[[3]]))
assert(
test_character(data[[treatment$treatment]]),
test_factor(data[[treatment$treatment]])
test_character(data[[trt_vars$treatment]]),
test_factor(data[[trt_vars$treatment]])
)
data[[treatment$treatment]] <- as.factor(data[[treatment$treatment]])
data[[trt_vars$treatment]] <- as.factor(data[[trt_vars$treatment]])
assert_flag(unbiased)

trt_lvls <- levels(data[[treatment$treatment]])
trt_lvls <- levels(data[[trt_vars$treatment]])
n_lvls <- length(trt_lvls)

df <- lapply(
Expand All @@ -38,7 +38,7 @@ predict_counterfactual.lm <- function(fit, treatment, data, unbiased = TRUE) {
}
)

df[[treatment$treatment]] <- rep(trt_lvls, each = nrow(data))
df[[trt_vars$treatment]] <- rep(trt_lvls, each = nrow(data))

mm <- model.matrix(fit, data = df)
pred_linear <- mm %*% coefficients(fit)
Expand All @@ -48,14 +48,14 @@ predict_counterfactual.lm <- function(fit, treatment, data, unbiased = TRUE) {
y <- model.response(fit$model)
residual <- y - fitted(fit)

strata <- data[treatment$strata]
strata <- data[trt_vars$strata]
if (ncol(strata) == 0L) {
strata <- integer(nrow(strata))
}
group_idx <- split(seq_len(nrow(data)), strata)

if (unbiased) {
ret <- ret - bias(residual, data[[treatment$treatment]], group_idx)
ret <- ret - bias(residual, data[[trt_vars$treatment]], group_idx)
}
structure(
.Data = colMeans(ret),
Expand All @@ -65,7 +65,8 @@ predict_counterfactual.lm <- function(fit, treatment, data, unbiased = TRUE) {
response = y,
fit = fit,
model_matrix = mm,
treatment = data[[treatment$treatment]],
treatment_formula = treatment,
treatment = data[[trt_vars$treatment]],
group_idx = group_idx,
class = "prediction_cf"
)
Expand Down
31 changes: 31 additions & 0 deletions R/treatment_effect.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ treatment_effect.prediction_cf <- function(
structure(
.Data = trt_effect,
name = pair_names[lower.tri(pair_names)],
fit = attr(object, "fit"),
vartype = deparse(substitute(variance)),
treatment = attr(object, "treatment_formula"),
variance = diag(trt_var),
class = "treatment_effect"
)
Expand Down Expand Up @@ -165,3 +168,31 @@ h_lower_tri_idx <- function(n) {
rc <- c(n, n)
which(.row(rc) > .col(rc), arr.ind = TRUE)
}

#' @export
print.treatment_effect <- function(obj, digits = 3, ...) {
cat("Treatment Effect\n")
cat("-------------\n")
cat("Model : ", deparse(attr(obj,"fit")$formula), "\n")

Check warning on line 176 in R/treatment_effect.R

View workflow job for this annotation

GitHub Actions / SuperLinter 🦸‍♀️ / Lint R code 🧶

file=R/treatment_effect.R,line=176,col=43,[commas_linter] Commas should always have a space after.
cat("Randomization: ", deparse(attr(obj, "treatment")), "\n")
cat("Variance Type: ", attr(obj, "vartype"), "\n")
trt_sd <- sqrt(attr(obj, "variance"))
z_value <- obj / trt_sd
p <- pnorm(abs(z_value), lower.tail = FALSE)
coef_mat <- matrix(
c(
obj,
trt_sd,
z_value,
p
),
nrow = length(obj)
)
colnames(coef_mat) <- c("Estimate", "Std.Err", "Z Value", "Pr(>z)")
row.names(coef_mat) <- attr(obj, "name")
stats::printCoefmat(
coef_mat,
zap.ind = 3,
digits = digits
)
}

Check warning on line 198 in R/treatment_effect.R

View workflow job for this annotation

GitHub Actions / SuperLinter 🦸‍♀️ / Lint R code 🧶

file=R/treatment_effect.R,line=198,col=2,[trailing_blank_lines_linter] Missing terminal newline.

0 comments on commit b0c4d25

Please sign in to comment.