Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow chain thinning/burnin while running #91

Merged
merged 8 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.23
Version: 0.2.24
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
PACKAGE := $(shell grep '^Package:' DESCRIPTION | sed -E 's/^Package:[[:space:]]+//')
RSCRIPT = Rscript --no-init-file
RSCRIPT = Rscript

all:
${RSCRIPT} -e 'pkgbuild::compile_dll()'
Expand Down
18 changes: 9 additions & 9 deletions R/runner-callr.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ monty_runner_callr <- function(n_workers, progress = NULL) {
all(env$status == "done")
}

loop <- function(path, n_workers, n_chains, n_steps, progress) {
pb <- progress_bar(n_chains, n_steps, progress, show_overall = TRUE)
loop <- function(path, n_workers, n_chains, steps, progress) {
pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE)
n_workers <- min(n_chains, n_workers)
env$path <- path
env$sessions <- vector("list", n_workers)
env$target <- rep(NA_integer_, n_workers)
env$status <- rep("pending", n_chains)
env$result_path <- rep(NA_character_, n_chains)
env$n_steps <- n_steps
env$n_steps <- steps$total
env$n_steps_progress <- rep(0, n_chains)
env$progress <- pb(seq_len(n_chains))
for (session_id in seq_len(n_workers)) {
Expand All @@ -99,26 +99,26 @@ monty_runner_callr <- function(n_workers, progress = NULL) {
res
}

run <- function(pars, model, sampler, n_steps, rng) {
run <- function(pars, model, sampler, steps, rng) {
seed <- unlist(lapply(rng, function(r) r$state()))
n_chains <- length(rng)
path <- tempfile()
sample_manual_prepare(
model = model, sampler = sampler, n_steps = n_steps, path = path,
model = model, sampler = sampler, steps = steps, path = path,
initial = pars, n_chains = n_chains,
seed = seed)
loop(path, n_workers, n_chains, n_steps, progress)
loop(path, n_workers, n_chains, steps, progress)
}

continue <- function(state, model, sampler, n_steps) {
continue <- function(state, model, sampler, steps) {
restart <- list(state = state,
model = model,
sampler = sampler)
n_chains <- length(state)
path <- tempfile()
monty_sample_manual_prepare_continue(
list(restart = restart), n_steps, path, "nothing")
loop(path, n_workers, n_chains, n_steps, progress)
list(restart = restart), steps, path, "nothing")
loop(path, n_workers, n_chains, steps, progress)
}

monty_runner("callr",
Expand Down
29 changes: 15 additions & 14 deletions R/runner-simultaneous.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ monty_runner_simultaneous <- function(progress = NULL) {
call = environment())
}

run <- function(pars, model, sampler, n_steps, rng) {
run <- function(pars, model, sampler, steps, rng) {
validate_suitable(model)
n_chains <- length(rng)
pb <- progress_bar(n_chains, n_steps, progress, show_overall = FALSE)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE)
progress <- pb(seq_len(n_chains))
rng_state <- lapply(rng, function(r) r$state())
## TODO: get the rng state back into 'rng' here, or (better) look
Expand All @@ -45,16 +45,16 @@ monty_runner_simultaneous <- function(progress = NULL) {
## > rng[[i]]$set_state(rng_state[, i]) # not supported!
## > }
monty_run_chains_simultaneous(pars, model, sampler,
n_steps, progress, rng_state)
steps, progress, rng_state)
}

continue <- function(state, model, sampler, n_steps) {
continue <- function(state, model, sampler, steps) {
validate_suitable(model)
n_chains <- length(state)
pb <- progress_bar(n_chains, n_steps, progress, show_overall = FALSE)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE)
progress <- pb(seq_len(n_chains))
monty_continue_chains_simultaneous(state, model, sampler,
n_steps, progress)
steps, progress)
}

monty_runner("Simultaneous",
Expand All @@ -74,20 +74,20 @@ monty_runner_simultaneous <- function(progress = NULL) {
## hard to avoid.
## * there's quite a lot of churn around rng state
monty_run_chains_simultaneous <- function(pars, model, sampler,
n_steps, progress, rng_state) {
steps, progress, rng_state) {
r_rng_state <- get_r_rng_state()
n_chains <- length(rng_state)
rng <- monty_rng$new(unlist(rng_state), n_chains)

chain_state <- sampler$initialise(pars, model, rng)

monty_run_chains_simultaneous2(chain_state, model, sampler,
n_steps, progress, rng, r_rng_state)
steps, progress, rng, r_rng_state)
}


monty_continue_chains_simultaneous <- function(state, model, sampler,
n_steps, progress) {
steps, progress) {
r_rng_state <- get_r_rng_state()
n_chains <- length(state)
n_pars <- length(model$parameters)
Expand All @@ -111,21 +111,22 @@ monty_continue_chains_simultaneous <- function(state, model, sampler,
## Need to use model$rng_state$set to put state$model_rng into the model

monty_run_chains_simultaneous2(chain_state, model, sampler,
n_steps, progress, rng, r_rng_state)
steps, progress, rng, r_rng_state)
}


monty_run_chains_simultaneous2 <- function(chain_state, model, sampler,
n_steps, progress, rng,
steps, progress, rng,
r_rng_state) {
initial <- chain_state$pars
n_pars <- length(model$parameters)
n_chains <- length(chain_state$density)
n_steps_record <- steps$total

history_pars <- array(NA_real_, c(n_pars, n_steps, n_chains))
history_density <- matrix(NA_real_, n_steps, n_chains)
history_pars <- array(NA_real_, c(n_pars, n_steps_record, n_chains))
history_density <- matrix(NA_real_, n_steps_record, n_chains)

for (i in seq_len(n_steps)) {
for (i in seq_len(steps$total)) {
chain_state <- sampler$step(chain_state, model, rng)
history_pars[, i, ] <- chain_state$pars
history_density[i, ] <- chain_state$density
Expand Down
59 changes: 34 additions & 25 deletions R/runner.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,24 @@
##'
##' @export
monty_runner_serial <- function(progress = NULL) {
run <- function(pars, model, sampler, n_steps, rng) {
run <- function(pars, model, sampler, steps, rng) {
n_chains <- length(rng)
pb <- progress_bar(n_chains, n_steps, progress, show_overall = TRUE)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE)
lapply(
seq_along(rng),
function(i) {
monty_run_chain(pars[, i], model, sampler, n_steps,
monty_run_chain(pars[, i], model, sampler, steps,
pb(i), rng[[i]])
})
}

continue <- function(state, model, sampler, n_steps) {
continue <- function(state, model, sampler, steps) {
n_chains <- length(state)
pb <- progress_bar(n_chains, n_steps, progress, show_overall = TRUE)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE)
lapply(
seq_along(state),
function(i) {
monty_continue_chain(state[[i]], model, sampler, n_steps, pb(i))
monty_continue_chain(state[[i]], model, sampler, steps, pb(i))
})
}

Expand Down Expand Up @@ -91,7 +91,7 @@ monty_runner_parallel <- function(n_workers) {
## get the advantage that the cluster startup happens asyncronously
## and may be ready by the time we actually pass any work onto it.

run <- function(pars, model, sampler, n_steps, rng) {
run <- function(pars, model, sampler, steps, rng) {
n_chains <- length(rng)
cl <- parallel::makeCluster(min(n_chains, n_workers))
on.exit(parallel::stopCluster(cl))
Expand All @@ -110,7 +110,7 @@ monty_runner_parallel <- function(n_workers) {

args <- list(model = model,
sampler = sampler,
n_steps = n_steps)
steps = steps)

## To debug issues in the parallel sampler, it's most efficient to
## replace this call with `Map` and drop the `cl` argument, then
Expand All @@ -123,13 +123,13 @@ monty_runner_parallel <- function(n_workers) {
MoreArgs = args)
}

continue <- function(state, model, sampler, n_steps) {
continue <- function(state, model, sampler, steps) {
n_chains <- length(state)
cl <- parallel::makeCluster(min(n_chains, n_workers))
on.exit(parallel::stopCluster(cl))
args <- list(model = model,
sampler = sampler,
n_steps = n_steps,
steps = steps,
progress = function(i) NULL)
parallel::clusterMap(
cl,
Expand All @@ -145,14 +145,14 @@ monty_runner_parallel <- function(n_workers) {
}


monty_run_chain_parallel <- function(pars, model, sampler, n_steps, rng) {
monty_run_chain_parallel <- function(pars, model, sampler, steps, rng) {
rng <- monty_rng$new(rng)
progress <- function(i) NULL
monty_run_chain(pars, model, sampler, n_steps, progress, rng)
monty_run_chain(pars, model, sampler, steps, progress, rng)
}


monty_run_chain <- function(pars, model, sampler, n_steps,
monty_run_chain <- function(pars, model, sampler, steps,
progress, rng) {
r_rng_state <- get_r_rng_state()
chain_state <- sampler$initialise(pars, model, rng)
Expand All @@ -177,40 +177,49 @@ monty_run_chain <- function(pars, model, sampler, n_steps,
cli::cli_abort("Chain does not have finite starting density")
}

monty_run_chain2(chain_state, model, sampler, n_steps, progress,
monty_run_chain2(chain_state, model, sampler, steps, progress,
rng, r_rng_state)
}


monty_continue_chain <- function(state, model, sampler, n_steps,
progress) {
monty_continue_chain <- function(state, model, sampler, steps, progress) {
r_rng_state <- get_r_rng_state()
rng <- monty_rng$new(seed = state$rng)
sampler$set_internal_state(state$sampler)
if (model$properties$is_stochastic) {
model$rng_state$set(state$model_rng)
}
monty_run_chain2(state$chain, model, sampler, n_steps, progress,
monty_run_chain2(state$chain, model, sampler, steps, progress,
rng, r_rng_state)
}


monty_run_chain2 <- function(chain_state, model, sampler, n_steps,
monty_run_chain2 <- function(chain_state, model, sampler, steps,
progress, rng, r_rng_state) {
initial <- chain_state$pars
n_pars <- length(model$parameters)
has_observer <- model$properties$has_observer

history_pars <- matrix(NA_real_, n_pars, n_steps)
history_density <- rep(NA_real_, n_steps)
history_observation <- if (has_observer) vector("list", n_steps) else NULL
burnin <- steps$burnin
thinning_factor <- steps$thinning_factor
n_steps <- steps$total
n_steps_record <- ceiling((steps$total - burnin) / thinning_factor)

history_pars <- matrix(NA_real_, n_pars, n_steps_record)
history_density <- rep(NA_real_, n_steps_record)
history_observation <-
if (has_observer) vector("list", n_steps_record) else NULL

j <- 1L
for (i in seq_len(n_steps)) {
chain_state <- sampler$step(chain_state, model, rng)
history_pars[, i] <- chain_state$pars
history_density[[i]] <- chain_state$density
if (has_observer && !is.null(chain_state$observation)) {
history_observation[[i]] <- chain_state$observation
if (i > burnin && i %% thinning_factor == 0) {
history_pars[, j] <- chain_state$pars
history_density[[j]] <- chain_state$density
if (has_observer && !is.null(chain_state$observation)) {
history_observation[[j]] <- chain_state$observation
}
j <- j + 1L
}
progress(i)
}
Expand Down
30 changes: 19 additions & 11 deletions R/sample-manual.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,16 @@
##' # Clean up samples
##' monty_sample_manual_cleanup(path)
monty_sample_manual_prepare <- function(model, sampler, n_steps, path,
initial = NULL, n_chains = 1L) {
initial = NULL, n_chains = 1L,
burnin = NULL, thinning_factor = NULL) {
## This break exists to hide the 'seed' argument from the public
## interface. We will use this from the callr version though.
sample_manual_prepare(model, sampler, n_steps, path, initial, n_chains)
steps <- monty_sample_steps(n_steps, burnin, thinning_factor)
sample_manual_prepare(model, sampler, steps, path, initial, n_chains)
}


sample_manual_prepare <- function(model, sampler, n_steps, path, initial,
sample_manual_prepare <- function(model, sampler, steps, path, initial,
n_chains, seed = NULL,
call = parent.frame()) {
assert_is(model, "monty_model", call = call)
Expand All @@ -81,7 +83,7 @@ sample_manual_prepare <- function(model, sampler, n_steps, path, initial,
sampler = sampler,
rng_state = rng_state,
n_chains = n_chains,
n_steps = n_steps)
steps = steps)

sample_manual_path_create(path, dat, call = call)
}
Expand Down Expand Up @@ -119,25 +121,25 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) {
}

n_chains <- inputs$n_chains
n_steps <- inputs$n_steps
steps <- inputs$steps

restart <- inputs$restart
is_continue <- is.list(restart)

pb <- progress_bar(n_chains, n_steps, progress,
pb <- progress_bar(n_chains, steps$total, progress,
show_overall = FALSE, single_chain = TRUE)(chain_id)

if (is_continue) {
state <- restart$state
model <- restart$model
sampler <- restart$sampler
res <- monty_continue_chain(state[[chain_id]], model, sampler, n_steps, pb)
res <- monty_continue_chain(state[[chain_id]], model, sampler, steps, pb)
} else {
pars <- inputs$pars
model <- inputs$model
sampler <- inputs$sampler
rng <- monty_rng$new(seed = inputs$rng_state[[chain_id]])
res <- monty_run_chain(pars[, chain_id], model, sampler, n_steps, pb, rng)
res <- monty_run_chain(pars[, chain_id], model, sampler, steps, pb, rng)
}

saveRDS(res, path$results)
Expand Down Expand Up @@ -186,7 +188,8 @@ monty_sample_manual_collect <- function(path, samples = NULL,
}

if (restartable) {
samples$restart <- restart_data(res, inputs$model, inputs$sampler, NULL)
samples$restart <- restart_data(res, inputs$model, inputs$sampler, NULL,
inputs$steps$thinning_factor)
}
samples
}
Expand Down Expand Up @@ -233,11 +236,12 @@ monty_sample_manual_info <- function(path) {
path <- sample_manual_path(path)
inputs <- readRDS(path$inputs)
n_chains <- inputs$n_chains
n_steps <- inputs$steps$total
path <- sample_manual_path(path$root, seq_len(inputs$n_chains))
done <- file.exists(path$results)
cli::cli_h1("Manual monty sampling at {.path {path$root}}")
cli::cli_alert_info("Created {format(file.info(path$inputs)$ctime)}")
cli::cli_alert_info("{inputs$n_steps} steps x {n_chains} chains")
cli::cli_alert_info("{n_steps} steps x {n_chains} chains")
if (is.list(inputs$restart)) {
cli::cli_alert_info("This is a restart")
}
Expand Down Expand Up @@ -289,9 +293,13 @@ monty_sample_manual_prepare_continue <- function(samples, n_steps, path,
restart <- samples$restart
samples <- sample_manual_prepare_check_samples(samples, save_samples)

steps <- monty_sample_steps(n_steps,
burnin = NULL,
thinning_factor = restart$thinning_factor)

dat <- list(restart = restart,
n_chains = length(restart$state),
n_steps = n_steps,
steps = steps,
samples = samples)
sample_manual_path_create(path, dat)
}
Expand Down
Loading
Loading