Skip to content

Commit

Permalink
add randomization schemas (#44)
Browse files Browse the repository at this point in the history
* update programs

* [skip style] [skip vbump] Restyle files

---------

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
clarkliming and github-actions[bot] authored Oct 15, 2024
1 parent defae34 commit 75f9fe0
Show file tree
Hide file tree
Showing 17 changed files with 101 additions and 102 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ S3method(treatment_effect,lm)
S3method(treatment_effect,prediction_cf)
S3method(vcovHC,prediction_cf)
export(bias)
export(gvcov)
export(h_diff)
export(h_jac_diff)
export(h_jac_odds_ratio)
Expand All @@ -18,7 +19,6 @@ export(h_ratio)
export(predict_counterfactual)
export(robin_glm)
export(treatment_effect)
export(vcovANHECOVA)
import(checkmate)
importFrom(numDeriv,jacobian)
importFrom(prediction,find_data)
Expand Down
20 changes: 10 additions & 10 deletions R/predict_couterfactual.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,28 @@
#' @description Obtain counterfactual prediction of a fit.
#'
#' @param fit fitted object.
#' @param treatment (`string` or `formula`) treatment variable in string, or a formula of form
#' treatment ~ strata(s).
#' @param treatment (`formula`) formula of form treatment ~ strata(s).
#' @param data (`data.frame`) raw dataset.
#' @param unbiased (`flag`) indicator of whether to remove potential bias of the prediction.
#'
#' @return Numeric matrix of counter factual prediction.
#'
#' @export
predict_counterfactual <- function(fit, treatment, data, unbiased) {
predict_counterfactual <- function(fit, treatment, data) {
UseMethod("predict_counterfactual")
}

#' @export
predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unbiased = TRUE) {
predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit)) {
trt_vars <- h_get_vars(treatment)
assert_data_frame(data)
assert_subset(unlist(trt_vars), colnames(data))
assert_subset(c(trt_vars$treatment, trt_vars$strata), colnames(data))
formula <- formula(fit)
assert_subset(trt_vars$treatment, all.vars(formula[[3]]))
assert(
test_character(data[[trt_vars$treatment]]),
test_factor(data[[trt_vars$treatment]])
)
data[[trt_vars$treatment]] <- as.factor(data[[trt_vars$treatment]])
assert_flag(unbiased)

trt_lvls <- levels(data[[trt_vars$treatment]])
n_lvls <- length(trt_lvls)
Expand Down Expand Up @@ -54,13 +51,16 @@ predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unb
}
group_idx <- split(seq_len(nrow(data)), strata)

