diff --git a/Makefile b/Makefile index ce34dcf8..9d10accd 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ PACKAGE := $(shell grep '^Package:' DESCRIPTION | sed -E 's/^Package:[[:space:]]+//') -RSCRIPT = Rscript --no-init-file +RSCRIPT = Rscript all: ${RSCRIPT} -e 'pkgbuild::compile_dll()' diff --git a/R/runner-callr.R b/R/runner-callr.R index b71ddfc2..0ce2bb56 100644 --- a/R/runner-callr.R +++ b/R/runner-callr.R @@ -77,15 +77,15 @@ monty_runner_callr <- function(n_workers, progress = NULL) { all(env$status == "done") } - loop <- function(path, n_workers, n_chains, n_steps, progress) { - pb <- progress_bar(n_chains, n_steps, progress, show_overall = TRUE) + loop <- function(path, n_workers, n_chains, steps, progress) { + pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE) n_workers <- min(n_chains, n_workers) env$path <- path env$sessions <- vector("list", n_workers) env$target <- rep(NA_integer_, n_workers) env$status <- rep("pending", n_chains) env$result_path <- rep(NA_character_, n_chains) - env$n_steps <- n_steps + env$n_steps <- steps$total env$n_steps_progress <- rep(0, n_chains) env$progress <- pb(seq_len(n_chains)) for (session_id in seq_len(n_workers)) { @@ -99,26 +99,26 @@ monty_runner_callr <- function(n_workers, progress = NULL) { res } - run <- function(pars, model, sampler, n_steps, rng) { + run <- function(pars, model, sampler, steps, rng) { seed <- unlist(lapply(rng, function(r) r$state())) n_chains <- length(rng) path <- tempfile() sample_manual_prepare( - model = model, sampler = sampler, n_steps = n_steps, path = path, + model = model, sampler = sampler, steps = steps, path = path, initial = pars, n_chains = n_chains, seed = seed) - loop(path, n_workers, n_chains, n_steps, progress) + loop(path, n_workers, n_chains, steps, progress) } - continue <- function(state, model, sampler, n_steps) { + continue <- function(state, model, sampler, steps) { restart <- list(state = state, model = model, sampler = sampler) n_chains <- length(state) path <- tempfile() monty_sample_manual_prepare_continue( - list(restart = restart), n_steps, path, "nothing") - loop(path, n_workers, n_chains, n_steps, progress) + list(restart = restart), steps, path, "nothing") + loop(path, n_workers, n_chains, steps, progress) } monty_runner("callr", diff --git a/R/runner-simultaneous.R b/R/runner-simultaneous.R index 73b71c99..282c88a0 100644 --- a/R/runner-simultaneous.R +++ b/R/runner-simultaneous.R @@ -32,10 +32,10 @@ monty_runner_simultaneous <- function(progress = NULL) { call = environment()) } - run <- function(pars, model, sampler, n_steps, rng) { + run <- function(pars, model, sampler, steps, rng) { validate_suitable(model) n_chains <- length(rng) - pb <- progress_bar(n_chains, n_steps, progress, show_overall = FALSE) + pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE) progress <- pb(seq_len(n_chains)) rng_state <- lapply(rng, function(r) r$state()) ## TODO: get the rng state back into 'rng' here, or (better) look @@ -45,16 +45,16 @@ monty_runner_simultaneous <- function(progress = NULL) { ## > rng[[i]]$set_state(rng_state[, i]) # not supported! ## > } monty_run_chains_simultaneous(pars, model, sampler, - n_steps, progress, rng_state) + steps, progress, rng_state) } - continue <- function(state, model, sampler, n_steps) { + continue <- function(state, model, sampler, steps) { validate_suitable(model) n_chains <- length(state) - pb <- progress_bar(n_chains, n_steps, progress, show_overall = FALSE) + pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE) progress <- pb(seq_len(n_chains)) monty_continue_chains_simultaneous(state, model, sampler, - n_steps, progress) + steps, progress) } monty_runner("Simultaneous", @@ -74,7 +74,7 @@ monty_runner_simultaneous <- function(progress = NULL) { ## hard to avoid. ## * there's quite a lot of churn around rng state monty_run_chains_simultaneous <- function(pars, model, sampler, - n_steps, progress, rng_state) { + steps, progress, rng_state) { r_rng_state <- get_r_rng_state() n_chains <- length(rng_state) rng <- monty_rng$new(unlist(rng_state), n_chains) @@ -82,12 +82,12 @@ monty_run_chains_simultaneous <- function(pars, model, sampler, chain_state <- sampler$initialise(pars, model, rng) monty_run_chains_simultaneous2(chain_state, model, sampler, - n_steps, progress, rng, r_rng_state) + steps, progress, rng, r_rng_state) } monty_continue_chains_simultaneous <- function(state, model, sampler, - n_steps, progress) { + steps, progress) { r_rng_state <- get_r_rng_state() n_chains <- length(state) n_pars <- length(model$parameters) @@ -111,21 +111,22 @@ monty_continue_chains_simultaneous <- function(state, model, sampler, ## Need to use model$rng_state$set to put state$model_rng into the model monty_run_chains_simultaneous2(chain_state, model, sampler, - n_steps, progress, rng, r_rng_state) + steps, progress, rng, r_rng_state) } monty_run_chains_simultaneous2 <- function(chain_state, model, sampler, - n_steps, progress, rng, + steps, progress, rng, r_rng_state) { initial <- chain_state$pars n_pars <- length(model$parameters) n_chains <- length(chain_state$density) + n_steps_record <- steps$total - history_pars <- array(NA_real_, c(n_pars, n_steps, n_chains)) - history_density <- matrix(NA_real_, n_steps, n_chains) + history_pars <- array(NA_real_, c(n_pars, n_steps_record, n_chains)) + history_density <- matrix(NA_real_, n_steps_record, n_chains) - for (i in seq_len(n_steps)) { + for (i in seq_len(steps$total)) { chain_state <- sampler$step(chain_state, model, rng) history_pars[, i, ] <- chain_state$pars history_density[i, ] <- chain_state$density diff --git a/R/runner.R b/R/runner.R index 78538cf3..b52469d6 100644 --- a/R/runner.R +++ b/R/runner.R @@ -22,24 +22,24 @@ ##' ##' @export monty_runner_serial <- function(progress = NULL) { - run <- function(pars, model, sampler, n_steps, rng) { + run <- function(pars, model, sampler, steps, rng) { n_chains <- length(rng) - pb <- progress_bar(n_chains, n_steps, progress, show_overall = TRUE) + pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE) lapply( seq_along(rng), function(i) { - monty_run_chain(pars[, i], model, sampler, n_steps, + monty_run_chain(pars[, i], model, sampler, steps, pb(i), rng[[i]]) }) } - continue <- function(state, model, sampler, n_steps) { + continue <- function(state, model, sampler, steps) { n_chains <- length(state) - pb <- progress_bar(n_chains, n_steps, progress, show_overall = TRUE) + pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE) lapply( seq_along(state), function(i) { - monty_continue_chain(state[[i]], model, sampler, n_steps, pb(i)) + monty_continue_chain(state[[i]], model, sampler, steps, pb(i)) }) } @@ -91,7 +91,7 @@ monty_runner_parallel <- function(n_workers) { ## get the advantage that the cluster startup happens asyncronously ## and may be ready by the time we actually pass any work onto it. - run <- function(pars, model, sampler, n_steps, rng) { + run <- function(pars, model, sampler, steps, rng) { n_chains <- length(rng) cl <- parallel::makeCluster(min(n_chains, n_workers)) on.exit(parallel::stopCluster(cl)) @@ -110,7 +110,7 @@ monty_runner_parallel <- function(n_workers) { args <- list(model = model, sampler = sampler, - n_steps = n_steps) + steps = steps) ## To debug issues in the parallel sampler, it's most efficient to ## replace this call with `Map` and drop the `cl` argument, then @@ -123,13 +123,13 @@ monty_runner_parallel <- function(n_workers) { MoreArgs = args) } - continue <- function(state, model, sampler, n_steps) { + continue <- function(state, model, sampler, steps) { n_chains <- length(state) cl <- parallel::makeCluster(min(n_chains, n_workers)) on.exit(parallel::stopCluster(cl)) args <- list(model = model, sampler = sampler, - n_steps = n_steps, + steps = steps, progress = function(i) NULL) parallel::clusterMap( cl, @@ -145,14 +145,14 @@ monty_runner_parallel <- function(n_workers) { } -monty_run_chain_parallel <- function(pars, model, sampler, n_steps, rng) { +monty_run_chain_parallel <- function(pars, model, sampler, steps, rng) { rng <- monty_rng$new(rng) progress <- function(i) NULL - monty_run_chain(pars, model, sampler, n_steps, progress, rng) + monty_run_chain(pars, model, sampler, steps, progress, rng) } -monty_run_chain <- function(pars, model, sampler, n_steps, +monty_run_chain <- function(pars, model, sampler, steps, progress, rng) { r_rng_state <- get_r_rng_state() chain_state <- sampler$initialise(pars, model, rng) @@ -177,40 +177,49 @@ monty_run_chain <- function(pars, model, sampler, n_steps, cli::cli_abort("Chain does not have finite starting density") } - monty_run_chain2(chain_state, model, sampler, n_steps, progress, + monty_run_chain2(chain_state, model, sampler, steps, progress, rng, r_rng_state) } -monty_continue_chain <- function(state, model, sampler, n_steps, - progress) { +monty_continue_chain <- function(state, model, sampler, steps, progress) { r_rng_state <- get_r_rng_state() rng <- monty_rng$new(seed = state$rng) sampler$set_internal_state(state$sampler) if (model$properties$is_stochastic) { model$rng_state$set(state$model_rng) } - monty_run_chain2(state$chain, model, sampler, n_steps, progress, + monty_run_chain2(state$chain, model, sampler, steps, progress, rng, r_rng_state) } -monty_run_chain2 <- function(chain_state, model, sampler, n_steps, +monty_run_chain2 <- function(chain_state, model, sampler, steps, progress, rng, r_rng_state) { initial <- chain_state$pars n_pars <- length(model$parameters) has_observer <- model$properties$has_observer - history_pars <- matrix(NA_real_, n_pars, n_steps) - history_density <- rep(NA_real_, n_steps) - history_observation <- if (has_observer) vector("list", n_steps) else NULL + burnin <- steps$burnin + thinning_factor <- steps$thinning_factor + n_steps <- steps$total + n_steps_record <- ceiling((steps$total - burnin) / thinning_factor) + history_pars <- matrix(NA_real_, n_pars, n_steps_record) + history_density <- rep(NA_real_, n_steps_record) + history_observation <- + if (has_observer) vector("list", n_steps_record) else NULL + + j <- 1L for (i in seq_len(n_steps)) { chain_state <- sampler$step(chain_state, model, rng) - history_pars[, i] <- chain_state$pars - history_density[[i]] <- chain_state$density - if (has_observer && !is.null(chain_state$observation)) { - history_observation[[i]] <- chain_state$observation + if (i > burnin && i %% thinning_factor == 0) { + history_pars[, j] <- chain_state$pars + history_density[[j]] <- chain_state$density + if (has_observer && !is.null(chain_state$observation)) { + history_observation[[j]] <- chain_state$observation + } + j <- j + 1L } progress(i) } diff --git a/R/sample-manual.R b/R/sample-manual.R index 5d27aaf4..baebb534 100644 --- a/R/sample-manual.R +++ b/R/sample-manual.R @@ -59,14 +59,16 @@ ##' # Clean up samples ##' monty_sample_manual_cleanup(path) monty_sample_manual_prepare <- function(model, sampler, n_steps, path, - initial = NULL, n_chains = 1L) { + initial = NULL, n_chains = 1L, + burnin = NULL, thinning_factor = NULL) { ## This break exists to hide the 'seed' argument from the public ## interface. We will use this from the callr version though. - sample_manual_prepare(model, sampler, n_steps, path, initial, n_chains) + steps <- monty_sample_steps(n_steps, burnin, thinning_factor) + sample_manual_prepare(model, sampler, steps, path, initial, n_chains) } -sample_manual_prepare <- function(model, sampler, n_steps, path, initial, +sample_manual_prepare <- function(model, sampler, steps, path, initial, n_chains, seed = NULL, call = parent.frame()) { assert_is(model, "monty_model", call = call) @@ -81,7 +83,7 @@ sample_manual_prepare <- function(model, sampler, n_steps, path, initial, sampler = sampler, rng_state = rng_state, n_chains = n_chains, - n_steps = n_steps) + steps = steps) sample_manual_path_create(path, dat, call = call) } @@ -119,25 +121,25 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) { } n_chains <- inputs$n_chains - n_steps <- inputs$n_steps + steps <- inputs$steps restart <- inputs$restart is_continue <- is.list(restart) - pb <- progress_bar(n_chains, n_steps, progress, + pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE, single_chain = TRUE)(chain_id) if (is_continue) { state <- restart$state model <- restart$model sampler <- restart$sampler - res <- monty_continue_chain(state[[chain_id]], model, sampler, n_steps, pb) + res <- monty_continue_chain(state[[chain_id]], model, sampler, steps, pb) } else { pars <- inputs$pars model <- inputs$model sampler <- inputs$sampler rng <- monty_rng$new(seed = inputs$rng_state[[chain_id]]) - res <- monty_run_chain(pars[, chain_id], model, sampler, n_steps, pb, rng) + res <- monty_run_chain(pars[, chain_id], model, sampler, steps, pb, rng) } saveRDS(res, path$results) @@ -186,7 +188,8 @@ monty_sample_manual_collect <- function(path, samples = NULL, } if (restartable) { - samples$restart <- restart_data(res, inputs$model, inputs$sampler, NULL) + samples$restart <- restart_data(res, inputs$model, inputs$sampler, NULL, + inputs$steps$thinning_factor) } samples } @@ -233,11 +236,12 @@ monty_sample_manual_info <- function(path) { path <- sample_manual_path(path) inputs <- readRDS(path$inputs) n_chains <- inputs$n_chains + n_steps <- inputs$steps$total path <- sample_manual_path(path$root, seq_len(inputs$n_chains)) done <- file.exists(path$results) cli::cli_h1("Manual monty sampling at {.path {path$root}}") cli::cli_alert_info("Created {format(file.info(path$inputs)$ctime)}") - cli::cli_alert_info("{inputs$n_steps} steps x {n_chains} chains") + cli::cli_alert_info("{n_steps} steps x {n_chains} chains") if (is.list(inputs$restart)) { cli::cli_alert_info("This is a restart") } @@ -289,9 +293,13 @@ monty_sample_manual_prepare_continue <- function(samples, n_steps, path, restart <- samples$restart samples <- sample_manual_prepare_check_samples(samples, save_samples) + steps <- monty_sample_steps(n_steps, + burnin = NULL, + thinning_factor = restart$thinning_factor) + dat <- list(restart = restart, n_chains = length(restart$state), - n_steps = n_steps, + steps = steps, samples = samples) sample_manual_path_create(path, dat) } diff --git a/R/sample.R b/R/sample.R index c38f47ce..79fa88fb 100644 --- a/R/sample.R +++ b/R/sample.R @@ -35,6 +35,27 @@ ##' restartable. This will add additional data to the chains ##' object. ##' +##' @param burnin Number of steps to discard as burnin. This affects +##' only the recording of steps as your chains run; we don't record +##' the first `burnin` steps. Generally you would want to do this +##' in post-processing as this data is discarded with no chance of +##' getting it back. However, if your observation process creates a +##' large amount of data, then you may prefer to apply a burnin here +##' to reduce how much memory is used. +##' +##' @param thinning_factor A thinning factor to apply while the chain +##' is running. If given, then we save every `thinning_factor`'th +##' step. So if `thinning_factor = 2` we save every second step, +##' and if 10, we'd save every 10th. Like `burnin` above, it is +##' preferable to apply this in post processing. However, for +##' slow-mixing chains that have a large observer output you can use +##' this to reduce the memory usage. Use of `thinning_factor` +##' requires that `n_steps` is an even multiple of +##' `thinning_factor`; so if `thinning_factor` is 10, then `n_steps` +##' must be a multiple of 10. This ensures that the last step is in +##' the sample. The thinning factor cannot be changed when +##' continuing a chain. +##' ##' @return A list of parameters and densities. We provide conversion ##' to formats used by other packages, notably ##' [posterior::as_draws_array], [posterior::as_draws_df] and @@ -83,7 +104,8 @@ ##' # diagnostics. monty_sample <- function(model, sampler, n_steps, initial = NULL, n_chains = 1L, runner = NULL, - restartable = FALSE) { + restartable = FALSE, burnin = NULL, + thinning_factor = NULL) { assert_is(model, "monty_model") assert_is(sampler, "monty_sampler") if (is.null(runner)) { @@ -95,12 +117,14 @@ monty_sample <- function(model, sampler, n_steps, initial = NULL, rng <- initial_rng(n_chains) pars <- initial_parameters(initial, model, rng, environment()) - res <- runner$run(pars, model, sampler, n_steps, rng) + steps <- monty_sample_steps(n_steps, burnin, thinning_factor) + res <- runner$run(pars, model, sampler, steps, rng) observer <- if (model$properties$has_observer) model$observer else NULL samples <- combine_chains(res, model$observer) if (restartable) { - samples$restart <- restart_data(res, model, sampler, runner) + samples$restart <- restart_data(res, model, sampler, runner, + thinning_factor) } samples } @@ -149,13 +173,18 @@ monty_sample_continue <- function(samples, n_steps, restartable = FALSE, model <- samples$restart$model sampler <- samples$restart$sampler - res <- runner$continue(state, model, sampler, n_steps) + burnin <- NULL + thinning_factor <- samples$restart$thinning_factor + steps <- monty_sample_steps(n_steps, burnin, thinning_factor) + + res <- runner$continue(state, model, sampler, steps) observer <- if (model$properties$has_observer) model$observer else NULL samples <- append_chains(samples, combine_chains(res, observer), observer) if (restartable) { - samples$restart <- restart_data(res, model, sampler, runner) + samples$restart <- restart_data(res, model, sampler, runner, + thinning_factor) } samples } @@ -362,7 +391,7 @@ initial_rng <- function(n_chains, seed = NULL) { } -restart_data <- function(res, model, sampler, runner) { +restart_data <- function(res, model, sampler, runner, thinning_factor) { if (is.null(names(res))) { state <- lapply(res, function(x) x$internal$state) } else { @@ -382,7 +411,8 @@ restart_data <- function(res, model, sampler, runner) { list(state = state, model = model, sampler = sampler, - runner = runner) + runner = runner, + thinning_factor = thinning_factor) } @@ -399,3 +429,36 @@ direct_sample_within_domain <- function(model, rng, max_attempts = 100) { "samples that fall outside your model's domain. Probably", "you should fix one or both of these!"))) } + + +monty_sample_steps <- function(n_steps, burnin = NULL, thinning_factor = NULL, + call = parent.frame()) { + if (inherits(n_steps, "monty_sample_steps")) { + return(n_steps) + } + assert_scalar_size(n_steps, call = call) + if (is.null(burnin)) { + burnin <- 0 + } else { + assert_scalar_size(burnin, allow_zero = TRUE, call = call) + if (burnin >= n_steps) { + cli::cli_abort("'burnin' must be smaller than 'n_steps'", + arg = "burnin", call = call) + } + } + if (is.null(thinning_factor)) { + thinning_factor <- 1 + } else { + assert_scalar_size(thinning_factor, allow_zero = FALSE, call = call) + if (n_steps %% thinning_factor != 0) { + cli::cli_abort( + "'thinning_factor' must be a divisor of 'n_steps'", + call = call) + } + } + ret <- list(total = n_steps, + burnin = burnin, + thinning_factor = thinning_factor) + class(ret) <- "monty_sample_steps" + ret +} diff --git a/man/monty_sample.Rd b/man/monty_sample.Rd index cc4dd568..02776879 100644 --- a/man/monty_sample.Rd +++ b/man/monty_sample.Rd @@ -11,7 +11,9 @@ monty_sample( initial = NULL, n_chains = 1L, runner = NULL, - restartable = FALSE + restartable = FALSE, + burnin = NULL, + thinning_factor = NULL ) } \arguments{ @@ -42,6 +44,27 @@ one chain then this argument is best left alone.} \item{restartable}{Logical, indicating if the chains should be restartable. This will add additional data to the chains object.} + +\item{burnin}{Number of steps to discard as burnin. This affects +only the recording of steps as your chains run; we don't record +the first \code{burnin} steps. Generally you would want to do this +in post-processing as this data is discarded with no chance of +getting it back. However, if your observation process creates a +large amount of data, then you may prefer to apply a burnin here +to reduce how much memory is used.} + +\item{thinning_factor}{A thinning factor to apply while the chain +is running. If given, then we save every \code{thinning_factor}'th +step. So if \code{thinning_factor = 2} we save every second step, +and if 10, we'd save every 10th. Like \code{burnin} above, it is +preferable to apply this in post processing. However, for +slow-mixing chains that have a large observer output you can use +this to reduce the memory usage. Use of \code{thinning_factor} +requires that \code{n_steps} is an even multiple of +\code{thinning_factor}; so if \code{thinning_factor} is 10, then \code{n_steps} +must be a multiple of 10. This ensures that the last step is in +the sample. The thinning factor cannot be changed when +continuing a chain.} } \value{ A list of parameters and densities. We provide conversion diff --git a/man/monty_sample_manual_prepare.Rd b/man/monty_sample_manual_prepare.Rd index 1e7d3ed2..cb111079 100644 --- a/man/monty_sample_manual_prepare.Rd +++ b/man/monty_sample_manual_prepare.Rd @@ -10,7 +10,9 @@ monty_sample_manual_prepare( n_steps, path, initial = NULL, - n_chains = 1L + n_chains = 1L, + burnin = NULL, + thinning_factor = NULL ) } \arguments{ @@ -43,6 +45,27 @@ sampling. If not given, we sample from the model (or its prior).} \item{n_chains}{Number of chains to run. The default is to run a single chain, but you will likely want to run more.} + +\item{burnin}{Number of steps to discard as burnin. This affects +only the recording of steps as your chains run; we don't record +the first \code{burnin} steps. Generally you would want to do this +in post-processing as this data is discarded with no chance of +getting it back. However, if your observation process creates a +large amount of data, then you may prefer to apply a burnin here +to reduce how much memory is used.} + +\item{thinning_factor}{A thinning factor to apply while the chain +is running. If given, then we save every \code{thinning_factor}'th +step. So if \code{thinning_factor = 2} we save every second step, +and if 10, we'd save every 10th. Like \code{burnin} above, it is +preferable to apply this in post processing. However, for +slow-mixing chains that have a large observer output you can use +this to reduce the memory usage. Use of \code{thinning_factor} +requires that \code{n_steps} is an even multiple of +\code{thinning_factor}; so if \code{thinning_factor} is 10, then \code{n_steps} +must be a multiple of 10. This ensures that the last step is in +the sample. The thinning factor cannot be changed when +continuing a chain.} } \value{ Invisibly, the path used to store files (the same as the diff --git a/tests/testthat/test-sample-manual.R b/tests/testthat/test-sample-manual.R index f7244e8b..bd0f3a59 100644 --- a/tests/testthat/test-sample-manual.R +++ b/tests/testthat/test-sample-manual.R @@ -223,3 +223,30 @@ test_that("samples, if provided, must match", { samples2), "Provided 'samples' does not match those at the start of the chain") }) + + +test_that("can use burnin/thinning_factor in manual sampling", { + model <- ex_simple_gamma1() + sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01) + + set.seed(1) + res1 <- monty_sample(model, sampler, 100, n_chains = 2, + burnin = 20, thinning_factor = 4) + + set.seed(1) + path_a <- withr::local_tempdir() + monty_sample_manual_prepare(model, sampler, 60, path_a, n_chains = 2, + burnin = 20, thinning_factor = 4) + monty_sample_manual_run(1, path_a) + monty_sample_manual_run(2, path_a) + res2a <- monty_sample_manual_collect(path_a, restartable = TRUE) + expect_equal(res2a$restart$thinning_factor, 4) + + path_b <- withr::local_tempdir() + monty_sample_manual_prepare_continue(res2a, 40, path_b) + monty_sample_manual_run(1, path_b) + monty_sample_manual_run(2, path_b) + res2b <- monty_sample_manual_collect(path_b, res2a) + + expect_equal(res2b$pars, res1$pars) +}) diff --git a/tests/testthat/test-sample.R b/tests/testthat/test-sample.R index 82117b2e..cc8e0476 100644 --- a/tests/testthat/test-sample.R +++ b/tests/testthat/test-sample.R @@ -328,3 +328,100 @@ test_that("can change runner on restart", { expect_equal(res2d, res1) }) + + +test_that("validate burnin", { + expect_equal(monty_sample_steps(100, NULL)$burnin, 0) + expect_equal(monty_sample_steps(100, 0)$burnin, 0) + expect_equal(monty_sample_steps(100, 20)$burnin, 20) + expect_error(monty_sample_steps(100, 200), + "'burnin' must be smaller than 'n_steps'") +}) + + +test_that("validate thinning rate", { + expect_equal( + monty_sample_steps(100, thinning_factor = NULL)$thinning_factor, 1) + expect_equal( + monty_sample_steps(100, thinning_factor = 1)$thinning_factor, 1) + expect_equal( + monty_sample_steps(100, thinning_factor = 20)$thinning_factor, 20) + expect_error( + monty_sample_steps(100, thinning_factor = 0), + "'thinning_factor' must be at least 1") + expect_error( + monty_sample_steps(100, thinning_factor = 17)$thinning_factor, + "'thinning_factor' must be a divisor of 'n_steps'") +}) + + +test_that("can sample with burnin", { + model <- ex_simple_gamma1() + sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01) + set.seed(1) + res1 <- monty_sample(model, sampler, 100, 3) + set.seed(1) + res2 <- monty_sample(model, sampler, 100, 3, burnin = 30) + expect_equal(res2$pars, res1$pars[, 31:100, , drop = FALSE]) +}) + + +test_that("can sample with thinning", { + model <- ex_simple_gamma1() + sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01) + set.seed(1) + res1 <- monty_sample(model, sampler, 100, 3) + set.seed(1) + res2 <- monty_sample(model, sampler, 100, 3, thinning_factor = 4) + expect_equal(res2$pars, res1$pars[, seq(4, 100, by = 4), , drop = FALSE]) +}) + + +test_that("can sample with burnin and thinning", { + model <- ex_simple_gamma1() + sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01) + set.seed(1) + res1 <- monty_sample(model, sampler, 100, 3) + ## Pick deliberately hard values here: + set.seed(1) + res2 <- monty_sample(model, sampler, 100, 3, + burnin = 25, thinning_factor = 10) + expect_equal(res2$pars, res1$pars[, seq(30, 100, by = 10), , drop = FALSE]) + set.seed(1) + res3 <- monty_sample(model, sampler, 100, 3, + burnin = 20, thinning_factor = 10) + expect_equal(res3$pars, res1$pars[, seq(30, 100, by = 10), , drop = FALSE]) +}) + + +test_that("can continue burnt in chain, does not burnin any further", { + model <- ex_simple_gamma1() + sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01) + + set.seed(1) + res1 <- monty_sample(model, sampler, 100, 1, n_chains = 3, burnin = 20) + + set.seed(1) + res2a <- monty_sample(model, sampler, 60, 1, n_chains = 3, + burnin = 20, restartable = TRUE) + res2b <- monty_sample_continue(res2a, 40) + + expect_equal(res2b, res1) +}) + + +test_that("can continue thinned chain, continues thinning", { + model <- ex_simple_gamma1() + sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01) + + set.seed(1) + res1 <- monty_sample(model, sampler, 100, 1, n_chains = 3, + thinning_factor = 4) + + set.seed(1) + res2a <- monty_sample(model, sampler, 60, 1, n_chains = 3, + thinning_factor = 4, restartable = TRUE) + res2b <- monty_sample_continue(res2a, 40) + + expect_equal(res2b, res1) +})