add helper for bridging causal fits #955
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: parsnip | ||
Title: A Common API to Modeling and Analysis Functions | ||
Version: 1.1.0.9000 | ||
Version: 1.1.0.9001 | ||
Authors@R: c( | ||
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")), | ||
person("Davis", "Vaughan", , "[email protected]", role = "aut"), | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,90 @@ | ||||||||||||||||||
#' Helper for bridging two-stage causal fits | ||||||||||||||||||
#' | ||||||||||||||||||
#' @description | ||||||||||||||||||
#' `weight_propensity()` is a helper function to more easily link the | ||||||||||||||||||
#' propensity and outcome models in causal workflows. **The main documentation | ||||||||||||||||||
#' for this function lives in the tune package at** `?tune::weight_propensity`. | ||||||||||||||||||
#' | ||||||||||||||||||
#' @param object The object containing the model fit(s) that will generate | ||||||||||||||||||
#' predictions used to calculate propensity weights. Currently, either a | ||||||||||||||||||
#' [parsnip model fit][parsnip::fit.model_spec()], fitted | ||||||||||||||||||
#' [workflow][workflows::workflow()], or | ||||||||||||||||||
#' tuning results (`?tune::fit_resamples`) object. If a tuning result, the | ||||||||||||||||||
#' object must have been generated with the control argument | ||||||||||||||||||
#' (`?tune::control_resamples`) `extract = identity`. | ||||||||||||||||||
#' @param wt_fn A function used to calculate the propensity weights. The first | ||||||||||||||||||
#' argument gives the predicted probability of exposure, the true value for | ||||||||||||||||||
#' which is provided in the second argument. See `?propensity::wt_ate()` for | ||||||||||||||||||
#' an example. | ||||||||||||||||||
Comment on lines +15 to +18
I didn't find that second sentence the easiest to read with the "the true value for which". Is the following correct?
Suggested change
|
||||||||||||||||||
#' @param .treated The level of the exposure corresponding to the treatment, as | ||||||||||||||||||
#' a string. Additionally passed as `.treated` to `wt_fn`. | ||||||||||||||||||
#' @param ... Additional arguments passed to `wt_fn`. | ||||||||||||||||||
#' @param data The data supplied as the `data` argument to `fit()` the `object`. | ||||||||||||||||||
#' This argument is only required for the `model_fit` and `workflow` methods---the | ||||||||||||||||||
#' needed data for the `tune_results` method lives inside of `object`. | ||||||||||||||||||
#' | ||||||||||||||||||
#' @return | ||||||||||||||||||
#' For `model_fit` and fitted `workflow` input, a modified version of the data | ||||||||||||||||||
#' set supplied in `data` that contains a `.wts` column with class | ||||||||||||||||||
#' `importance_weights`. For `tune_results` input, a modified version of the | ||||||||||||||||||
#' resampling object underlying the tuning results containing a new `.wts` column | ||||||||||||||||||
#' with propensity values corresponding to each element of the analysis set. | ||||||||||||||||||
#' | ||||||||||||||||||
#' @references Barrett M & D'Agostino McGowan L (forthcoming). | ||||||||||||||||||
#' _Causal Inference in R_. \url{https://www.r-causal.org/} | ||||||||||||||||||
#' @name weight_propensity.model_fit | ||||||||||||||||||
NULL | ||||||||||||||||||
|
||||||||||||||||||
#' @rdname weight_propensity.model_fit | ||||||||||||||||||
#' @export | ||||||||||||||||||
weight_propensity <- function(object, wt_fn, ...) { | ||||||||||||||||||
UseMethod("weight_propensity") | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
#' @rdname weight_propensity.model_fit | ||||||||||||||||||
#' @method weight_propensity default | ||||||||||||||||||
#' @export | ||||||||||||||||||
weight_propensity.default <- function(object, wt_fn, ...) { | ||||||||||||||||||
abort("No known `weight_propensity()` method for this type of object.") | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
#' @noRd | ||||||||||||||||||
#' @method weight_propensity model_spec | ||||||||||||||||||
#' @export | ||||||||||||||||||
weight_propensity.model_spec <- function(object, wt_fn, ...) { | ||||||||||||||||||
abort(c( | ||||||||||||||||||
"`weight_propensity()` is not well-defined for a model specification.", | ||||||||||||||||||
"i" = "Supply `object` to `fit()` before generating propensity weights." | ||||||||||||||||||
)) | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
#' @rdname weight_propensity.model_fit | ||||||||||||||||||
#' @method weight_propensity model_fit | ||||||||||||||||||
#' @export | ||||||||||||||||||
weight_propensity.model_fit <- function(object, | ||||||||||||||||||
wt_fn, | ||||||||||||||||||
.treated = object$lvl[2], | ||||||||||||||||||
We need to pass I've tried out a few different interfaces to this argument and don't feel strongly on how we can best handle this. We could alternatively add an Note that the current form of that argument is not checked / tested, pending a decision on how we want it to feel.
I think that's worth discussing with Lucy and Malcolm
@LucyMcGowan and @malcolmbarrett: We're currently working on a set of PRs to better accommodate causal workflows in tidymodels via a helper,
fit_resamples(
  propensity_workflow,
  resamples = bootstraps(data),
  control = control_resamples(extract = identity)
) %>%
  weight_propensity(wt_ate, ...) %>%
  fit_resamples(
    outcome_workflow,
    resamples = .
  )
where the second argument to There's surely lots to digest here, but do you have opinions on how we should open up the interface to the
This is all awesome! A few comments:
|
||||||||||||||||||
..., | ||||||||||||||||||
data) { | ||||||||||||||||||
if (rlang::is_missing(wt_fn) || !is.function(wt_fn)) { | ||||||||||||||||||
abort("`wt_fn` must be a function.") | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
if (rlang::is_missing(data) || !is.data.frame(data)) { | ||||||||||||||||||
abort("`data` must be the data supplied as the data argument to `fit()`.") | ||||||||||||||||||
} | ||||||||||||||||||
The only check this PR makes on the input data is that it's a data frame. This gives a window for folks to supply different data to
||||||||||||||||||
|
||||||||||||||||||
# TODO: I'm not sure we have a way to identify `y` via a model | ||||||||||||||||||
# spec fitted with `fit_xy()`---this will error in that case. | ||||||||||||||||||
Comment on lines +77 to +78
When we started out with the "censored regression" mode, we required models to be fit via the formula interface, i.e., In that spirit, you could add an error here to point people towards
||||||||||||||||||
outcome_name <- object$preproc$y_var | ||||||||||||||||||
|
||||||||||||||||||
preds <- predict(object, data, type = "prob") | ||||||||||||||||||
preds <- preds[[paste0(".pred_", .treated)]] | ||||||||||||||||||
|
||||||||||||||||||
data$.wts <- | ||||||||||||||||||
hardhat::importance_weights( | ||||||||||||||||||
wt_fn(preds, data[[outcome_name]], .treated = .treated, ...) | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
data | ||||||||||||||||||
} |
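As a rough illustration of the calling convention used in the method above, where `wt_fn` receives the predicted propensities first, the observed exposure second, and `.treated` by name, here is a minimal base-R sketch of an inverse-probability weighting function for the average treatment effect. The name `ate_weights()` and the toy inputs are hypothetical and are not part of parsnip or the propensity package; `propensity::wt_ate()` is the function the documentation actually points to, and this sketch does not claim to match its behavior.

```r
# Hypothetical weighting function for illustration only; not part of parsnip
# or propensity. It follows the argument order that
# `weight_propensity.model_fit()` uses when it calls `wt_fn`:
# predicted propensity, observed exposure, then `.treated` by name.
ate_weights <- function(.propensity, .exposure, .treated, ...) {
  # TRUE for rows whose observed exposure equals the treatment level
  treated <- as.character(.exposure) == .treated

  # Inverse-probability weights for the ATE:
  # 1 / p for treated rows, 1 / (1 - p) for untreated rows
  ifelse(treated, 1 / .propensity, 1 / (1 - .propensity))
}

# Toy inputs (assumed values, not taken from the PR)
p <- c(0.8, 0.5, 0.25)
exposure <- factor(c("yes", "no", "yes"), levels = c("no", "yes"))

wts <- ate_weights(p, exposure, .treated = "yes")
```

Because these weights are never negative, a function of this shape would pass the non-negativity check in `hardhat::importance_weights()` that the snapshot tests exercise with a deliberately bad `wt_fn`.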
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# errors informatively with bad input | ||
|
||
Code | ||
weight_propensity(spec, silly_wt_fn, data = two_class_dat) | ||
Condition | ||
Error in `weight_propensity()`: | ||
! `weight_propensity()` is not well-defined for a model specification. | ||
i Supply `object` to `fit()` before generating propensity weights. | ||
|
||
--- | ||
|
||
Code | ||
weight_propensity("boop", silly_wt_fn, data = two_class_dat) | ||
Condition | ||
Error in `weight_propensity()`: | ||
! No known `weight_propensity()` method for this type of object. | ||
|
||
--- | ||
|
||
Code | ||
weight_propensity(spec_fit, two_class_dat) | ||
Condition | ||
Error in `weight_propensity()`: | ||
! `wt_fn` must be a function. | ||
|
||
--- | ||
|
||
Code | ||
weight_propensity(spec_fit, "boop", data = two_class_dat) | ||
Condition | ||
Error in `weight_propensity()`: | ||
! `wt_fn` must be a function. | ||
|
||
--- | ||
|
||
Code | ||
weight_propensity(spec_fit, function(...) { | ||
-1L | ||
}, data = two_class_dat) | ||
Condition | ||
Error in `hardhat::importance_weights()`: | ||
! `x` can't contain negative weights. | ||
|
||
--- | ||
|
||
Code | ||
weight_propensity(spec_fit, silly_wt_fn) | ||
Condition | ||
Error in `weight_propensity()`: | ||
! `data` must be the data supplied as the data argument to `fit()`. | ||
|
||
--- | ||
|
||
Code | ||
weight_propensity(spec_fit, silly_wt_fn, data = "boop") | ||
Condition | ||
Error in `weight_propensity()`: | ||
! `data` must be the data supplied as the data argument to `fit()`. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
test_that("basic functionality", { | ||
skip_if_not_installed("modeldata") | ||
library(modeldata) | ||
library(parsnip) | ||
|
||
silly_wt_fn <- function(.propensity, .exposure = NULL, ...) { | ||
seq(1, 2, length.out = length(.propensity)) | ||
} | ||
|
||
lr_fit <- fit(logistic_reg(), Class ~ A + B, two_class_dat) | ||
|
||
lr_res1 <- weight_propensity(lr_fit, silly_wt_fn, data = two_class_dat) | ||
expect_s3_class(lr_res1, "tbl_df") | ||
expect_true(all(names(lr_res1) %in% c(names(two_class_dat), ".wts"))) | ||
expect_equal(lr_res1$.wts, importance_weights(seq(1, 2, length.out = nrow(two_class_dat)))) | ||
}) | ||
|
||
test_that("errors informatively with bad input", { | ||
skip_if_not_installed("modeldata") | ||
library(modeldata) | ||
library(parsnip) | ||
|
||
silly_wt_fn <- function(.propensity, .exposure = NULL, ...) { | ||
seq(1, 2, length.out = length(.propensity)) | ||
} | ||
|
||
# bad `object` | ||
spec <- logistic_reg() | ||
|
||
expect_snapshot( | ||
error = TRUE, | ||
weight_propensity(spec, silly_wt_fn, data = two_class_dat) | ||
) | ||
|
||
expect_snapshot( | ||
error = TRUE, | ||
weight_propensity("boop", silly_wt_fn, data = two_class_dat) | ||
) | ||
|
||
# bad `wt_fn` | ||
spec_fit <- fit(spec, Class ~ A + B, data = two_class_dat) | ||
|
||
expect_snapshot( | ||
error = TRUE, | ||
weight_propensity(spec_fit, two_class_dat) | ||
) | ||
|
||
expect_snapshot( | ||
error = TRUE, | ||
weight_propensity(spec_fit, "boop", data = two_class_dat) | ||
) | ||
|
||
expect_snapshot( | ||
error = TRUE, | ||
weight_propensity(spec_fit, function(...) {-1L}, data = two_class_dat) | ||
) | ||
|
||
# bad `data` | ||
expect_snapshot( | ||
error = TRUE, | ||
weight_propensity(spec_fit, silly_wt_fn) | ||
) | ||
|
||
expect_snapshot( | ||
error = TRUE, | ||
weight_propensity(spec_fit, silly_wt_fn, data = "boop") | ||
) | ||
}) |
Another option would be to instead require `save_pred = TRUE`, but we couldn't make use of `weight_propensity.workflow` in that case. This approach is a bit more DRY.
This feels like a little bit of a rough edge to me. I'm not sure we need to sand over it right now in terms of the interface, but I would add more documentation, especially on the "main" doc page in tune, which currently only mentions this in an example. What about adding a sentence to the Details section, explaining why this needs to be set like this? I think that would help people remember.