Skip to content

Commit

Permalink
Merge pull request #97 from mrc-ide/mrc-5918
Browse files Browse the repository at this point in the history
Support domain expansion for grouped packers
  • Loading branch information
richfitz authored Nov 6, 2024
2 parents 14ffdec + 09f63f8 commit c0a23ee
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 8 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.29
Version: 0.2.30
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
29 changes: 22 additions & 7 deletions R/domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ monty_domain_expand <- function(domain, packer) {
}

## Below here needs some actual work for grouped packers
assert_is(packer, "monty_packer")
nms <- rownames(domain)
if (is.null(nms)) {
cli::cli_abort("Expected 'domain' to have row names", arg = "domain")
Expand All @@ -59,7 +58,14 @@ monty_domain_expand <- function(domain, packer) {

nms_full <- packer$names()
nms_map <- packer$unpack(nms_full)
nms_logical <- names(nms_map)

is_grouped <- inherits(packer, "monty_packer_grouped")

if (is_grouped) {
nms_logical <- unique(unlist(lapply(nms_map, names), FALSE, FALSE))
} else {
nms_logical <- names(nms_map)
}

i <- nms %in% nms_logical & !(nms %in% intersect(nms_logical, nms_full))
err <- !(i | nms %in% nms_full)
Expand All @@ -71,11 +77,20 @@ monty_domain_expand <- function(domain, packer) {

if (any(i)) {
nms_expand <- nms[i]
extra <- unname(domain)[
rep(which(i), lengths(nms_map[nms_expand])), , drop = FALSE]
rownames(extra) <- unlist(nms_map[nms_expand], FALSE, FALSE)
j <- !(rownames(extra) %in% rownames(domain))
domain <- rbind(extra[j, , drop = FALSE],
if (is_grouped) {
j <- unlist(lapply(unname(nms_map), function(el) {
nms_el <- intersect(nms, names(el))
set_names(rep(match(nms_el, nms), lengths(el[nms_el])),
unlist(el[nms_el]))
}))
} else {
j <- set_names(rep(which(i), lengths(nms_map[nms_expand])),
unlist(nms_map[nms_expand], FALSE, FALSE))
}
extra <- unname(domain)[j, , drop = FALSE]
rownames(extra) <- names(j)
keep <- !(rownames(extra) %in% rownames(domain))
domain <- rbind(extra[keep, , drop = FALSE],
domain[!i, , drop = FALSE])
}

Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/test-domain.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,23 @@ test_that("can expand parameters", {
monty_domain_expand(rbind(x = 0:1, b = 2:3, a = 4:5), packer),
rbind(a = 4:5, b = 2:3, "x[1]" = 0:1, "x[2]" = 0:1, "x[3]" = 0:1))
})


test_that("Expand domain for grouped packer", {
packer <- monty_packer_grouped(c("x", "y"), c("a", "b"))
expect_equal(
monty_domain_expand(rbind(a = 0:1), packer),
rbind("a<x>" = 0:1,
"a<y>" = 0:1))
})


test_that("Expand domain for grouped packer with arrays", {
packer <- monty_packer_grouped(c("x", "y"), c("a", "b"), list(c = 3, d = 2))
expected <- cbind(c(0, 0, 0, 0, 2, 0), c(1, 1, 1, 1, 3, 1))
rownames(expected) <-
c("c[1]<x>", "c[2]<x>", "c[3]<x>", "c[1]<y>", "c[2]<y>", "c[3]<y>")
expect_equal(
monty_domain_expand(rbind(c = 0:1, "c[2]<y>" = 2:3), packer),
expected)
})

0 comments on commit c0a23ee

Please sign in to comment.