diff --git a/DESCRIPTION b/DESCRIPTION index e929dbbd..9e5a4eec 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: monty Title: Monte Carlo Models -Version: 0.2.28 +Version: 0.2.29 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Wes", "Hinsley", role = "aut"), diff --git a/R/progress.R b/R/progress.R index 0249f8a9..5577db75 100644 --- a/R/progress.R +++ b/R/progress.R @@ -11,7 +11,7 @@ progress_bar <- function(n_chains, n_steps, progress, show_overall, show_progress_bar <- function(progress, call = NULL) { if (is.null(progress)) { - progress <- getOption("monty.progress", TRUE) + progress <- getOption("monty.progress", !is_testing()) } ## Errors here are not great if we get this from an option, probably ## needs its own error path. @@ -33,48 +33,64 @@ show_progress_bar <- function(progress, call = NULL) { ## ## These are currently not tuneable from user-facing code. progress_bar_simple <- function(n_steps, every_s = 1, min_updates = 20) { - function(chain_index) { - env <- new.env(parent = emptyenv()) - env$t_next <- Sys.time() - freq <- ceiling(n_steps / min_updates) - function(at) { - now <- Sys.time() - show_progress <- at == n_steps || at %% freq == 0 || now > env$t_next - if (show_progress) { - env$t_next <- now + every_s - message(sprintf("MONTY-PROGRESS: chain: %s, step: %s", - chain_index, at)) - } + env <- new.env(parent = emptyenv()) + env$t_next <- Sys.time() + freq <- ceiling(n_steps / min_updates) + update <- function(chain_id, at) { + now <- Sys.time() + show_progress <- at == n_steps || at %% freq == 0 || now > env$t_next + if (show_progress) { + env$t_next <- now + every_s + message(sprintf("MONTY-PROGRESS: chain: %s, step: %s", + chain_id, at)) } } + list(update = update, fail = fail_no_action) } progress_bar_fancy <- function(n_chains, n_steps, show_overall, single_chain = FALSE) { + ## We're expecting to take a while, so we show immediately, if enabled: + oo <- options(cli.progress_show_after = 0) + on.exit(options(oo)) + e <- new.env() e$n <- rep(0, n_chains) overall <- progress_overall(n_chains, n_steps, show_overall, single_chain) fmt <- paste("Sampling {overall(e$n)} {cli::pb_bar} |", "{cli::pb_percent} ETA: {cli::pb_eta}") + fmt_done <- paste( + "{cli::col_green(cli::symbol$tick)} Sampled {cli::pb_total} steps", + "across {n_chains} chains in {cli::pb_elapsed}") + fmt_failed <- paste( + "{cli::col_red(cli::symbol$cross)} Sampling stopped at {cli::pb_current}", + "step{?s} after {cli::pb_elapsed}") n_steps_total <- if (single_chain) n_steps else n_chains * n_steps id <- cli::cli_progress_bar( total = n_steps_total, format = fmt, + format_done = fmt_done, + format_failed = fmt_failed, + clear = FALSE, .auto_close = FALSE) - function(chain_index) { - function(at) { - ## Avoid writing into a closed progress bar, it will cause an - ## error. We do this by checking to see if progress has changed - ## from last time we tried updating. - changed <- any(e$n[chain_index] != at, na.rm = TRUE) - if (changed) { - e$n[chain_index] <- at - cli::cli_progress_update(id = id, set = sum(e$n)) - } + update <- function(chain_id, at) { + ## Avoid writing into a closed progress bar, it will cause an + ## error. We do this by checking to see if progress has changed + ## from last time we tried updating. + changed <- any(e$n[chain_id] != at, na.rm = TRUE) + if (changed) { + e$n[chain_id] <- at + cli::cli_progress_update(id = id, set = sum(e$n)) } } + + fail <- function() { + cli::cli_progress_done(id, result = "failed") + } + + list(update = update, fail = fail) } @@ -93,10 +109,13 @@ parse_progress_bar_simple <- function(txt) { ## Dummy version that can be used where no progress bar is wanted. progress_bar_none <- function(...) { - function(chain_index) { - function(at) { - } + update <- function(chain_id, at) { } + list(update = update, fail = fail_no_action) +} + + +fail_no_action <- function() { } @@ -122,3 +141,17 @@ progress_overall <- function(n_chains, n_steps, show_overall, single_chain) { paste0(c("[", ret, "]"), collapse = "") } } + + +## Sets up some erorr handlers that close out the progress bar where +## we exit uncleanly; this works for everything other than closing out +## progress bars that were exited from a browser call, which we can't +## really help with. If we don't do this then an ugly partial +## progress bar will be printed a the completion of every subsequent +## progress bar, because we use .auto_close = FALSE. +with_progress_fail_on_error <- function(progress, code) { + withCallingHandlers( + code, + error = function(e) progress$fail(), + interrupt = function(e) progress$fail()) +} diff --git a/R/runner-callr.R b/R/runner-callr.R index 0ce2bb56..43b61b71 100644 --- a/R/runner-callr.R +++ b/R/runner-callr.R @@ -71,7 +71,7 @@ monty_runner_callr <- function(n_workers, progress = NULL) { } } } - env$progress(env$n_steps_progress) + env$progress$update(env$n_steps_progress) } all(env$status == "done") @@ -87,12 +87,14 @@ monty_runner_callr <- function(n_workers, progress = NULL) { env$result_path <- rep(NA_character_, n_chains) env$n_steps <- steps$total env$n_steps_progress <- rep(0, n_chains) - env$progress <- pb(seq_len(n_chains)) + env$progress <- pb for (session_id in seq_len(n_workers)) { launch(session_id) } - while (!step()) { - } + with_progress_fail_on_error( + pb, + while (!step()) { + }) res <- lapply(env$result_path, readRDS) unlink(env$path, recursive = TRUE) diff --git a/R/runner-simultaneous.R b/R/runner-simultaneous.R index 282c88a0..5fd8b6ea 100644 --- a/R/runner-simultaneous.R +++ b/R/runner-simultaneous.R @@ -36,7 +36,6 @@ monty_runner_simultaneous <- function(progress = NULL) { validate_suitable(model) n_chains <- length(rng) pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE) - progress <- pb(seq_len(n_chains)) rng_state <- lapply(rng, function(r) r$state()) ## TODO: get the rng state back into 'rng' here, or (better) look ## at if we should just be using seed instead here perhaps? @@ -44,17 +43,20 @@ monty_runner_simultaneous <- function(progress = NULL) { ## > for (i in seq_len(n_chains)) { ## > rng[[i]]$set_state(rng_state[, i]) # not supported! ## > } - monty_run_chains_simultaneous(pars, model, sampler, - steps, progress, rng_state) + with_progress_fail_on_error( + pb, + monty_run_chains_simultaneous(pars, model, sampler, steps, pb$update, + rng_state)) } continue <- function(state, model, sampler, steps) { validate_suitable(model) n_chains <- length(state) pb <- progress_bar(n_chains, steps$total, progress, show_overall = FALSE) - progress <- pb(seq_len(n_chains)) - monty_continue_chains_simultaneous(state, model, sampler, - steps, progress) + with_progress_fail_on_error( + pb, + monty_continue_chains_simultaneous(state, model, sampler, steps, + pb$update)) } monty_runner("Simultaneous", @@ -126,12 +128,14 @@ monty_run_chains_simultaneous2 <- function(chain_state, model, sampler, history_pars <- array(NA_real_, c(n_pars, n_steps_record, n_chains)) history_density <- matrix(NA_real_, n_steps_record, n_chains) + chain_id <- seq_len(n_chains) + for (i in seq_len(steps$total)) { chain_state <- sampler$step(chain_state, model, rng) history_pars[, i, ] <- chain_state$pars history_density[i, ] <- chain_state$density ## TODO: also allow observations here if enabled - progress(i) + progress(chain_id, i) } ## Pop the parameter names on last diff --git a/R/runner.R b/R/runner.R index b52469d6..6cc9f0d8 100644 --- a/R/runner.R +++ b/R/runner.R @@ -7,11 +7,8 @@ ##' @param progress Optional logical, indicating if we should print a ##' progress bar while running. If `NULL`, we use the value of the ##' option `monty.progress` if set, otherwise we show the progress -##' bar (as it is typically wanted). The progress bar itself -##' responds to cli's options; in particular -##' `cli.progress_show_after` and `cli.progress_clear` will affect -##' your experience. Alternatively, you can provide a string -##' indicating the progress bar type. Options are `fancy` +##' bar (as it is typically wanted). Alternatively, you can provide +##' a string indicating the progress bar type. Options are `fancy` ##' (equivalent to `TRUE`), `none` (equivalent to `FALSE`) and ##' `simple` (a very simple text-mode progress indicator designed ##' play nicely with logging; it does not use special codes to clear @@ -25,22 +22,26 @@ monty_runner_serial <- function(progress = NULL) { run <- function(pars, model, sampler, steps, rng) { n_chains <- length(rng) pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE) - lapply( - seq_along(rng), - function(i) { - monty_run_chain(pars[, i], model, sampler, steps, - pb(i), rng[[i]]) - }) + with_progress_fail_on_error( + pb, + lapply( + seq_along(rng), + function(i) { + monty_run_chain(i, pars[, i], model, sampler, steps, + pb$update, rng[[i]]) + })) } continue <- function(state, model, sampler, steps) { n_chains <- length(state) pb <- progress_bar(n_chains, steps$total, progress, show_overall = TRUE) - lapply( - seq_along(state), - function(i) { - monty_continue_chain(state[[i]], model, sampler, steps, pb(i)) - }) + with_progress_fail_on_error( + pb, + lapply( + seq_along(state), + function(i) { + monty_continue_chain(i, state[[i]], model, sampler, steps, pb$update) + })) } monty_runner("Serial", @@ -130,10 +131,11 @@ monty_runner_parallel <- function(n_workers) { args <- list(model = model, sampler = sampler, steps = steps, - progress = function(i) NULL) + progress = progress_bar_none()$update) parallel::clusterMap( cl, monty_continue_chain, + seq_len(n_chains), state, MoreArgs = args) } @@ -145,14 +147,15 @@ monty_runner_parallel <- function(n_workers) { } -monty_run_chain_parallel <- function(pars, model, sampler, steps, rng) { +monty_run_chain_parallel <- function(chain_id, pars, model, sampler, steps, + rng) { rng <- monty_rng$new(rng) - progress <- function(i) NULL - monty_run_chain(pars, model, sampler, steps, progress, rng) + progress <- progress_bar_none()$update + monty_run_chain(chain_id, pars, model, sampler, steps, progress, rng) } -monty_run_chain <- function(pars, model, sampler, steps, +monty_run_chain <- function(chain_id, pars, model, sampler, steps, progress, rng) { r_rng_state <- get_r_rng_state() chain_state <- sampler$initialise(pars, model, rng) @@ -177,24 +180,25 @@ monty_run_chain <- function(pars, model, sampler, steps, cli::cli_abort("Chain does not have finite starting density") } - monty_run_chain2(chain_state, model, sampler, steps, progress, + monty_run_chain2(chain_id, chain_state, model, sampler, steps, progress, rng, r_rng_state) } -monty_continue_chain <- function(state, model, sampler, steps, progress) { +monty_continue_chain <- function(chain_id, state, model, sampler, steps, + progress) { r_rng_state <- get_r_rng_state() rng <- monty_rng$new(seed = state$rng) sampler$set_internal_state(state$sampler) if (model$properties$is_stochastic) { model$rng_state$set(state$model_rng) } - monty_run_chain2(state$chain, model, sampler, steps, progress, + monty_run_chain2(chain_id, state$chain, model, sampler, steps, progress, rng, r_rng_state) } -monty_run_chain2 <- function(chain_state, model, sampler, steps, +monty_run_chain2 <- function(chain_id, chain_state, model, sampler, steps, progress, rng, r_rng_state) { initial <- chain_state$pars n_pars <- length(model$parameters) @@ -221,7 +225,7 @@ monty_run_chain2 <- function(chain_state, model, sampler, steps, } j <- j + 1L } - progress(i) + progress(chain_id, i) } ## Pop the parameter names on last diff --git a/R/sample-manual.R b/R/sample-manual.R index 7f9d36e0..2281293c 100644 --- a/R/sample-manual.R +++ b/R/sample-manual.R @@ -127,20 +127,25 @@ monty_sample_manual_run <- function(chain_id, path, progress = NULL) { is_continue <- is.list(restart) pb <- progress_bar(n_chains, steps$total, progress, - show_overall = FALSE, single_chain = TRUE)(chain_id) - - if (is_continue) { - state <- restart$state - model <- restart$model - sampler <- restart$sampler - res <- monty_continue_chain(state[[chain_id]], model, sampler, steps, pb) - } else { - pars <- inputs$pars - model <- inputs$model - sampler <- inputs$sampler - rng <- monty_rng$new(seed = inputs$rng_state[[chain_id]]) - res <- monty_run_chain(pars[, chain_id], model, sampler, steps, pb, rng) - } + show_overall = FALSE, single_chain = TRUE) + + with_progress_fail_on_error( + pb, + if (is_continue) { + state <- restart$state + model <- restart$model + sampler <- restart$sampler + res <- monty_continue_chain(chain_id, state[[chain_id]], model, sampler, + steps, pb$update) + } else { + pars <- inputs$pars + model <- inputs$model + sampler <- inputs$sampler + rng <- monty_rng$new(seed = inputs$rng_state[[chain_id]]) + res <- monty_run_chain(chain_id, pars[, chain_id], model, sampler, steps, + pb$update, rng) + } + ) saveRDS(res, path$results) invisible(path$results) diff --git a/R/util.R b/R/util.R index 03f080a7..415e4ea1 100644 --- a/R/util.R +++ b/R/util.R @@ -219,3 +219,8 @@ callr_safe_result <- function(rs, grace = 2, dt = 0.1) { last <- function(x) { x[[length(x)]] } + + +is_testing <- function() { + identical(Sys.getenv("TESTTHAT"), "true") +} diff --git a/man/monty_runner_callr.Rd b/man/monty_runner_callr.Rd index 3db9ae8a..c6e96f83 100644 --- a/man/monty_runner_callr.Rd +++ b/man/monty_runner_callr.Rd @@ -18,11 +18,8 @@ will likely be no faster than 4).} \item{progress}{Optional logical, indicating if we should print a progress bar while running. If \code{NULL}, we use the value of the option \code{monty.progress} if set, otherwise we show the progress -bar (as it is typically wanted). The progress bar itself -responds to cli's options; in particular -\code{cli.progress_show_after} and \code{cli.progress_clear} will affect -your experience. Alternatively, you can provide a string -indicating the progress bar type. Options are \code{fancy} +bar (as it is typically wanted). Alternatively, you can provide +a string indicating the progress bar type. Options are \code{fancy} (equivalent to \code{TRUE}), \code{none} (equivalent to \code{FALSE}) and \code{simple} (a very simple text-mode progress indicator designed play nicely with logging; it does not use special codes to clear diff --git a/man/monty_runner_serial.Rd b/man/monty_runner_serial.Rd index 370784aa..dbfc72e5 100644 --- a/man/monty_runner_serial.Rd +++ b/man/monty_runner_serial.Rd @@ -10,11 +10,8 @@ monty_runner_serial(progress = NULL) \item{progress}{Optional logical, indicating if we should print a progress bar while running. If \code{NULL}, we use the value of the option \code{monty.progress} if set, otherwise we show the progress -bar (as it is typically wanted). The progress bar itself -responds to cli's options; in particular -\code{cli.progress_show_after} and \code{cli.progress_clear} will affect -your experience. Alternatively, you can provide a string -indicating the progress bar type. Options are \code{fancy} +bar (as it is typically wanted). Alternatively, you can provide +a string indicating the progress bar type. Options are \code{fancy} (equivalent to \code{TRUE}), \code{none} (equivalent to \code{FALSE}) and \code{simple} (a very simple text-mode progress indicator designed play nicely with logging; it does not use special codes to clear diff --git a/man/monty_runner_simultaneous.Rd b/man/monty_runner_simultaneous.Rd index d2954e48..de13d71e 100644 --- a/man/monty_runner_simultaneous.Rd +++ b/man/monty_runner_simultaneous.Rd @@ -10,11 +10,8 @@ monty_runner_simultaneous(progress = NULL) \item{progress}{Optional logical, indicating if we should print a progress bar while running. If \code{NULL}, we use the value of the option \code{monty.progress} if set, otherwise we show the progress -bar (as it is typically wanted). The progress bar itself -responds to cli's options; in particular -\code{cli.progress_show_after} and \code{cli.progress_clear} will affect -your experience. Alternatively, you can provide a string -indicating the progress bar type. Options are \code{fancy} +bar (as it is typically wanted). Alternatively, you can provide +a string indicating the progress bar type. Options are \code{fancy} (equivalent to \code{TRUE}), \code{none} (equivalent to \code{FALSE}) and \code{simple} (a very simple text-mode progress indicator designed play nicely with logging; it does not use special codes to clear diff --git a/man/monty_sample_manual_run.Rd b/man/monty_sample_manual_run.Rd index 1a6895f2..c1486321 100644 --- a/man/monty_sample_manual_run.Rd +++ b/man/monty_sample_manual_run.Rd @@ -18,11 +18,8 @@ provide an integer that does not correspond to a chain in 1 to \item{progress}{Optional logical, indicating if we should print a progress bar while running. If \code{NULL}, we use the value of the option \code{monty.progress} if set, otherwise we show the progress -bar (as it is typically wanted). The progress bar itself -responds to cli's options; in particular -\code{cli.progress_show_after} and \code{cli.progress_clear} will affect -your experience. Alternatively, you can provide a string -indicating the progress bar type. Options are \code{fancy} +bar (as it is typically wanted). Alternatively, you can provide +a string indicating the progress bar type. Options are \code{fancy} (equivalent to \code{TRUE}), \code{none} (equivalent to \code{FALSE}) and \code{simple} (a very simple text-mode progress indicator designed play nicely with logging; it does not use special codes to clear diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R deleted file mode 100644 index 02afbcdd..00000000 --- a/tests/testthat/setup.R +++ /dev/null @@ -1,4 +0,0 @@ -withr::local_options( - monty.progress = FALSE, - .local_envir = teardown_env() -) diff --git a/tests/testthat/test-progress.R b/tests/testthat/test-progress.R index a1fa5190..1d943159 100644 --- a/tests/testthat/test-progress.R +++ b/tests/testthat/test-progress.R @@ -1,4 +1,5 @@ test_that("can select sensible values for progress", { + withr::local_envvar(TESTTHAT = FALSE) withr::with_options(list(monty.progress = TRUE), { expect_equal(show_progress_bar(FALSE), "none") expect_equal(show_progress_bar(TRUE), "fancy") @@ -23,8 +24,8 @@ test_that("can select sensible values for progress", { test_that("null progress bar does nothing", { - p <- progress_bar(10, 10, FALSE)(1) - expect_silent(p(1)) + p <- progress_bar(10, 10, FALSE) + expect_silent(p$update(1, 1)) }) @@ -67,14 +68,13 @@ test_that("overall progress is empty if disabled", { test_that("can format fancy", { - f <- progress_bar_fancy(4, 100, TRUE) - g <- f(1) + f <- progress_bar_fancy(4, 100, TRUE)$update id <- environment(f)$id e <- environment(f)$e mock_update <- mockery::mock() - mockery::stub(g, "cli::cli_progress_update", mock_update) - g(5) + mockery::stub(f, "cli::cli_progress_update", mock_update) + f(1, 5) mockery::expect_called(mock_update, 1) expect_equal(mockery::mock_args(mock_update)[[1]], list(id = id, set = 5)) @@ -102,14 +102,13 @@ test_that("can create pb", { test_that("can create a simple progress bar", { - pb <- progress_bar_simple(104, 5) - p <- pb(1) - expect_message(p(10), "MONTY-PROGRESS: chain: 1, step: 10") - expect_no_message(p(11)) - expect_message(p(36), "MONTY-PROGRESS: chain: 1, step: 36") - expect_message(p(102), "MONTY-PROGRESS: chain: 1, step: 102") - expect_no_message(p(103)) - expect_message(p(104), "MONTY-PROGRESS: chain: 1, step: 104") + p <- progress_bar_simple(104, 5)$update + expect_message(p(1, 10), "MONTY-PROGRESS: chain: 1, step: 10") + expect_no_message(p(1, 11)) + expect_message(p(1, 36), "MONTY-PROGRESS: chain: 1, step: 36") + expect_message(p(1, 102), "MONTY-PROGRESS: chain: 1, step: 102") + expect_no_message(p(1, 103)) + expect_message(p(1, 104), "MONTY-PROGRESS: chain: 1, step: 104") })