Skip to content

Commit

Permalink
Merge branch 'main' into fix-beta-binomial
Browse files Browse the repository at this point in the history
  • Loading branch information
edknock committed Oct 29, 2024
2 parents bdd34ce + d41c2a9 commit 0f54dfd
Show file tree
Hide file tree
Showing 10 changed files with 320 additions and 69 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
PACKAGE := $(shell grep '^Package:' DESCRIPTION | sed -E 's/^Package:[[:space:]]+//')
RSCRIPT = Rscript --no-init-file
RSCRIPT = Rscript

all:
${RSCRIPT} -e 'pkgbuild::compile_dll()'
Expand Down
18 changes: 9 additions & 9 deletions R/runner-callr.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ monty_runner_callr <- function(n_workers, progress = NULL) {
all(env$status == "done")
}

loop <- function(path, n_workers, n_chains, n_steps, progress) {
pb <- progress_bar(n_chains, n_steps, progress, show_overall = TRUE)
loop <- function(path, n_workers, n_chains, steps, progress) {
pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE)
n_workers <- min(n_chains, n_workers)
env$path <- path
env$sessions <- vector("list", n_workers)
env$target <- rep(NA_integer_, n_workers)
env$status <- rep("pending", n_chains)
env$result_path <- rep(NA_character_, n_chains)
env$n_steps <- n_steps
env$n_steps <- steps$total
env$n_steps_progress <- rep(0, n_chains)
env$progress <- pb(seq_len(n_chains))
for (session_id in seq_len(n_workers)) {
Expand All @@ -99,26 +99,26 @@ monty_runner_callr <- function(n_workers, progress = NULL) {
res
}

run <- function(pars, model, sampler, n_steps, rng) {
run <- function(pars, model, sampler, steps, rng) {
seed <- unlist(lapply(rng, function(r) r$state()))
n_chains <- length(rng)
path <- tempfile()
sample_manual_prepare(
model = model, sampler = sampler, n_steps = n_steps, path = path,
model = model, sampler = sampler, steps = steps, path = path,
initial = pars, n_chains = n_chains,
seed = seed)
loop(path, n_workers, n_chains, n_steps, progress)
loop(path, n_workers, n_chains, steps, progress)
}

continue <- function(state, model, sampler, n_steps) {
continue <- function(state, model, sampler, steps) {
restart <- list(state = state,
model = model,
sampler = sampler)
n_chains <- length(state)
path <- tempfile()
monty_sample_manual_prepare_continue(
list(restart = restart), n_steps, path, "nothing")
loop(path, n_workers, n_chains, n_steps, progress)
list(restart = restart), steps, path, "nothing")
loop(path, n_workers, n_chains, steps, progress)
}

monty_runner("callr",
Expand Down
29 changes: 15 additions & 14 deletions R/runner-simultaneous.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ monty_runner_simultaneous <- function(progress = NULL) {
call = environment())
}

run <- function(pars, model, sampler, n_steps, rng) {
run <- function(pars, model, sampler, steps, rng) {
validate_suitable(model)
n_chains <- length(rng)
pb <- progress_bar(n_chains, n_steps, progress, show_overall = FALSE)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE)
progress <- pb(seq_len(n_chains))
rng_state <- lapply(rng, function(r) r$state())
## TODO: get the rng state back into 'rng' here, or (better) look
Expand All @@ -45,16 +45,16 @@ monty_runner_simultaneous <- function(progress = NULL) {
## > rng[[i]]$set_state(rng_state[, i]) # not supported!
## > }
monty_run_chains_simultaneous(pars, model, sampler,
n_steps, progress, rng_state)
steps, progress, rng_state)
}

continue <- function(state, model, sampler, n_steps) {
continue <- function(state, model, sampler, steps) {
validate_suitable(model)
n_chains <- length(state)
pb <- progress_bar(n_chains, n_steps, progress, show_overall = FALSE)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE)
progress <- pb(seq_len(n_chains))
monty_continue_chains_simultaneous(state, model, sampler,
n_steps, progress)
steps, progress)
}

monty_runner("Simultaneous",
Expand All @@ -74,20 +74,20 @@ monty_runner_simultaneous <- function(progress = NULL) {
## hard to avoid.
## * there's quite a lot of churn around rng state
monty_run_chains_simultaneous <- function(pars, model, sampler,
n_steps, progress, rng_state) {
steps, progress, rng_state) {
## Entry point for a fresh run of all chains together: builds one rng
## object from the per-chain states, initialises the sampler from
## `pars`, and hands the actual stepping over to
## monty_run_chains_simultaneous2().
##
## NOTE(review): the paired `n_steps, ...` / `steps, ...` lines in this
## block look like the before/after sides of a rendered diff; only one
## of each pair should exist in the real file - confirm against the
## repository.

## Taken before any other work in this function and passed on
## unchanged as the last argument of monty_run_chains_simultaneous2().
r_rng_state <- get_r_rng_state()
## The number of chains is the number of elements in rng_state.
n_chains <- length(rng_state)
## The per-chain states are flattened into a single vector and given
## to monty_rng$new() together with n_chains.
rng <- monty_rng$new(unlist(rng_state), n_chains)

## `pars` is passed whole (not per chain) to the sampler here.
chain_state <- sampler$initialise(pars, model, rng)

monty_run_chains_simultaneous2(chain_state, model, sampler,
n_steps, progress, rng, r_rng_state)
steps, progress, rng, r_rng_state)
}


monty_continue_chains_simultaneous <- function(state, model, sampler,
n_steps, progress) {
steps, progress) {
r_rng_state <- get_r_rng_state()
n_chains <- length(state)
n_pars <- length(model$parameters)
Expand All @@ -111,21 +111,22 @@ monty_continue_chains_simultaneous <- function(state, model, sampler,
## Need to use model$rng_state$set to put state$model_rng into the model

monty_run_chains_simultaneous2(chain_state, model, sampler,
n_steps, progress, rng, r_rng_state)
steps, progress, rng, r_rng_state)
}


monty_run_chains_simultaneous2 <- function(chain_state, model, sampler,
n_steps, progress, rng,
steps, progress, rng,
r_rng_state) {
initial <- chain_state$pars
n_pars <- length(model$parameters)
n_chains <- length(chain_state$density)
n_steps_record <- steps$total

history_pars <- array(NA_real_, c(n_pars, n_steps, n_chains))
history_density <- matrix(NA_real_, n_steps, n_chains)
history_pars <- array(NA_real_, c(n_pars, n_steps_record, n_chains))
history_density <- matrix(NA_real_, n_steps_record, n_chains)

for (i in seq_len(n_steps)) {
for (i in seq_len(steps$total)) {
chain_state <- sampler$step(chain_state, model, rng)
history_pars[, i, ] <- chain_state$pars
history_density[i, ] <- chain_state$density
Expand Down
59 changes: 34 additions & 25 deletions R/runner.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,24 @@
##'
##' @export
monty_runner_serial <- function(progress = NULL) {
run <- function(pars, model, sampler, n_steps, rng) {
run <- function(pars, model, sampler, steps, rng) {
n_chains <- length(rng)
pb <- progress_bar(n_chains, n_steps, progress, show_overall = TRUE)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE)
lapply(
seq_along(rng),
function(i) {
monty_run_chain(pars[, i], model, sampler, n_steps,
monty_run_chain(pars[, i], model, sampler, steps,
pb(i), rng[[i]])
})
}

continue <- function(state, model, sampler, n_steps) {
continue <- function(state, model, sampler, steps) {
n_chains <- length(state)
pb <- progress_bar(n_chains, n_steps, progress, show_overall = TRUE)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE)
lapply(
seq_along(state),
function(i) {
monty_continue_chain(state[[i]], model, sampler, n_steps, pb(i))
monty_continue_chain(state[[i]], model, sampler, steps, pb(i))
})
}

Expand Down Expand Up @@ -91,7 +91,7 @@ monty_runner_parallel <- function(n_workers) {
## get the advantage that the cluster startup happens asynchronously
## and may be ready by the time we actually pass any work onto it.

run <- function(pars, model, sampler, n_steps, rng) {
run <- function(pars, model, sampler, steps, rng) {
n_chains <- length(rng)
cl <- parallel::makeCluster(min(n_chains, n_workers))
on.exit(parallel::stopCluster(cl))
Expand All @@ -110,7 +110,7 @@ monty_runner_parallel <- function(n_workers) {

args <- list(model = model,
sampler = sampler,
n_steps = n_steps)
steps = steps)

## To debug issues in the parallel sampler, it's most efficient to
## replace this call with `Map` and drop the `cl` argument, then
Expand All @@ -123,13 +123,13 @@ monty_runner_parallel <- function(n_workers) {
MoreArgs = args)
}

continue <- function(state, model, sampler, n_steps) {
continue <- function(state, model, sampler, steps) {
n_chains <- length(state)
cl <- parallel::makeCluster(min(n_chains, n_workers))
on.exit(parallel::stopCluster(cl))
args <- list(model = model,
sampler = sampler,
n_steps = n_steps,
steps = steps,
progress = function(i) NULL)
parallel::clusterMap(
cl,
Expand All @@ -145,14 +145,14 @@ monty_runner_parallel <- function(n_workers) {
}


## Run a single chain without progress reporting.
##
## `rng` is passed to monty_rng$new() and the resulting object is what
## monty_run_chain() receives; the progress callback is a function that
## ignores its argument and returns NULL. All other arguments are
## forwarded unchanged.
##
## NOTE(review): presumably `rng` arrives here as a plain state/seed
## rather than an rng object because this is called on a parallel
## worker - confirm against the caller in monty_runner_parallel().
## NOTE(review): the paired `n_steps` / `steps` lines below look like
## the before/after sides of a rendered diff; only one of each pair
## should exist in the real file.
monty_run_chain_parallel <- function(pars, model, sampler, n_steps, rng) {
monty_run_chain_parallel <- function(pars, model, sampler, steps, rng) {
rng <- monty_rng$new(rng)
progress <- function(i) NULL
monty_run_chain(pars, model, sampler, n_steps, progress, rng)
monty_run_chain(pars, model, sampler, steps, progress, rng)
}


monty_run_chain <- function(pars, model, sampler, n_steps,
monty_run_chain <- function(pars, model, sampler, steps,
progress, rng) {
r_rng_state <- get_r_rng_state()
chain_state <- sampler$initialise(pars, model, rng)
Expand All @@ -177,40 +177,49 @@ monty_run_chain <- function(pars, model, sampler, n_steps,
cli::cli_abort("Chain does not have finite starting density")
}

monty_run_chain2(chain_state, model, sampler, n_steps, progress,
monty_run_chain2(chain_state, model, sampler, steps, progress,
rng, r_rng_state)
}


## Resume one chain from a saved `state`.
##
## The fields of `state` read here are:
##   * state$rng       - used as the seed for a new monty_rng object
##   * state$sampler   - given to sampler$set_internal_state()
##   * state$model_rng - given to model$rng_state$set(), but only when
##                       model$properties$is_stochastic is true
##   * state$chain     - the chain state passed to monty_run_chain2()
##
## NOTE(review): the paired `n_steps` / `steps` lines below look like
## the before/after sides of a rendered diff; only one of each pair
## should exist in the real file.
monty_continue_chain <- function(state, model, sampler, n_steps,
progress) {
monty_continue_chain <- function(state, model, sampler, steps, progress) {
## Taken before the rng, sampler and model are touched, and passed on
## unchanged as the last argument of monty_run_chain2().
r_rng_state <- get_r_rng_state()
rng <- monty_rng$new(seed = state$rng)
sampler$set_internal_state(state$sampler)
## The model's rng state is only restored for stochastic models.
if (model$properties$is_stochastic) {
model$rng_state$set(state$model_rng)
}
monty_run_chain2(state$chain, model, sampler, n_steps, progress,
monty_run_chain2(state$chain, model, sampler, steps, progress,
rng, r_rng_state)
}


monty_run_chain2 <- function(chain_state, model, sampler, n_steps,
monty_run_chain2 <- function(chain_state, model, sampler, steps,
progress, rng, r_rng_state) {
initial <- chain_state$pars
n_pars <- length(model$parameters)
has_observer <- model$properties$has_observer

history_pars <- matrix(NA_real_, n_pars, n_steps)
history_density <- rep(NA_real_, n_steps)
history_observation <- if (has_observer) vector("list", n_steps) else NULL
burnin <- steps$burnin
thinning_factor <- steps$thinning_factor
n_steps <- steps$total
n_steps_record <- ceiling((steps$total - burnin) / thinning_factor)

history_pars <- matrix(NA_real_, n_pars, n_steps_record)
history_density <- rep(NA_real_, n_steps_record)
history_observation <-
if (has_observer) vector("list", n_steps_record) else NULL

j <- 1L
for (i in seq_len(n_steps)) {
chain_state <- sampler$step(chain_state, model, rng)
history_pars[, i] <- chain_state$pars
history_density[[i]] <- chain_state$density
if (has_observer && !is.null(chain_state$observation)) {
history_observation[[i]] <- chain_state$observation
if (i > burnin && i %% thinning_factor == 0) {
history_pars[, j] <- chain_state$pars
history_density[[j]] <- chain_state$density
if (has_observer && !is.null(chain_state$observation)) {
history_observation[[j]] <- chain_state$observation
}
j <- j + 1L
}
progress(i)
}
Expand Down
30 changes: 19 additions & 11 deletions R/sample-manual.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,16 @@
##' # Clean up samples
##' monty_sample_manual_cleanup(path)
monty_sample_manual_prepare <- function(model, sampler, n_steps, path,
initial = NULL, n_chains = 1L) {
initial = NULL, n_chains = 1L,
burnin = NULL, thinning_factor = NULL) {
## NOTE(review): the two `initial = NULL, n_chains = 1L` lines above
## and the two sample_manual_prepare() calls below look like the
## before/after sides of a rendered diff; only one of each pair should
## exist in the real file.
##
## This break exists to hide the 'seed' argument from the public
## interface. We will use this from the callr version though.
sample_manual_prepare(model, sampler, n_steps, path, initial, n_chains)
## n_steps, burnin and thinning_factor are combined into one `steps`
## value by monty_sample_steps(); burnin and thinning_factor are passed
## through as given (NULL by default). Elsewhere in this file the
## stored value is read back as `steps$total` and
## `steps$thinning_factor`.
steps <- monty_sample_steps(n_steps, burnin, thinning_factor)
sample_manual_prepare(model, sampler, steps, path, initial, n_chains)
}


sample_manual_prepare <- function(model, sampler, n_steps, path, initial,
sample_manual_prepare <- function(model, sampler, steps, path, initial,
n_chains, seed = NULL,
call = parent.frame()) {
assert_is(model, "monty_model", call = call)
Expand All @@ -81,7 +83,7 @@ sample_manual_prepare <- function(model, sampler, n_steps, path, initial,
sampler = sampler,
rng_state = rng_state,
n_chains = n_chains,
n_steps = n_steps)
steps = steps)

sample_manual_path_create(path, dat, call = call)
}
Expand Down Expand Up @@ -119,25 +121,25 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) {
}

n_chains <- inputs$n_chains
n_steps <- inputs$n_steps
steps <- inputs$steps

restart <- inputs$restart
is_continue <- is.list(restart)

pb <- progress_bar(n_chains, n_steps, progress,
pb <- progress_bar(n_chains, steps$total, progress,
show_overall = FALSE, single_chain = TRUE)(chain_id)

if (is_continue) {
state <- restart$state
model <- restart$model
sampler <- restart$sampler
res <- monty_continue_chain(state[[chain_id]], model, sampler, n_steps, pb)
res <- monty_continue_chain(state[[chain_id]], model, sampler, steps, pb)
} else {
pars <- inputs$pars
model <- inputs$model
sampler <- inputs$sampler
rng <- monty_rng$new(seed = inputs$rng_state[[chain_id]])
res <- monty_run_chain(pars[, chain_id], model, sampler, n_steps, pb, rng)
res <- monty_run_chain(pars[, chain_id], model, sampler, steps, pb, rng)
}

saveRDS(res, path$results)
Expand Down Expand Up @@ -186,7 +188,8 @@ monty_sample_manual_collect <- function(path, samples = NULL,
}

if (restartable) {
samples$restart <- restart_data(res, inputs$model, inputs$sampler, NULL)
samples$restart <- restart_data(res, inputs$model, inputs$sampler, NULL,
inputs$steps$thinning_factor)
}
samples
}
Expand Down Expand Up @@ -233,11 +236,12 @@ monty_sample_manual_info <- function(path) {
path <- sample_manual_path(path)
inputs <- readRDS(path$inputs)
n_chains <- inputs$n_chains
n_steps <- inputs$steps$total
path <- sample_manual_path(path$root, seq_len(inputs$n_chains))
done <- file.exists(path$results)
cli::cli_h1("Manual monty sampling at {.path {path$root}}")
cli::cli_alert_info("Created {format(file.info(path$inputs)$ctime)}")
cli::cli_alert_info("{inputs$n_steps} steps x {n_chains} chains")
cli::cli_alert_info("{n_steps} steps x {n_chains} chains")
if (is.list(inputs$restart)) {
cli::cli_alert_info("This is a restart")
}
Expand Down Expand Up @@ -289,9 +293,13 @@ monty_sample_manual_prepare_continue <- function(samples, n_steps, path,
restart <- samples$restart
samples <- sample_manual_prepare_check_samples(samples, save_samples)

steps <- monty_sample_steps(n_steps,
burnin = NULL,
thinning_factor = restart$thinning_factor)

dat <- list(restart = restart,
n_chains = length(restart$state),
n_steps = n_steps,
steps = steps,
samples = samples)
sample_manual_path_create(path, dat)
}
Expand Down
Loading

0 comments on commit 0f54dfd

Please sign in to comment.