diff --git a/DESCRIPTION b/DESCRIPTION index f7d1b90..7448268 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -22,7 +22,6 @@ Imports: checkmate, numDeriv, MASS, - prediction, sandwich, stats Suggests: diff --git a/NAMESPACE b/NAMESPACE index a82c494..c7697c3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,7 @@ # Generated by roxygen2: do not edit by hand +S3method(find_data,glm) +S3method(find_data,lm) S3method(predict_counterfactual,glm) S3method(predict_counterfactual,lm) S3method(print,prediction_cf) @@ -9,6 +11,7 @@ S3method(treatment_effect,lm) S3method(treatment_effect,prediction_cf) S3method(vcovHC,prediction_cf) export(bias) +export(find_data) export(gvcov) export(h_diff) export(h_jac_diff) @@ -22,7 +25,6 @@ export(treatment_effect) import(checkmate) importFrom(MASS,negative.binomial) importFrom(numDeriv,jacobian) -importFrom(prediction,find_data) importFrom(sandwich,vcovHC) importFrom(stats,as.formula) importFrom(stats,coefficients) diff --git a/R/RobinCar2-package.R b/R/RobinCar2-package.R index 533192f..c820c27 100644 --- a/R/RobinCar2-package.R +++ b/R/RobinCar2-package.R @@ -11,5 +11,4 @@ #' gaussian terms glm var family pnorm var as.formula #' @importFrom sandwich vcovHC #' @importFrom MASS negative.binomial -#' @importFrom prediction find_data NULL diff --git a/R/find_data.R b/R/find_data.R new file mode 100644 index 0000000..5a32d12 --- /dev/null +++ b/R/find_data.R @@ -0,0 +1,15 @@ +#' Find Data in a Fit +#' @export +#' @param fit A fit object. +#' @param ... Additional arguments. +find_data <- function(fit, ...) { + UseMethod("find_data") +} +#' @export +find_data.glm <- function(fit, ...) { + fit$data +} +#' @export +find_data.lm <- function(fit, ...) { + stop("data must be provided explicitly for lm objects") +} diff --git a/R/robin_glm.R b/R/robin_glm.R index 15c7dc4..4601d7c 100644 --- a/R/robin_glm.R +++ b/R/robin_glm.R @@ -27,6 +27,8 @@ robin_glm <- function( attr(formula, ".Environment") <- environment() # check if using negative.binomial family with NA as theta. # If so, use MASS::glm.nb instead of glm. + assert_subset(all.vars(formula), names(data)) + assert_subset(all.vars(treatment), names(data)) if (identical(family$family, "Negative Binomial(NA)")) { fit <- MASS::glm.nb(formula, data = data, ...) } else { diff --git a/R/treatment_effect.R b/R/treatment_effect.R index b04ad0e..3bd35a3 100644 --- a/R/treatment_effect.R +++ b/R/treatment_effect.R @@ -53,6 +53,7 @@ treatment_effect.prediction_cf <- function( } trt_var <- trt_jac %*% inner_variance %*% t(trt_jac) } else { + inner_variance <- NULL trt_var <- diag(NULL) } @@ -63,6 +64,7 @@ treatment_effect.prediction_cf <- function( marginal_mean = object, fit = attr(object, "fit"), vartype = variance_name, + mmvariance = inner_variance, treatment = attr(object, "treatment_formula"), variance = diag(trt_var), class = "treatment_effect" @@ -183,7 +185,12 @@ print.treatment_effect <- function(x, ...) { cat("Randomization: ", deparse(attr(x, "treatment")), "\n") cat("Marginal Mean: \n") print(attr(x, "marginal_mean")) - + if (!identical(attr(x, "vartype"), "none")) { + v <- attr(x, "mmvariance") + cat("Marginal Mean Variance: \n") + print(sqrt(diag(v))) + cat("\n\n") + } cat("Variance Type: ", attr(x, "vartype"), "\n") if (identical(attr(x, "vartype"), "none")) { trt_sd <- rep(NA, length(x)) diff --git a/man/find_data.Rd b/man/find_data.Rd new file mode 100644 index 0000000..02a27d3 --- /dev/null +++ b/man/find_data.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/find_data.R +\name{find_data} +\alias{find_data} +\title{Find Data in a Fit} +\usage{ +find_data(fit, ...) +} +\arguments{ +\item{fit}{A fit object.} + +\item{...}{Additional arguments.} +} +\description{ +Find Data in a Fit +} diff --git a/tests/testthat/_snaps/treatment_effect.md b/tests/testthat/_snaps/treatment_effect.md index b27e133..e91e990 100644 --- a/tests/testthat/_snaps/treatment_effect.md +++ b/tests/testthat/_snaps/treatment_effect.md @@ -13,6 +13,11 @@ pbo trt1 trt2 0.3560965 0.5806957 0.6213865 + Marginal Mean Variance: + pbo trt1 trt2 + 0.03359913 0.03441801 0.03401864 + + Variance Type: gvcov Estimate Std.Err Z Value Pr(>|z|) trt1 - pbo 0.2246 0.0477 4.71 2.5e-06 *** @@ -24,7 +29,8 @@ --- Code - treatment_effect(fit_lm, treatment = treatment ~ s1, eff_measure = h_diff) + treatment_effect(fit_lm, treatment = treatment ~ s1, eff_measure = h_diff, + data = dummy_data) Output Treatment Effect ------------- @@ -36,6 +42,11 @@ pbo trt1 trt2 0.2003208 0.7639709 0.9712499 + Marginal Mean Variance: + pbo trt1 trt2 + 0.06768998 0.07592944 0.07654319 + + Variance Type: gvcov Estimate Std.Err Z Value Pr(>|z|) trt1 - pbo 0.564 0.101 5.60 2.2e-08 *** @@ -82,6 +93,11 @@ pbo trt1 trt2 0.3560965 0.5806957 0.6213865 + Marginal Mean Variance: + pbo trt1 + 0.03359913 0.03441801 + + Variance Type: gvcov Estimate Std.Err Z Value Pr(>|z|) trt1 - pbo 0.2246 0.0477 4.71 2.5e-06 *** diff --git a/tests/testthat/test-find_data.R b/tests/testthat/test-find_data.R new file mode 100644 index 0000000..28d76b5 --- /dev/null +++ b/tests/testthat/test-find_data.R @@ -0,0 +1,13 @@ +test_that("find_data works for glm", { + expect_identical( + find_data(fit_glm), + fit_glm$data + ) +}) + +test_that("find_data fails for lm", { + expect_error( + find_data(fit_lm), + "data must be provided explicitly for lm objects" + ) +}) diff --git a/tests/testthat/test-treatment_effect.R b/tests/testthat/test-treatment_effect.R index 8a545e1..e804bae 100644 --- a/tests/testthat/test-treatment_effect.R +++ b/tests/testthat/test-treatment_effect.R @@ -147,7 +147,7 @@ test_that("treatment_effect works as expected for custom contrast", { test_that("treatment_effect works for lm/glm object", { expect_snapshot(treatment_effect(fit_binom, treatment = treatment ~ s1, eff_measure = h_diff)) - expect_snapshot(treatment_effect(fit_lm, treatment = treatment ~ s1, eff_measure = h_diff)) + expect_snapshot(treatment_effect(fit_lm, treatment = treatment ~ s1, eff_measure = h_diff, data = dummy_data)) }) test_that("treatment_effect works if variance is not used", {