From 75f9fe08e50de372bf923dbd7653febe68236a5d Mon Sep 17 00:00:00 2001 From: Liming <36079400+clarkliming@users.noreply.github.com> Date: Tue, 15 Oct 2024 10:11:55 +0800 Subject: [PATCH] add randomization schemas (#44) * update programs * [skip style] [skip vbump] Restyle files --------- Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com> --- NAMESPACE | 2 +- R/predict_couterfactual.R | 20 ++++---- R/robin_glm.R | 8 +-- R/treatment_effect.R | 14 ++--- R/utils.R | 51 ++++++++++--------- R/variance_anhecova.R | 8 ++- man/{vcovANHECOVA.Rd => gvcov.Rd} | 8 ++- man/h_get_vars.Rd | 4 +- man/predict_counterfactual.Rd | 7 +-- man/robin_glm.Rd | 4 +- .../testthat/_snaps/predict_counterfactual.md | 6 +-- tests/testthat/_snaps/treatment_effect.md | 12 ++--- tests/testthat/_snaps/variance.md | 6 +-- tests/testthat/test-predict_counterfactual.R | 6 +-- tests/testthat/test-robin_glm.R | 4 +- tests/testthat/test-utils.R | 37 +++++++------- tests/testthat/test-variance.R | 6 +-- 17 files changed, 101 insertions(+), 102 deletions(-) rename man/{vcovANHECOVA.Rd => gvcov.Rd} (70%) diff --git a/NAMESPACE b/NAMESPACE index 25f6914..d875952 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -9,6 +9,7 @@ S3method(treatment_effect,lm) S3method(treatment_effect,prediction_cf) S3method(vcovHC,prediction_cf) export(bias) +export(gvcov) export(h_diff) export(h_jac_diff) export(h_jac_odds_ratio) @@ -18,7 +19,6 @@ export(h_ratio) export(predict_counterfactual) export(robin_glm) export(treatment_effect) -export(vcovANHECOVA) import(checkmate) importFrom(numDeriv,jacobian) importFrom(prediction,find_data) diff --git a/R/predict_couterfactual.R b/R/predict_couterfactual.R index 803d10d..16c5438 100644 --- a/R/predict_couterfactual.R +++ b/R/predict_couterfactual.R @@ -2,23 +2,21 @@ #' @description Obtain counterfactual prediction of a fit. #' #' @param fit fitted object. -#' @param treatment (`string` or `formula`) treatment variable in string, or a formula of form -#' treatment ~ strata(s). +#' @param treatment (`formula`) formula of form treatment ~ strata(s). #' @param data (`data.frame`) raw dataset. -#' @param unbiased (`flag`) indicator of whether to remove potential bias of the prediction. #' #' @return Numeric matrix of counter factual prediction. #' #' @export -predict_counterfactual <- function(fit, treatment, data, unbiased) { +predict_counterfactual <- function(fit, treatment, data) { UseMethod("predict_counterfactual") } #' @export -predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unbiased = TRUE) { +predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit)) { trt_vars <- h_get_vars(treatment) assert_data_frame(data) - assert_subset(unlist(trt_vars), colnames(data)) + assert_subset(c(trt_vars$treatment, trt_vars$strata), colnames(data)) formula <- formula(fit) assert_subset(trt_vars$treatment, all.vars(formula[[3]])) assert( @@ -26,7 +24,6 @@ predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unb test_factor(data[[trt_vars$treatment]]) ) data[[trt_vars$treatment]] <- as.factor(data[[trt_vars$treatment]]) - assert_flag(unbiased) trt_lvls <- levels(data[[trt_vars$treatment]]) n_lvls <- length(trt_lvls) @@ -54,13 +51,16 @@ predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unb } group_idx <- split(seq_len(nrow(data)), strata) - if (unbiased) { + if (identical(trt_vars$schema, "ps")) { ret <- ret - bias(residual, data[[trt_vars$treatment]], group_idx) + } else { + ret <- ret - bias(residual, data[[trt_vars$treatment]], list(seq_len(nrow(ret)))) } structure( .Data = colMeans(ret), residual = residual, predictions = ret, + schema = trt_vars$schema, predictions_linear = pred_linear, response = y, fit = fit, @@ -73,6 +73,6 @@ predict_counterfactual.lm <- function(fit, treatment, data = find_data(fit), unb } #' @export -predict_counterfactual.glm <- function(fit, treatment, data = find_data(fit), unbiased = TRUE) { - predict_counterfactual.lm(fit = fit, data = data, treatment = treatment, unbiased = unbiased) +predict_counterfactual.glm <- function(fit, treatment, data = find_data(fit)) { + predict_counterfactual.lm(fit = fit, data = data, treatment = treatment) } diff --git a/R/robin_glm.R b/R/robin_glm.R index cd4d24e..e604ba5 100644 --- a/R/robin_glm.R +++ b/R/robin_glm.R @@ -8,7 +8,7 @@ #' @param contrast_jac (`function`) A function to calculate the Jacobian of the contrast function. Ignored if using #' default contrasts. #' @param vcov (`function`) A function to calculate the variance-covariance matrix of the treatment effect, -#' including `vcovHC` and `vcovANHECOVA`. +#' including `vcovHC` and `gvcov`. #' @param family (`family`) A family object of the glm model. #' @param ... Additional arguments passed to `vcov`. For finer control of glm, refer to usage of `treatment_effect`, #' `difference`, `risk_ratio`, `odds_ratio`. @@ -21,14 +21,14 @@ #' ) robin_glm <- function( formula, data, treatment, contrast = "difference", - contrast_jac = NULL, vcov = vcovANHECOVA, family = gaussian, ...) { + contrast_jac = NULL, vcov = gvcov, family = gaussian, ...) { attr(formula, ".Environment") <- environment() fit <- glm(formula, family = family, data = data) - pc <- predict_counterfactual(fit, treatment, data, unbiased = TRUE) + pc <- predict_counterfactual(fit, treatment, data) has_interaction <- h_interaction(formula, treatment) if (has_interaction && identical(vcov, vcovHC) && !identical(contrast, "difference")) { stop( - "Huber-White standard error only works for difference contrasts in models without interaction term." + "Huber-White variance estimator is ONLY supported when the expected outcome difference is estimated using a linear model without treatment-covariate interactions; see the 2023 FDA guidance." ) } if (identical(contrast, "difference")) { diff --git a/R/treatment_effect.R b/R/treatment_effect.R index 8dd731b..0405dd2 100644 --- a/R/treatment_effect.R +++ b/R/treatment_effect.R @@ -15,7 +15,7 @@ treatment_effect <- function(object, pair, variance, eff_measure, eff_jacobian, #' @export treatment_effect.prediction_cf <- function( - object, pair = names(object), variance = vcovANHECOVA, eff_measure, eff_jacobian, ...) { + object, pair = names(object), variance = gvcov, eff_measure, eff_jacobian, ...) { assert( test_function(variance), test_null(variance) @@ -59,9 +59,9 @@ treatment_effect.prediction_cf <- function( #' @export #' @inheritParams predict_counterfactual treatment_effect.lm <- function( - object, pair = names(object), variance = vcovANHECOVA, eff_measure, eff_jacobian, - treatment, data = find_data(object), unbiased = TRUE, ...) { - pc <- predict_counterfactual(object, data = data, treatment, unbiased) + object, pair = names(object), variance = gvcov, eff_measure, eff_jacobian, + treatment, data = find_data(object), ...) { + pc <- predict_counterfactual(object, data = data, treatment) if (missing(pair)) { treatment_effect(pc, pair = , , variance = variance, eff_measure = eff_measure, eff_jacobian = eff_jacobian, ...) } else { @@ -71,9 +71,9 @@ treatment_effect.lm <- function( #' @export treatment_effect.glm <- function( - object, pair, variance = vcovANHECOVA, eff_measure, eff_jacobian, - treatment, data = find_data(object), unbiased = TRUE, ...) { - pc <- predict_counterfactual(object, treatment, data, unbiased) + object, pair, variance = gvcov, eff_measure, eff_jacobian, + treatment, data = find_data(object), ...) { + pc <- predict_counterfactual(object, treatment, data) if (missing(pair)) { treatment_effect(pc, pair = , , variance = variance, eff_measure = eff_measure, eff_jacobian = eff_jacobian, ...) } else { diff --git a/R/utils.R b/R/utils.R index 891c827..c2c1520 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1,33 +1,38 @@ +randomization_schema <- data.frame( + schema = c("Pocock-Simon", "permuted-block", "simple"), + id = c("ps", "pb", "sp"), + stringsAsFactors = FALSE +) + #' Extract Variable Names #' #' @param treatment (`string` or `formula`) string name of the treatment, or a formula. #' -#' @details Extract the formula elements, including `treatment` and `strata`. +#' @details Extract the formula elements, including `treatment`, `schema` and `strata`. #' -#' @return A list of two elements, `treatmetn` and `strata`. +#' @return A list of three elements, `treatment`, `schema` and `strata`. h_get_vars <- function(treatment) { - if (test_string(treatment)) { - ret <- list( - treatment = treatment, - strata = character(0) - ) - } else if (test_formula(treatment)) { - if (!identical(length(treatment), 3L)) { - stop("treatment formula must be of type treatment ~ strata") - } - if (!is.name(treatment[[2]])) { - stop("left hand side of the treatment formula should be a single name!") - } - treatvar <- as.character(treatment[[2]]) - strata <- setdiff(all.vars(treatment[[3]]), ".") - ret <- list( - treatment = treatvar, - strata = strata - ) - } else { - stop("treatment must be a length 1 character or a formula of form treatment ~ strata") + assert_formula(treatment) + if (!identical(length(treatment), 3L)) { + stop("treatment formula must be of type treatment ~ strata") + } + if (!is.name(treatment[[2]])) { + stop("left hand side of the treatment formula should be a single name!") + } + treatvar <- as.character(treatment[[2]]) + tms <- terms(treatment, specials = randomization_schema$id) + schema <- names(Filter(Negate(is.null), attr(tms, "specials"))) + if (length(schema) > 1) { + stop("only one randomization schema is allowed!") + } else if (length(schema) == 0) { + schema <- "sp" } - ret + strata <- setdiff(all.vars(treatment[[3]]), ".") + list( + treatment = treatvar, + schema = schema, + strata = strata + ) } block_sum <- function(x, n) { diff --git a/R/variance_anhecova.R b/R/variance_anhecova.R index 5e4ec09..99faa37 100644 --- a/R/variance_anhecova.R +++ b/R/variance_anhecova.R @@ -2,15 +2,13 @@ #' #' @param x (`prediction_cf`) Counter-factual prediction. #' @param decompose (`flag`) whether to use decompose method to calculate the variance. -#' @param randomization (`string`) randomization method. #' @param ... Not used. #' #' @return Named covariance matrix. #' @export -vcovANHECOVA <- function(x, decompose = TRUE, randomization = "simple", ...) { # nolint +gvcov <- function(x, decompose = TRUE, ...) { # nolint assert_class(x, "prediction_cf") assert_flag(decompose) - assert_string(randomization) resi <- attr(x, "residual") est <- as.numeric(x) preds <- attr(x, "predictions") @@ -31,7 +29,7 @@ vcovANHECOVA <- function(x, decompose = TRUE, randomization = "simple", ...) { # } v <- diag(vcov_sr) + cov_ymu + t(cov_ymu) - var_preds - v <- v - h_get_erb(resi, group_idx, trt, pi_t, randomization) + v <- v - h_get_erb(resi, group_idx, trt, pi_t, attr(x, "schema")) ret <- v / length(resi) dimnames(ret) <- list(trt_lvls, trt_lvls) return(ret) @@ -59,7 +57,7 @@ h_get_erb <- function(resi, group_idx, trt, pi, randomization) { return(0) } assert_string(randomization) - if (randomization %in% c("simple", "pocock-simon")) { + if (randomization %in% c("sp", "ps")) { return(0) } assert_numeric(resi) diff --git a/man/vcovANHECOVA.Rd b/man/gvcov.Rd similarity index 70% rename from man/vcovANHECOVA.Rd rename to man/gvcov.Rd index 37b859d..b0ed059 100644 --- a/man/vcovANHECOVA.Rd +++ b/man/gvcov.Rd @@ -1,18 +1,16 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/variance_anhecova.R -\name{vcovANHECOVA} -\alias{vcovANHECOVA} +\name{gvcov} +\alias{gvcov} \title{ANHECOVA Covariance} \usage{ -vcovANHECOVA(x, decompose = TRUE, randomization = "simple", ...) +gvcov(x, decompose = TRUE, ...) } \arguments{ \item{x}{(\code{prediction_cf}) Counter-factual prediction.} \item{decompose}{(\code{flag}) whether to use decompose method to calculate the variance.} -\item{randomization}{(\code{string}) randomization method.} - \item{...}{Not used.} } \value{ diff --git a/man/h_get_vars.Rd b/man/h_get_vars.Rd index 7a63c73..c00dcf1 100644 --- a/man/h_get_vars.Rd +++ b/man/h_get_vars.Rd @@ -10,11 +10,11 @@ h_get_vars(treatment) \item{treatment}{(\code{string} or \code{formula}) string name of the treatment, or a formula.} } \value{ -A list of two elements, \code{treatmetn} and \code{strata}. +A list of three elements, \code{treatment}, \code{schema} and \code{strata}. } \description{ Extract Variable Names } \details{ -Extract the formula elements, including \code{treatment} and \code{strata}. +Extract the formula elements, including \code{treatment}, \code{schema} and \code{strata}. } diff --git a/man/predict_counterfactual.Rd b/man/predict_counterfactual.Rd index 581f5b8..1076408 100644 --- a/man/predict_counterfactual.Rd +++ b/man/predict_counterfactual.Rd @@ -4,17 +4,14 @@ \alias{predict_counterfactual} \title{Counterfactual Prediction} \usage{ -predict_counterfactual(fit, treatment, data, unbiased) +predict_counterfactual(fit, treatment, data) } \arguments{ \item{fit}{fitted object.} -\item{treatment}{(\code{string} or \code{formula}) treatment variable in string, or a formula of form -treatment ~ strata(s).} +\item{treatment}{(\code{formula}) formula of form treatment ~ strata(s).} \item{data}{(\code{data.frame}) raw dataset.} - -\item{unbiased}{(\code{flag}) indicator of whether to remove potential bias of the prediction.} } \value{ Numeric matrix of counter factual prediction. diff --git a/man/robin_glm.Rd b/man/robin_glm.Rd index abc8c3b..01271b9 100644 --- a/man/robin_glm.Rd +++ b/man/robin_glm.Rd @@ -10,7 +10,7 @@ robin_glm( treatment, contrast = "difference", contrast_jac = NULL, - vcov = vcovANHECOVA, + vcov = gvcov, family = gaussian, ... ) @@ -30,7 +30,7 @@ or a string name of treatment assignment.} default contrasts.} \item{vcov}{(\code{function}) A function to calculate the variance-covariance matrix of the treatment effect, -including \code{vcovHC} and \code{vcovANHECOVA}.} +including \code{vcovHC} and \code{gvcov}.} \item{family}{(\code{family}) A family object of the glm model.} diff --git a/tests/testthat/_snaps/predict_counterfactual.md b/tests/testthat/_snaps/predict_counterfactual.md index ce7b71f..d5f2513 100644 --- a/tests/testthat/_snaps/predict_counterfactual.md +++ b/tests/testthat/_snaps/predict_counterfactual.md @@ -1,7 +1,7 @@ # predict_counterfactual works for guassian Code - predict_counterfactual(fit_glm, "treatment") + predict_counterfactual(fit_glm, treatment ~ 1) Output counter-factual prediction @@ -12,7 +12,7 @@ # predict_counterfactual works for guassian with lm Code - predict_counterfactual(fit_lm, "treatment", data = dummy_data) + predict_counterfactual(fit_lm, treatment ~ 1, data = dummy_data) Output counter-factual prediction @@ -23,7 +23,7 @@ # predict_counterfactual works for binomial Code - predict_counterfactual(fit_binom, "treatment") + predict_counterfactual(fit_binom, treatment ~ 1) Output counter-factual prediction diff --git a/tests/testthat/_snaps/treatment_effect.md b/tests/testthat/_snaps/treatment_effect.md index f73636b..4ce0541 100644 --- a/tests/testthat/_snaps/treatment_effect.md +++ b/tests/testthat/_snaps/treatment_effect.md @@ -9,9 +9,9 @@ Randomization: treatment ~ s1 Variance Type: variance Estimate Std.Err Z Value Pr(>z) - trt1 - pbo 0.2246 0.0477 4.71 1.3e-06 *** - trt2 - pbo 0.2653 0.0475 5.58 1.2e-08 *** - trt2 - trt1 0.0407 0.0479 0.85 0.2 + trt1 - pbo 0.2246 0.0477 4.71 2.5e-06 *** + trt2 - pbo 0.2653 0.0475 5.58 2.4e-08 *** + trt2 - trt1 0.0407 0.0479 0.85 0.4 --- Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 @@ -26,9 +26,9 @@ Randomization: treatment ~ s1 Variance Type: variance Estimate Std.Err Z Value Pr(>z) - trt1 - pbo 0.564 0.101 5.60 1.1e-08 *** - trt2 - pbo 0.771 0.101 7.61 1.4e-14 *** - trt2 - trt1 0.207 0.107 1.94 0.026 * + trt1 - pbo 0.564 0.101 5.60 2.2e-08 *** + trt2 - pbo 0.771 0.101 7.61 2.8e-14 *** + trt2 - trt1 0.207 0.107 1.94 0.052 . --- Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 diff --git a/tests/testthat/_snaps/variance.md b/tests/testthat/_snaps/variance.md index 5c5c8c3..c3eeba1 100644 --- a/tests/testthat/_snaps/variance.md +++ b/tests/testthat/_snaps/variance.md @@ -8,10 +8,10 @@ trt1 4.523445e-07 1.164889e-03 -7.709031e-07 trt2 -9.709004e-06 -7.709031e-07 1.170214e-03 -# vcovANHECOVA works +# gvcov works Code - vcovANHECOVA(pc) + gvcov(pc) Output pbo trt1 trt2 pbo 1.128902e-03 1.856234e-05 1.333885e-05 @@ -21,7 +21,7 @@ --- Code - vcovANHECOVA(pc, randomization = "permute_block") + gvcov(pc, randomization = "permute_block") Output pbo trt1 trt2 pbo 1.128902e-03 1.856234e-05 1.333885e-05 diff --git a/tests/testthat/test-predict_counterfactual.R b/tests/testthat/test-predict_counterfactual.R index 8812488..aa7d5e7 100644 --- a/tests/testthat/test-predict_counterfactual.R +++ b/tests/testthat/test-predict_counterfactual.R @@ -1,13 +1,13 @@ test_that("predict_counterfactual works for guassian", { - expect_snapshot(predict_counterfactual(fit_glm, "treatment")) + expect_snapshot(predict_counterfactual(fit_glm, treatment ~ 1)) }) test_that("predict_counterfactual works for guassian with lm", { - expect_snapshot(predict_counterfactual(fit_lm, "treatment", data = dummy_data)) + expect_snapshot(predict_counterfactual(fit_lm, treatment ~ 1, data = dummy_data)) }) test_that("predict_counterfactual works for binomial", { - expect_snapshot(predict_counterfactual(fit_binom, "treatment")) + expect_snapshot(predict_counterfactual(fit_binom, treatment ~ 1)) }) test_that("predict_counterfactual works if contrast are non-standard", { diff --git a/tests/testthat/test-robin_glm.R b/tests/testthat/test-robin_glm.R index 85b42d5..4a124b8 100644 --- a/tests/testthat/test-robin_glm.R +++ b/tests/testthat/test-robin_glm.R @@ -4,7 +4,7 @@ test_that("h_interaction works correctly", { expect_false(h_interaction(y ~ trt + z, treatment = trt ~ x)) expect_true(h_interaction(y ~ trt:z, treatment = trt ~ x)) expect_true(h_interaction(trt * y ~ trt:z, treatment = trt ~ x)) - expect_true(h_interaction(y ~ trt:z, treatment = "trt")) + expect_true(h_interaction(y ~ trt:z, treatment = trt ~ 1)) }) # robin_glm ---- @@ -26,7 +26,7 @@ test_that("robin_glm works correctly", { data = dummy_data, treatment = treatment ~ s1, contrast = "odds_ratio", vcov = vcovHC ), - "Huber-White standard error only works for difference contrasts in models without interaction term." + "Huber-White variance estimator is ONLY" ) expect_silent(robin_glm(y_b ~ treatment * s1, data = dummy_data, treatment = treatment ~ s1, contrast = h_diff)) }) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index ef50a98..92a29f5 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -1,34 +1,35 @@ -test_that("h_get_vars works for single character", { - res <- expect_silent(h_get_vars("abc")) - expect_identical(res, list(treatment = "abc", strata = character(0))) +test_that("h_get_vars works for formula", { + res <- expect_silent(h_get_vars(abc ~ 1)) + expect_identical(res, list(treatment = "abc", schema = "sp", strata = character(0))) expect_error( - h_get_vars(c("abc", "def")), - "length 1 character" + h_get_vars("treatment"), + "Must be a formula, not character" ) expect_error( h_get_vars(NULL), - "length 1 character" + "Must be a formula, not 'NULL'" ) -}) - -test_that("h_get_vars works for formula", { - res <- expect_silent(h_get_vars(a ~ b + c)) - expect_identical(res, list(treatment = "a", strata = c("b", "c"))) - - res <- expect_silent(h_get_vars(`~`(a, ))) - expect_identical(res, list(treatment = "a", strata = character(0))) - expect_error( - h_get_vars(log(a) ~ strata(b)), + h_get_vars(log(a) ~ b), "left hand side of the treatment formula should be a single name!" ) - expect_error( h_get_vars(a + b ~ strata(b)), "left hand side of the treatment formula should be a single name!" ) +}) + +test_that("h_get_vars works for formula with schemas", { + res <- expect_silent(h_get_vars(a ~ b + c)) + expect_identical(res, list(treatment = "a", schema = "sp", strata = c("b", "c"))) + + res <- expect_silent(h_get_vars(a ~ pb(1) + b)) + expect_identical(res, list(treatment = "a", schema = "pb", strata = "b")) + + res <- expect_silent(h_get_vars(a ~ ps(b) + c)) + expect_identical(res, list(treatment = "a", schema = "ps", strata = c("b", "c"))) res <- expect_silent(h_get_vars(a ~ strata(b))) - expect_identical(res, list(treatment = "a", strata = "b")) + expect_identical(res, list(treatment = "a", schema = "sp", strata = "b")) }) diff --git a/tests/testthat/test-variance.R b/tests/testthat/test-variance.R index 217f294..a6d5136 100644 --- a/tests/testthat/test-variance.R +++ b/tests/testthat/test-variance.R @@ -5,12 +5,12 @@ test_that("vcovHC works", { ) }) -test_that("vcovANHECOVA works", { +test_that("gvcov works", { pc <- predict_counterfactual(fit_binom, treatment ~ s1) expect_snapshot( - vcovANHECOVA(pc) + gvcov(pc) ) expect_snapshot( - vcovANHECOVA(pc, randomization = "permute_block") + gvcov(pc, randomization = "permute_block") ) })