Skip to content

Commit

Permalink
add print functions (#36)
Browse files Browse the repository at this point in the history
* add print functions

* [skip style] [skip vbump] Restyle files

* Empty

* update docs

* update namespace

* fix check issues

* [skip style] [skip vbump] Restyle files

---------

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
clarkliming and github-actions[bot] authored Aug 19, 2024
1 parent a55a208 commit 374a3f9
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 14 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
S3method(predict_counterfactual,glm)
S3method(predict_counterfactual,lm)
S3method(print,prediction_cf)
S3method(print,treatment_effect)
S3method(treatment_effect,glm)
S3method(treatment_effect,lm)
S3method(treatment_effect,prediction_cf)
Expand All @@ -29,6 +30,7 @@ importFrom(stats,gaussian)
importFrom(stats,glm)
importFrom(stats,model.matrix)
importFrom(stats,model.response)
importFrom(stats,pnorm)
importFrom(stats,predict)
importFrom(stats,residuals)
importFrom(stats,terms)
Expand Down
2 changes: 1 addition & 1 deletion R/RobinCar2-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#' @import checkmate
#' @importFrom numDeriv jacobian
#' @importFrom stats predict residuals fitted model.response model.matrix coefficients family
#' gaussian terms glm var
#' gaussian terms glm var family pnorm var
#' @importFrom sandwich vcovHC
#' @importFrom prediction find_data
NULL
23 changes: 12 additions & 11 deletions R/predict_couterfactual.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ predict_counterfactual <- function(fit, treatment, data, unbiased) {

#' @export
predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unbiased = TRUE) {
treatment <- h_get_vars(treatment)
trt_vars <- h_get_vars(treatment)
assert_data_frame(data)
assert_subset(unlist(treatment), colnames(data))
assert_subset(unlist(trt_vars), colnames(data))
formula <- formula(fit)
assert_subset(treatment$treatment, all.vars(formula[[3]]))
assert_subset(trt_vars$treatment, all.vars(formula[[3]]))
assert(
test_character(data[[treatment$treatment]]),
test_factor(data[[treatment$treatment]])
test_character(data[[trt_vars$treatment]]),
test_factor(data[[trt_vars$treatment]])
)
data[[treatment$treatment]] <- as.factor(data[[treatment$treatment]])
data[[trt_vars$treatment]] <- as.factor(data[[trt_vars$treatment]])
assert_flag(unbiased)

trt_lvls <- levels(data[[treatment$treatment]])
trt_lvls <- levels(data[[trt_vars$treatment]])
n_lvls <- length(trt_lvls)

df <- lapply(
Expand All @@ -38,7 +38,7 @@ predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unb
}
)

df[[treatment$treatment]] <- rep(trt_lvls, each = nrow(data))
df[[trt_vars$treatment]] <- rep(trt_lvls, each = nrow(data))

mm <- model.matrix(fit, data = df)
pred_linear <- mm %*% coefficients(fit)
Expand All @@ -48,14 +48,14 @@ predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unb
y <- model.response(fit$model)
residual <- y - fitted(fit)

strata <- data[treatment$strata]
strata <- data[trt_vars$strata]
if (ncol(strata) == 0L) {
strata <- integer(nrow(strata))
}
group_idx <- split(seq_len(nrow(data)), strata)

if (unbiased) {
ret <- ret - bias(residual, data[[treatment$treatment]], group_idx)
ret <- ret - bias(residual, data[[trt_vars$treatment]], group_idx)
}
structure(
.Data = colMeans(ret),
Expand All @@ -65,7 +65,8 @@ predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unb
response = y,
fit = fit,
model_matrix = mm,
treatment = data[[treatment$treatment]],
treatment_formula = treatment,
treatment = data[[trt_vars$treatment]],
group_idx = group_idx,
class = "prediction_cf"
)
Expand Down
6 changes: 5 additions & 1 deletion R/robin_glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
#' `difference`, `risk_ratio`, `odds_ratio`.
#' @export
#' @examples
#' robin_glm(y ~ treatment * s1, data = dummy_data, treatment = treatment ~ s1, contrast = "difference")
#' robin_glm(
#' y ~ treatment * s1,
#' data = dummy_data,
#' treatment = treatment ~ s1, contrast = "difference"
#' )
robin_glm <- function(
formula, data, treatment, contrast = "difference",
contrast_jac = NULL, vcov = vcovANHECOVA, family = gaussian, ...) {
Expand Down
31 changes: 31 additions & 0 deletions R/treatment_effect.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ treatment_effect.prediction_cf <- function(
structure(
.Data = trt_effect,
name = pair_names[lower.tri(pair_names)],
fit = attr(object, "fit"),
vartype = deparse(substitute(variance)),
treatment = attr(object, "treatment_formula"),
variance = diag(trt_var),
class = "treatment_effect"
)
Expand Down Expand Up @@ -165,3 +168,31 @@ h_lower_tri_idx <- function(n) {
rc <- c(n, n)
which(.row(rc) > .col(rc), arr.ind = TRUE)
}

#' @export
print.treatment_effect <- function(x, ...) {
cat("Treatment Effect\n")
cat("-------------\n")
cat("Model : ", deparse(attr(x, "fit")$formula), "\n")
cat("Randomization: ", deparse(attr(x, "treatment")), "\n")
cat("Variance Type: ", attr(x, "vartype"), "\n")
trt_sd <- sqrt(attr(x, "variance"))
z_value <- x / trt_sd
p <- pnorm(abs(z_value), lower.tail = FALSE)
coef_mat <- matrix(
c(
x,
trt_sd,
z_value,
p
),
nrow = length(x)
)
colnames(coef_mat) <- c("Estimate", "Std.Err", "Z Value", "Pr(>z)")
row.names(coef_mat) <- attr(x, "name")
stats::printCoefmat(
coef_mat,
zap.ind = 3,
digits = 3
)
}
5 changes: 4 additions & 1 deletion man/robin_glm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 374a3f9

Please sign in to comment.