Skip to content

Commit

Permalink
Vectorise gradients too
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Nov 5, 2024
1 parent fa6c5ca commit b81e70c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 6 deletions.
33 changes: 27 additions & 6 deletions R/dsl-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dsl_generate <- function(dat) {
properties <- monty_model_properties(allow_multiple_parameters = TRUE)
monty_model(
list(parameters = dat$parameters,
density = vectorise_over_parameters(density),
density = density,
gradient = gradient,
domain = domain,
direct_sample = direct_sample),
Expand All @@ -31,7 +31,8 @@ dsl_generate_density <- function(dat, env, meta) {
call("<-", meta[["density"]], quote(numeric())),
exprs,
call("sum", meta[["density"]]))
as_function(alist(x = ), body, env)
vectorise_density_over_parameters(
as_function(alist(x = ), body, env))
}


Expand All @@ -50,7 +51,9 @@ dsl_generate_gradient <- function(dat, env, meta) {
body <- c(call("<-", meta[["data"]], quote(packer$unpack(x))),
unname(eqs),
eq_return)
as_function(alist(x = ), body, env)
vectorise_gradient_over_parameters(
as_function(alist(x = ), body, env),
length(dat$parameters))
}


Expand Down Expand Up @@ -166,8 +169,26 @@ fold_c <- function(x) {
## until the rest of the DSL is written, especially arrays. For
## simple models with scalars we should be able to just pass through
## multiple parameters at once.
vectorise_over_parameters <- function(f) {
function(p) {
if (is.matrix(p)) vnapply(seq_len(ncol(p)), function(i) f(p[, i])) else f(p)
vectorise_density_over_parameters <- function(density) {
function(x) {
if (is.matrix(x)) {
vnapply(seq_len(ncol(x)), function(i) density(x[, i]))
} else {
density(x)
}
}
}

vectorise_gradient_over_parameters <- function(gradient, len) {
function(x) {
if (is.matrix(x)) {
g <- vapply(seq_len(ncol(x)), function(i) gradient(x[, i]), numeric(len))
if (is.null(dim(g))) {
dim(g) <- c(len, ncol(x))
}
g
} else {
gradient(x)
}
}
}
23 changes: 23 additions & 0 deletions tests/testthat/test-dsl.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,27 @@ test_that("can evaluate dsl model densities for multiple parameters", {
x <- matrix(runif(10), 2, 5)
expect_equal(m$density(x),
dnorm(x[1, ], 0, 1, TRUE) + dexp(x[2, ], 2, TRUE))
expect_equal(
m$gradient(x),
apply(x, 2, m$gradient))
expect_equal(
m$gradient(x[, 1, drop = FALSE]),
cbind(m$gradient(x[, 1])))
})


test_that("gradient calculation correct single-parameter model", {
m <- monty_dsl({
a ~ Normal(0, 1)
})
expect_true(m$properties$allow_multiple_parameters)
x <- matrix(runif(5), 1, 5)
expect_equal(m$density(x),
dnorm(x[1, ], 0, 1, TRUE))
expect_equal(
m$gradient(x),
rbind(apply(x, 2, m$gradient)))
expect_equal(
m$gradient(x[, 1, drop = FALSE]),
cbind(m$gradient(x[, 1])))
})

0 comments on commit b81e70c

Please sign in to comment.