Skip to content

Commit

Permalink
Merge pull request #77 from mrc-ide/mrc-5859
Browse files Browse the repository at this point in the history
Add model-based observer
  • Loading branch information
richfitz authored Oct 11, 2024
2 parents 2fa59ec + 319a828 commit db71f46
Show file tree
Hide file tree
Showing 16 changed files with 258 additions and 81 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.13
Version: 0.2.14
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
33 changes: 32 additions & 1 deletion R/combine.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,14 @@ monty_model_combine <- function(a, b, properties = NULL,
domain <- model_combine_domain(a, b, parameters)
density <- model_combine_density(a, b, parameters)


gradient <- model_combine_gradient(
a, b, parameters, properties, call)
direct_sample <- model_combine_direct_sample(
a, b, parameters, properties, name_a, name_b, call)
stochastic <- model_combine_stochastic(
a, b, properties)
observer <- model_combine_observer(
a, b, parameters, properties, name_a, name_b, call)

monty_model(
list(model = list(a, b),
Expand All @@ -109,6 +110,7 @@ monty_model_combine <- function(a, b, properties = NULL,
gradient = gradient,
get_rng_state = stochastic$get_rng_state,
set_rng_state = stochastic$set_rng_state,
observer = observer,
direct_sample = direct_sample),
properties)
}
Expand Down Expand Up @@ -280,3 +282,32 @@ model_combine_direct_sample <- function(a, b, parameters, properties,
model$direct_sample(...)[i]
}
}


