Skip to content

Commit

Permalink
Alternative sir implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Oct 15, 2024
1 parent 08e394b commit d7adc1f
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 91 deletions.
103 changes: 12 additions & 91 deletions tests/testthat/helper-monty.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,6 @@ ex_simple_nested_with_base <- function(n_groups) {
}


ex_dust_sir <- function(...) {
testthat::skip_if_not_installed("dust")
prior <- monty_dsl({
beta ~ Gamma(shape = 1, rate = 1 / 0.5)
gamma ~ Gamma(shape = 1, rate = 1 / 0.5)
})
ex_dust_sir_likelihood(...) + prior
}


random_array <- function(dim, named = FALSE) {
if (named) {
dn <- lapply(seq_along(dim), function(i) {
Expand All @@ -104,91 +94,22 @@ random_array <- function(dim, named = FALSE) {
}


ex_dust_sir_likelihood <- function(n_particles = 100, n_threads = 1,
ex_dust_sir_likelihood <- function(n_particles = 100,
deterministic = FALSE,
save_trajectories = FALSE) {
testthat::skip_if_not_installed("dust")
sir <- dust::dust_example("sir")

np <- 10
end <- 150 * 4
times <- seq(0, end, by = 4)
ans <- sir$new(list(), 0, np, seed = 1L)$simulate(times)
dat <- data.frame(time = times[-1], incidence = ans[5, 1, -1])

## TODO: an upshot here is that our dust models are always going to
## need to be initialisable; we might need to sample from the
## statistical parameters, or set things up to allow two-phases of
## initialsation (which is I think where we are heading, so that's
## fine).
model <- sir$new(list(), 0, n_particles, seed = 1L, n_threads = n_threads,
deterministic = deterministic)
model$set_data(dust::dust_data(dat))
model$set_index(c(2, 4))

trajectories <- NULL

density <- function(x) {
beta <- x[[1]]
gamma <- x[[2]]
model$update_state(
pars = list(beta = x[[1]], gamma = x[[2]]),
time = 0,
set_initial_state = TRUE)
res <- model$filter(save_trajectories = save_trajectories)
if (save_trajectories) {
trajectories <<- res$trajectories
}
res$log_likelihood
}

set_rng_state <- function(rng_state) {
n_streams <- n_particles + 1
if (length(rng_state) != 32 * n_streams) {
## Expand the state by short jumps; we'll make this nicer once
## we refactor the RNG interface and dust.
rng_state <- monty_rng$new(rng_state, n_streams)$state()
}
model$set_rng_state(rng_state)
}

get_rng_state <- function() {
model$rng_state()
}
data <- data.frame(time = c( 4, 8, 12, 16, 20, 24, 28, 32, 36),
incidence = c( 1, 0, 3, 5, 2, 4, 3, 7, 2))
sir_filter_monty(data, n_particles, deterministic, save_trajectories)
}

if (save_trajectories) {
observer <- monty_observer(
function() {
## TODO: It's not really clear to me (Rich) that we want the
## rng coming in here. In dust2 we'll correctly use the
## *filter* rng. however, even there we end up with the act
## of observation changing the behaviour of the sampler and I
## am not sure that is really desirable. We could probably
## improve that in dust, but that would require that we do not
## pass the rng here too. So for now we take the first
## particle.
## i <- floor(rng$random_real(1) * model$model$n_particles()) + 1L
i <- 1L
if (save_trajectories) {
traj <- trajectories[, i, , drop = FALSE]
dim(traj) <- dim(traj)[-2]
} else {
traj <- NULL
}
list(trajectories = traj, state = model$state()[, i])
})
} else {
observer <- NULL
}

monty_model(
list(model = model,
density = density,
parameters = c("beta", "gamma"),
observer = observer,
set_rng_state = set_rng_state,
get_rng_state = get_rng_state),
monty_model_properties(is_stochastic = !deterministic))
ex_dust_sir <- function(...) {
testthat::skip_if_not_installed("dust")
prior <- monty_dsl({
beta ~ Gamma(shape = 1, rate = 1 / 0.5)
gamma ~ Gamma(shape = 1, rate = 1 / 0.5)
})
ex_dust_sir_likelihood(...) + prior
}


Expand Down
123 changes: 123 additions & 0 deletions tests/testthat/helper-sir-filter.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
sir_filter_monty <- function(data, n_particles, deterministic = FALSE,
save_trajectories = FALSE, seed = NULL) {
parameters <- c("beta", "gamma")
env <- new.env()
base <- list(N = 1000, I0 = 10, beta = 0.2, gamma = 0.1, exp_noise = 1e6)

get_rng_state <- function() {
c(monty_rng_state(env$rng$filter$ptr, FALSE),
monty_rng_state(env$rng$system$ptr, FALSE))
}

set_rng_state <- function(rng_state) {
n_streams <- n_particles + 1
r <- matrix(
monty_rng$new(n_streams = n_streams, seed = rng_state)$state(),
ncol = n_streams)
env$rng <- list(
filter = monty_random_alloc(1, r[, 1], deterministic),
system = monty_random_alloc(n_particles, c(r[, -1]), deterministic))
}

set_rng_state(seed)

density <- function(x) {
pars <- base
pars[parameters] <- x
res <- sir_filter(pars, data, n_particles, env$rng, save_trajectories)
if (save_trajectories) {
env$trajectories <- res$trajectories
}
res$log_likelihood
}

if (save_trajectories) {
observer <- monty_observer(
function() {
i <- 1L
trajectories <- env$trajectories[c(2, 4), i, , drop = FALSE]
dim(trajectories) <- dim(trajectories)[-2]
list(trajectories = trajectories)
})
} else {
observer <- NULL
}

monty_model(
list(density = density,
parameters = parameters,
observer = observer,
set_rng_state = set_rng_state,
get_rng_state = get_rng_state),
monty_model_properties(is_stochastic = !deterministic))
}


sir_filter <- function(pars, data, n_particles, rng,
save_trajectories = FALSE) {
y <- list(S = pars$N - pars$I0, I = pars$I0, R = 0, cases = 0)
packer <- monty_packer(names(y))
state <- matrix(packer$pack(y), 4, n_particles)
time <- 0
dt <- 1
ll <- 0
exp_noise <- rep_len(pars$exp_noise, n_particles)
i_cases <- packer$index()$cases
if (save_trajectories) {
trajectories <- array(NA_real_, c(length(y), n_particles, nrow(data)))
} else {
trajectories <- NULL
}

for (i in seq_len(nrow(data))) {
from <- time
to <- data$time[[i]]
state <- sir_run(from, to, dt, state, packer, pars, rng$system)
noise <- monty_random_exponential_rate(exp_noise, rng$system)
lambda <- state[i_cases, ] + noise
tmp <- dpois(data$incidence[[i]], lambda, log = TRUE)
w <- exp(tmp - max(tmp))
ll <- ll + log(mean(w)) + max(tmp)
u <- monty_random_real(rng$filter)
k <- dust_resample_weight(w, u)

if (save_trajectories) {
trajectories[, , i] <- state
trajectories <- trajectories[, k, , drop = FALSE] # slow but correct...
}

state <- state[, k, drop = FALSE]
time <- to
}

list(log_likelihood = ll, trajectories = trajectories)
}


sir_run <- function(from, to, dt, state, packer, pars, rng) {
state <- packer$unpack(state)
for (time in seq(from, to, by = dt)[-1]) {
state <- sir_step(time, dt, state, pars, rng)
}
packer$pack(state)
}


sir_step <- function(time, dt, state, pars, rng) {
p_SI <- 1 - exp(-pars$beta * state$I / pars$N * dt)
p_IR <- rep(1 - exp(-pars$gamma * dt), length(state$I))
n_SI <- monty_random_binomial(state$S, p_SI, rng)
n_IR <- monty_random_binomial(state$I, p_IR, rng)
cases <- if (time %% 1 == 0) 0 else state$cases
list(S = state$S - n_SI,
I = state$I + n_SI - n_IR,
R = state$R + n_IR,
cases = cases + n_SI)
}


dust_resample_weight <- function(w, u) {
n <- length(w)
uu <- u / n + seq(0, by = 1 / n, length.out = n)
findInterval(uu, cumsum(w / sum(w))) + 1L
}

0 comments on commit d7adc1f

Please sign in to comment.