Skip to content

Commit

Permalink
Merge pull request #80 from mrc-ide/mrc-5856-2
Browse files Browse the repository at this point in the history
Fixed data for the dsl
  • Loading branch information
weshinsley authored Oct 10, 2024
2 parents af75c00 + 0be858f commit 6c64b5c
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 13 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.11
Version: 0.2.12
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
10 changes: 8 additions & 2 deletions R/dsl-generate.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
dsl_generate <- function(dat) {
env <- new.env(parent = asNamespace("monty"))
env$packer <- monty_packer(dat$parameters)
env$fixed <- dat$fixed

meta <- list(
pars = quote(pars),
data = quote(data),
density = quote(density))
density = quote(density),
fixed = quote(fixed),
fixed_contents = names(env$fixed))

density <- dsl_generate_density(dat, env, meta)
direct_sample <- dsl_generate_direct_sample(dat, env, meta)
Expand Down Expand Up @@ -110,7 +113,10 @@ dsl_generate_density_rewrite_lookup <- function(expr, dest, meta) {
dest, meta)
as.call(expr)
} else if (is.name(expr)) {
call("[[", meta[[dest]], as.character(expr))
if (as.character(expr) %in% meta$fixed_contents) {
dest <- meta$fixed
}
call("[[", meta[[as.character(dest)]], as.character(expr))
} else {
expr
}
Expand Down
31 changes: 26 additions & 5 deletions R/dsl-parse.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
## The default of gradient_required = TRUE here helps with tests
dsl_parse <- function(exprs, gradient_required = TRUE, call = NULL) {
dsl_parse <- function(exprs, gradient_required = TRUE, fixed = NULL,
call = NULL) {
exprs <- lapply(exprs, dsl_parse_expr, call)

dsl_parse_check_duplicates(exprs, call)
dsl_parse_check_usage(exprs, call)
dsl_parse_check_fixed(exprs, fixed, call)
dsl_parse_check_usage(exprs, fixed, call)

name <- vcapply(exprs, "[[", "name")
parameters <- name[vcapply(exprs, "[[", "type") == "stochastic"]

adjoint <- dsl_parse_adjoint(parameters, exprs, gradient_required)

list(parameters = parameters, exprs = exprs, adjoint = adjoint)
list(parameters = parameters, exprs = exprs, adjoint = adjoint, fixed = fixed)
}


Expand Down Expand Up @@ -109,11 +111,28 @@ dsl_parse_check_duplicates <- function(exprs, call) {
}


dsl_parse_check_usage <- function(exprs, call) {
dsl_parse_check_fixed <- function(exprs, fixed, call) {
if (is.null(fixed)) {
return()
}

name <- vcapply(exprs, "[[", "name")
err <- name %in% names(fixed)
if (any(err)) {
eq <- exprs[[which(err)[[1]]]]
dsl_parse_error(
"Value '{eq$name}' in 'fixed' is shadowed by {eq$type}",
"E207", eq$expr, call)
}
}


dsl_parse_check_usage <- function(exprs, fixed, call) {
name <- vcapply(exprs, "[[", "name")
names_fixed <- names(fixed)
for (i in seq_along(exprs)) {
e <- exprs[[i]]
err <- setdiff(e$depends, name[seq_len(i - 1)])
err <- setdiff(e$depends, c(name[seq_len(i - 1)], names_fixed))
if (length(err) > 0) {
## Out of order:
out_of_order <- intersect(name, err)
Expand All @@ -124,6 +143,8 @@ dsl_parse_check_usage <- function(exprs, call) {
## Could also tell the user about variables found in the
## calling env, but that requires detecting and then passing
## through the correct environment.
##
## Could also tell about near misses.
context <- NULL
}
## TODO: It would be nice to indicate that we want to highlight
Expand Down
42 changes: 38 additions & 4 deletions R/dsl.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@
##' then we will error if it is not possible to create a gradient
##' function.
##'
##' @param fixed An optional list of values that can be used within
##' the DSL code. Anything you provide here is available for your
##' calculations. In the interest of future compatibility, we check
##' currently that all elements are scalars. In future this may
##' become more flexible and allow passing environments, etc. Once
##' provided, these values cannot be changed without rebuilding the
##' model; they are fixed data. You might use these for
##' hyperparameters that are fixed across a set of model runs, for
##' example.
##'
##' @return A [monty_model] object derived from the expressions you
##' provide.
##'
Expand All @@ -38,31 +48,33 @@
##'
##' # You can also pass strings
##' monty_dsl("a ~ Normal(0, 1)")
monty_dsl <- function(x, type = NULL, gradient = NULL) {
monty_dsl <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
quo <- rlang::enquo(x)
if (rlang::quo_is_symbol(quo)) {
x <- rlang::eval_tidy(quo)
} else {
x <- rlang::quo_get_expr(quo)
}
call <- environment()
fixed <- check_dsl_fixed(fixed)
exprs <- dsl_preprocess(x, type, call)
dat <- dsl_parse(exprs, gradient, call)
dat <- dsl_parse(exprs, gradient, fixed, call)
dsl_generate(dat)
}



monty_dsl_parse <- function(x, type = NULL, gradient = NULL) {
monty_dsl_parse <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
call <- environment()
quo <- rlang::enquo(x)
if (rlang::quo_is_symbol(quo)) {
x <- rlang::eval_tidy(quo)
} else {
x <- rlang::quo_get_expr(quo)
}
fixed <- check_dsl_fixed(fixed, call)
exprs <- dsl_preprocess(x, type, call)
dsl_parse(exprs, gradient, call)
dsl_parse(exprs, gradient, fixed, call)
}


Expand Down Expand Up @@ -154,3 +166,25 @@ monty_dsl_parse_distribution <- function(expr, name = NULL) {
list(success = TRUE,
value = value)
}


check_dsl_fixed <- function(fixed, call) {
if (is.null(fixed)) {
return(NULL)
}
assert_list(fixed, call = call)
if (length(fixed) == 0) {
return(NULL)
}
assert_named(fixed, unique = TRUE, call = call)
err <- lengths(fixed) != 1
if (any(err)) {
info <- sprintf("'%s' had length %d",
names(fixed)[err], lengths(fixed[err]))
cli::cli_abort(
c("All elements of 'fixed' must currently be scalars",
set_names(info, "x")),
arg = "fixed", call = call)
}
fixed
}
Binary file modified R/sysdata.rda
Binary file not shown.
12 changes: 11 additions & 1 deletion man/monty_dsl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions tests/testthat/test-dsl-parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,32 @@ test_that("can explain an error", {
mockery::mock_args(mock_explain)[[1]],
list(dsl_errors, "E101", "pretty"))
})


test_that("empty fixed data is null", {
expect_null(check_dsl_fixed(NULL))
expect_null(check_dsl_fixed(list()))
})


test_that("validate fixed data for dsl", {
expect_error(
check_dsl_fixed(c(a = 1, b = 2)),
"Expected 'fixed' to be a list")
expect_error(
check_dsl_fixed(list(a = 1, b = 2, a = 2)),
"'fixed' must have unique names")
expect_error(
check_dsl_fixed(list(a = 1, b = 2:3, c = numeric(10))),
"All elements of 'fixed' must currently be scalars")
expect_equal(
check_dsl_fixed(list(a = 1, b = 2)),
list(a = 1, b = 2))
})


test_that("assignments cannot shadow names of fixed variables", {
expect_error(
dsl_parse(list(quote(a <- 1)), fixed = list(a = 1)),
"Value 'a' in 'fixed' is shadowed by assignment")
})
8 changes: 8 additions & 0 deletions tests/testthat/test-dsl.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,11 @@ test_that("can compute gradients of complicated models", {

expect_equal(m$gradient(p), numDeriv::grad(m$density, p))
})


test_that("can use fixed data in dsl", {
m <- monty_dsl({
a ~ Normal(mu, sd)
}, fixed = list(mu = 1, sd = 2))
expect_equal(m$density(0), dnorm(0, 1, 2, log = TRUE))
})
4 changes: 4 additions & 0 deletions vignettes/dsl-errors.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ Variables are used out of order. If you are using odin this is a big departure
# `E206`

Failed to differentiate the model. This error will only be seen where it was not possible to differentiate your model but you requested that a gradient be available. Not all functions supported in the DSL can currently be differentiated by monty; if you think that yours should be, please let us know.

# `E207`

A value in `fixed` is shadowed by an assignment or a relationship. If you pass in fixed data it may not be used on the left hand side of any expression in your DSL code.
52 changes: 52 additions & 0 deletions vignettes/dsl.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,55 @@ The computed properties for the model are:
```{r}
prior$properties
```

# Calculations in the DSL

Sometimes it will be useful to perform calculations in the code; you can do this with assignments. Most trivially, giving names to numbers may help make code more understandable:

```{r}
m <- monty_dsl({
mu <- 10
sd <- 2
a ~ Normal(mu, sd)
})
```

You can also use this to do things like:

```{r}
m <- monty_dsl({
a ~ Normal(0, 1)
b ~ Normal(0, 1)
mu <- (a + b) / 2
c ~ Normal(mu, 1)
})
```

Where `c` is drawn from a normal distribution with a mean that is the average of `a` and `b`.

# Pass in fixed data

You can also pass in a list of data with values that should be available in the DSL code. For example, our first example:

```{r}
prior <- monty_dsl({
alpha ~ Normal(178, 20)
beta ~ Normal(0, 10)
sigma ~ Uniform(0, 50)
})
```

Might be written as

```{r}
fixed <- list(alpha_mean = 170, alpha_sd = 20,
beta_mean = 0, beta_sd = 10,
sigma_max = 50)
prior <- monty_dsl({
alpha ~ Normal(alpha_mean, alpha_sd)
beta ~ Normal(beta_mean, beta_sd)
sigma ~ Uniform(0, sigma_max)
}, fixed = fixed)
```

Values you pass in this way are **fixed** (hence the name!) and cannot be modified after the model object is created.

0 comments on commit 6c64b5c

Please sign in to comment.