model_combine_observer <- function(a, b, parameters, properties,
name_a, name_b, call = NULL) {
if (isFALSE(properties$has_observer)) {
return(NULL)
}
possible <- a$properties$has_observer != b$properties$has_observer
required <- isTRUE(properties$has_observer)
if (!possible && !required) {
return(NULL)
}
if (required && !possible) {
if (a$properties$has_observer) {
hint <- paste("Both models have an 'observer' object so we can't",
"combine them. Set 'has_observer = FALSE' on one",
"of your models and try again")
} else {
hint <- "Neither of your models have 'observer' objects"
}
cli::cli_abort(
c("Can't create an observer from these models",
i = hint),
call = call)
}

model <- if (a$properties$has_observer) a else b
model$observer
}
41 changes: 39 additions & 2 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
##' we may also support this in `gradient`). Use `NULL` (the
##' default) to detect this from the model.
##'
##' @param has_observer Logical, indicating if the model has an
##' "observation" function, which we will describe more fully soon.
##' An observer is a function `observe` which takes no arguments and
##' returns arbitrary data about the previously evaluated density.
##' Use `NULL` (the default) to detect this from the model.
##'
##' @param allow_multiple_parameters Logical, indicating if the
##' density calculation can support being passed a matrix of
##' parameters (with each column corresponding to a different
Expand All @@ -51,13 +57,20 @@ monty_model_properties <- function(has_gradient = NULL,
has_direct_sample = NULL,
is_stochastic = NULL,
has_parameter_groups = NULL,
has_observer = NULL,
allow_multiple_parameters = FALSE) {
## TODO: What name do we want for this property, really?
assert_scalar_logical(has_gradient, allow_null = TRUE)
assert_scalar_logical(has_direct_sample, allow_null = TRUE)
assert_scalar_logical(is_stochastic, allow_null = TRUE)
assert_scalar_logical(has_parameter_groups, allow_null = TRUE)
assert_scalar_logical(has_observer, allow_null = TRUE)
assert_scalar_logical(allow_multiple_parameters)

ret <- list(has_gradient = has_gradient,
has_direct_sample = has_direct_sample,
is_stochastic = is_stochastic,
has_parameter_groups = has_parameter_groups,
has_observer = has_observer,
## TODO: I am not convinced on this name
allow_multiple_parameters = allow_multiple_parameters)
class(ret) <- "monty_model_properties"
Expand Down Expand Up @@ -184,6 +197,7 @@ monty_model <- function(model, properties = NULL) {
properties <- validate_model_properties(properties, call)
gradient <- validate_model_gradient(model, properties, call)
direct_sample <- validate_model_direct_sample(model, properties, call)
observer <- validate_model_observer(model, properties, call)
rng_state <- validate_model_rng_state(model, properties, call)
parameter_groups <- validate_model_parameter_groups(model, properties, call)

Expand All @@ -192,6 +206,7 @@ monty_model <- function(model, properties = NULL) {
properties$has_direct_sample <- !is.null(direct_sample)
properties$is_stochastic <- !is.null(rng_state$set)
properties$has_parameter_groups <- !is.null(parameter_groups)
properties$has_observer <- !is.null(observer)

ret <- list(model = model,
parameters = parameters,
Expand All @@ -200,6 +215,7 @@ monty_model <- function(model, properties = NULL) {
density = density,
gradient = gradient,
direct_sample = direct_sample,
observer = observer,
rng_state = rng_state,
properties = properties)
class(ret) <- "monty_model"
Expand Down Expand Up @@ -438,6 +454,26 @@ validate_model_direct_sample <- function(model, properties, call) {
}


validate_model_observer <- function(model, properties, call) {
if (isFALSE(properties$has_observer)) {
return(NULL)
}
value <- model$observer
if (isTRUE(properties$has_observer) && !inherits(value, "monty_observer")) {
cli::cli_abort(
paste("Did not find a 'monty_observer' object 'observer' within",
"your model, but your properties say that it should exist"),
arg = "model", call = call)
}
if (!is.null(value) && !inherits(value, "monty_observer")) {
cli::cli_abort(
"Expected 'model$observer' to be a 'monty_observer' object if non-NULL",
arg = "model", call = call)
}
value
}


validate_model_rng_state <- function(model, properties, call) {
not_stochastic <- isFALSE(properties$is_stochastic) || (
is.null(properties$is_stochastic) &&
Expand Down Expand Up @@ -582,7 +618,8 @@ print.monty_model <- function(x, ...) {
monty_model_properties_str <- function(properties) {
c(if (properties$has_gradient) "can compute gradients",
if (properties$has_direct_sample) "can be directly sampled from",
if (properties$is_stochastic) "is stochastic")
if (properties$is_stochastic) "is stochastic",
if (properties$has_observer) "has an observer")
}


Expand Down
1 change: 0 additions & 1 deletion R/observer.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ monty_observer <- function(observe,
##' @export
print.monty_observer <- function(x, ...) {
cli::cli_h1("<monty_observer>")
cli::cli_alert_info("Use {.help monty_sample} to use this observer")
cli::cli_alert_info("See {.help monty_observer} for more information")
invisible(x)
}
Expand Down
7 changes: 3 additions & 4 deletions R/runner.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ monty_run_chain2 <- function(chain_state, model, sampler, n_steps,
progress, rng, r_rng_state) {
initial <- chain_state$pars
n_pars <- length(model$parameters)
## has_observer <- model$properties$has_observer
has_observer <- FALSE
has_observer <- model$properties$has_observer

history_pars <- matrix(NA_real_, n_pars, n_steps)
history_density <- rep(NA_real_, n_steps)
Expand All @@ -210,7 +209,7 @@ monty_run_chain2 <- function(chain_state, model, sampler, n_steps,
chain_state <- sampler$step(chain_state, model, rng)
history_pars[, i] <- chain_state$pars
history_density[[i]] <- chain_state$density
if (!is.null(chain_state$observation)) {
if (has_observer && !is.null(chain_state$observation)) {
history_observation[[i]] <- chain_state$observation
}
progress(i)
Expand All @@ -223,7 +222,7 @@ monty_run_chain2 <- function(chain_state, model, sampler, n_steps,
details <- sampler$finalise(chain_state, model, rng)

if (has_observer) {
## history_observation <- model$observer$finalise(history_observation)
history_observation <- model$observer$finalise(history_observation)
}

## This list will hold things that we'll use internally but not
Expand Down
6 changes: 2 additions & 4 deletions R/sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ monty_sample <- function(model, sampler, n_steps, initial = NULL,
pars <- initial_parameters(initial, model, rng, environment())
res <- runner$run(pars, model, sampler, n_steps, rng)

## TODO: everywhere we write this combine_chains with observer,
## respect the property, not the presence of a method. Will be done
## in the next PR.
observer <- if (model$properties$has_observer) model$observer else NULL
samples <- combine_chains(res, model$observer)
if (restartable) {
samples$restart <- restart_data(res, model, sampler, runner)
Expand Down Expand Up @@ -153,7 +151,7 @@ monty_sample_continue <- function(samples, n_steps, restartable = FALSE,

res <- runner$continue(state, model, sampler, n_steps)

observer <- model$observer # See above about respecting properties
observer <- if (model$properties$has_observer) model$observer else NULL
samples <- append_chains(samples, combine_chains(res, observer), observer)

if (restartable) {
Expand Down
11 changes: 8 additions & 3 deletions R/sampler-helpers.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
initialise_state <- function(pars, model, rng) {
initialise_rng_state(model, rng)
density <- model$density(pars)
## TODO: in next PR observe.
observation <- NULL
if (model$properties$has_observer) {
observation <- model$observer$observe()
} else {
observation <- NULL
}
list(pars = pars, density = density, observation = observation)
}

Expand All @@ -16,7 +19,9 @@ update_state <- function(state, pars, density, accept, model, rng) {
} else {
state$pars <- pars
state$density <- density
## TODO: cope with observation
if (model$properties$has_observer) {
state$observation <- model$observer$observe()
}
}
}
state
Expand Down
17 changes: 12 additions & 5 deletions R/sampler-nested-adaptive.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ monty_sampler_nested_adaptive <- function(initial_vcv,
}

state <- list(pars = pars, density = c(density))
## TODO: we need to fix observation here; it should move into a
## helper as part of a setup I think; see the sampler-helpers for
## the single-parameter case.
## TODO: mrc-5862
if (model$properties$has_observer) {
state$observation <- model$observer$observe()
}
state
}

Expand Down Expand Up @@ -201,7 +202,10 @@ monty_sampler_nested_adaptive <- function(initial_vcv,
state$pars <- pars_next
state$density <- density_next
internal$density_by_group <- density_by_group_next
## TODO: observe here
## TODO: mrc-5862
if (model$properties$has_observer) {
state$observation <- model$observer$observe()
}
}
} else {
accept_prob_base <- NULL
Expand Down Expand Up @@ -272,7 +276,10 @@ monty_sampler_nested_adaptive <- function(initial_vcv,
state$pars <- pars_next
state$density <- c(density_next)
internal$density_by_group <- density_by_group_next
## TODO: observe here
## TODO: mrc-5862
if (model$properties$has_observer) {
state$observation <- model$observer$observe()
}
}

if (internal$multiple_parameters) {
Expand Down
17 changes: 12 additions & 5 deletions R/sampler-nested-random-walk.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ monty_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") {

internal$density_by_group <- density_by_group
state <- list(pars = pars, density = c(density))
## TODO: we need to fix observation here; it should move into a
## helper as part of a setup I think; see the sampler-helpers for
## the single-parameter case.
## TODO: mrc-5862
if (model$properties$has_observer) {
state$observation <- model$observer$observe()
}
state
}

Expand Down Expand Up @@ -163,7 +164,10 @@ monty_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") {
state$pars <- pars_next
state$density <- density_next
internal$density_by_group <- density_by_group_next
## TODO: observe here
## TODO: mrc-5862
if (model$properties$has_observer) {
state$observation <- model$observer$observe()
}
}
}

Expand Down Expand Up @@ -227,7 +231,10 @@ monty_sampler_nested_random_walk <- function(vcv, boundaries = "reflect") {
state$pars <- pars_next
state$density <- c(density_next)
internal$density_by_group <- density_by_group_next
## TODO: observe here
## TODO: mrc-5862
if (model$properties$has_observer) {
state$observation <- model$observer$observe()
}
}
state
}
Expand Down
7 changes: 7 additions & 0 deletions man/monty_model_properties.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit db71f46

Please sign in to comment.