From d7adc1fdb68d851055cfbcd54f7215797e611f6a Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 15 Oct 2024 11:05:53 +0100 Subject: [PATCH] Alternative sir implementation --- tests/testthat/helper-monty.R | 103 +++--------------------- tests/testthat/helper-sir-filter.R | 123 +++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 91 deletions(-) create mode 100644 tests/testthat/helper-sir-filter.R diff --git a/tests/testthat/helper-monty.R b/tests/testthat/helper-monty.R index 2d9fca72..4739beb3 100644 --- a/tests/testthat/helper-monty.R +++ b/tests/testthat/helper-monty.R @@ -81,16 +81,6 @@ ex_simple_nested_with_base <- function(n_groups) { } -ex_dust_sir <- function(...) { - testthat::skip_if_not_installed("dust") - prior <- monty_dsl({ - beta ~ Gamma(shape = 1, rate = 1 / 0.5) - gamma ~ Gamma(shape = 1, rate = 1 / 0.5) - }) - ex_dust_sir_likelihood(...) + prior -} - - random_array <- function(dim, named = FALSE) { if (named) { dn <- lapply(seq_along(dim), function(i) { @@ -104,91 +94,22 @@ random_array <- function(dim, named = FALSE) { } -ex_dust_sir_likelihood <- function(n_particles = 100, n_threads = 1, +ex_dust_sir_likelihood <- function(n_particles = 100, deterministic = FALSE, save_trajectories = FALSE) { - testthat::skip_if_not_installed("dust") - sir <- dust::dust_example("sir") - - np <- 10 - end <- 150 * 4 - times <- seq(0, end, by = 4) - ans <- sir$new(list(), 0, np, seed = 1L)$simulate(times) - dat <- data.frame(time = times[-1], incidence = ans[5, 1, -1]) - - ## TODO: an upshot here is that our dust models are always going to - ## need to be initialisable; we might need to sample from the - ## statistical parameters, or set things up to allow two-phases of - ## initialsation (which is I think where we are heading, so that's - ## fine). - model <- sir$new(list(), 0, n_particles, seed = 1L, n_threads = n_threads, - deterministic = deterministic) - model$set_data(dust::dust_data(dat)) - model$set_index(c(2, 4)) - - trajectories <- NULL - - density <- function(x) { - beta <- x[[1]] - gamma <- x[[2]] - model$update_state( - pars = list(beta = x[[1]], gamma = x[[2]]), - time = 0, - set_initial_state = TRUE) - res <- model$filter(save_trajectories = save_trajectories) - if (save_trajectories) { - trajectories <<- res$trajectories - } - res$log_likelihood - } - - set_rng_state <- function(rng_state) { - n_streams <- n_particles + 1 - if (length(rng_state) != 32 * n_streams) { - ## Expand the state by short jumps; we'll make this nicer once - ## we refactor the RNG interface and dust. - rng_state <- monty_rng$new(rng_state, n_streams)$state() - } - model$set_rng_state(rng_state) - } - - get_rng_state <- function() { - model$rng_state() - } + data <- data.frame(time = c( 4, 8, 12, 16, 20, 24, 28, 32, 36), + incidence = c( 1, 0, 3, 5, 2, 4, 3, 7, 2)) + sir_filter_monty(data, n_particles, deterministic, save_trajectories) +} - if (save_trajectories) { - observer <- monty_observer( - function() { - ## TODO: It's not really clear to me (Rich) that we want the - ## rng coming in here. In dust2 we'll correctly use the - ## *filter* rng. however, even there we end up with the act - ## of observation changing the behaviour of the sampler and I - ## am not sure that is really desirable. We could probably - ## improve that in dust, but that would require that we do not - ## pass the rng here too. So for now we take the first - ## particle. - ## i <- floor(rng$random_real(1) * model$model$n_particles()) + 1L - i <- 1L - if (save_trajectories) { - traj <- trajectories[, i, , drop = FALSE] - dim(traj) <- dim(traj)[-2] - } else { - traj <- NULL - } - list(trajectories = traj, state = model$state()[, i]) - }) - } else { - observer <- NULL - } - monty_model( - list(model = model, - density = density, - parameters = c("beta", "gamma"), - observer = observer, - set_rng_state = set_rng_state, - get_rng_state = get_rng_state), - monty_model_properties(is_stochastic = !deterministic)) +ex_dust_sir <- function(...) { + testthat::skip_if_not_installed("dust") + prior <- monty_dsl({ + beta ~ Gamma(shape = 1, rate = 1 / 0.5) + gamma ~ Gamma(shape = 1, rate = 1 / 0.5) + }) + ex_dust_sir_likelihood(...) + prior } diff --git a/tests/testthat/helper-sir-filter.R b/tests/testthat/helper-sir-filter.R new file mode 100644 index 00000000..c5cce7a5 --- /dev/null +++ b/tests/testthat/helper-sir-filter.R @@ -0,0 +1,123 @@ +sir_filter_monty <- function(data, n_particles, deterministic = FALSE, + save_trajectories = FALSE, seed = NULL) { + parameters <- c("beta", "gamma") + env <- new.env() + base <- list(N = 1000, I0 = 10, beta = 0.2, gamma = 0.1, exp_noise = 1e6) + + get_rng_state <- function() { + c(monty_rng_state(env$rng$filter$ptr, FALSE), + monty_rng_state(env$rng$system$ptr, FALSE)) + } + + set_rng_state <- function(rng_state) { + n_streams <- n_particles + 1 + r <- matrix( + monty_rng$new(n_streams = n_streams, seed = rng_state)$state(), + ncol = n_streams) + env$rng <- list( + filter = monty_random_alloc(1, r[, 1], deterministic), + system = monty_random_alloc(n_particles, c(r[, -1]), deterministic)) + } + + set_rng_state(seed) + + density <- function(x) { + pars <- base + pars[parameters] <- x + res <- sir_filter(pars, data, n_particles, env$rng, save_trajectories) + if (save_trajectories) { + env$trajectories <- res$trajectories + } + res$log_likelihood + } + + if (save_trajectories) { + observer <- monty_observer( + function() { + i <- 1L + trajectories <- env$trajectories[c(2, 4), i, , drop = FALSE] + dim(trajectories) <- dim(trajectories)[-2] + list(trajectories = trajectories) + }) + } else { + observer <- NULL + } + + monty_model( + list(density = density, + parameters = parameters, + observer = observer, + set_rng_state = set_rng_state, + get_rng_state = get_rng_state), + monty_model_properties(is_stochastic = !deterministic)) +} + + +sir_filter <- function(pars, data, n_particles, rng, + save_trajectories = FALSE) { + y <- list(S = pars$N - pars$I0, I = pars$I0, R = 0, cases = 0) + packer <- monty_packer(names(y)) + state <- matrix(packer$pack(y), 4, n_particles) + time <- 0 + dt <- 1 + ll <- 0 + exp_noise <- rep_len(pars$exp_noise, n_particles) + i_cases <- packer$index()$cases + if (save_trajectories) { + trajectories <- array(NA_real_, c(length(y), n_particles, nrow(data))) + } else { + trajectories <- NULL + } + + for (i in seq_len(nrow(data))) { + from <- time + to <- data$time[[i]] + state <- sir_run(from, to, dt, state, packer, pars, rng$system) + noise <- monty_random_exponential_rate(exp_noise, rng$system) + lambda <- state[i_cases, ] + noise + tmp <- dpois(data$incidence[[i]], lambda, log = TRUE) + w <- exp(tmp - max(tmp)) + ll <- ll + log(mean(w)) + max(tmp) + u <- monty_random_real(rng$filter) + k <- dust_resample_weight(w, u) + + if (save_trajectories) { + trajectories[, , i] <- state + trajectories <- trajectories[, k, , drop = FALSE] # slow but correct... + } + + state <- state[, k, drop = FALSE] + time <- to + } + + list(log_likelihood = ll, trajectories = trajectories) +} + + +sir_run <- function(from, to, dt, state, packer, pars, rng) { + state <- packer$unpack(state) + for (time in seq(from, to, by = dt)[-1]) { + state <- sir_step(time, dt, state, pars, rng) + } + packer$pack(state) +} + + +sir_step <- function(time, dt, state, pars, rng) { + p_SI <- 1 - exp(-pars$beta * state$I / pars$N * dt) + p_IR <- rep(1 - exp(-pars$gamma * dt), length(state$I)) + n_SI <- monty_random_binomial(state$S, p_SI, rng) + n_IR <- monty_random_binomial(state$I, p_IR, rng) + cases <- if (time %% 1 == 0) 0 else state$cases + list(S = state$S - n_SI, + I = state$I + n_SI - n_IR, + R = state$R + n_IR, + cases = cases + n_SI) +} + + +dust_resample_weight <- function(w, u) { + n <- length(w) + uu <- u / n + seq(0, by = 1 / n, length.out = n) + findInterval(uu, cumsum(w / sum(w))) + 1L +}