Skip to content

Commit

Permalink
Merge pull request #94 from mrc-ide/mrc-5933
Browse files Browse the repository at this point in the history
Allow continuing chain without appending
  • Loading branch information
edknock authored Oct 31, 2024
2 parents 72b5670 + 4b2b909 commit c3f1e79
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 20 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.26
Version: 0.2.27
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
19 changes: 13 additions & 6 deletions R/sample-manual.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) {
##' @title Collect manually run samples
##'
##' @inheritParams monty_sample_manual_run
##' @inheritParams monty_sample_continue
##'
##' @param samples Samples from the parent run. You need to provide
##' these where `save_samples` was set to anything other than "value"
Expand All @@ -168,23 +169,24 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) {
##' @export
##' @inherit monty_sample_manual_prepare examples
monty_sample_manual_collect <- function(path, samples = NULL,
restartable = FALSE) {
restartable = FALSE,
append = TRUE) {
inputs <- readRDS(sample_manual_path(path)$inputs)
path <- sample_manual_path(path, seq_len(inputs$n_chains))
assert_scalar_logical(append)

msg <- !file.exists(path$results)
if (any(msg)) {
cli::cli_abort("Results missing for chain{?s} {as.character(which(msg))}")
}

prev <- sample_manual_collect_check_samples(inputs, samples)
prev <- sample_manual_collect_check_samples(inputs, samples, append)

observer <- inputs$model$observer
res <- lapply(path$results, readRDS)
samples <- combine_chains(res, observer)
if (!is.null(prev)) {
samples <- append_chains(prev, combine_chains(res, observer),
observer)
samples <- append_chains(prev, samples, observer)
}

if (restartable) {
Expand Down Expand Up @@ -372,7 +374,7 @@ sample_manual_prepare_check_samples <- function(samples, save_samples,
}


sample_manual_collect_check_samples <- function(inputs, samples,
sample_manual_collect_check_samples <- function(inputs, samples, append,
call = parent.frame()) {
if (!is.list(inputs$restart)) {
if (!is.null(samples)) {
Expand All @@ -385,7 +387,9 @@ sample_manual_collect_check_samples <- function(inputs, samples,
arg = "samples", call = call)
}
} else if (is.null(samples)) {
if (!is.null(inputs$samples$value)) {
if (!append) {
samples <- NULL
} else if (!is.null(inputs$samples$value)) {
samples <- inputs$samples$value
} else {
cli::cli_abort(
Expand All @@ -401,6 +405,9 @@ sample_manual_collect_check_samples <- function(inputs, samples,
"Provided 'samples' does not match those at the start of the chain",
arg = "samples", call = call)
}
if (!append) {
samples <- NULL
}
}
samples
}
15 changes: 12 additions & 3 deletions R/sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,18 @@ monty_sample <- function(model, sampler, n_steps, initial = NULL,
##' well as the type of runner (e.g., changing the number of
##' allocated cores).
##'
##' @param append Logical, indicating if we should append the results
##' of the resumed chain together with the original chain.
##'
##' @inheritParams monty_sample
##'
##' @return A list of parameters and densities
##' @export
monty_sample_continue <- function(samples, n_steps, restartable = FALSE,
runner = NULL) {
runner = NULL, append = TRUE) {
check_can_continue_samples(samples)
assert_scalar_logical(restartable)
assert_scalar_logical(append)

if (is.null(runner)) {
runner <- samples$restart$runner
Expand All @@ -178,9 +182,14 @@ monty_sample_continue <- function(samples, n_steps, restartable = FALSE,
steps <- monty_sample_steps(n_steps, burnin, thinning_factor)

res <- runner$continue(state, model, sampler, steps)

observer <- if (model$properties$has_observer) model$observer else NULL
samples <- append_chains(samples, combine_chains(res, observer), observer)
samples_new <- combine_chains(res, observer)

if (append) {
samples <- append_chains(samples, samples_new, observer)
} else {
samples <- samples_new
}

if (restartable) {
samples$restart <- restart_data(res, model, sampler, runner,
Expand Down
11 changes: 10 additions & 1 deletion man/monty_sample_continue.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/monty_sample_manual_collect.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 59 additions & 8 deletions tests/testthat/test-sample-manual.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ test_that("can continue a manually sampled chain", {
})


test_that("can continue a manually sampled chain without appending", {
model <- ex_simple_gamma1()
sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01)

set.seed(1)
res1a <- monty_sample(model, sampler, 100, n_chains = 2, restartable = TRUE)
res1b <- monty_sample_continue(res1a, 50)

set.seed(1)
path_a <- withr::local_tempdir()
monty_sample_manual_prepare(model, sampler, 100, path_a, n_chains = 2)
monty_sample_manual_run(1, path_a)
monty_sample_manual_run(2, path_a)
res2a <- monty_sample_manual_collect(path_a, restartable = TRUE)
expect_equal(res1a$pars, res2a$pars)

path_b <- withr::local_tempdir()
monty_sample_manual_prepare_continue(res2a, 50, path_b)
monty_sample_manual_run(1, path_b)
monty_sample_manual_run(2, path_b)
res2b <- monty_sample_manual_collect(path_b, res2a, append = FALSE)

expect_equal(res2b$pars, res1b$pars[, 101:150, , drop = FALSE])
})


test_that("path used for manual sampling must be empty", {
tmp <- withr::local_tempdir()
file.create(file.path(tmp, "other"))
Expand Down Expand Up @@ -135,7 +161,7 @@ test_that("can print information about chain completeness", {
})


test_that("...", {
test_that("can validate previously provided samples", {
path <- withr::local_tempdir()
model <- ex_simple_gamma1()
sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01)
Expand Down Expand Up @@ -173,22 +199,28 @@ test_that("can validate samples on collect for non-restarts", {
inputs <- NULL
expect_null(sample_manual_collect_check_samples(inputs, NULL))
expect_error(
sample_manual_collect_check_samples(inputs, list()),
sample_manual_collect_check_samples(inputs, list(), TRUE),
"'samples' provided, but this was not a restarted sample")
expect_error(
sample_manual_collect_check_samples(inputs, list(), FALSE),
"'samples' provided, but this was not a restarted sample")
})


test_that("continuation without passing samples allowed if samples saved", {
inputs <- list(restart = list(1), samples = list(value = 2))
expect_equal(sample_manual_collect_check_samples(inputs, NULL), 2)
expect_equal(sample_manual_collect_check_samples(inputs, NULL, TRUE), 2)
expect_null(sample_manual_collect_check_samples(inputs, NULL, FALSE))
})


test_that("continuation without passing samples not allowed otherwise", {
inputs <- list(restart = list(1))
expect_error(
sample_manual_collect_check_samples(inputs, NULL),
sample_manual_collect_check_samples(inputs, NULL, TRUE),
"Expected 'samples' to be provided, as this chain is a continuation")
expect_null(
sample_manual_collect_check_samples(inputs, NULL, FALSE))
})


Expand All @@ -204,23 +236,42 @@ test_that("samples, if provided, must match", {
expect_equal(
sample_manual_collect_check_samples(
list(restart = list(), samples = list(value = samples1, hash = NULL)),
samples1),
samples1, TRUE),
samples1)
expect_equal(
sample_manual_collect_check_samples(
list(restart = list(), samples = list(value = NULL, hash = hash1)),
samples1),
samples1, TRUE),
samples1)

expect_null(
sample_manual_collect_check_samples(
list(restart = list(), samples = list(value = samples1, hash = NULL)),
samples1, FALSE))
expect_null(
sample_manual_collect_check_samples(
list(restart = list(), samples = list(value = NULL, hash = hash1)),
samples1, FALSE))

expect_error(
sample_manual_collect_check_samples(
list(restart = list(), samples = list(value = samples1, hash = NULL)),
samples2, TRUE),
"Provided 'samples' does not match those at the start of the chain")
expect_error(
sample_manual_collect_check_samples(
list(restart = list(), samples = list(value = NULL, hash = hash1)),
samples2, TRUE),
"Provided 'samples' does not match those at the start of the chain")
expect_error(
sample_manual_collect_check_samples(
list(restart = list(), samples = list(value = samples1, hash = NULL)),
samples2),
samples2, FALSE),
"Provided 'samples' does not match those at the start of the chain")
expect_error(
sample_manual_collect_check_samples(
list(restart = list(), samples = list(value = NULL, hash = hash1)),
samples2),
samples2, FALSE),
"Provided 'samples' does not match those at the start of the chain")
})

Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/test-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,19 @@ test_that("can continue thinned chain, continues thinning", {

expect_equal(res2b, res1)
})


test_that("can choose not to append when continuing samples", {
model <- ex_simple_gamma1()
sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01)

set.seed(1)
res1 <- monty_sample(model, sampler, 100, 1, n_chains = 3)

set.seed(1)
res2a <- monty_sample(model, sampler, 60, 1, n_chains = 3,
restartable = TRUE)
res2b <- monty_sample_continue(res2a, 40, append = FALSE)

expect_equal(res2b$pars, res1$pars[, 61:100, , drop = FALSE])
})

0 comments on commit c3f1e79

Please sign in to comment.