Skip to content

Commit

Permalink
Fix gradient and expand testing
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Nov 5, 2024
1 parent b81e70c commit 6244a2c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
20 changes: 19 additions & 1 deletion R/combine.R
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,30 @@ model_combine_gradient <- function(a, b, parameters, properties, call = NULL) {
n_pars <- length(parameters)
i_a <- match(a$parameters, parameters)
i_b <- match(b$parameters, parameters)
function(x, ...) {

gradient_vector <- function(x, ...) {
ret <- numeric(n_pars)
ret[i_a] <- ret[i_a] + a$gradient(x[i_a], ...)
ret[i_b] <- ret[i_b] + b$gradient(x[i_b], ...)
ret
}

if (properties$allow_multiple_parameters) {
function(x, ...) {
if (is.matrix(x)) {
ret <- matrix(0, n_pars, ncol(x))
ret[i_a, ] <-
ret[i_a, , drop = FALSE] + a$gradient(x[i_a, , drop = FALSE], ...)
ret[i_b, ] <-
ret[i_b, , drop = FALSE] + b$gradient(x[i_b, , drop = FALSE], ...)
ret
} else {
gradient_vector(x, ...)
}
}
} else {
gradient_vector
}
}


Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test-combine.R
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,16 @@ test_that("can combine models that allow multiple parameters", {
expect_equal(
m$density(x),
m1$density(x[1, , drop = FALSE]) + m2$density(x))
expect_equal(
m$density(x[, 1]),
m1$density(x[1, 1]) + m2$density(x[, 1]))

expect_equal(
m$gradient(x),
rbind(m1$gradient(x[1, , drop = FALSE]), 0) + m2$gradient(x))
expect_equal(
m$gradient(x[, 1]),
c(m1$gradient(x[1, 1]), 0) + m2$gradient(x[, 1]))
})


Expand Down

0 comments on commit 6244a2c

Please sign in to comment.