From 8ffb6779878c194152cb8554383d60d81a406c19 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 9 Oct 2024 15:49:20 +0100 Subject: [PATCH 1/9] Alternative approach to observation Moving responsibility for this into the model itself --- R/combine.R | 36 +++++++++- R/model.R | 41 +++++++++++- R/runner.R | 7 +- R/sample.R | 6 +- R/sampler-helpers.R | 11 ++- tests/testthat/helper-monty.R | 82 ++++++++++++++--------- tests/testthat/test-sampler-random-walk.R | 22 +----- 7 files changed, 142 insertions(+), 63 deletions(-) diff --git a/R/combine.R b/R/combine.R index de95e59d..793775bb 100644 --- a/R/combine.R +++ b/R/combine.R @@ -93,13 +93,14 @@ monty_model_combine <- function(a, b, properties = NULL, domain <- model_combine_domain(a, b, parameters) density <- model_combine_density(a, b, parameters) - gradient <- model_combine_gradient( a, b, parameters, properties, call) direct_sample <- model_combine_direct_sample( a, b, parameters, properties, name_a, name_b, call) stochastic <- model_combine_stochastic( a, b, properties) + observer <- model_combine_observer( + a, b, parameters, properties, name_a, name_b, call) monty_model( list(model = list(a, b), @@ -109,6 +110,7 @@ monty_model_combine <- function(a, b, properties = NULL, gradient = gradient, get_rng_state = stochastic$get_rng_state, set_rng_state = stochastic$set_rng_state, + observer = observer, direct_sample = direct_sample), properties) } @@ -280,3 +282,35 @@ model_combine_direct_sample <- function(a, b, parameters, properties, model$direct_sample(...)[i] } } + + +model_combine_observer <- function(a, b, parameters, properties, + name_a, name_b, call = NULL) { + if (isFALSE(properties$has_observer)) { + return(NULL) + } + possible <- a$properties$has_observer != b$properties$has_observer + required <- isTRUE(properties$has_observer) + if (!possible && !required) { + return(NULL) + } + if (required && !possible) { + if (a$properties$has_observer) { + hint <- paste("Both models have a 'observer' method so we can't", + "combine them. Set 'has_observer = FALSE' on one", + "of your models and try again") + } else { + hint <- "Neither of your models have 'observer' methods" + } + cli::cli_abort( + c("Can't create a observer from these models", + i = hint), + call = call) + } + + model <- if (a$properties$has_observer) a else b + + function(...) { + model$observer(...) + } +} diff --git a/R/model.R b/R/model.R index acd9663e..949093ba 100644 --- a/R/model.R +++ b/R/model.R @@ -25,6 +25,12 @@ ##' we may also support this in `gradient`). Use `NULL` (the ##' default) to detect this from the model. ##' +##' @param has_observer Logical, indicating if the model has an +##' "observation" function, which we will describe more fully soon. +##' An observer is a function `observe` which takes no arguments and +##' returns arbitrary data about the previously evaluated density. +##' Use `NULL` (the default) to detect this from the model. +##' ##' @param allow_multiple_parameters Logical, indicating if the ##' density calculation can support being passed a matrix of ##' parameters (with each column corresponding to a different @@ -51,13 +57,20 @@ monty_model_properties <- function(has_gradient = NULL, has_direct_sample = NULL, is_stochastic = NULL, has_parameter_groups = NULL, + has_observer = NULL, allow_multiple_parameters = FALSE) { - ## TODO: What name do we want for this property, really? + assert_scalar_logical(has_gradient, allow_null = TRUE) + assert_scalar_logical(has_direct_sample, allow_null = TRUE) + assert_scalar_logical(is_stochastic, allow_null = TRUE) + assert_scalar_logical(has_parameter_groups, allow_null = TRUE) + assert_scalar_logical(has_observer, allow_null = TRUE) assert_scalar_logical(allow_multiple_parameters) + ret <- list(has_gradient = has_gradient, has_direct_sample = has_direct_sample, is_stochastic = is_stochastic, has_parameter_groups = has_parameter_groups, + has_observer = has_observer, ## TODO: I am not convinced on this name allow_multiple_parameters = allow_multiple_parameters) class(ret) <- "monty_model_properties" @@ -184,6 +197,7 @@ monty_model <- function(model, properties = NULL) { properties <- validate_model_properties(properties, call) gradient <- validate_model_gradient(model, properties, call) direct_sample <- validate_model_direct_sample(model, properties, call) + observer <- validate_model_observer(model, properties, call) rng_state <- validate_model_rng_state(model, properties, call) parameter_groups <- validate_model_parameter_groups(model, properties, call) @@ -192,6 +206,7 @@ monty_model <- function(model, properties = NULL) { properties$has_direct_sample <- !is.null(direct_sample) properties$is_stochastic <- !is.null(rng_state$set) properties$has_parameter_groups <- !is.null(parameter_groups) + properties$has_observer <- !is.null(observer) ret <- list(model = model, parameters = parameters, @@ -200,6 +215,7 @@ monty_model <- function(model, properties = NULL) { density = density, gradient = gradient, direct_sample = direct_sample, + observer = observer, rng_state = rng_state, properties = properties) class(ret) <- "monty_model" @@ -438,6 +454,26 @@ validate_model_direct_sample <- function(model, properties, call) { } +validate_model_observer <- function(model, properties, call) { + if (isFALSE(properties$has_observer)) { + return(NULL) + } + value <- model$observer + if (isTRUE(properties$has_observer) && !inherits(value, "monty_observer")) { + cli::cli_abort( + paste("Did not find a 'monty_observer' object 'observer' within", + "your model, but your properties say that it should exist"), + arg = "model", call = call) + } + if (!is.null(value) && !inherits(value, "monty_observer")) { + cli::cli_abort( + "Expected 'model${method_name}' to be a 'monty_observer' if non-NULL", + arg = "model", call = call) + } + value +} + + validate_model_rng_state <- function(model, properties, call) { not_stochastic <- isFALSE(properties$is_stochastic) || ( is.null(properties$is_stochastic) && @@ -582,7 +618,8 @@ print.monty_model <- function(x, ...) { monty_model_properties_str <- function(properties) { c(if (properties$has_gradient) "can compute gradients", if (properties$has_direct_sample) "can be directly sampled from", - if (properties$is_stochastic) "is stochastic") + if (properties$is_stochastic) "is stochastic", + if (properties$has_observer) "has an observer") } diff --git a/R/runner.R b/R/runner.R index 441129b4..78538cf3 100644 --- a/R/runner.R +++ b/R/runner.R @@ -199,8 +199,7 @@ monty_run_chain2 <- function(chain_state, model, sampler, n_steps, progress, rng, r_rng_state) { initial <- chain_state$pars n_pars <- length(model$parameters) - ## has_observer <- model$properties$has_observer - has_observer <- FALSE + has_observer <- model$properties$has_observer history_pars <- matrix(NA_real_, n_pars, n_steps) history_density <- rep(NA_real_, n_steps) @@ -210,7 +209,7 @@ monty_run_chain2 <- function(chain_state, model, sampler, n_steps, chain_state <- sampler$step(chain_state, model, rng) history_pars[, i] <- chain_state$pars history_density[[i]] <- chain_state$density - if (!is.null(chain_state$observation)) { + if (has_observer && !is.null(chain_state$observation)) { history_observation[[i]] <- chain_state$observation } progress(i) @@ -223,7 +222,7 @@ monty_run_chain2 <- function(chain_state, model, sampler, n_steps, details <- sampler$finalise(chain_state, model, rng) if (has_observer) { - ## history_observation <- model$observer$finalise(history_observation) + history_observation <- model$observer$finalise(history_observation) } ## This list will hold things that we'll use internally but not diff --git a/R/sample.R b/R/sample.R index e2f96d04..c38f47ce 100644 --- a/R/sample.R +++ b/R/sample.R @@ -97,9 +97,7 @@ monty_sample <- function(model, sampler, n_steps, initial = NULL, pars <- initial_parameters(initial, model, rng, environment()) res <- runner$run(pars, model, sampler, n_steps, rng) - ## TODO: everywhere we write this combine_chains with observer, - ## respect the property, not the presence of a method. Will be done - ## in the next PR. + observer <- if (model$properties$has_observer) model$observer else NULL samples <- combine_chains(res, model$observer) if (restartable) { samples$restart <- restart_data(res, model, sampler, runner) @@ -153,7 +151,7 @@ monty_sample_continue <- function(samples, n_steps, restartable = FALSE, res <- runner$continue(state, model, sampler, n_steps) - observer <- model$observer # See above about respecting properties + observer <- if (model$properties$has_observer) model$observer else NULL samples <- append_chains(samples, combine_chains(res, observer), observer) if (restartable) { diff --git a/R/sampler-helpers.R b/R/sampler-helpers.R index 0f8b7479..c9a70684 100644 --- a/R/sampler-helpers.R +++ b/R/sampler-helpers.R @@ -1,8 +1,11 @@ initialise_state <- function(pars, model, rng) { initialise_rng_state(model, rng) density <- model$density(pars) - ## TODO: in next PR observe. - observation <- NULL + if (model$properties$has_observer) { + observation <- model$observer$observe() + } else { + observation <- NULL + } list(pars = pars, density = density, observation = observation) } @@ -16,7 +19,9 @@ update_state <- function(state, pars, density, accept, model, rng) { } else { state$pars <- pars state$density <- density - ## TODO: cope with observation + if (model$properties$has_observer) { + state$observation <- model$observer$observe() + } } } state diff --git a/tests/testthat/helper-monty.R b/tests/testthat/helper-monty.R index aaf272a8..1d98e076 100644 --- a/tests/testthat/helper-monty.R +++ b/tests/testthat/helper-monty.R @@ -109,20 +109,6 @@ ex_dust_sir <- function(n_particles = 100, n_threads = 1, trajectories <- NULL - ## In the new dust wrapper we'll need to make this nicer; I think - ## that this is pretty painful atm because we wrap via the particle - ## filter method in mcstate1. This version replicates most of what - ## we need though, which is some subset of the model - details <- function(idx_particle) { - if (save_trajectories) { - traj <- trajectories[, idx_particle, , drop = FALSE] - dim(traj) <- dim(traj)[-2] - } else { - traj <- NULL - } - list(trajectories = traj, state = model$state()[, idx_particle]) - } - density <- function(x) { beta <- x[[1]] gamma <- x[[2]] @@ -163,13 +149,38 @@ ex_dust_sir <- function(n_particles = 100, n_threads = 1, model$rng_state() } + if (save_trajectories) { + observer <- monty_observer( + function() { + ## TODO: It's not really clear to me (Rich) that we want the + ## rng coming in here. In dust2 we'll correctly use the + ## *filter* rng. however, even there we end up with the act + ## of observation changing the behaviour of the sampler and I + ## am not sure that is really desirable. We could probably + ## improve that in dust, but that would require that we do not + ## pass the rng here too. So for now we take the first + ## particle. + ## i <- floor(rng$random_real(1) * model$model$n_particles()) + 1L + i <- 1L + if (save_trajectories) { + traj <- trajectories[, i, , drop = FALSE] + dim(traj) <- dim(traj)[-2] + } else { + traj <- NULL + } + list(trajectories = traj, state = model$state()[, i]) + }) + } else { + observer <- NULL + } + monty_model( list(model = model, - details = details, density = density, direct_sample = direct_sample, parameters = c("beta", "gamma"), domain = cbind(c(0, 0), c(Inf, Inf)), + observer = observer, set_rng_state = set_rng_state, get_rng_state = get_rng_state), monty_model_properties(is_stochastic = !deterministic)) @@ -213,20 +224,6 @@ ex_dust_sir_likelihood <- function(n_particles = 100, n_threads = 1, trajectories <- NULL - ## In the new dust wrapper we'll need to make this nicer; I think - ## that this is pretty painful atm because we wrap via the particle - ## filter method in mcstate1. This version replicates most of what - ## we need though, which is some subset of the model - details <- function(idx_particle) { - if (save_trajectories) { - traj <- trajectories[, idx_particle, , drop = FALSE] - dim(traj) <- dim(traj)[-2] - } else { - traj <- NULL - } - list(trajectories = traj, state = model$state()[, idx_particle]) - } - density <- function(x) { beta <- x[[1]] gamma <- x[[2]] @@ -255,11 +252,36 @@ ex_dust_sir_likelihood <- function(n_particles = 100, n_threads = 1, model$rng_state() } + if (save_trajectories) { + observer <- monty_observer( + function() { + ## TODO: It's not really clear to me (Rich) that we want the + ## rng coming in here. In dust2 we'll correctly use the + ## *filter* rng. however, even there we end up with the act + ## of observation changing the behaviour of the sampler and I + ## am not sure that is really desirable. We could probably + ## improve that in dust, but that would require that we do not + ## pass the rng here too. So for now we take the first + ## particle. + ## i <- floor(rng$random_real(1) * model$model$n_particles()) + 1L + i <- 1L + if (save_trajectories) { + traj <- trajectories[, i, , drop = FALSE] + dim(traj) <- dim(traj)[-2] + } else { + traj <- NULL + } + list(trajectories = traj, state = model$state()[, i]) + }) + } else { + observer <- NULL + } + monty_model( list(model = model, - details = details, density = density, parameters = c("beta", "gamma"), + observer = observer, set_rng_state = set_rng_state, get_rng_state = get_rng_state), monty_model_properties(is_stochastic = !deterministic)) diff --git a/tests/testthat/test-sampler-random-walk.R b/tests/testthat/test-sampler-random-walk.R index 0b1e9681..d43339f7 100644 --- a/tests/testthat/test-sampler-random-walk.R +++ b/tests/testthat/test-sampler-random-walk.R @@ -36,20 +36,13 @@ test_that("can draw samples from a random model", { test_that("can observe a model", { - skip("FIXME: add model-based observer") m <- ex_dust_sir(save_trajectories = TRUE) vcv <- matrix(c(0.0006405, 0.0005628, 0.0005628, 0.0006641), 2, 2) sampler <- monty_sampler_random_walk(vcv = vcv) - observer <- monty_observer( - function(model, rng) { - i <- floor(rng$random_real(1) * model$model$n_particles()) + 1L - model$details(i) - }) - ## This takes quite a while, and that seems mostly to be the time ## taken to call the filter in dust. - res <- monty_sample(m, sampler, 20, n_chains = 3, observer = observer) + res <- monty_sample(m, sampler, 20, n_chains = 3) expect_setequal(names(res), c("pars", "density", "initial", "details", "observations")) expect_equal(names(res$observations), @@ -64,24 +57,15 @@ test_that("can observe a model", { test_that("can continue observed models", { - skip("FIXME: add model-based observer") m <- ex_dust_sir(save_trajectories = TRUE) vcv <- matrix(c(0.0006405, 0.0005628, 0.0005628, 0.0006641), 2, 2) sampler <- monty_sampler_random_walk(vcv = vcv) - observer <- monty_observer( - function(model, rng) { - ## ideally we get a random sample here, but that's not easy with - ## current dust - model$details(4) - }) - set.seed(1) - res1 <- monty_sample(m, sampler, 15, n_chains = 3, observer = observer) + res1 <- monty_sample(m, sampler, 15, n_chains = 3) set.seed(1) - res2a <- monty_sample(m, sampler, 5, n_chains = 3, observer = observer, - restartable = TRUE) + res2a <- monty_sample(m, sampler, 5, n_chains = 3, restartable = TRUE) res2b <- monty_sample_continue(res2a, 10) expect_equal(res1$observations, res2b$observations) From 6f90e2acb1bde7fcc0464552a77e55a20a795df0 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 9 Oct 2024 15:57:55 +0100 Subject: [PATCH 2/9] Bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index aa63f4be..78ce0de9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: monty Title: Monte Carlo Models -Version: 0.2.10 +Version: 0.2.11 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Wes", "Hinsley", role = "aut"), From fcaad7444ebf6edbbf1051cf4dab841322122b0a Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 9 Oct 2024 16:38:48 +0100 Subject: [PATCH 3/9] Fix bugs found using this from dust --- R/combine.R | 5 +---- R/model.R | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/R/combine.R b/R/combine.R index 793775bb..8d36578a 100644 --- a/R/combine.R +++ b/R/combine.R @@ -309,8 +309,5 @@ model_combine_observer <- function(a, b, parameters, properties, } model <- if (a$properties$has_observer) a else b - - function(...) { - model$observer(...) - } + model$observer } diff --git a/R/model.R b/R/model.R index 949093ba..8921bdfc 100644 --- a/R/model.R +++ b/R/model.R @@ -467,7 +467,7 @@ validate_model_observer <- function(model, properties, call) { } if (!is.null(value) && !inherits(value, "monty_observer")) { cli::cli_abort( - "Expected 'model${method_name}' to be a 'monty_observer' if non-NULL", + "Expected 'model$observer' to be a 'monty_observer' if non-NULL", arg = "model", call = call) } value From c6d29c5eb9da869dfa94b6b1f6faa7d75dfda721 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 9 Oct 2024 17:25:33 +0100 Subject: [PATCH 4/9] Redocument --- man/monty_model_properties.Rd | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/man/monty_model_properties.Rd b/man/monty_model_properties.Rd index a7a02b90..e369b332 100644 --- a/man/monty_model_properties.Rd +++ b/man/monty_model_properties.Rd @@ -9,6 +9,7 @@ monty_model_properties( has_direct_sample = NULL, is_stochastic = NULL, has_parameter_groups = NULL, + has_observer = NULL, allow_multiple_parameters = FALSE ) } @@ -33,6 +34,12 @@ by the presence of a \code{by_group} argument to \code{density} and (later we may also support this in \code{gradient}). Use \code{NULL} (the default) to detect this from the model.} +\item{has_observer}{Logical, indicating if the model has an +"observation" function, which we will describe more fully soon. +An observer is a function \code{observe} which takes no arguments and +returns arbitrary data about the previously evaluated density. +Use \code{NULL} (the default) to detect this from the model.} + \item{allow_multiple_parameters}{Logical, indicating if the density calculation can support being passed a matrix of parameters (with each column corresponding to a different From 40623b9b88231e48cf5cc135c64153088373d38a Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 9 Oct 2024 17:42:08 +0100 Subject: [PATCH 5/9] Expand tests --- R/combine.R | 6 +++--- R/model.R | 2 +- R/observer.R | 1 - tests/testthat/test-combine.R | 40 +++++++++++++++++++++++++++++++++++ tests/testthat/test-model.R | 12 +++++++++++ 5 files changed, 56 insertions(+), 5 deletions(-) diff --git a/R/combine.R b/R/combine.R index 8d36578a..a891dd93 100644 --- a/R/combine.R +++ b/R/combine.R @@ -296,14 +296,14 @@ model_combine_observer <- function(a, b, parameters, properties, } if (required && !possible) { if (a$properties$has_observer) { - hint <- paste("Both models have a 'observer' method so we can't", + hint <- paste("Both models have an 'observer' object so we can't", "combine them. Set 'has_observer = FALSE' on one", "of your models and try again") } else { - hint <- "Neither of your models have 'observer' methods" + hint <- "Neither of your models have 'observer' objects" } cli::cli_abort( - c("Can't create a observer from these models", + c("Can't create an observer from these models", i = hint), call = call) } diff --git a/R/model.R b/R/model.R index 8921bdfc..e19cab21 100644 --- a/R/model.R +++ b/R/model.R @@ -467,7 +467,7 @@ validate_model_observer <- function(model, properties, call) { } if (!is.null(value) && !inherits(value, "monty_observer")) { cli::cli_abort( - "Expected 'model$observer' to be a 'monty_observer' if non-NULL", + "Expected 'model$observer' to be a 'monty_observer' object if non-NULL", arg = "model", call = call) } value diff --git a/R/observer.R b/R/observer.R index a46a04ac..55243f4a 100644 --- a/R/observer.R +++ b/R/observer.R @@ -70,7 +70,6 @@ monty_observer <- function(observe, ##' @export print.monty_observer <- function(x, ...) { cli::cli_h1("") - cli::cli_alert_info("Use {.help monty_sample} to use this observer") cli::cli_alert_info("See {.help monty_observer} for more information") invisible(x) } diff --git a/tests/testthat/test-combine.R b/tests/testthat/test-combine.R index 21d96a04..a1ddc6b1 100644 --- a/tests/testthat/test-combine.R +++ b/tests/testthat/test-combine.R @@ -226,3 +226,43 @@ test_that("Can't force creation of stochastic model from deterministic", { monty_model_combine(a, a, monty_model_properties(is_stochastic = TRUE)), "Can't create stochastic support functions for these models") }) + + +test_that("cominining models with observers is possible", { + a <- monty_model( + list( + parameters = "x", + density = identity, + observer = monty_observer(identity))) + b <- monty_model( + list( + parameters = "x", + density = identity)) + c <- a + b + expect_true(c$properties$has_observer) + expect_identical(c$observer, a$observer) +}) + + +test_that("Can't create observer where both models have them", { + a <- monty_model( + list( + parameters = "x", + density = identity, + observer = monty_observer(identity))) + b <- a + a + expect_false(b$properties$has_observer) + expect_null(b$observer) + + properties <- monty_model_properties(has_observer = TRUE) + err <- expect_error( + monty_model_combine(a, a, properties = properties), + "Can't create an observer from these models") + expect_match(conditionMessage(err), + "Both models have an 'observer' object") + err <- expect_error( + monty_model_combine(b, b, properties = properties), + "Can't create an observer from these models") + expect_match(conditionMessage(err), + "Neither of your models have 'observer' objects") +}) diff --git a/tests/testthat/test-model.R b/tests/testthat/test-model.R index 1c585931..631e9776 100644 --- a/tests/testthat/test-model.R +++ b/tests/testthat/test-model.R @@ -6,6 +6,7 @@ test_that("can create a minimal model", { monty_model_properties(has_gradient = FALSE, has_direct_sample = FALSE, is_stochastic = FALSE, + has_observer = FALSE, has_parameter_groups = FALSE)) expect_equal(m$domain, rbind(a = c(-Inf, Inf))) expect_equal(m$parameters, "a") @@ -20,6 +21,7 @@ test_that("can create a more interesting model", { has_direct_sample = TRUE, is_stochastic = FALSE, has_parameter_groups = FALSE, + has_observer = FALSE, allow_multiple_parameters = TRUE)) expect_equal(m$domain, rbind(gamma = c(0, Inf))) expect_equal(m$parameters, "gamma") @@ -60,6 +62,16 @@ test_that("require direct sample is a function if given", { }) +test_that("require observer is a monty_observer if given", { + expect_error( + monty_model(list(density = identity, + observer = TRUE, + parameters = "a")), + "Expected 'model$observer' to be a 'monty_observer' object if non-NULL", + fixed = TRUE) +}) + + test_that("validate domain", { expect_error( monty_model(list(density = identity, parameters = "a", domain = list())), From b943bd90404f17286e32d8d3b134d7fec8907ead Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 10 Oct 2024 09:19:02 +0100 Subject: [PATCH 6/9] Expand testing --- tests/testthat/test-combine.R | 4 ++++ tests/testthat/test-model.R | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/tests/testthat/test-combine.R b/tests/testthat/test-combine.R index a1ddc6b1..93cc91c3 100644 --- a/tests/testthat/test-combine.R +++ b/tests/testthat/test-combine.R @@ -265,4 +265,8 @@ test_that("Can't create observer where both models have them", { "Can't create an observer from these models") expect_match(conditionMessage(err), "Neither of your models have 'observer' objects") + + properties <- monty_model_properties(has_observer = FALSE) + res <- monty_model_combine(a, a, properties = properties) + expect_false(res$properties$has_observer) }) diff --git a/tests/testthat/test-model.R b/tests/testthat/test-model.R index 631e9776..1fe3f3ca 100644 --- a/tests/testthat/test-model.R +++ b/tests/testthat/test-model.R @@ -72,6 +72,27 @@ test_that("require observer is a monty_observer if given", { }) +test_that("observer must exist if required", { + properties <- monty_model_properties(has_observer = TRUE) + expect_error( + monty_model(list(parameters = "a", density = identity), + properties = properties), + "Did not find a 'monty_observer' object 'observer' within your model") +}) + + +test_that("Ignore invalid observer if properties say to ignore it", { + properties <- monty_model_properties(has_observer = FALSE) + res <- monty_model(list(density = identity, + observer = TRUE, + parameters = "a"), + properties = properties) + expect_s3_class(res, "monty_model") + expect_false(res$properties$has_observer) + expect_null(res$observer) +}) + + test_that("validate domain", { expect_error( monty_model(list(density = identity, parameters = "a", domain = list())), From dc3487f7e7d1f41a7ade74f69be9657d75bb4b44 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 10 Oct 2024 09:42:58 +0100 Subject: [PATCH 7/9] Support nested models, sort of --- R/sampler-nested-adaptive.R | 17 ++++++++++++----- R/sampler-nested-random-walk.R | 17 ++++++++++++----- tests/testthat/test-sampler-nested-adaptive.R | 8 +++++--- .../testthat/test-sampler-nested-random-walk.R | 8 +++++--- 4 files changed, 34 insertions(+), 16 deletions(-) diff --git a/R/sampler-nested-adaptive.R b/R/sampler-nested-adaptive.R index c168a43e..5f378861 100644 --- a/R/sampler-nested-adaptive.R +++ b/R/sampler-nested-adaptive.R @@ -146,9 +146,10 @@ monty_sampler_nested_adaptive <- function(initial_vcv, } state <- list(pars = pars, density = c(density)) - ## TODO: we need to fix observation here; it should move into a - ## helper as part of a setup I think; see the sampler-helpers for - ## the single-parameter case. + ## TODO: mrc-5862 + if (model$properties$has_observer) { + state$observation <- m$observer$observe() + } state } @@ -201,7 +202,10 @@ monty_sampler_nested_adaptive <- function(initial_vcv, state$pars <- pars_next state$density <- density_next internal$density_by_group <- density_by_group_next - ## TODO: observe here + ## TODO: mrc-5862 + if (model$properties$has_observer) { + state$observation <- m$observer$observe() + } } } else { accept_prob_base <- NULL @@ -272,7 +276,10 @@ monty_sampler_nested_adaptive <- function(initial_vcv, state$pars <- pars_next state$density <- c(density_next) internal$density_by_group <- density_by_group_next - ## TODO: observe here + ## TODO: mrc-5862 + if (model$properties$has_observer) { + state$observation <- m$observer$observe() + } } if (internal$multiple_parameters) { diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index 77f1fca0..623cf782 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -116,9 +116,10 @@ monty_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { internal$density_by_group <- density_by_group state <- list(pars = pars, density = c(density)) - ## TODO: we need to fix observation here; it should move into a - ## helper as part of a setup I think; see the sampler-helpers for - ## the single-parameter case. + ## TODO: mrc-5862 + if (model$properties$has_observer) { + state$observation <- m$observer$observe() + } state } @@ -163,7 +164,10 @@ monty_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { state$pars <- pars_next state$density <- density_next internal$density_by_group <- density_by_group_next - ## TODO: observe here + ## TODO: mrc-5862 + if (model$properties$has_observer) { + state$observation <- m$observer$observe() + } } } @@ -227,7 +231,10 @@ monty_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { state$pars <- pars_next state$density <- c(density_next) internal$density_by_group <- density_by_group_next - ## TODO: observe here + ## TODO: mrc-5862 + if (model$properties$has_observer) { + state$observation <- m$observer$observe() + } } state } diff --git a/tests/testthat/test-sampler-nested-adaptive.R b/tests/testthat/test-sampler-nested-adaptive.R index 02def6e4..d100abb6 100644 --- a/tests/testthat/test-sampler-nested-adaptive.R +++ b/tests/testthat/test-sampler-nested-adaptive.R @@ -120,18 +120,20 @@ test_that("can run a sampler with shared parameters", { test_that("can run an observer during a nested fit", { - skip("FIXME: add model-based observer") set.seed(1) ng <- 5 m <- ex_simple_nested_with_base(ng) s <- monty_sampler_nested_adaptive( list(base = diag(1), groups = rep(list(diag(1)), ng))) counter <- 0 - observer <- monty_observer(function(...) { + ## Directly wire this in for now; we really just need better + ## examples here. + m$observer <- monty_observer(function(...) { counter <<- counter + 1 list(n = counter) }) - res <- monty_sample(m, s, 100, observer = observer) + m$properties$has_observer <- TRUE + res <- monty_sample(m, s, 100) expect_equal( dim(res$observations$n), c(1, 100, 1)) diff --git a/tests/testthat/test-sampler-nested-random-walk.R b/tests/testthat/test-sampler-nested-random-walk.R index 5186e689..6ba86111 100644 --- a/tests/testthat/test-sampler-nested-random-walk.R +++ b/tests/testthat/test-sampler-nested-random-walk.R @@ -186,18 +186,20 @@ test_that("can run a sampler with shared parameters", { test_that("can run an observer during a nested fit", { - skip("FIXME: add model-based observer") set.seed(1) ng <- 5 m <- ex_simple_nested_with_base(ng) s <- monty_sampler_nested_random_walk( list(base = diag(1), groups = rep(list(diag(1)), ng))) counter <- 0 - observer <- monty_observer(function(...) { + ## Directly wire this in for now; we really just need better + ## examples here. + m$observer <- monty_observer(function(...) { counter <<- counter + 1 list(n = counter) }) - res <- monty_sample(m, s, 100, observer = observer) + m$properties$has_observer <- TRUE + res <- monty_sample(m, s, 100) expect_equal( dim(res$observations$n), c(1, 100, 1)) From e7e1ce2bab0d30041344914ff1160cd658864e62 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 10 Oct 2024 09:50:45 +0100 Subject: [PATCH 8/9] Fix lookup --- R/sampler-nested-adaptive.R | 6 +++--- R/sampler-nested-random-walk.R | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/R/sampler-nested-adaptive.R b/R/sampler-nested-adaptive.R index 5f378861..634f4197 100644 --- a/R/sampler-nested-adaptive.R +++ b/R/sampler-nested-adaptive.R @@ -148,7 +148,7 @@ monty_sampler_nested_adaptive <- function(initial_vcv, state <- list(pars = pars, density = c(density)) ## TODO: mrc-5862 if (model$properties$has_observer) { - state$observation <- m$observer$observe() + state$observation <- model$observer$observe() } state } @@ -204,7 +204,7 @@ monty_sampler_nested_adaptive <- function(initial_vcv, internal$density_by_group <- density_by_group_next ## TODO: mrc-5862 if (model$properties$has_observer) { - state$observation <- m$observer$observe() + state$observation <- model$observer$observe() } } } else { @@ -278,7 +278,7 @@ monty_sampler_nested_adaptive <- function(initial_vcv, internal$density_by_group <- density_by_group_next ## TODO: mrc-5862 if (model$properties$has_observer) { - state$observation <- m$observer$observe() + state$observation <- model$observer$observe() } } diff --git a/R/sampler-nested-random-walk.R b/R/sampler-nested-random-walk.R index 623cf782..04447f94 100644 --- a/R/sampler-nested-random-walk.R +++ b/R/sampler-nested-random-walk.R @@ -118,7 +118,7 @@ monty_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { state <- list(pars = pars, density = c(density)) ## TODO: mrc-5862 if (model$properties$has_observer) { - state$observation <- m$observer$observe() + state$observation <- model$observer$observe() } state } @@ -166,7 +166,7 @@ monty_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { internal$density_by_group <- density_by_group_next ## TODO: mrc-5862 if (model$properties$has_observer) { - state$observation <- m$observer$observe() + state$observation <- model$observer$observe() } } } @@ -233,7 +233,7 @@ monty_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") { internal$density_by_group <- density_by_group_next ## TODO: mrc-5862 if (model$properties$has_observer) { - state$observation <- m$observer$observe() + state$observation <- model$observer$observe() } } state From 319a828f6a69c594b24aa75c4ade931a9a7997f5 Mon Sep 17 00:00:00 2001 From: Wes Hinsley Date: Fri, 11 Oct 2024 16:07:56 +0100 Subject: [PATCH 9/9] Update tests/testthat/test-combine.R --- tests/testthat/test-combine.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-combine.R b/tests/testthat/test-combine.R index 93cc91c3..525415c0 100644 --- a/tests/testthat/test-combine.R +++ b/tests/testthat/test-combine.R @@ -228,7 +228,7 @@ test_that("Can't force creation of stochastic model from deterministic", { }) -test_that("cominining models with observers is possible", { +test_that("combining models with observers is possible", { a <- monty_model( list( parameters = "x",