Skip to content

Commit

Permalink
Issue #507: Add weighting support to marginal model class (#509)
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs authored Jan 28, 2025
1 parent dd738bc commit e1a8fa0
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 29 deletions.
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: deps-in-desc
args: [--allow_private_imports]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-added-large-files
args: ['--maxkb=200']
Expand All @@ -28,13 +28,12 @@ repos:
entry: Cannot commit .Rhistory, .RData, .Rds or .rds.
language: fail
files: '\.Rhistory|\.RData|\.Rds|\.rds$'
exclude: '^inst/extdata/(fit|fit_gamma)\.rds$'
- repo: meta
hooks:
- id: check-hooks-apply
- id: check-useless-excludes
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-yaml
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Development version of `epidist`.
## Models

- Added a marginalised likelihood model based on `primarycensored`. This can be specified using `as_epidist_marginal_model()`. This is currently limited to Weibull, log-normal, and gamma distributions with uniform primary censoring but this will be generalised in future releases. See #426.
- Added a `weight` argument to `as_epidist_marginal_model()` so that weighted data (for example, count data) can be used in the marginal model. See #509.
- Added user settable primary event priors to the latent model. See #474.
- Added a marginalised likelihood to the latent model. See #474.

Expand Down
17 changes: 14 additions & 3 deletions R/marginal_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ as_epidist_marginal_model <- function(data, ...) {
#' `obs_time_threshold` times the maximum delay will be set to Inf to improve
#' model efficiency by reducing the number of unique observation times.
#' Default is 2.
#' @param weight A column name in `data` to use for weighting the data in the
#' likelihood (for example a column of counts). Default is `NULL`, in which
#' case every row is given a weight of 1. Internally this is used to define
#' the `n` column of the returned object.
#' @param ... Not used in this method.
#' @method as_epidist_marginal_model epidist_linelist_data
#' @family marginal_model
#' @autoglobal
#' @export
as_epidist_marginal_model.epidist_linelist_data <- function(
data, obs_time_threshold = 2, ...) {
data, obs_time_threshold = 2, weight = NULL, ...) {
assert_epidist(data)

data <- data |>
Expand All @@ -32,10 +35,18 @@ as_epidist_marginal_model.epidist_linelist_data <- function(
relative_obs_time = .data$obs_time - .data$ptime_lwr,
orig_relative_obs_time = .data$obs_time - .data$ptime_lwr,
delay_lwr = .data$stime_lwr - .data$ptime_lwr,
delay_upr = .data$stime_upr - .data$ptime_lwr,
n = 1
delay_upr = .data$stime_upr - .data$ptime_lwr
)

if (!is.null(weight)) {
assert_names(names(data), must.include = weight)
data <- data |>
mutate(n = .data[[weight]])
} else {
data <- data |>
mutate(n = 1)
}

# Calculate maximum delay
max_delay <- max(data$delay_upr, na.rm = TRUE)
threshold <- max_delay * obs_time_threshold
Expand Down
21 changes: 10 additions & 11 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -157,34 +157,33 @@ All contributions to this project are gratefully acknowledged using the [`allcon
### Code


<a href="https://github.com/epinowcast/epidist/commits?author=seabbs">seabbs</a>,
<a href="https://github.com/epinowcast/epidist/commits?author=athowes">athowes</a>,
<a href="https://github.com/epinowcast/epidist/commits?author=parksw3">parksw3</a>,
<a href="https://github.com/epinowcast/epidist/commits?author=damonbayer">damonbayer</a>,
<a href="https://github.com/epinowcast/epidist/commits?author=seabbs">seabbs</a>,
<a href="https://github.com/epinowcast/epidist/commits?author=athowes">athowes</a>,
<a href="https://github.com/epinowcast/epidist/commits?author=parksw3">parksw3</a>,
<a href="https://github.com/epinowcast/epidist/commits?author=damonbayer">damonbayer</a>,
<a href="https://github.com/epinowcast/epidist/commits?author=medewitt">medewitt</a>



### Issue Authors


<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+author%3Akgostic">kgostic</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+author%3ATimTaylor">TimTaylor</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+author%3Akgostic">kgostic</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+author%3ATimTaylor">TimTaylor</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+author%3Ajamesmbaazam">jamesmbaazam</a>



### Issue Contributors


<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+commenter%3Apearsonca">pearsonca</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+commenter%3Asbfnk">sbfnk</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+commenter%3ASamuelBrand1">SamuelBrand1</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+commenter%3Azsusswein">zsusswein</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+commenter%3Apearsonca">pearsonca</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+commenter%3Asbfnk">sbfnk</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+commenter%3ASamuelBrand1">SamuelBrand1</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+commenter%3Azsusswein">zsusswein</a>,
<a href="https://github.com/epinowcast/epidist/issues?q=is%3Aissue+commenter%3Akcharniga">kcharniga</a>


<!-- markdownlint-enable -->
<!-- prettier-ignore-end -->
<!-- ALL-CONTRIBUTORS-LIST:END -->

6 changes: 5 additions & 1 deletion man/as_epidist_marginal_model.epidist_linelist_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 11 additions & 7 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,38 +125,42 @@ prep_marginal_obs_sex <- as_epidist_marginal_model(sim_obs_sex)

if (not_on_cran()) {
set.seed(1)
cli::cli_inform("Compiling the latent model with cmdstanr")
cli::cli_alert_info("Compiling the latent model with cmdstanr")
fit <- epidist(
data = prep_obs, seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0,
backend = "cmdstanr"
)

cli::cli_inform("Compiling the latent model with rstan")
cli::cli_alert_info("Compiling the latent model with rstan")
fit_rstan <- epidist(
data = prep_obs, seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0
)

cli::cli_inform("Compiling the marginal model with cmdstanr")
cli::cli_alert_info("Compiling the marginal model with cmdstanr")
fit_marginal <- suppressMessages(epidist(
data = prep_marginal_obs, seed = 1, chains = 2, cores = 2, silent = 2,
refresh = 0, backend = "cmdstanr"
))

cli::cli_inform("Compiling the latent model with cmdstanr and a gamma dist")
cli::cli_alert_info(
"Compiling the latent model with cmdstanr and a gamma dist"
)
fit_gamma <- epidist(
data = prep_obs_gamma, family = Gamma(link = "log"),
seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0,
backend = "cmdstanr"
)

cli::cli_inform("Compiling the marginal model with cmdstanr and a gamma dist")
cli::cli_alert_info(
"Compiling the marginal model with cmdstanr and a gamma dist"
)
fit_marginal_gamma <- suppressMessages(epidist(
data = prep_marginal_obs_gamma, family = Gamma(link = "log"),
seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0,
backend = "cmdstanr"
))

cli::cli_inform(
cli::cli_alert_info(
"Compiling the latent model with cmdstanr and a sex stratification"
)
fit_sex <- epidist(
Expand All @@ -166,7 +170,7 @@ if (not_on_cran()) {
cores = 2, chains = 2, backend = "cmdstanr"
)

cli::cli_inform(
cli::cli_alert_info(
"Compiling the marginal model with cmdstanr and a sex stratification"
)
fit_marginal_sex <- suppressMessages(epidist(
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-gen.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ test_that("epidist_gen_posterior_predict returns a function that errors for i ou
prep <- brms::prepare_predictions(fit)
i_out_of_bounds <- length(prep$data$Y) + 1
predict_fn <- epidist_gen_posterior_predict(family)
expect_warning(
suppressMessages(expect_warning(
expect_error(
predict_fn(i = i_out_of_bounds, prep)
)
)
))
}

# Test lognormal - latent and marginal
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-int-direct_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ test_that("epidist.epidist_naive_model Stan code has no syntax errors in the def
mod <- cmdstanr::cmdstan_model(
stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE
)
expect_true(mod$check_syntax())
suppressMessages(expect_true(mod$check_syntax()))
})

test_that("epidist.epidist_naive_model fits and the MCMC converges in the default case", { # nolint: line_length_linter.
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-int-latent_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ test_that("epidist.epidist_latent_model Stan code has no syntax errors in the de
mod <- cmdstanr::cmdstan_model(
stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE
)
expect_true(mod$check_syntax())
suppressMessages(expect_true(mod$check_syntax()))
})

test_that("epidist.epidist_latent_model samples from the prior according to marginal Kolmogorov-Smirnov tests in the default case.", { # nolint: line_length_linter.
Expand Down
27 changes: 27 additions & 0 deletions tests/testthat/test-marginal_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,33 @@ test_that("as_epidist_marginal_model.epidist_linelist_data errors when passed in
expect_error(as_epidist_marginal_model(sim_obs[, 1]))
})

test_that("as_epidist_marginal_model.epidist_linelist_data respects weight variable", { # nolint: line_length_linter.
  # Attach an alternating 1/2 weight column to a copy of the simulated data
  counts <- rep(c(1, 2), length.out = nrow(sim_obs))
  with_counts <- sim_obs
  with_counts$counts <- counts

  # When `weight` is supplied, `n` should carry the weight column through
  from_weights <- as_epidist_marginal_model(with_counts, weight = "counts")
  expect_identical(from_weights$n, with_counts$counts)

  # When `weight` is left at its default, every row should have n = 1
  default_n <- as_epidist_marginal_model(sim_obs)$n
  expect_true(all(default_n == 1))
})

test_that(
  "as_epidist_marginal_model.epidist_linelist_data errors with invalid weight column", # nolint: line_length_linter.
  {
    # A weight name that is not a column of the data must be rejected
    missing_weight <- "nonexistent_column"
    expect_error(
      as_epidist_marginal_model(sim_obs, weight = missing_weight),
      regexp = "Names must include the elements"
    )
  }
)

# Make this data available for other tests
family_lognormal <- epidist_family(prep_marginal_obs, family = lognormal())

Expand Down

0 comments on commit e1a8fa0

Please sign in to comment.