Skip to content

Commit

Permalink
Merge pull request #66 from mrc-ide/mrc-5640
Browse files Browse the repository at this point in the history
Add simple function interface
  • Loading branch information
richfitz authored Aug 30, 2024
2 parents a5d6414 + 5c77b26 commit 8ae84a2
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 44 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.1
Version: 0.2.2
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export(monty_model)
export(monty_model_combine)
export(monty_model_density)
export(monty_model_direct_sample)
export(monty_model_function)
export(monty_model_gradient)
export(monty_model_properties)
export(monty_observer)
Expand Down
56 changes: 56 additions & 0 deletions R/model-function.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
##' Create a [monty_model] from a function that computes density.
##' This allows use of any R function as a simple monty model. If you
##' need advanced model features, then this interface may not suit you
##' and you may prefer to use [monty_model] directly.
##'
##' This interface will expand in future versions of monty to support
##' gradients, stochastic models, parameter groups and simultaneous
##' calculation of density.
##'
##' @title Create `monty_model` from a function computing density
##'
##' @param density A function to compute log density. It can take any
##' number of parameters
##'
##' @param packer Optionally, a [monty_packer] object to control how
##' your function parameters are packed into a numeric vector. You
##' can typically omit this if all the arguments to your functions
##' are present in your numeric vector and if they are all scalars.
##'
##' @param fixed Optionally, a named list of fixed values to
##' substitute into the call to `density`. This cannot be used in
##' conjunction with `packer` (you should use the `fixed` argument
##' to `monty_packer` instead).
##'
##' @return A [monty_model] object that computes log density with the
##' provided `density` function, given a numeric vector argument
##' representing all parameters.
##'
##' @export
monty_model_function <- function(density, packer = NULL, fixed = NULL) {
if (!is.function(density)) {
cli::cli_abort("Expected 'density' to be a function", arg = "density")
}

if (!is.null(fixed)) {
assert_named(fixed, unique = TRUE)
assert_list(fixed, call = call)
}

if (is.null(packer)) {
packer <- monty_packer(
setdiff(names(formals(density)), names(fixed)),
fixed = fixed)
} else {
assert_is(packer, "monty_packer")
if (!is.null(fixed)) {
cli::cli_abort("Can't provide both 'packer' and 'fixed'", arg = "fixed")
}
}

monty_model(
list(parameters = packer$parameters,
density = function(x) {
rlang::inject(density(!!!packer$unpack(x)))
}))
}
7 changes: 4 additions & 3 deletions R/util_assert.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
assert_is <- function(x, what, name = deparse(substitute(x)), call = NULL) {
assert_is <- function(x, what, name = deparse(substitute(x)),
call = parent.frame()) {
if (!inherits(x, what)) {
cli::cli_abort("Expected '{name}' to be a '{what}' object",
arg = name, call = call)
Expand Down Expand Up @@ -63,7 +64,7 @@ assert_scalar_character <- function(x, name = deparse(substitute(x)),


assert_named <- function(x, unique = FALSE, name = deparse(substitute(x)),
arg = name, call = NULL) {
arg = name, call = parent.frame()) {
if (is.null(names(x))) {
cli::cli_abort("'{name}' must be named", call = call, arg = arg)
}
Expand All @@ -79,7 +80,7 @@ assert_named <- function(x, unique = FALSE, name = deparse(substitute(x)),


assert_list <- function(x, name = deparse(substitute(x)), arg = name,
call = NULL) {
call = parent.frame()) {
if (!is.list(x)) {
cli::cli_abort("Expected '{name}' to be a list",
arg = arg, call = call)
Expand Down
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ reference:
contents:
- monty_model
- monty_model_combine
- monty_model_function
- monty_model_properties

- subtitle: Functions for working with models
Expand Down
38 changes: 38 additions & 0 deletions man/monty_model_function.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 54 additions & 0 deletions tests/testthat/test-model-function.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
test_that("can create model from function", {
fn <- function(a, b) {
dnorm(0, a, b)
}
m <- monty_model_function(fn)
expect_s3_class(m, "monty_model")
expect_equal(m$parameters, c("a", "b"))
expect_equal(monty_model_density(m, c(1, 2)),
dnorm(0, 1, 2))
})


test_that("density must be a function", {
expect_error(monty_model_function(NULL),
"Expected 'density' to be a function")
})


test_that("can provide a custom packer", {
p <- monty_packer(c("a", "b"), fixed = list(x = 10))
fn <- function(a, b, x) {
dnorm(x, b, a)
}
m <- monty_model_function(fn, p)
expect_equal(m$parameters, c("a", "b"))
expect_equal(monty_model_density(m, c(1, 2)),
dnorm(10, 2, 1))
})


test_that("packer must be a monty_packer if provided", {
fn <- function(a, b) {
dnorm(0, a, b)
}
expect_no_error(monty_model_function(fn, NULL))
expect_error(
monty_model_function(fn, TRUE),
"Expected 'packer' to be a 'monty_packer' object")
})


test_that("can fix some data", {
p <- monty_packer(c("a", "b"))
fn <- function(a, b, x) {
dnorm(x, b, a)
}
m <- monty_model_function(fn, fixed = list(x = 10))
expect_equal(m$parameters, c("a", "b"))
expect_equal(monty_model_density(m, c(1, 2)),
dnorm(10, 2, 1))
expect_error(
monty_model_function(fn, p, fixed = list(x = 10)),
"Can't provide both 'packer' and 'fixed'")
})
66 changes: 26 additions & 40 deletions vignettes/monty.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ knitr::opts_chunk$set(
)
```

This vignette is an introduction to some of the ideas in `monty`. As the package is not ready for users to use it, this is more a place for us to jot down ideas about what it will do.
This vignette is an introduction to some of the ideas in `monty`. Please note that the interface is not yet stable and some function names, arguments and ideas will change before the first generally usable version.

```{r setup}
library(monty)
Expand Down Expand Up @@ -55,53 +55,35 @@ plot(height ~ weight, data)
A simple likelihood, following the model formulation in "Statistical Rethinking" chapter 3; height is modelled as normally distributed departures from a linear relationship with weight.

```{r}
likelihood <- monty_model(
list(
parameters = c("a", "b", "sigma"),
density = function(x) {
a <- x[[1]]
b <- x[[2]]
sigma <- x[[3]]
mu <- a + b * data$weight
sum(dnorm(data$height, mu, sigma, log = TRUE))
}))
fn <- function(a, b, sigma, data) {
mu <- a + b * data$weight
sum(dnorm(data$height, mu, sigma, log = TRUE))
}
```

The prior we'll make much nicer to work with in the future, but here we construct the density by hand as a sum of normally distributed priors on `a` and `b`, and a weak uniform prior on `sigma`. We provide a `direct_sample` function here so that we can draw samples from the prior distribution directly.
We can wrap this density function in a `monty_model`. The `data` argument is "fixed" - it's not part of the statistical model, so we'll pass that in as the `fixed` argument:

```{r}
prior <- local({
a_mu <- 178
a_sd <- 100
b_mu <- 0
b_sd <- 10
sigma_min <- 0
sigma_max <- 50
monty_model(
list(
parameters = c("a", "b", "sigma"),
density = function(x) {
a <- x[[1]]
b <- x[[2]]
sigma <- x[[3]]
dnorm(a, a_mu, a_sd, log = TRUE) +
dnorm(b, b_mu, b_sd, log = TRUE) +
dunif(sigma, sigma_min, sigma_max, log = TRUE)
},
direct_sample = function(rng) {
c(rng$normal(1, a_mu, a_sd),
rng$normal(1, b_mu, b_sd),
rng$uniform(1, sigma_min, sigma_max))
},
domain = rbind(c(-Inf, Inf), c(-Inf, Inf), c(0, Inf))
))
likelihood <- monty_model_function(fn, fixed = list(data = data))
likelihood
```

The prior we'll make much nicer to work with in the future, but here we construct the density by hand as a sum of normally distributed priors on `a` and `b`, and a weak uniform prior on `sigma`.

```{r}
prior <- monty_dsl({
a ~ Normal(178, 100)
b ~ Normal(0, 10)
sigma ~ Uniform(0, 50)
})
prior
```

The posterior distribution is the combination of these two models (indicated with a `+` because we're adding on a log-scale, or because we are using `prior` *and* `posterior`; you can use `monty_model_combine()` if you prefer).

```{r}
posterior <- likelihood + prior
posterior
```

Constructing a sensible initial variance-covariance matrix is a bit of a trick, and using an adaptive sampler reduces the pain here. These values are chosen to be reasonable starting points.
Expand All @@ -116,14 +98,14 @@ sampler <- monty_sampler_random_walk(vcv = vcv)
Now run the sampler. We've started from a good starting point to make this simple sampler converge quickly:

```{r}
samples <- monty_sample(posterior, sampler, 5000, initial = c(114, 0.9, 3),
samples <- monty_sample(posterior, sampler, 2000, initial = c(114, 0.9, 3),
n_chains = 4)
```

We don't yet have tools for working with the samples objects, but we can see the density over time easily enough:
We don't aim to directly provide tools for visualising and working with samples, as this is well trodden ground in other packages. However, we can directly plot density over time:

```{r}
matplot(t(samples$density), type = "l", lty = 1,
matplot(samples$density, type = "l", lty = 1,
xlab = "log posterior density", ylab = "sample", col = "#00000055")
```

Expand All @@ -138,3 +120,7 @@ abline(v = 0.9, col = "red")
plot(density(samples$pars["sigma", , ]), main = "sigma")
abline(v = 3, col = "red")
```

If you have `coda` installed you can convert these samples into a `coda` `mcmc.list` using `coda::as.mcmc.list()`, and if you have `posterior` installed you can convert into a `draws_df` using `posterior::as_draws_df()`, from which you can probably use your favourite plotting tools.

See `vignette("samplers")` for more information.

0 comments on commit 8ae84a2

Please sign in to comment.