Skip to content

Commit

Permalink
linting and format check
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Feb 26, 2025
1 parent a2e2cab commit ff48990
Show file tree
Hide file tree
Showing 25 changed files with 273 additions and 171 deletions.
45 changes: 29 additions & 16 deletions R/aggregate_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,15 @@ as_epidist_aggregate_data <- function(data, ...) {
#' n = c(1, 2, 3)
#' )
as_epidist_aggregate_data.default <- function(
data, n = NULL, ptime_upr = NULL, stime_lwr = NULL,
stime_upr = NULL, obs_time = NULL, ...) {
data,
n = NULL,
ptime_upr = NULL,
stime_lwr = NULL,
stime_upr = NULL,
obs_time = NULL,
...) {
# Create linelist data first
df <- as_epidist_linelist_data.default(
linelist_data <- as_epidist_linelist_data.default(
data = data,
ptime_upr = ptime_upr,
stime_lwr = stime_lwr,
Expand All @@ -55,13 +60,13 @@ as_epidist_aggregate_data.default <- function(
)

if (!is.null(n)) {
df$n <- n
linelist_data$n <- n
} else {
cli::cli_abort("{.var n} is NULL but must be provided.")
}
df <- new_epidist_aggregate_data(df)
assert_epidist(df)
return(df)
aggregate_data <- new_epidist_aggregate_data(linelist_data)
assert_epidist(aggregate_data)
return(aggregate_data)
}

#' Create an epidist_aggregate_data object from a data.frame
Expand Down Expand Up @@ -91,10 +96,16 @@ as_epidist_aggregate_data.default <- function(
#' n = "n"
#' )
as_epidist_aggregate_data.data.frame <- function(
data, n = NULL, pdate_lwr = NULL, sdate_lwr = NULL,
pdate_upr = NULL, sdate_upr = NULL, obs_date = NULL, ...) {
data,
n = NULL,
pdate_lwr = NULL,
sdate_lwr = NULL,
pdate_upr = NULL,
sdate_upr = NULL,
obs_date = NULL,
...) {
# First convert to linelist data
df <- as_epidist_linelist_data.data.frame(
linelist_data <- as_epidist_linelist_data.data.frame(
data = data,
pdate_lwr = pdate_lwr,
sdate_lwr = sdate_lwr,
Expand All @@ -112,11 +123,11 @@ as_epidist_aggregate_data.data.frame <- function(
n <- "n"
}

df$n <- data[[n]]
linelist_data$n <- data[[n]]

df <- new_epidist_aggregate_data(df)
assert_epidist(df)
return(df)
aggregate_data <- new_epidist_aggregate_data(linelist_data)
assert_epidist(aggregate_data)
return(aggregate_data)
}

#' Convert linelist data to aggregate format
Expand Down Expand Up @@ -156,7 +167,9 @@ as_epidist_aggregate_data.data.frame <- function(
#' ) |>
#' as_epidist_aggregate_data(by = "age")
as_epidist_aggregate_data.epidist_linelist_data <- function(
data, by = NULL, ...) {
data,
by = NULL,
...) {
assert_epidist.epidist_linelist_data(data)

# Required variables for epidist objects
Expand Down Expand Up @@ -206,7 +219,7 @@ new_epidist_aggregate_data <- function(data) {
#' @family aggregate_data
#' @export
is_epidist_aggregate_data <- function(data, ...) {
inherits(data, "epidist_aggregate_data")
return(inherits(data, "epidist_aggregate_data"))
}

#' Assert validity of `epidist_aggregate_data` objects
Expand Down
8 changes: 5 additions & 3 deletions R/diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ epidist_diagnostics <- function(fit) {
if (fit$algorithm %in% c("laplace", "meanfield", "fullrank", "pathfinder")) {
cli_abort(c(
"!" = paste0(
"Diagnostics not yet supported for the algorithm: ", fit$algorithm
"Diagnostics not yet supported for the algorithm: ",
fit$algorithm
)
))
}
Expand All @@ -57,8 +58,9 @@ epidist_diagnostics <- function(fit) {
max_treedepth = max(np[treedepth_ind, ]$Value)
) |>
mutate(
no_at_max_treedepth =
sum(np[treedepth_ind, ]$Value == .data$max_treedepth),
no_at_max_treedepth = sum(
np[treedepth_ind, ]$Value == .data$max_treedepth
),
per_at_max_treedepth = .data$no_at_max_treedepth / samples
)
} else {
Expand Down
38 changes: 27 additions & 11 deletions R/epidist.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,29 +47,45 @@
#' epidist(chains = 2, cores = 2, refresh = ifelse(interactive(), 250, 0))
#'
#' summary(fit)
epidist <- function(data, formula = mu ~ 1,
family = lognormal(), prior = NULL,
merge_priors = TRUE,
fn = brms::brm, ...) {
epidist <- function(
data,
formula = mu ~ 1,
family = lognormal(),
prior = NULL,
merge_priors = TRUE,
fn = brms::brm,
...) {
assert_epidist(data)
epidist_family <- epidist_family(data, family)
epidist_formula <- epidist_formula(
data = data, family = epidist_family, formula = formula
data = data,
family = epidist_family,
formula = formula
)
transformed_data <- epidist_transform_data(
data, epidist_family, epidist_formula
data,
epidist_family,
epidist_formula
)
epidist_prior <- epidist_prior(
data = transformed_data, family = epidist_family,
formula = epidist_formula, prior,
data = transformed_data,
family = epidist_family,
formula = epidist_formula,
prior,
merge = merge_priors
)
epidist_stancode <- epidist_stancode(
data = transformed_data, family = epidist_family, formula = epidist_formula
data = transformed_data,
family = epidist_family,
formula = epidist_formula
)
fit <- fn(
formula = epidist_formula, family = epidist_family, prior = epidist_prior,
stanvars = epidist_stancode, data = transformed_data, ...
formula = epidist_formula,
family = epidist_family,
prior = epidist_prior,
stanvars = epidist_stancode,
data = transformed_data,
...
)
class(fit) <- c(class(fit), "epidist_fit")
return(fit)
Expand Down
3 changes: 2 additions & 1 deletion R/formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ epidist_formula_model <- function(data, formula, ...) {
#' @export
epidist_formula_model.default <- function(data, formula, ...) {
formula <- stats::update(
formula, delay ~ .
formula,
delay ~ .

Check warning on line 46 in R/formula.R

View check run for this annotation

Codecov / codecov/patch

R/formula.R#L45-L46

Added lines #L45 - L46 were not covered by tests
)
return(formula)
}
44 changes: 33 additions & 11 deletions R/latent_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ new_epidist_latent_model <- function(data, ...) {
#' @family latent_model
#' @export
is_epidist_latent_model <- function(data) {
inherits(data, "epidist_latent_model")
return(inherits(data, "epidist_latent_model"))
}

#' @method assert_epidist epidist_latent_model
Expand Down Expand Up @@ -152,7 +152,9 @@ assert_epidist.epidist_latent_model <- function(data, ...) {
#' @family latent_model
#' @export
epidist_family_model.epidist_latent_model <- function(
data, family, ...) {
data,
family,
...) {
# Really the name and vars are the "model-specific" parts here
custom_family <- brms::custom_family(
paste0("latent_", family$family),
Expand All @@ -162,8 +164,13 @@ epidist_family_model.epidist_latent_model <- function(
ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))),
type = family$type,
vars = c(
"vreal1", "vreal2", "vreal3", "pwindow_raw", "swindow_raw",
"woverlap", "wN"
"vreal1",
"vreal2",
"vreal3",
"pwindow_raw",
"swindow_raw",
"woverlap",
"wN"
),
loop = FALSE,
log_lik = epidist_gen_log_lik(family),
Expand All @@ -186,9 +193,12 @@ epidist_family_model.epidist_latent_model <- function(
#' @family latent_model
#' @export
epidist_formula_model.epidist_latent_model <- function(
data, formula, ...) {
data,
formula,
...) {
formula <- stats::update(
formula, delay | vreal(relative_obs_time, pwindow, swindow) ~ .
formula,
delay | vreal(relative_obs_time, pwindow, swindow) ~ .
)
return(formula)
}
Expand Down Expand Up @@ -228,7 +238,8 @@ epidist_model_prior.epidist_latent_model <- function(data, formula, ...) {
epidist_stancode.epidist_latent_model <- function(
data,
family = epidist_family(data),
formula = epidist_formula(data), ...) {
formula = epidist_formula(data),
...) {
assert_epidist(data)

stanvars_version <- .version_stanvar()
Expand All @@ -241,7 +252,9 @@ epidist_stancode.epidist_latent_model <- function(
family_name <- gsub("latent_", "", family$name, fixed = TRUE)

stanvars_functions[[1]]$scode <- gsub(
"family", family_name, stanvars_functions[[1]]$scode,
"family",
family_name,
stanvars_functions[[1]]$scode,
fixed = TRUE
)

Expand All @@ -258,7 +271,9 @@ epidist_stancode.epidist_latent_model <- function(
)

stanvars_functions[[1]]$scode <- gsub(
"dpars_B", family$param, stanvars_functions[[1]]$scode,
"dpars_B",
family$param,
stanvars_functions[[1]]$scode,
fixed = TRUE
)

Expand All @@ -284,7 +299,9 @@ epidist_stancode.epidist_latent_model <- function(
scode = "vector<lower=0,upper=1>[N] swindow_raw;"
)

stanvars_all <- stanvars_version + stanvars_functions + stanvars_data +
stanvars_all <- stanvars_version +
stanvars_functions +
stanvars_data +
stanvars_parameters

return(stanvars_all)
Expand All @@ -293,6 +310,11 @@ epidist_stancode.epidist_latent_model <- function(
.latent_required_cols <- function() {
return(c(
.linelist_required_cols(),
"relative_obs_time", "pwindow", "woverlap", "swindow", "delay", ".row_id"
"relative_obs_time",
"pwindow",
"woverlap",
"swindow",
"delay",
".row_id"
))
}
16 changes: 12 additions & 4 deletions R/naive_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ as_epidist_naive_model <- function(data, ...) {
#' ) |>
#' as_epidist_naive_model()
as_epidist_naive_model.epidist_linelist_data <- function(
data, weight = NULL, ...) {
data,
weight = NULL,
...) {
assert_epidist.epidist_linelist_data(data)

data <- data |>

Check warning on line 61 in R/naive_model.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/naive_model.R,line=61,col=11,[one_call_pipe_linter] Avoid pipe |> for expressions with only a single call.
Expand Down Expand Up @@ -144,10 +146,13 @@ assert_epidist.epidist_naive_model <- function(data, ...) {
#' @family naive_model
#' @export
epidist_formula_model.epidist_naive_model <- function(
data, formula, ...) {
data,
formula,
...) {
# data is only used to dispatch on
formula <- stats::update(
formula, delay | weights(n) ~ .
formula,
delay | weights(n) ~ .
)
return(formula)
}
Expand Down Expand Up @@ -177,7 +182,10 @@ epidist_formula_model.epidist_naive_model <- function(
#' @importFrom purrr map_chr
#' @export
epidist_transform_data_model.epidist_naive_model <- function(
data, family, formula, ...) {
data,
family,
formula,
...) {
required_cols <- .naive_required_cols()
trans_data <- data |>
.summarise_n_by_formula(by = required_cols, formula = formula) |>
Expand Down
21 changes: 15 additions & 6 deletions R/prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@
#' @rdname epidist_prior
#' @family prior
#' @export
epidist_prior <- function(data, family, formula, prior, merge = TRUE,
enforce_presence = FALSE) {
epidist_prior <- function(
data,
family,
formula,
prior,
merge = TRUE,
enforce_presence = FALSE) {
assert_epidist(data)
default <- brms::default_prior(formula, data = data)
model <- epidist_model_prior(data, formula)
Expand All @@ -48,13 +53,17 @@ epidist_prior <- function(data, family, formula, prior, merge = TRUE,
family$source <- "family"
}
custom <- .replace_prior(
family, model,
merge = TRUE, enforce_presence = FALSE
family,
model,
merge = TRUE,
enforce_presence = FALSE
)
internal <- .replace_prior(default, custom, merge = TRUE)
prior <- .replace_prior(
internal, prior,
warn = TRUE, merge = merge,
internal,
prior,
warn = TRUE,
merge = merge,
enforce_presence = enforce_presence
)

Expand Down
Loading

0 comments on commit ff48990

Please sign in to comment.