diff --git a/DESCRIPTION b/DESCRIPTION index 78ce0de9..8929557f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: monty Title: Monte Carlo Models -Version: 0.2.11 +Version: 0.2.12 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Wes", "Hinsley", role = "aut"), diff --git a/R/dsl-generate.R b/R/dsl-generate.R index 520291e3..01ac83d7 100644 --- a/R/dsl-generate.R +++ b/R/dsl-generate.R @@ -1,11 +1,14 @@ dsl_generate <- function(dat) { env <- new.env(parent = asNamespace("monty")) env$packer <- monty_packer(dat$parameters) + env$fixed <- dat$fixed meta <- list( pars = quote(pars), data = quote(data), - density = quote(density)) + density = quote(density), + fixed = quote(fixed), + fixed_contents = names(env$fixed)) density <- dsl_generate_density(dat, env, meta) direct_sample <- dsl_generate_direct_sample(dat, env, meta) @@ -110,7 +113,10 @@ dsl_generate_density_rewrite_lookup <- function(expr, dest, meta) { dest, meta) as.call(expr) } else if (is.name(expr)) { - call("[[", meta[[dest]], as.character(expr)) + if (as.character(expr) %in% meta$fixed_contents) { + dest <- meta$fixed + } + call("[[", meta[[as.character(dest)]], as.character(expr)) } else { expr } diff --git a/R/dsl-parse.R b/R/dsl-parse.R index 1e0f1147..6968266d 100644 --- a/R/dsl-parse.R +++ b/R/dsl-parse.R @@ -1,16 +1,18 @@ ## The default of gradient_required = TRUE here helps with tests -dsl_parse <- function(exprs, gradient_required = TRUE, call = NULL) { +dsl_parse <- function(exprs, gradient_required = TRUE, fixed = NULL, + call = NULL) { exprs <- lapply(exprs, dsl_parse_expr, call) dsl_parse_check_duplicates(exprs, call) - dsl_parse_check_usage(exprs, call) + dsl_parse_check_fixed(exprs, fixed, call) + dsl_parse_check_usage(exprs, fixed, call) name <- vcapply(exprs, "[[", "name") parameters <- name[vcapply(exprs, "[[", "type") == "stochastic"] adjoint <- dsl_parse_adjoint(parameters, exprs, gradient_required) - list(parameters = parameters, exprs = exprs, adjoint = adjoint) + list(parameters = parameters, exprs = exprs, adjoint = adjoint, fixed = fixed) } @@ -109,11 +111,28 @@ dsl_parse_check_duplicates <- function(exprs, call) { } -dsl_parse_check_usage <- function(exprs, call) { +dsl_parse_check_fixed <- function(exprs, fixed, call) { + if (is.null(fixed)) { + return() + } + + name <- vcapply(exprs, "[[", "name") + err <- name %in% names(fixed) + if (any(err)) { + eq <- exprs[[which(err)[[1]]]] + dsl_parse_error( + "Value '{eq$name}' in 'fixed' is shadowed by {eq$type}", + "E207", eq$expr, call) + } +} + + +dsl_parse_check_usage <- function(exprs, fixed, call) { name <- vcapply(exprs, "[[", "name") + names_fixed <- names(fixed) for (i in seq_along(exprs)) { e <- exprs[[i]] - err <- setdiff(e$depends, name[seq_len(i - 1)]) + err <- setdiff(e$depends, c(name[seq_len(i - 1)], names_fixed)) if (length(err) > 0) { ## Out of order: out_of_order <- intersect(name, err) @@ -124,6 +143,8 @@ dsl_parse_check_usage <- function(exprs, call) { ## Could also tell the user about variables found in the ## calling env, but that requires detecting and then passing ## through the correct environment. + ## + ## Could also tell about near misses. context <- NULL } ## TODO: It would be nice to indicate that we want to highlight diff --git a/R/dsl.R b/R/dsl.R index 419bd224..f0e67648 100644 --- a/R/dsl.R +++ b/R/dsl.R @@ -23,6 +23,16 @@ ##' then we will error if it is not possible to create a gradient ##' function. ##' +##' @param fixed An optional list of values that can be used within +##' the DSL code. Anything you provide here is available for your +##' calculations. In the interest of future compatibility, we check +##' currently that all elements are scalars. In future this may +##' become more flexible and allow passing environments, etc. Once +##' provided, these values cannot be changed without rebuilding the +##' model; they are fixed data. You might use these for +##' hyperparameters that are fixed across a set of model runs, for +##' example. +##' ##' @return A [monty_model] object derived from the expressions you ##' provide. ##' @@ -38,7 +48,7 @@ ##' ##' # You can also pass strings ##' monty_dsl("a ~ Normal(0, 1)") -monty_dsl <- function(x, type = NULL, gradient = NULL) { +monty_dsl <- function(x, type = NULL, gradient = NULL, fixed = NULL) { quo <- rlang::enquo(x) if (rlang::quo_is_symbol(quo)) { x <- rlang::eval_tidy(quo) @@ -46,14 +56,15 @@ monty_dsl <- function(x, type = NULL, gradient = NULL) { x <- rlang::quo_get_expr(quo) } call <- environment() + fixed <- check_dsl_fixed(fixed) exprs <- dsl_preprocess(x, type, call) - dat <- dsl_parse(exprs, gradient, call) + dat <- dsl_parse(exprs, gradient, fixed, call) dsl_generate(dat) } -monty_dsl_parse <- function(x, type = NULL, gradient = NULL) { +monty_dsl_parse <- function(x, type = NULL, gradient = NULL, fixed = NULL) { call <- environment() quo <- rlang::enquo(x) if (rlang::quo_is_symbol(quo)) { @@ -61,8 +72,9 @@ monty_dsl_parse <- function(x, type = NULL, gradient = NULL) { } else { x <- rlang::quo_get_expr(quo) } + fixed <- check_dsl_fixed(fixed, call) exprs <- dsl_preprocess(x, type, call) - dsl_parse(exprs, gradient, call) + dsl_parse(exprs, gradient, fixed, call) } @@ -154,3 +166,25 @@ monty_dsl_parse_distribution <- function(expr, name = NULL) { list(success = TRUE, value = value) } + + +check_dsl_fixed <- function(fixed, call) { + if (is.null(fixed)) { + return(NULL) + } + assert_list(fixed, call = call) + if (length(fixed) == 0) { + return(NULL) + } + assert_named(fixed, unique = TRUE, call = call) + err <- lengths(fixed) != 1 + if (any(err)) { + info <- sprintf("'%s' had length %d", + names(fixed)[err], lengths(fixed[err])) + cli::cli_abort( + c("All elements of 'fixed' must currently be scalars", + set_names(info, "x")), + arg = "fixed", call = call) + } + fixed +} diff --git a/R/sysdata.rda b/R/sysdata.rda index 4c9f4fd9..6e7a88bd 100644 Binary files a/R/sysdata.rda and b/R/sysdata.rda differ diff --git a/man/monty_dsl.Rd b/man/monty_dsl.Rd index 352799fe..7edd127b 100644 --- a/man/monty_dsl.Rd +++ b/man/monty_dsl.Rd @@ -4,7 +4,7 @@ \alias{monty_dsl} \title{Domain Specific Language for monty} \usage{ -monty_dsl(x, type = NULL, gradient = NULL) +monty_dsl(x, type = NULL, gradient = NULL, fixed = NULL) } \arguments{ \item{x}{The model as an expression. This may be given as an @@ -26,6 +26,16 @@ attempt to construct a gradient function, which prevents a warning being generated if this is not possible. If \code{TRUE}, then we will error if it is not possible to create a gradient function.} + +\item{fixed}{An optional list of values that can be used within +the DSL code. Anything you provide here is available for your +calculations. In the interest of future compatibility, we check +currently that all elements are scalars. In future this may +become more flexible and allow passing environments, etc. Once +provided, these values cannot be changed without rebuilding the +model; they are fixed data. You might use these for +hyperparameters that are fixed across a set of model runs, for +example.} } \value{ A \link{monty_model} object derived from the expressions you diff --git a/tests/testthat/test-dsl-parse.R b/tests/testthat/test-dsl-parse.R index ef1af335..0146eb60 100644 --- a/tests/testthat/test-dsl-parse.R +++ b/tests/testthat/test-dsl-parse.R @@ -210,3 +210,32 @@ test_that("can explain an error", { mockery::mock_args(mock_explain)[[1]], list(dsl_errors, "E101", "pretty")) }) + + +test_that("empty fixed data is null", { + expect_null(check_dsl_fixed(NULL)) + expect_null(check_dsl_fixed(list())) +}) + + +test_that("validate fixed data for dsl", { + expect_error( + check_dsl_fixed(c(a = 1, b = 2)), + "Expected 'fixed' to be a list") + expect_error( + check_dsl_fixed(list(a = 1, b = 2, a = 2)), + "'fixed' must have unique names") + expect_error( + check_dsl_fixed(list(a = 1, b = 2:3, c = numeric(10))), + "All elements of 'fixed' must currently be scalars") + expect_equal( + check_dsl_fixed(list(a = 1, b = 2)), + list(a = 1, b = 2)) +}) + + +test_that("assignments cannot shadow names of fixed variables", { + expect_error( + dsl_parse(list(quote(a <- 1)), fixed = list(a = 1)), + "Value 'a' in 'fixed' is shadowed by assignment") +}) diff --git a/tests/testthat/test-dsl.R b/tests/testthat/test-dsl.R index 99ba74f5..51afcd5b 100644 --- a/tests/testthat/test-dsl.R +++ b/tests/testthat/test-dsl.R @@ -142,3 +142,11 @@ test_that("can compute gradients of complicated models", { expect_equal(m$gradient(p), numDeriv::grad(m$density, p)) }) + + +test_that("can use fixed data in dsl", { + m <- monty_dsl({ + a ~ Normal(mu, sd) + }, fixed = list(mu = 1, sd = 2)) + expect_equal(m$density(0), dnorm(0, 1, 2, log = TRUE)) +}) diff --git a/vignettes/dsl-errors.Rmd b/vignettes/dsl-errors.Rmd index 1f2115bc..c38da621 100644 --- a/vignettes/dsl-errors.Rmd +++ b/vignettes/dsl-errors.Rmd @@ -125,3 +125,7 @@ Variables are used out of order. If you are using odin this is a big departure # `E206` Failed to differentiate the model. This error will only be seen where it was not possible to differentiate your model but you requested that a gradient be available. Not all functions supported in the DSL can currently be differentiated by monty; if you think that yours should be, please let us know. + +# `E207` + +A value in `fixed` is shadowed by an assignment or a relationship. If you pass in fixed data it may not be used on the left hand side of any expression in your DSL code. diff --git a/vignettes/dsl.Rmd b/vignettes/dsl.Rmd index d178ef18..4c77b693 100644 --- a/vignettes/dsl.Rmd +++ b/vignettes/dsl.Rmd @@ -67,3 +67,55 @@ The computed properties for the model are: ```{r} prior$properties ``` + +# Calculations in the DSL + +Sometimes it will be useful to perform calculations in the code; you can do this with assignments. Most trivially, giving names to numbers may help make code more understandable: + +```{r} +m <- monty_dsl({ + mu <- 10 + sd <- 2 + a ~ Normal(mu, sd) +}) +``` + +You can also use this to do things like: + +```{r} +m <- monty_dsl({ + a ~ Normal(0, 1) + b ~ Normal(0, 1) + mu <- (a + b) / 2 + c ~ Normal(mu, 1) +}) +``` + +Where `c` is drawn from a normal distribution with a mean that is the average of `a` and `b`. + +# Pass in fixed data + +You can also pass in a list of data with values that should be available in the DSL code. For example, our first example: + +```{r} +prior <- monty_dsl({ + alpha ~ Normal(178, 20) + beta ~ Normal(0, 10) + sigma ~ Uniform(0, 50) +}) +``` + +Might be written as + +```{r} +fixed <- list(alpha_mean = 170, alpha_sd = 20, + beta_mean = 0, beta_sd = 10, + sigma_max = 50) +prior <- monty_dsl({ + alpha ~ Normal(alpha_mean, alpha_sd) + beta ~ Normal(beta_mean, beta_sd) + sigma ~ Uniform(0, sigma_max) +}, fixed = fixed) +``` + +Values you pass in this way are **fixed** (hence the name!) and cannot be modified after the model object is created.