Skip to content

Commit

Permalink
Merge pull request #79 from mrc-ide/mrc-5860
Browse files Browse the repository at this point in the history
Light refactor of generation
  • Loading branch information
richfitz authored Oct 10, 2024
2 parents e7847b7 + 0e075d3 commit af75c00
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 39 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.10
Version: 0.2.11
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
81 changes: 43 additions & 38 deletions R/dsl-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@ dsl_generate <- function(dat) {
env <- new.env(parent = asNamespace("monty"))
env$packer <- monty_packer(dat$parameters)

density <- dsl_generate_density(dat, env)
direct_sample <- dsl_generate_direct_sample(dat, env)
gradient <- dsl_generate_gradient(dat, env)
domain <- dsl_generate_domain(dat)
meta <- list(
pars = quote(pars),
data = quote(data),
density = quote(density))

density <- dsl_generate_density(dat, env, meta)
direct_sample <- dsl_generate_direct_sample(dat, env, meta)
gradient <- dsl_generate_gradient(dat, env, meta)
domain <- dsl_generate_domain(dat, meta)
monty_model(
list(parameters = dat$parameters,
density = density,
Expand All @@ -15,104 +20,104 @@ dsl_generate <- function(dat) {
}


dsl_generate_density <- function(dat, env) {
exprs <- lapply(dat$exprs, dsl_generate_density_expr,
quote(pars), quote(density))
body <- c(quote(pars <- packer$unpack(x)),
quote(density <- numeric()),
dsl_generate_density <- function(dat, env, meta) {
exprs <- lapply(dat$exprs, dsl_generate_density_expr, meta)
body <- c(call("<-", meta[["pars"]], quote(packer$unpack(x))),
call("<-", meta[["density"]], quote(numeric())),
exprs,
quote(sum(density)))
call("sum", meta[["density"]]))
as_function(alist(x = ), body, env)
}


dsl_generate_gradient <- function(dat, env) {
dsl_generate_gradient <- function(dat, env, meta) {
if (is.null(dat$adjoint)) {
return(NULL)
}

i_main <- match(dat$adjoint$exprs_main, vcapply(dat$exprs, "[[", "name"))
exprs <- c(dat$exprs[i_main], dat$adjoint$exprs)

eqs <- lapply(exprs, dsl_generate_assignment, quote(data))
eqs <- lapply(exprs, dsl_generate_assignment, "data", meta)
eq_return <- fold_c(
lapply(dat$adjoint$gradient, function(nm) call("[[", quote(data), nm)))
lapply(dat$adjoint$gradient, function(nm) call("[[", meta[["data"]], nm)))

body <- c(quote(data <- packer$unpack(x)),
body <- c(call("<-", meta[["data"]], quote(packer$unpack(x))),
unname(eqs),
eq_return)
as_function(alist(x = ), body, env)
}


dsl_generate_direct_sample <- function(dat, env) {
exprs <- lapply(dat$exprs, dsl_generate_sample_expr,
quote(pars), quote(result))
body <- c(quote(pars <- list()),
dsl_generate_direct_sample <- function(dat, env, meta) {
exprs <- lapply(dat$exprs, dsl_generate_sample_expr, meta)
body <- c(call("<-", meta[["pars"]], quote(list())),
exprs,
quote(unlist(pars[packer$parameters], FALSE, FALSE)))
bquote(unlist(.(meta[["pars"]])[packer$parameters], FALSE, FALSE)))
as_function(alist(rng = ), body, env)
}


dsl_generate_density_expr <- function(expr, env, density) {
dsl_generate_density_expr <- function(expr, meta) {
switch(expr$type,
assignment = dsl_generate_assignment(expr, env),
stochastic = dsl_generate_density_stochastic(expr, env, density),
assignment = dsl_generate_assignment(expr, "pars", meta),
stochastic = dsl_generate_density_stochastic(expr, meta),
cli::cli_abort(paste(
"Unimplemented expression type '{expr$type}';",
"this is a monty bug")))
}


dsl_generate_sample_expr <- function(expr, env, result) {
dsl_generate_sample_expr <- function(expr, meta) {
switch(expr$type,
assignment = dsl_generate_assignment(expr, env),
stochastic = dsl_generate_sample_stochastic(expr, env),
assignment = dsl_generate_assignment(expr, "pars", meta),
stochastic = dsl_generate_sample_stochastic(expr, meta),
cli::cli_abort(paste(
"Unimplemented expression type '{expr$type}';",
"this is a monty bug")))
}


dsl_generate_assignment <- function(expr, env) {
dsl_generate_assignment <- function(expr, dest, meta) {
e <- expr$expr
e[[2]] <- call("[[", env, as.character(e[[2]]))
e[[3]] <- dsl_generate_density_rewrite_lookup(e[[3]], env)
e[[2]] <- call("[[", meta[[dest]], as.character(e[[2]]))
e[[3]] <- dsl_generate_density_rewrite_lookup(e[[3]], dest, meta)
e
}


dsl_generate_density_stochastic <- function(expr, env, density) {
lhs <- bquote(.(density)[[.(expr$name)]])
dsl_generate_density_stochastic <- function(expr, meta) {
lhs <- bquote(.(meta[["density"]])[[.(expr$name)]])
rhs <- rlang::call2(expr$distribution$density,
as.name(expr$name), !!!expr$distribution$args)
rlang::call2("<-", lhs, dsl_generate_density_rewrite_lookup(rhs, env))
rlang::call2("<-", lhs,
dsl_generate_density_rewrite_lookup(rhs, "pars", meta))
}


dsl_generate_sample_stochastic <- function(expr, env) {
lhs <- bquote(.(env)[[.(expr$name)]])
dsl_generate_sample_stochastic <- function(expr, meta) {
lhs <- bquote(.(meta[["pars"]])[[.(expr$name)]])
args <- lapply(expr$distribution$args, dsl_generate_density_rewrite_lookup,
env)
"pars", meta)
rhs <- rlang::call2(expr$distribution$sample, quote(rng), !!!args)
rlang::call2("<-", lhs, rhs)
}


dsl_generate_density_rewrite_lookup <- function(expr, env) {
dsl_generate_density_rewrite_lookup <- function(expr, dest, meta) {
if (is.recursive(expr)) {
expr[-1] <- lapply(expr[-1], dsl_generate_density_rewrite_lookup, env)
expr[-1] <- lapply(expr[-1], dsl_generate_density_rewrite_lookup,
dest, meta)
as.call(expr)
} else if (is.name(expr)) {
call("[[", env, as.character(expr))
call("[[", meta[[dest]], as.character(expr))
} else {
expr
}
}


dsl_generate_domain <- function(dat) {
dsl_generate_domain <- function(dat, meta) {
n <- length(dat$parameters)
domain <- cbind(rep(-Inf, n), rep(Inf, n))
rownames(domain) <- dat$parameters
Expand Down

0 comments on commit af75c00

Please sign in to comment.