Skip to content

Commit

Permalink
Merge pull request #96 from mrc-ide/mrc-5937
Browse files Browse the repository at this point in the history
Proper closing of progress bars on failure
  • Loading branch information
weshinsley authored Nov 5, 2024
2 parents 8c96fc0 + 6520b77 commit 14ffdec
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 116 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.28
Version: 0.2.29
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
85 changes: 59 additions & 26 deletions R/progress.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ progress_bar <- function(n_chains, n_steps, progress, show_overall,

show_progress_bar <- function(progress, call = NULL) {
if (is.null(progress)) {
progress <- getOption("monty.progress", TRUE)
progress <- getOption("monty.progress", !is_testing())
}
## Errors here are not great if we get this from an option, probably
## needs its own error path.
Expand All @@ -33,48 +33,64 @@ show_progress_bar <- function(progress, call = NULL) {
##
## These are currently not tuneable from user-facing code.
progress_bar_simple <- function(n_steps, every_s = 1, min_updates = 20) {
function(chain_index) {
env <- new.env(parent = emptyenv())
env$t_next <- Sys.time()
freq <- ceiling(n_steps / min_updates)
function(at) {
now <- Sys.time()
show_progress <- at == n_steps || at %% freq == 0 || now > env$t_next
if (show_progress) {
env$t_next <- now + every_s
message(sprintf("MONTY-PROGRESS: chain: %s, step: %s",
chain_index, at))
}
env <- new.env(parent = emptyenv())
env$t_next <- Sys.time()
freq <- ceiling(n_steps / min_updates)
update <- function(chain_id, at) {
now <- Sys.time()
show_progress <- at == n_steps || at %% freq == 0 || now > env$t_next
if (show_progress) {
env$t_next <- now + every_s
message(sprintf("MONTY-PROGRESS: chain: %s, step: %s",
chain_id, at))
}
}
list(update = update, fail = fail_no_action)
}


progress_bar_fancy <- function(n_chains, n_steps, show_overall,
single_chain = FALSE) {
## We're expecting to take a while, so we show immediately, if enabled:
oo <- options(cli.progress_show_after = 0)
on.exit(options(oo))

e <- new.env()
e$n <- rep(0, n_chains)
overall <- progress_overall(n_chains, n_steps, show_overall, single_chain)
fmt <- paste("Sampling {overall(e$n)} {cli::pb_bar} |",
"{cli::pb_percent} ETA: {cli::pb_eta}")
fmt_done <- paste(
"{cli::col_green(cli::symbol$tick)} Sampled {cli::pb_total} steps",
"across {n_chains} chains in {cli::pb_elapsed}")
fmt_failed <- paste(
"{cli::col_red(cli::symbol$cross)} Sampling stopped at {cli::pb_current}",
"step{?s} after {cli::pb_elapsed}")
n_steps_total <- if (single_chain) n_steps else n_chains * n_steps
id <- cli::cli_progress_bar(
total = n_steps_total,
format = fmt,
format_done = fmt_done,
format_failed = fmt_failed,
clear = FALSE,
.auto_close = FALSE)

function(chain_index) {
function(at) {
## Avoid writing into a closed progress bar, it will cause an
## error. We do this by checking to see if progress has changed
## from last time we tried updating.
changed <- any(e$n[chain_index] != at, na.rm = TRUE)
if (changed) {
e$n[chain_index] <- at
cli::cli_progress_update(id = id, set = sum(e$n))
}
update <- function(chain_id, at) {
## Avoid writing into a closed progress bar, it will cause an
## error. We do this by checking to see if progress has changed
## from last time we tried updating.
changed <- any(e$n[chain_id] != at, na.rm = TRUE)
if (changed) {
e$n[chain_id] <- at
cli::cli_progress_update(id = id, set = sum(e$n))
}
}

fail <- function() {
cli::cli_progress_done(id, result = "failed")
}

list(update = update, fail = fail)
}


Expand All @@ -93,10 +109,13 @@ parse_progress_bar_simple <- function(txt) {

## Dummy version that can be used where no progress bar is wanted.
progress_bar_none <- function(...) {
function(chain_index) {
function(at) {
}
update <- function(chain_id, at) {
}
list(update = update, fail = fail_no_action)
}


fail_no_action <- function() {
}


Expand All @@ -122,3 +141,17 @@ progress_overall <- function(n_chains, n_steps, show_overall, single_chain) {
paste0(c("[", ret, "]"), collapse = "")
}
}


## Sets up some erorr handlers that close out the progress bar where
## we exit uncleanly; this works for everything other than closing out
## progress bars that were exited from a browser call, which we can't
## really help with. If we don't do this then an ugly partial
## progress bar will be printed a the completion of every subsequent
## progress bar, because we use .auto_close = FALSE.
with_progress_fail_on_error <- function(progress, code) {
withCallingHandlers(
code,
error = function(e) progress$fail(),
interrupt = function(e) progress$fail())
}
10 changes: 6 additions & 4 deletions R/runner-callr.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ monty_runner_callr <- function(n_workers, progress = NULL) {
}
}
}
env$progress(env$n_steps_progress)
env$progress$update(env$n_steps_progress)
}

all(env$status == "done")
Expand All @@ -87,12 +87,14 @@ monty_runner_callr <- function(n_workers, progress = NULL) {
env$result_path <- rep(NA_character_, n_chains)
env$n_steps <- steps$total
env$n_steps_progress <- rep(0, n_chains)
env$progress <- pb(seq_len(n_chains))
env$progress <- pb
for (session_id in seq_len(n_workers)) {
launch(session_id)
}
while (!step()) {
}
with_progress_fail_on_error(
pb,
while (!step()) {
})

res <- lapply(env$result_path, readRDS)
unlink(env$path, recursive = TRUE)
Expand Down
18 changes: 11 additions & 7 deletions R/runner-simultaneous.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,27 @@ monty_runner_simultaneous <- function(progress = NULL) {
validate_suitable(model)
n_chains <- length(rng)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE)
progress <- pb(seq_len(n_chains))
rng_state <- lapply(rng, function(r) r$state())
## TODO: get the rng state back into 'rng' here, or (better) look
## at if we should just be using seed instead here perhaps?
## > rng_state <- matrix(res$internal$state$rng, ncol = n_chains)
## > for (i in seq_len(n_chains)) {
## > rng[[i]]$set_state(rng_state[, i]) # not supported!
## > }
monty_run_chains_simultaneous(pars, model, sampler,
steps, progress, rng_state)
with_progress_fail_on_error(
pb,
monty_run_chains_simultaneous(pars, model, sampler, steps, pb$update,
rng_state))
}

continue <- function(state, model, sampler, steps) {
validate_suitable(model)
n_chains <- length(state)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE)
progress <- pb(seq_len(n_chains))
monty_continue_chains_simultaneous(state, model, sampler,
steps, progress)
with_progress_fail_on_error(
pb,
monty_continue_chains_simultaneous(state, model, sampler, steps,
pb$update))
}

monty_runner("Simultaneous",
Expand Down Expand Up @@ -126,12 +128,14 @@ monty_run_chains_simultaneous2 <- function(chain_state, model, sampler,
history_pars <- array(NA_real_, c(n_pars, n_steps_record, n_chains))
history_density <- matrix(NA_real_, n_steps_record, n_chains)

chain_id <- seq_len(n_chains)

for (i in seq_len(steps$total)) {
chain_state <- sampler$step(chain_state, model, rng)
history_pars[, i, ] <- chain_state$pars
history_density[i, ] <- chain_state$density
## TODO: also allow observations here if enabled
progress(i)
progress(chain_id, i)
}

## Pop the parameter names on last
Expand Down
56 changes: 30 additions & 26 deletions R/runner.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@
##' @param progress Optional logical, indicating if we should print a
##' progress bar while running. If `NULL`, we use the value of the
##' option `monty.progress` if set, otherwise we show the progress
##' bar (as it is typically wanted). The progress bar itself
##' responds to cli's options; in particular
##' `cli.progress_show_after` and `cli.progress_clear` will affect
##' your experience. Alternatively, you can provide a string
##' indicating the progress bar type. Options are `fancy`
##' bar (as it is typically wanted). Alternatively, you can provide
##' a string indicating the progress bar type. Options are `fancy`
##' (equivalent to `TRUE`), `none` (equivalent to `FALSE`) and
##' `simple` (a very simple text-mode progress indicator designed
##' play nicely with logging; it does not use special codes to clear
Expand All @@ -25,22 +22,26 @@ monty_runner_serial <- function(progress = NULL) {
run <- function(pars, model, sampler, steps, rng) {
n_chains <- length(rng)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE)
lapply(
seq_along(rng),
function(i) {
monty_run_chain(pars[, i], model, sampler, steps,
pb(i), rng[[i]])
})
with_progress_fail_on_error(
pb,
lapply(
seq_along(rng),
function(i) {
monty_run_chain(i, pars[, i], model, sampler, steps,
pb$update, rng[[i]])
}))
}

continue <- function(state, model, sampler, steps) {
n_chains <- length(state)
pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE)
lapply(
seq_along(state),
function(i) {
monty_continue_chain(state[[i]], model, sampler, steps, pb(i))
})
with_progress_fail_on_error(
pb,
lapply(
seq_along(state),
function(i) {
monty_continue_chain(i, state[[i]], model, sampler, steps, pb$update)
}))
}

monty_runner("Serial",
Expand Down Expand Up @@ -130,10 +131,11 @@ monty_runner_parallel <- function(n_workers) {
args <- list(model = model,
sampler = sampler,
steps = steps,
progress = function(i) NULL)
progress = progress_bar_none()$update)
parallel::clusterMap(
cl,
monty_continue_chain,
seq_len(n_chains),
state,
MoreArgs = args)
}
Expand All @@ -145,14 +147,15 @@ monty_runner_parallel <- function(n_workers) {
}


monty_run_chain_parallel <- function(pars, model, sampler, steps, rng) {
monty_run_chain_parallel <- function(chain_id, pars, model, sampler, steps,
rng) {
rng <- monty_rng$new(rng)
progress <- function(i) NULL
monty_run_chain(pars, model, sampler, steps, progress, rng)
progress <- progress_bar_none()$update
monty_run_chain(chain_id, pars, model, sampler, steps, progress, rng)
}


monty_run_chain <- function(pars, model, sampler, steps,
monty_run_chain <- function(chain_id, pars, model, sampler, steps,
progress, rng) {
r_rng_state <- get_r_rng_state()
chain_state <- sampler$initialise(pars, model, rng)
Expand All @@ -177,24 +180,25 @@ monty_run_chain <- function(pars, model, sampler, steps,
cli::cli_abort("Chain does not have finite starting density")
}

monty_run_chain2(chain_state, model, sampler, steps, progress,
monty_run_chain2(chain_id, chain_state, model, sampler, steps, progress,
rng, r_rng_state)
}


monty_continue_chain <- function(state, model, sampler, steps, progress) {
monty_continue_chain <- function(chain_id, state, model, sampler, steps,
progress) {
r_rng_state <- get_r_rng_state()
rng <- monty_rng$new(seed = state$rng)
sampler$set_internal_state(state$sampler)
if (model$properties$is_stochastic) {
model$rng_state$set(state$model_rng)
}
monty_run_chain2(state$chain, model, sampler, steps, progress,
monty_run_chain2(chain_id, state$chain, model, sampler, steps, progress,
rng, r_rng_state)
}


monty_run_chain2 <- function(chain_state, model, sampler, steps,
monty_run_chain2 <- function(chain_id, chain_state, model, sampler, steps,
progress, rng, r_rng_state) {
initial <- chain_state$pars
n_pars <- length(model$parameters)
Expand All @@ -221,7 +225,7 @@ monty_run_chain2 <- function(chain_state, model, sampler, steps,
}
j <- j + 1L
}
progress(i)
progress(chain_id, i)
}

## Pop the parameter names on last
Expand Down
33 changes: 19 additions & 14 deletions R/sample-manual.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,25 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) {
is_continue <- is.list(restart)

pb <- progress_bar(n_chains, steps$total, progress,
show_overall = FALSE, single_chain = TRUE)(chain_id)

if (is_continue) {
state <- restart$state
model <- restart$model
sampler <- restart$sampler
res <- monty_continue_chain(state[[chain_id]], model, sampler, steps, pb)
} else {
pars <- inputs$pars
model <- inputs$model
sampler <- inputs$sampler
rng <- monty_rng$new(seed = inputs$rng_state[[chain_id]])
res <- monty_run_chain(pars[, chain_id], model, sampler, steps, pb, rng)
}
show_overall = FALSE, single_chain = TRUE)

with_progress_fail_on_error(
pb,
if (is_continue) {
state <- restart$state
model <- restart$model
sampler <- restart$sampler
res <- monty_continue_chain(chain_id, state[[chain_id]], model, sampler,
steps, pb$update)
} else {
pars <- inputs$pars
model <- inputs$model
sampler <- inputs$sampler
rng <- monty_rng$new(seed = inputs$rng_state[[chain_id]])
res <- monty_run_chain(chain_id, pars[, chain_id], model, sampler, steps,
pb$update, rng)
}
)

saveRDS(res, path$results)
invisible(path$results)
Expand Down
5 changes: 5 additions & 0 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,8 @@ callr_safe_result <- function(rs, grace = 2, dt = 0.1) {
last <- function(x) {
x[[length(x)]]
}


is_testing <- function() {
identical(Sys.getenv("TESTTHAT"), "true")
}
Loading

0 comments on commit 14ffdec

Please sign in to comment.