Skip to content

Commit

Permalink
Merge branch 'main' into mrc-5859
Browse files Browse the repository at this point in the history
  • Loading branch information
weshinsley authored Oct 11, 2024
2 parents e7e1ce2 + 2fa59ec commit 6d1ece8
Show file tree
Hide file tree
Showing 13 changed files with 321 additions and 51 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.11
Version: 0.2.14
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
89 changes: 50 additions & 39 deletions R/dsl-generate.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
dsl_generate <- function(dat) {
env <- new.env(parent = asNamespace("monty"))
env$packer <- monty_packer(dat$parameters)

density <- dsl_generate_density(dat, env)
direct_sample <- dsl_generate_direct_sample(dat, env)
gradient <- dsl_generate_gradient(dat, env)
domain <- dsl_generate_domain(dat)
env$fixed <- dat$fixed

meta <- list(
pars = quote(pars),
data = quote(data),
density = quote(density),
fixed = quote(fixed),
fixed_contents = names(env$fixed))

density <- dsl_generate_density(dat, env, meta)
direct_sample <- dsl_generate_direct_sample(dat, env, meta)
gradient <- dsl_generate_gradient(dat, env, meta)
domain <- dsl_generate_domain(dat, meta)
monty_model(
list(parameters = dat$parameters,
density = density,
Expand All @@ -15,104 +23,107 @@ dsl_generate <- function(dat) {
}


dsl_generate_density <- function(dat, env) {
exprs <- lapply(dat$exprs, dsl_generate_density_expr,
quote(pars), quote(density))
body <- c(quote(pars <- packer$unpack(x)),
quote(density <- numeric()),
dsl_generate_density <- function(dat, env, meta) {
exprs <- lapply(dat$exprs, dsl_generate_density_expr, meta)
body <- c(call("<-", meta[["pars"]], quote(packer$unpack(x))),
call("<-", meta[["density"]], quote(numeric())),
exprs,
quote(sum(density)))
call("sum", meta[["density"]]))
as_function(alist(x = ), body, env)
}


dsl_generate_gradient <- function(dat, env) {
dsl_generate_gradient <- function(dat, env, meta) {
if (is.null(dat$adjoint)) {
return(NULL)
}

i_main <- match(dat$adjoint$exprs_main, vcapply(dat$exprs, "[[", "name"))
exprs <- c(dat$exprs[i_main], dat$adjoint$exprs)

eqs <- lapply(exprs, dsl_generate_assignment, quote(data))
eqs <- lapply(exprs, dsl_generate_assignment, "data", meta)
eq_return <- fold_c(
lapply(dat$adjoint$gradient, function(nm) call("[[", quote(data), nm)))
lapply(dat$adjoint$gradient, function(nm) call("[[", meta[["data"]], nm)))

body <- c(quote(data <- packer$unpack(x)),
body <- c(call("<-", meta[["data"]], quote(packer$unpack(x))),
unname(eqs),
eq_return)
as_function(alist(x = ), body, env)
}


dsl_generate_direct_sample <- function(dat, env) {
exprs <- lapply(dat$exprs, dsl_generate_sample_expr,
quote(pars), quote(result))
body <- c(quote(pars <- list()),
dsl_generate_direct_sample <- function(dat, env, meta) {
exprs <- lapply(dat$exprs, dsl_generate_sample_expr, meta)
body <- c(call("<-", meta[["pars"]], quote(list())),
exprs,
quote(unlist(pars[packer$parameters], FALSE, FALSE)))
bquote(unlist(.(meta[["pars"]])[packer$parameters], FALSE, FALSE)))
as_function(alist(rng = ), body, env)
}


dsl_generate_density_expr <- function(expr, env, density) {
dsl_generate_density_expr <- function(expr, meta) {
switch(expr$type,
assignment = dsl_generate_assignment(expr, env),
stochastic = dsl_generate_density_stochastic(expr, env, density),
assignment = dsl_generate_assignment(expr, "pars", meta),
stochastic = dsl_generate_density_stochastic(expr, meta),
cli::cli_abort(paste(
"Unimplemented expression type '{expr$type}';",
"this is a monty bug")))
}


dsl_generate_sample_expr <- function(expr, env, result) {
dsl_generate_sample_expr <- function(expr, meta) {
switch(expr$type,
assignment = dsl_generate_assignment(expr, env),
stochastic = dsl_generate_sample_stochastic(expr, env),
assignment = dsl_generate_assignment(expr, "pars", meta),
stochastic = dsl_generate_sample_stochastic(expr, meta),
cli::cli_abort(paste(
"Unimplemented expression type '{expr$type}';",
"this is a monty bug")))
}


dsl_generate_assignment <- function(expr, env) {
dsl_generate_assignment <- function(expr, dest, meta) {
e <- expr$expr
e[[2]] <- call("[[", env, as.character(e[[2]]))
e[[3]] <- dsl_generate_density_rewrite_lookup(e[[3]], env)
e[[2]] <- call("[[", meta[[dest]], as.character(e[[2]]))
e[[3]] <- dsl_generate_density_rewrite_lookup(e[[3]], dest, meta)
e
}


dsl_generate_density_stochastic <- function(expr, env, density) {
lhs <- bquote(.(density)[[.(expr$name)]])
dsl_generate_density_stochastic <- function(expr, meta) {
lhs <- bquote(.(meta[["density"]])[[.(expr$name)]])
rhs <- rlang::call2(expr$distribution$density,
as.name(expr$name), !!!expr$distribution$args)
rlang::call2("<-", lhs, dsl_generate_density_rewrite_lookup(rhs, env))
rlang::call2("<-", lhs,
dsl_generate_density_rewrite_lookup(rhs, "pars", meta))
}


dsl_generate_sample_stochastic <- function(expr, env) {
lhs <- bquote(.(env)[[.(expr$name)]])
dsl_generate_sample_stochastic <- function(expr, meta) {
lhs <- bquote(.(meta[["pars"]])[[.(expr$name)]])
args <- lapply(expr$distribution$args, dsl_generate_density_rewrite_lookup,
env)
"pars", meta)
rhs <- rlang::call2(expr$distribution$sample, quote(rng), !!!args)
rlang::call2("<-", lhs, rhs)
}


dsl_generate_density_rewrite_lookup <- function(expr, env) {
dsl_generate_density_rewrite_lookup <- function(expr, dest, meta) {
if (is.recursive(expr)) {
expr[-1] <- lapply(expr[-1], dsl_generate_density_rewrite_lookup, env)
expr[-1] <- lapply(expr[-1], dsl_generate_density_rewrite_lookup,
dest, meta)
as.call(expr)
} else if (is.name(expr)) {
call("[[", env, as.character(expr))
if (as.character(expr) %in% meta$fixed_contents) {
dest <- meta$fixed
}
call("[[", meta[[as.character(dest)]], as.character(expr))
} else {
expr
}
}


dsl_generate_domain <- function(dat) {
dsl_generate_domain <- function(dat, meta) {
n <- length(dat$parameters)
domain <- cbind(rep(-Inf, n), rep(Inf, n))
rownames(domain) <- dat$parameters
Expand Down
31 changes: 26 additions & 5 deletions R/dsl-parse.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
## The default of gradient_required = TRUE here helps with tests
dsl_parse <- function(exprs, gradient_required = TRUE, call = NULL) {
dsl_parse <- function(exprs, gradient_required = TRUE, fixed = NULL,
call = NULL) {
exprs <- lapply(exprs, dsl_parse_expr, call)

dsl_parse_check_duplicates(exprs, call)
dsl_parse_check_usage(exprs, call)
dsl_parse_check_fixed(exprs, fixed, call)
dsl_parse_check_usage(exprs, fixed, call)

name <- vcapply(exprs, "[[", "name")
parameters <- name[vcapply(exprs, "[[", "type") == "stochastic"]

adjoint <- dsl_parse_adjoint(parameters, exprs, gradient_required)

list(parameters = parameters, exprs = exprs, adjoint = adjoint)
list(parameters = parameters, exprs = exprs, adjoint = adjoint, fixed = fixed)
}


Expand Down Expand Up @@ -109,11 +111,28 @@ dsl_parse_check_duplicates <- function(exprs, call) {
}


dsl_parse_check_usage <- function(exprs, call) {
dsl_parse_check_fixed <- function(exprs, fixed, call) {
if (is.null(fixed)) {
return()
}

name <- vcapply(exprs, "[[", "name")
err <- name %in% names(fixed)
if (any(err)) {
eq <- exprs[[which(err)[[1]]]]
dsl_parse_error(
"Value '{eq$name}' in 'fixed' is shadowed by {eq$type}",
"E207", eq$expr, call)
}
}


dsl_parse_check_usage <- function(exprs, fixed, call) {
name <- vcapply(exprs, "[[", "name")
names_fixed <- names(fixed)
for (i in seq_along(exprs)) {
e <- exprs[[i]]
err <- setdiff(e$depends, name[seq_len(i - 1)])
err <- setdiff(e$depends, c(name[seq_len(i - 1)], names_fixed))
if (length(err) > 0) {
## Out of order:
out_of_order <- intersect(name, err)
Expand All @@ -124,6 +143,8 @@ dsl_parse_check_usage <- function(exprs, call) {
## Could also tell the user about variables found in the
## calling env, but that requires detecting and then passing
## through the correct environment.
##
## Could also tell about near misses.
context <- NULL
}
## TODO: It would be nice to indicate that we want to highlight
Expand Down
42 changes: 38 additions & 4 deletions R/dsl.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@
##' then we will error if it is not possible to create a gradient
##' function.
##'
##' @param fixed An optional list of values that can be used within
##' the DSL code. Anything you provide here is available for your
##' calculations. In the interest of future compatibility, we check
##' currently that all elements are scalars. In future this may
##' become more flexible and allow passing environments, etc. Once
##' provided, these values cannot be changed without rebuilding the
##' model; they are fixed data. You might use these for
##' hyperparameters that are fixed across a set of model runs, for
##' example.
##'
##' @return A [monty_model] object derived from the expressions you
##' provide.
##'
Expand All @@ -38,31 +48,33 @@
##'
##' # You can also pass strings
##' monty_dsl("a ~ Normal(0, 1)")
monty_dsl <- function(x, type = NULL, gradient = NULL) {
monty_dsl <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
quo <- rlang::enquo(x)
if (rlang::quo_is_symbol(quo)) {
x <- rlang::eval_tidy(quo)
} else {
x <- rlang::quo_get_expr(quo)
}
call <- environment()
fixed <- check_dsl_fixed(fixed)
exprs <- dsl_preprocess(x, type, call)
dat <- dsl_parse(exprs, gradient, call)
dat <- dsl_parse(exprs, gradient, fixed, call)
dsl_generate(dat)
}



monty_dsl_parse <- function(x, type = NULL, gradient = NULL) {
monty_dsl_parse <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
call <- environment()
quo <- rlang::enquo(x)
if (rlang::quo_is_symbol(quo)) {
x <- rlang::eval_tidy(quo)
} else {
x <- rlang::quo_get_expr(quo)
}
fixed <- check_dsl_fixed(fixed, call)
exprs <- dsl_preprocess(x, type, call)
dsl_parse(exprs, gradient, call)
dsl_parse(exprs, gradient, fixed, call)
}


Expand Down Expand Up @@ -154,3 +166,25 @@ monty_dsl_parse_distribution <- function(expr, name = NULL) {
list(success = TRUE,
value = value)
}


check_dsl_fixed <- function(fixed, call) {
if (is.null(fixed)) {
return(NULL)
}
assert_list(fixed, call = call)
if (length(fixed) == 0) {
return(NULL)
}
assert_named(fixed, unique = TRUE, call = call)
err <- lengths(fixed) != 1
if (any(err)) {
info <- sprintf("'%s' had length %d",
names(fixed)[err], lengths(fixed[err]))
cli::cli_abort(
c("All elements of 'fixed' must currently be scalars",
set_names(info, "x")),
arg = "fixed", call = call)
}
fixed
}
37 changes: 36 additions & 1 deletion R/packer.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@
##' element has the name of a value in `parameters` and each value
##' has the indices within an unstructured vector where these values
##' can be found.
##' * `subset`: an experimental interface which can be used to subset a
##' packer to a packer for a subset of contents. Documentation will
##' be provided once the interface settles.
##'
##' @export
##'
Expand Down Expand Up @@ -349,10 +352,42 @@ monty_packer <- function(scalar = NULL, array = NULL, fixed = NULL,
ret
}

subset <- function(keep) {
## TODO: later we will allow passing integer indexes here. This
## is complicated because we should probably retain structure for
## any compartments that are entirely captured (contiguously and
## in order) and convert everything else into scalars. Or perhaps
## if we take a slice out of a matrix we keep it as an array.
## Lots of decisions to make, so do it later.
if (is.character(keep)) {
if (anyDuplicated(keep)) {
dups <- unique(keep[duplicated(keep)])
cli::cli_abort("Duplicated name{?s} in 'keep': {squote(dups)}")
}
i <- match(keep, names(idx))
if (anyNA(i)) {
cli::cli_abort("Unknown name{?s} in 'keep': {squote(keep[is.na(i)])}")
}
index <- unlist(idx[i], FALSE, FALSE)
## Convert all scalars into array spec for now; this allows
## reordering.
shape2 <- c(set_names(rep(list(integer()), length(scalar)), scalar),
shape)
scalar_keep <- NULL
array_keep <- shape2[keep]
} else {
cli::cli_abort(
"Invalid input for 'keep'; this must currently be a character vector")
}
list(index = index,
packer = monty_packer(scalar_keep, array_keep))
}

ret <- list(parameters = parameters,
unpack = unpack,
pack = pack,
index = function() idx)
index = function() idx,
subset = subset)
class(ret) <- "monty_packer"
ret
}
Expand Down
Binary file modified R/sysdata.rda
Binary file not shown.
12 changes: 11 additions & 1 deletion man/monty_dsl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 6d1ece8

Please sign in to comment.