Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specifying tolerance to dist_spec upon definition #724

Merged
merged 33 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4b562de
flexibly specify bounds on distributions
sbfnk Jul 12, 2024
5a51071
make max/tolerance arguments explicit
sbfnk Jul 17, 2024
e620220
fix S3 documentation
sbfnk Jul 17, 2024
6c331c7
update new_dist_spec use
sbfnk Jul 17, 2024
455c1fd
update plot documentation
sbfnk Jul 17, 2024
2332c15
create man files
sbfnk Jul 17, 2024
70a2f5c
correctly name var
sbfnk Jul 17, 2024
284c07d
remove obsolete return statements
sbfnk Jul 17, 2024
729babd
add global variable
sbfnk Jul 17, 2024
d65dc22
add news item
sbfnk Jul 17, 2024
d045521
simplify map syntax
sbfnk Jul 17, 2024
f59ed81
remove unused variables
sbfnk Jul 17, 2024
66f26e2
add distribution to globals
sbfnk Jul 17, 2024
0065efe
fix x-axis
sbfnk Jul 25, 2024
f723cab
add examples
sbfnk Jul 25, 2024
90ca08f
add tests for specifying tolerance
sbfnk Jul 25, 2024
b9a0774
add comment
sbfnk Jul 25, 2024
49a6f54
improve error message [ci skip]
sbfnk Jul 31, 2024
6b35983
remove superseded comment
sbfnk Jul 31, 2024
07eb4fb
clarify NA return values
sbfnk Jul 31, 2024
1825fe4
add informative error message
sbfnk Jul 31, 2024
2aadbbd
remove superseded comment
sbfnk Jul 31, 2024
0208f60
add [is_constrained()] function
sbfnk Jul 31, 2024
504dfce
remove unneeded checks
sbfnk Jul 31, 2024
4df83f1
update pkgdown
sbfnk Jul 31, 2024
ff2a8c8
set default tolerance in _opts functions
sbfnk Jul 31, 2024
3f6838d
make deprecated functions internal
sbfnk Jul 31, 2024
bc8e0b2
update tolerance argument
sbfnk Jul 31, 2024
48caadf
remove whitespace
sbfnk Jul 31, 2024
903c756
improve printout
sbfnk Jul 31, 2024
13393e2
improve documentation
sbfnk Jul 31, 2024
2b7279d
clarify id doc
sbfnk Aug 1, 2024
dc187ac
use correct function name
sbfnk Aug 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,25 @@

S3method("+",dist_spec)
S3method(c,dist_spec)
S3method(discretise,dist_spec)
S3method(discretise,multi_dist_spec)
S3method(fix_dist,dist_spec)
S3method(fix_dist,multi_dist_spec)
S3method(is_constrained,dist_spec)
S3method(is_constrained,multi_dist_spec)
S3method(max,dist_spec)
S3method(max,multi_dist_spec)
S3method(mean,dist_spec)
S3method(mean,multi_dist_spec)
S3method(plot,dist_spec)
S3method(plot,epinow)
S3method(plot,estimate_infections)
S3method(plot,estimate_secondary)
S3method(plot,estimate_truncation)
S3method(print,dist_spec)
S3method(sd,default)
S3method(sd,dist_spec)
S3method(sd,multi_dist_spec)
S3method(summary,epinow)
S3method(summary,estimate_infections)
export(Fixed)
Expand All @@ -19,9 +30,9 @@ export(NonParametric)
export(Normal)
export(R_to_growth)
export(adjust_infection_to_report)
export(apply_tolerance)
export(backcalc_opts)
export(bootstrapped_dist_fit)
export(bound_dist)
export(calc_CrI)
export(calc_CrIs)
export(calc_summary_measures)
Expand All @@ -35,8 +46,6 @@ export(delay_opts)
export(discretise)
export(discretize)
export(dist_fit)
export(dist_skel)
export(dist_spec)
export(epinow)
export(epinow2_cmdstan_model)
export(estimate_delay)
Expand All @@ -63,6 +72,7 @@ export(get_regional_results)
export(gp_opts)
export(growth_to_R)
export(gt_opts)
export(is_constrained)
export(lognorm_dist_def)
export(make_conf)
export(map_prob_change)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
## Model changes

- `epinow()` now returns the "timing" output in a "time difference" format that is easier to understand and work with. By @jamesmbaazam in #688 and reviewed by @sbfnk.
- The interface for defining delay distributions has been generalised to also cater for continuous distributions.
- When defining probability distributions, these can now be truncated using the `tolerance` argument.

## Bug fixes

