Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed data for the dsl #80

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.11
Version: 0.2.12
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
10 changes: 8 additions & 2 deletions R/dsl-generate.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
dsl_generate <- function(dat) {
env <- new.env(parent = asNamespace("monty"))
env$packer <- monty_packer(dat$parameters)
env$fixed <- dat$fixed

meta <- list(
pars = quote(pars),
data = quote(data),
density = quote(density))
density = quote(density),
fixed = quote(fixed),
fixed_contents = names(env$fixed))

density <- dsl_generate_density(dat, env, meta)
direct_sample <- dsl_generate_direct_sample(dat, env, meta)
Expand Down Expand Up @@ -110,7 +113,10 @@ dsl_generate_density_rewrite_lookup <- function(expr, dest, meta) {
dest, meta)
as.call(expr)
} else if (is.name(expr)) {
call("[[", meta[[dest]], as.character(expr))
if (as.character(expr) %in% meta$fixed_contents) {
dest <- meta$fixed
}
call("[[", meta[[as.character(dest)]], as.character(expr))
} else {
expr
}
Expand Down
31 changes: 26 additions & 5 deletions R/dsl-parse.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
## The default of gradient_required = TRUE here helps with tests
dsl_parse <- function(exprs, gradient_required = TRUE, call = NULL) {
dsl_parse <- function(exprs, gradient_required = TRUE, fixed = NULL,
call = NULL) {
exprs <- lapply(exprs, dsl_parse_expr, call)

dsl_parse_check_duplicates(exprs, call)
dsl_parse_check_usage(exprs, call)
dsl_parse_check_fixed(exprs, fixed, call)
dsl_parse_check_usage(exprs, fixed, call)

name <- vcapply(exprs, "[[", "name")
parameters <- name[vcapply(exprs, "[[", "type") == "stochastic"]

adjoint <- dsl_parse_adjoint(parameters, exprs, gradient_required)

list(parameters = parameters, exprs = exprs, adjoint = adjoint)
list(parameters = parameters, exprs = exprs, adjoint = adjoint, fixed = fixed)
}


Expand Down Expand Up @@ -109,11 +111,28 @@ dsl_parse_check_duplicates <- function(exprs, call) {
}


dsl_parse_check_usage <- function(exprs, call) {
dsl_parse_check_fixed <- function(exprs, fixed, call) {
if (is.null(fixed)) {
return()
}

name <- vcapply(exprs, "[[", "name")
err <- name %in% names(fixed)
if (any(err)) {
eq <- exprs[[which(err)[[1]]]]
dsl_parse_error(
"Value '{eq$name}' in 'fixed' is shadowed by {eq$type}",
"E207", eq$expr, call)
}
}


dsl_parse_check_usage <- function(exprs, fixed, call) {
name <- vcapply(exprs, "[[", "name")
names_fixed <- names(fixed)
for (i in seq_along(exprs)) {
e <- exprs[[i]]
err <- setdiff(e$depends, name[seq_len(i - 1)])
err <- setdiff(e$depends, c(name[seq_len(i - 1)], names_fixed))
if (length(err) > 0) {
## Out of order:
out_of_order <- intersect(name, err)
Expand All @@ -124,6 +143,8 @@ dsl_parse_check_usage <- function(exprs, call) {
## Could also tell the user about variables found in the
## calling env, but that requires detecting and then passing
## through the correct environment.
##
## Could also tell about near misses.
context <- NULL
}
## TODO: It would be nice to indicate that we want to highlight
Expand Down
42 changes: 38 additions & 4 deletions R/dsl.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@
##' then we will error if it is not possible to create a gradient
##' function.
##'
##' @param fixed An optional list of values that can be used within
##' the DSL code. Anything you provide here is available for your
##' calculations. In the interest of future compatibility, we check
##' currently that all elements are scalars. In future this may
##' become more flexible and allow passing environments, etc. Once
##' provided, these values cannot be changed without rebuilding the
##' model; they are fixed data. You might use these for
##' hyperparameters that are fixed across a set of model runs, for
##' example.
##'
##' @return A [monty_model] object derived from the expressions you
##' provide.
##'
Expand All @@ -38,31 +48,33 @@
##'
##' # You can also pass strings
##' monty_dsl("a ~ Normal(0, 1)")
monty_dsl <- function(x, type = NULL, gradient = NULL) {
monty_dsl <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
quo <- rlang::enquo(x)
if (rlang::quo_is_symbol(quo)) {
x <- rlang::eval_tidy(quo)
} else {
x <- rlang::quo_get_expr(quo)
}
call <- environment()
fixed <- check_dsl_fixed(fixed)
exprs <- dsl_preprocess(x, type, call)
dat <- dsl_parse(exprs, gradient, call)
dat <- dsl_parse(exprs, gradient, fixed, call)
dsl_generate(dat)
}



monty_dsl_parse <- function(x, type = NULL, gradient = NULL) {
monty_dsl_parse <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
call <- environment()
quo <- rlang::enquo(x)
if (rlang::quo_is_symbol(quo)) {
x <- rlang::eval_tidy(quo)
} else {
x <- rlang::quo_get_expr(quo)
}
fixed <- check_dsl_fixed(fixed, call)
exprs <- dsl_preprocess(x, type, call)
dsl_parse(exprs, gradient, call)
dsl_parse(exprs, gradient, fixed, call)
}


Expand Down Expand Up @@ -154,3 +166,25 @@ monty_dsl_parse_distribution <- function(expr, name = NULL) {
list(success = TRUE,
value = value)
}


check_dsl_fixed <- function(fixed, call) {
if (is.null(fixed)) {
return(NULL)
}
assert_list(fixed, call = call)
if (length(fixed) == 0) {
return(NULL)
}
assert_named(fixed, unique = TRUE, call = call)
err <- lengths(fixed) != 1
if (any(err)) {
info <- sprintf("'%s' had length %d",
names(fixed)[err], lengths(fixed[err]))
cli::cli_abort(
c("All elements of 'fixed' must currently be scalars",
set_names(info, "x")),
arg = "fixed", call = call)
}
fixed
}
Binary file modified R/sysdata.rda
Binary file not shown.
12 changes: 11 additions & 1 deletion man/monty_dsl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions tests/testthat/test-dsl-parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,32 @@ test_that("can explain an error", {
mockery::mock_args(mock_explain)[[1]],
list(dsl_errors, "E101", "pretty"))
})


test_that("empty fixed data is null", {
expect_null(check_dsl_fixed(NULL))
expect_null(check_dsl_fixed(list()))
})


test_that("validate fixed data for dsl", {
expect_error(
check_dsl_fixed(c(a = 1, b = 2)),
"Expected 'fixed' to be a list")
expect_error(
check_dsl_fixed(list(a = 1, b = 2, a = 2)),
"'fixed' must have unique names")
expect_error(
check_dsl_fixed(list(a = 1, b = 2:3, c = numeric(10))),
"All elements of 'fixed' must currently be scalars")
expect_equal(
check_dsl_fixed(list(a = 1, b = 2)),
list(a = 1, b = 2))
})


test_that("assignments cannot shadow names of fixed variables", {
expect_error(
dsl_parse(list(quote(a <- 1)), fixed = list(a = 1)),
"Value 'a' in 'fixed' is shadowed by assignment")
})
8 changes: 8 additions & 0 deletions tests/testthat/test-dsl.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,11 @@ test_that("can compute gradients of complicated models", {

expect_equal(m$gradient(p), numDeriv::grad(m$density, p))
})


test_that("can use fixed data in dsl", {
m <- monty_dsl({
a ~ Normal(mu, sd)
}, fixed = list(mu = 1, sd = 2))
expect_equal(m$density(0), dnorm(0, 1, 2, log = TRUE))
})
4 changes: 4 additions & 0 deletions vignettes/dsl-errors.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ Variables are used out of order. If you are using odin this is a big departure
# `E206`

Failed to differentiate the model. This error will only be seen where it was not possible to differentiate your model but you requested that a gradient be available. Not all functions supported in the DSL can currently be differentiated by monty; if you think that yours should be, please let us know.

# `E207`

A value in `fixed` is shadowed by an assignment or a relationship. If you pass in fixed data it may not be used on the left hand side of any expression in your DSL code.
52 changes: 52 additions & 0 deletions vignettes/dsl.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,55 @@ The computed properties for the model are:
```{r}
prior$properties
```

# Calculations in the DSL

Sometimes it will be useful to perform calculations in the code; you can do this with assignments. Most trivially, giving names to numbers may help make code more understandable:

```{r}
m <- monty_dsl({
mu <- 10
sd <- 2
a ~ Normal(mu, sd)
})
```

You can also use this to do things like:

```{r}
m <- monty_dsl({
a ~ Normal(0, 1)
b ~ Normal(0, 1)
mu <- (a + b) / 2
c ~ Normal(mu, 1)
})
```

Where `c` is drawn from a normal distribution with a mean that is the average of `a` and `b`.

# Pass in fixed data

You can also pass in a list of data with values that should be available in the DSL code. For example, our first example:

```{r}
prior <- monty_dsl({
alpha ~ Normal(178, 20)
beta ~ Normal(0, 10)
sigma ~ Uniform(0, 50)
})
```

Might be written as

```{r}
fixed <- list(alpha_mean = 170, alpha_sd = 20,
beta_mean = 0, beta_sd = 10,
sigma_max = 50)
prior <- monty_dsl({
alpha ~ Normal(alpha_mean, alpha_sd)
beta ~ Normal(beta_mean, beta_sd)
sigma ~ Uniform(0, sigma_max)
}, fixed = fixed)
```

Values you pass in this way are **fixed** (hence the name!) and cannot be modified after the model object is created.
Loading