Skip to content

Commit

Permalink
Merge pull request #98 from mrc-ide/mrc-5941
Browse files Browse the repository at this point in the history
Allow using previous samples as initial conditions
  • Loading branch information
richfitz authored Nov 6, 2024
2 parents edc84a3 + 2fcadce commit 2dc1e64
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 5 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.33
Version: 0.2.34
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
36 changes: 34 additions & 2 deletions R/sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
##' @param n_steps The number of steps to run the sampler for.
##'
##' @param initial Optionally, initial parameter values for the
##' sampling. If not given, we sample from the model (or its prior).
##' sampling. If not given, we sample from the model (or its
##' prior). Alternatively, you can provide a `monty_samples`
##' object here -- the result of a previous call to this function --
##' and we will sample some starting points from the final portion
##' of the chains (the exact details here are subject to change, but
##' we'll sample from the last 20 points or 5% of the chain, which
##' ever smaller, with replacement, pooled across all chains in the
##' previous sample).
##'
##' @param n_chains Number of chains to run. The default is to run a
##' single chain, but you will likely want to run more.
Expand Down Expand Up @@ -249,7 +256,23 @@ initial_parameters <- function(initial, model, rng, call = NULL) {
## sample from the posterior!
initial <- lapply(rng, function(r) direct_sample_within_domain(model, r))
}
if (is.list(initial)) {
if (inherits(initial, "monty_samples")) {
## Heuristic here; sample from the last 5% of the chain or 20
## points, whichever is smaller - hopefully a reasonable
## heuristic - pooled across chains.
pars <- tail_and_pool(initial$pars, 0.05, 20)
if (nrow(pars) != n_pars) {
cli::cli_abort(
c(paste("Unexpected parameter length in 'monty_samples' object",
"'initial'; expected {n_pars}"),
i = paste("Your model has {n_pars} parameter{?s}, so the 'initial'",
"object must have this many rows within its 'pars'",
"element, but yours had {nrow(pars)} row{?s}")),
arg = "initial", call = call)
}
i <- ceiling(vnapply(rng, function(r) r$random_real(1)) * ncol(pars))
initial <- pars[, i, drop = FALSE]
} else if (is.list(initial)) {
if (length(initial) != n_chains) {
cli::cli_abort(
c(paste("Unexpected length for list 'initial'",
Expand Down Expand Up @@ -471,3 +494,12 @@ monty_sample_steps <- function(n_steps, burnin = NULL, thinning_factor = NULL,
class(ret) <- "monty_sample_steps"
ret
}


tail_and_pool <- function(pars, p, n) {
n_samples <- ncol(pars)
n_keep <- min(ceiling(n_samples * p), n)
ret <- pars[, seq(to = n_samples, length.out = n_keep), , drop = FALSE]
dim(ret) <- c(nrow(ret), prod(dim(ret)[-1]))
ret
}
9 changes: 8 additions & 1 deletion man/monty_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion man/monty_sample_manual_prepare.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 47 additions & 0 deletions tests/testthat/test-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,37 @@ test_that("validate that initial have correct size for vector inputs", {
})


test_that("sample from previous samples", {
model <- ex_simple_gamma1()
sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01)
set.seed(1)
samples <- monty_sample(model, sampler, 100, 1, n_chains = 2)
r1 <- initial_rng(6, seed = 42)

initial <- initial_parameters(samples, model, r1)
expect_equal(dim(initial), c(1, 6))

cmp <- tail_and_pool(samples$pars, 0.05, 20)
r2 <- initial_rng(6, seed = 42)
i <- ceiling(vnapply(r2, function(r) r$random_real(1)) * ncol(cmp))
expect_equal(initial, cmp[, i, drop = FALSE])
})


test_that("validate parameter size when using previous samples", {
model <- ex_simple_gamma1()
sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01)
set.seed(1)
samples <- monty_sample(model, sampler, 100, 1, n_chains = 2)

samples$pars <- array(samples$pars, c(3, 50, 2))
r <- initial_rng(6, seed = 42)
expect_error(
initial_parameters(samples, model, r),
"Unexpected parameter length in 'monty_samples' object 'initial'")
})


test_that("can run more than one chain, in parallel", {
model <- ex_simple_gamma1()
sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01)
Expand Down Expand Up @@ -441,3 +472,19 @@ test_that("can choose not to append when continuing samples", {

expect_equal(res2b$pars, res1$pars[, 61:100, , drop = FALSE])
})


test_that("can use samples as initial conditions", {
set.seed(1)
m <- ex_sir_filter_posterior(n_particles = 20)
vcv <- matrix(c(0.0006405, 0.0005628, 0.0005628, 0.0006641), 2, 2)
sampler <- monty_sampler_random_walk(vcv = vcv)
initial <- c(0.2, 0.1)

set.seed(1)
res1 <- monty_sample(m, sampler, 50, initial, n_chains = 2)
res2 <- monty_sample(m, sampler, 20, res1, n_chains = 3)

expect_equal(dim(res1$pars), c(2, 50, 2))
expect_equal(dim(res2$pars), c(2, 20, 3))
})

0 comments on commit 2dc1e64

Please sign in to comment.