Expand Down
28 changes: 15 additions & 13 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ check_stan_delay <- function(dist) {
# Check that `dist` is a `dist_spec`
assert_class(dist, "dist_spec")
# Check that `dist` is lognormal, gamma, fixed or nonparametric
distributions <- vapply(dist, function(x) x$distribution, character(1))
distributions <- vapply(
seq_len(ndist(dist)), get_distribution, x = dist, FUN.VALUE = character(1)
)
if (
!all(distributions %in% c("lognormal", "gamma", "fixed", "nonparametric"))
) {
Expand All @@ -78,24 +80,24 @@ check_stan_delay <- function(dist) {
}
# Check that `dist` has parameters that are either numeric or normal
# distributions with numeric parameters and infinite maximum
numeric_parameters <- vapply(dist$parameters, is.numeric, logical(1))
normal_parameters <- vapply(
dist$parameters,
function(x) {
is(x, "dist_spec") &&
x$distribution == "normal" &&
all(vapply(x$parameters, is.numeric, logical(1))) &&
is.infinite(x$max)
},
logical(1)
)
if (!all(numeric_parameters | normal_parameters)) {
numeric_or_normal <- unlist(lapply(seq_len(ndist(dist)), function(id) {
params <- get_parameters(dist, id)
vapply(params, function(x) {
is.numeric(x) ||
(is(x, "dist_spec") && get_distribution(x) == "normal" &&
is.infinite(max(x)))
}, logical(1))
}))
if (!all(numeric_or_normal)) {
stop(
"Delay distributions passed to the model need to have parameters that ",
"are either numeric or normally distributed with numeric parameters ",
"and infinite maximum."
)
}
if (is.null(attr(dist, "tolerance"))) {
attr(dist, "tolerance") <- 0
}
assert_numeric(attr(dist, "tolerance"), lower = 0, upper = 1)
# Check that `dist` has a finite maximum
if (any(is.infinite(max(dist))) && !(attr(dist, "tolerance") > 0)) {
Expand Down
31 changes: 16 additions & 15 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -741,33 +741,33 @@ create_stan_delays <- function(..., time_points = 1L) {
delays <- list(...)
## discretise
delays <- map(delays, discretise, strict = FALSE)
## convolve where appropriate
delays <- map(delays, collapse)
## apply tolerance
delays <- map(delays, function(x) {
apply_tolerance(x, tolerance = attr(x, "tolerance"))
})
## get maximum delays
max_delay <- unname(as.numeric(flatten(map(delays, max))))
## number of different non-empty types
type_n <- lengths(delays)
type_n <- vapply(delays, ndist, integer(1))
## assign ID values to each type
ids <- rep(0L, length(type_n))
ids[type_n > 0] <- seq_len(sum(type_n > 0))
names(ids) <- paste(names(type_n), "id", sep = "_")

flat_delays <- flatten(delays)
## create "flat version" of delays, i.e. a list of all the delays (including
## elements of composite delays)
if (length(delays) > 1) {
flat_delays <- do.call(c, delays)
} else {
flat_delays <- delays
}
parametric <- unname(vapply(
flat_delays, function(x) x$distribution != "nonparametric", logical(1)
flat_delays, function(x) get_distribution(x) != "nonparametric", logical(1)
))
param_length <- unname(vapply(flat_delays[parametric], function(x) {
length(x$parameters)
length(get_parameters(x))
}, numeric(1)))
nonparam_length <- unname(vapply(flat_delays[!parametric], function(x) {
length(x$pmf)
}, numeric(1)))
distributions <- unname(as.character(
map(flat_delays[parametric], ~ .x$distribution)
map(flat_delays[parametric], get_distribution)
))

## create stan object
Expand All @@ -788,15 +788,16 @@ create_stan_delays <- function(..., time_points = 1L) {
ret$types_groups <- array(c(0, cumsum(unname(type_n[type_n > 0]))) + 1)

ret$params_mean <- array(unname(as.numeric(
map(flatten(map(flat_delays[parametric], ~ .x$parameters)), mean)
map(flatten(map(flat_delays[parametric], get_parameters)), mean)
)))
ret$params_sd <- array(unname(as.numeric(
map(flatten(map(flat_delays[parametric], ~ .x$parameters)), sd_dist)
map(flatten(map(flat_delays[parametric], get_parameters)), sd)
)))
ret$params_sd[is.na(ret$params_sd)] <- 0
ret$max <- array(max_delay[parametric])

ret$np_pmf <- array(unname(as.numeric(
flatten(map(flat_delays[!parametric], ~ .x$pmf))
flatten(map(flat_delays[!parametric], get_pmf))
)))
## get non zero length delay pmf lengths
ret$np_pmf_groups <- array(c(0, cumsum(nonparam_length)) + 1)
Expand All @@ -809,7 +810,7 @@ create_stan_delays <- function(..., time_points = 1L) {
## set lower bounds
ret$params_lower <- array(unname(as.numeric(flatten(
map(flat_delays[parametric], function(x) {
lower_bounds(x$distribution)[names(x$parameters)]
lower_bounds(get_distribution(x))[names(get_parameters(x))]
})
))))
## assign prior weights
Expand Down
85 changes: 35 additions & 50 deletions R/deprecated.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ adjust_infection_to_report <- function(infections, delay_defs,
#' @param fixed Deprecated, use [fix_dist()] instead.
#' @return A list of distribution options.
#' @importFrom rlang warn arg_match
#' @export
#' @keywords internal
dist_spec <- function(distribution = c(
"lognormal", "normal", "gamma", "fixed", "empty"
Expand Down Expand Up @@ -485,55 +484,7 @@ rstan_opts <- function(object = NULL,
#' Samples outside of this range are resampled.
#'
#' @return A vector of samples or a probability distribution.
#' @export
#' @examples
#'
#' ## Exponential model
#' # sample
#' dist_skel(10, model = "exp", params = list(rate = 1))
#'
#' # cumulative prob density
#' dist_skel(1:10, model = "exp", dist = TRUE, params = list(rate = 1))
#'
#' # probability density
#' dist_skel(1:10,
#' model = "exp", dist = TRUE,
#' cum = FALSE, params = list(rate = 1)
#' )
#'
#' ## Gamma model
#' # sample
#' dist_skel(10, model = "gamma", params = list(shape = 1, rate = 0.5))
#'
#' # cumulative prob density
#' dist_skel(0:10,
#' model = "gamma", dist = TRUE,
#' params = list(shape = 1, rate = 0.5)
#' )
#'
#' # probability density
#' dist_skel(0:10,
#' model = "gamma", dist = TRUE,
#' cum = FALSE, params = list(shape = 2, rate = 0.5)
#' )
#'
#' ## Log normal model
#' # sample
#' dist_skel(10,
#' model = "lognormal", params = list(meanlog = log(5), sdlog = log(2))
#' )
#'
#' # cumulative prob density
#' dist_skel(0:10,
#' model = "lognormal", dist = TRUE,
#' params = list(meanlog = log(5), sdlog = log(2))
#' )
#'
#' # probability density
#' dist_skel(0:10,
#' model = "lognormal", dist = TRUE, cum = FALSE,
#' params = list(meanlog = log(5), sdlog = log(2))
#' )
#' @keywords internal
dist_skel <- function(n, dist = FALSE, cum = TRUE, model,
discrete = FALSE, params, max_value = 120) {
lifecycle::deprecate_warn(
Expand Down Expand Up @@ -633,3 +584,37 @@ dist_skel <- function(n, dist = FALSE, cum = TRUE, model,
sample <- truncated_skel(n, dist = dist, cum = cum, max_value = max_value)
return(sample)
}

#' Truncate the tails of nonparametric distributions in a <dist_spec>
#'
#' @description `r lifecycle::badge("deprecated")`
#' This function is deprecated. Use `bound_dist()` instead.
#'
#' Every element of `x` whose `distribution` is `"nonparametric"` has its
#' probability mass function cut off and renormalised to sum to one. All other
#' elements are passed through unchanged.
#' @param x A `<dist_spec>`
#' @param tolerance Numeric; the desired tolerance level. The first entry of
#'   each pmf is always kept; a later entry is kept only if one minus the
#'   cumulative mass of all entries before it is at least `tolerance`.
#' @return A `<dist_spec>` carrying the same attributes as `x`, in which the
#'   nonparametric pmfs have been truncated and renormalised.
#' @keywords internal
apply_tolerance <- function(x, tolerance) {
  lifecycle::deprecate_warn(
    "1.6.0", "apply_tolerance()", "bound_dist()"
  )
  if (!is(x, "dist_spec")) {
    stop("Can only apply tolerance to distributions in a <dist_spec>.")
  }

  # Truncate a single component; parametric components are left alone.
  truncate_component <- function(component) {
    if (component$distribution != "nonparametric") {
      return(component)
    }
    cumulative <- cumsum(component$pmf)
    # Entry i + 1 survives when the mass remaining after entry i is still at
    # least `tolerance`; the leading TRUE keeps the first entry regardless.
    keep <- c(TRUE, (1 - cumulative[-length(cumulative)]) >= tolerance)
    kept_pmf <- component$pmf[keep]
    component$pmf <- kept_pmf / sum(kept_pmf)
    component
  }

  truncated <- lapply(x, truncate_component)

  ## lapply() returns a plain list, so copy the attributes of the input
  ## (class, names, ...) back onto the result
  attributes(truncated) <- attributes(x)
  truncated
}
Loading
Loading