if (unbiased) {
if (identical(trt_vars$schema, "ps")) {
ret <- ret - bias(residual, data[[trt_vars$treatment]], group_idx)
} else {
ret <- ret - bias(residual, data[[trt_vars$treatment]], list(seq_len(nrow(ret))))
}
structure(
.Data = colMeans(ret),
residual = residual,
predictions = ret,
schema = trt_vars$schema,
predictions_linear = pred_linear,
response = y,
fit = fit,
Expand All @@ -73,6 +73,6 @@ predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unb
}

#' @export
predict_counterfactual.glm <- function(fit, treatment, data = find_data(fit), unbiased = TRUE) {
predict_counterfactual.lm(fit = fit, data = data, treatment = treatment, unbiased = unbiased)
predict_counterfactual.glm <- function(fit, treatment, data = find_data(fit)) {
predict_counterfactual.lm(fit = fit, data = data, treatment = treatment)
}
8 changes: 4 additions & 4 deletions R/robin_glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#' @param contrast_jac (`function`) A function to calculate the Jacobian of the contrast function. Ignored if using
#' default contrasts.
#' @param vcov (`function`) A function to calculate the variance-covariance matrix of the treatment effect,
#' including `vcovHC` and `vcovANHECOVA`.
#' including `vcovHC` and `gvcov`.
#' @param family (`family`) A family object of the glm model.
#' @param ... Additional arguments passed to `vcov`. For finer control of glm, refer to usage of `treatment_effect`,
#' `difference`, `risk_ratio`, `odds_ratio`.
Expand All @@ -21,14 +21,14 @@
#' )
robin_glm <- function(
formula, data, treatment, contrast = "difference",
contrast_jac = NULL, vcov = vcovANHECOVA, family = gaussian, ...) {
contrast_jac = NULL, vcov = gvcov, family = gaussian, ...) {
attr(formula, ".Environment") <- environment()
fit <- glm(formula, family = family, data = data)
pc <- predict_counterfactual(fit, treatment, data, unbiased = TRUE)
pc <- predict_counterfactual(fit, treatment, data)
has_interaction <- h_interaction(formula, treatment)
if (has_interaction && identical(vcov, vcovHC) && !identical(contrast, "difference")) {
stop(
"Huber-White standard error only works for difference contrasts in models without interaction term."
"Huber-White variance estimator is ONLY supported when the expected outcome difference is estimated using a linear model without treatment-covariate interactions; see the 2023 FDA guidance."
)
}
if (identical(contrast, "difference")) {
Expand Down
14 changes: 7 additions & 7 deletions R/treatment_effect.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ treatment_effect <- function(object, pair, variance, eff_measure, eff_jacobian,

#' @export
treatment_effect.prediction_cf <- function(
object, pair = names(object), variance = vcovANHECOVA, eff_measure, eff_jacobian, ...) {
object, pair = names(object), variance = gvcov, eff_measure, eff_jacobian, ...) {
assert(
test_function(variance),
test_null(variance)
Expand Down Expand Up @@ -59,9 +59,9 @@ treatment_effect.prediction_cf <- function(
#' @export
#' @inheritParams predict_counterfactual
treatment_effect.lm <- function(
object, pair = names(object), variance = vcovANHECOVA, eff_measure, eff_jacobian,
treatment, data = find_data(object), unbiased = TRUE, ...) {
pc <- predict_counterfactual(object, data = data, treatment, unbiased)
object, pair = names(object), variance = gvcov, eff_measure, eff_jacobian,
treatment, data = find_data(object), ...) {
pc <- predict_counterfactual(object, data = data, treatment)
if (missing(pair)) {
treatment_effect(pc, pair = , , variance = variance, eff_measure = eff_measure, eff_jacobian = eff_jacobian, ...)
} else {
Expand All @@ -71,9 +71,9 @@ treatment_effect.lm <- function(

#' @export
treatment_effect.glm <- function(
object, pair, variance = vcovANHECOVA, eff_measure, eff_jacobian,
treatment, data = find_data(object), unbiased = TRUE, ...) {
pc <- predict_counterfactual(object, treatment, data, unbiased)
object, pair, variance = gvcov, eff_measure, eff_jacobian,
treatment, data = find_data(object), ...) {
pc <- predict_counterfactual(object, treatment, data)
if (missing(pair)) {
treatment_effect(pc, pair = , , variance = variance, eff_measure = eff_measure, eff_jacobian = eff_jacobian, ...)
} else {
Expand Down
51 changes: 28 additions & 23 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -1,33 +1,38 @@
randomization_schema <- data.frame(
schema = c("Pocock-Simon", "permuted-block", "simple"),
id = c("ps", "pb", "sp"),
stringsAsFactors = FALSE
)

#' Extract Variable Names
#'
#' @param treatment (`string` or `formula`) string name of the treatment, or a formula.
#'
#' @details Extract the formula elements, including `treatment` and `strata`.
#' @details Extract the formula elements, including `treatment`, `schema` and `strata`.
#'
#' @return A list of two elements, `treatmetn` and `strata`.
#' @return A list of three elements, `treatment`, `schema` and `strata`.
h_get_vars <- function(treatment) {
if (test_string(treatment)) {
ret <- list(
treatment = treatment,
strata = character(0)
)
} else if (test_formula(treatment)) {
if (!identical(length(treatment), 3L)) {
stop("treatment formula must be of type treatment ~ strata")
}
if (!is.name(treatment[[2]])) {
stop("left hand side of the treatment formula should be a single name!")
}
treatvar <- as.character(treatment[[2]])
strata <- setdiff(all.vars(treatment[[3]]), ".")
ret <- list(
treatment = treatvar,
strata = strata
)
} else {
stop("treatment must be a length 1 character or a formula of form treatment ~ strata")
assert_formula(treatment)
if (!identical(length(treatment), 3L)) {
stop("treatment formula must be of type treatment ~ strata")
}
if (!is.name(treatment[[2]])) {
stop("left hand side of the treatment formula should be a single name!")
}
treatvar <- as.character(treatment[[2]])
tms <- terms(treatment, specials = randomization_schema$id)
schema <- names(Filter(Negate(is.null), attr(tms, "specials")))
if (length(schema) > 1) {
stop("only one randomization schema is allowed!")
} else if (length(schema) == 0) {
schema <- "sp"
}
ret
strata <- setdiff(all.vars(treatment[[3]]), ".")
list(
treatment = treatvar,
schema = schema,
strata = strata
)
}

block_sum <- function(x, n) {
Expand Down
8 changes: 3 additions & 5 deletions R/variance_anhecova.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
#'
#' @param x (`prediction_cf`) Counter-factual prediction.
#' @param decompose (`flag`) whether to use decompose method to calculate the variance.
#' @param randomization (`string`) randomization method.
#' @param ... Not used.
#'
#' @return Named covariance matrix.
#' @export
vcovANHECOVA <- function(x, decompose = TRUE, randomization = "simple", ...) { # nolint
gvcov <- function(x, decompose = TRUE, ...) { # nolint
assert_class(x, "prediction_cf")
assert_flag(decompose)
assert_string(randomization)
resi <- attr(x, "residual")
est <- as.numeric(x)
preds <- attr(x, "predictions")
Expand All @@ -31,7 +29,7 @@ vcovANHECOVA <- function(x, decompose = TRUE, randomization = "simple", ...) { #
}

v <- diag(vcov_sr) + cov_ymu + t(cov_ymu) - var_preds
v <- v - h_get_erb(resi, group_idx, trt, pi_t, randomization)
v <- v - h_get_erb(resi, group_idx, trt, pi_t, attr(x, "schema"))
ret <- v / length(resi)
dimnames(ret) <- list(trt_lvls, trt_lvls)
return(ret)
Expand Down Expand Up @@ -59,7 +57,7 @@ h_get_erb <- function(resi, group_idx, trt, pi, randomization) {
return(0)
}
assert_string(randomization)
if (randomization %in% c("simple", "pocock-simon")) {
if (randomization %in% c("sp", "ps")) {
return(0)
}
assert_numeric(resi)
Expand Down
8 changes: 3 additions & 5 deletions man/vcovANHECOVA.Rd → man/gvcov.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/h_get_vars.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 2 additions & 5 deletions man/predict_counterfactual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/robin_glm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions tests/testthat/_snaps/predict_counterfactual.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# predict_counterfactual works for guassian

Code
predict_counterfactual(fit_glm, "treatment")
predict_counterfactual(fit_glm, treatment ~ 1)
Output
counter-factual prediction
Expand All @@ -12,7 +12,7 @@
# predict_counterfactual works for guassian with lm

Code
predict_counterfactual(fit_lm, "treatment", data = dummy_data)
predict_counterfactual(fit_lm, treatment ~ 1, data = dummy_data)
Output
counter-factual prediction
Expand All @@ -23,7 +23,7 @@
# predict_counterfactual works for binomial

Code
predict_counterfactual(fit_binom, "treatment")
predict_counterfactual(fit_binom, treatment ~ 1)
Output
counter-factual prediction
Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/_snaps/treatment_effect.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
Randomization: treatment ~ s1
Variance Type: variance
Estimate Std.Err Z Value Pr(>z)
trt1 - pbo 0.2246 0.0477 4.71 1.3e-06 ***
trt2 - pbo 0.2653 0.0475 5.58 1.2e-08 ***
trt2 - trt1 0.0407 0.0479 0.85 0.2
trt1 - pbo 0.2246 0.0477 4.71 2.5e-06 ***
trt2 - pbo 0.2653 0.0475 5.58 2.4e-08 ***
trt2 - trt1 0.0407 0.0479 0.85 0.4
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Expand All @@ -26,9 +26,9 @@
Randomization: treatment ~ s1
Variance Type: variance
Estimate Std.Err Z Value Pr(>z)
trt1 - pbo 0.564 0.101 5.60 1.1e-08 ***
trt2 - pbo 0.771 0.101 7.61 1.4e-14 ***
trt2 - trt1 0.207 0.107 1.94 0.026 *
trt1 - pbo 0.564 0.101 5.60 2.2e-08 ***
trt2 - pbo 0.771 0.101 7.61 2.8e-14 ***
trt2 - trt1 0.207 0.107 1.94 0.052 .
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

6 changes: 3 additions & 3 deletions tests/testthat/_snaps/variance.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
trt1 4.523445e-07 1.164889e-03 -7.709031e-07
trt2 -9.709004e-06 -7.709031e-07 1.170214e-03

# vcovANHECOVA works
# gvcov works

Code
vcovANHECOVA(pc)
gvcov(pc)
Output
pbo trt1 trt2
pbo 1.128902e-03 1.856234e-05 1.333885e-05
Expand All @@ -21,7 +21,7 @@
---

Code
vcovANHECOVA(pc, randomization = "permute_block")
gvcov(pc, randomization = "permute_block")
Output
pbo trt1 trt2
pbo 1.128902e-03 1.856234e-05 1.333885e-05
Expand Down
Loading

0 comments on commit 75f9fe0

Please sign in to comment.