Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow process for grouped packers #100

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.32
Version: 0.2.33
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
20 changes: 18 additions & 2 deletions R/packer-grouped.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@
##' @param shared Names of the elements in `scalar` and `array` that
##' are shared among all groups.
##'
##' @param process An arbitrary R function that will be passed the
##' final assembled list **for each region**; it may create any
richfitz marked this conversation as resolved.
Show resolved Hide resolved
##' *additional* entries, which will be concatenated onto the
##' original list. If you use this you should take care not to
##' return any values with the same names as entries listed in
##' `scalar`, `array` or `fixed`, as this is an error (this is so
##' that `pack()` is not broken). We will likely play around with
##' this process in future in order to get automatic differentiation
##' to work.
##'
##' @inheritParams monty_packer
##'
##' @return An object of class `monty_packer_grouped`, which has the
Expand Down Expand Up @@ -96,9 +106,12 @@ monty_packer_grouped <- function(groups, scalar = NULL, array = NULL,
cli::cli_abort("Expected at least two groups", arg = "groups")
}
if (!is.null(process)) {
cli::cli_abort("'process' is not yet compatible with grouped packers")
if (!is.function(process)) {
cli::cli_abort(
"Expected a function for 'process'",
arg = "process")
}
}

if (!is.null(fixed)) {
i <- names(fixed) %in% groups
fixed <- list(shared = if (!all(i)) fixed[!i],
Expand Down Expand Up @@ -186,6 +199,9 @@ monty_packer_grouped <- function(groups, scalar = NULL, array = NULL,
if (!is.null(fixed$varied)) {
ret[[i]][names(fixed$varied[[i]])] <- fixed$varied[[i]]
}
if (!is.null(process)) {
ret[[i]] <- unpack_vector_process(ret[[i]], process)
}
}
names(ret) <- groups
ret
Expand Down
34 changes: 17 additions & 17 deletions R/packer.R
Original file line number Diff line number Diff line change
Expand Up @@ -505,23 +505,7 @@ unpack_vector <- function(x, nms, len, idx, shape, fixed, process) {
res <- c(res, fixed)
}
if (!is.null(process)) {
extra <- process(res)
err <- intersect(names(extra), names(res))
if (length(err) > 0) {
cli::cli_abort(
c("'process()' is trying to overwrite entries in your list",
i = paste("The 'process()' function should only create elements",
"that are not already present in 'scalar', 'array'",
"or 'fixed', as this lets us reverse the transformation",
"process"),
x = "{?Entry/Entries} already present: {squote(err)}"))
}
## TODO: check names?
##
## TODO: this fails the multi-region use - but I think that we
## might want a different interface there anyway as we'll
## struggle to hold all the options here.
res <- c(res, extra)
res <- unpack_vector_process(res, process)
}
res
}
Expand Down Expand Up @@ -651,3 +635,19 @@ pack_check_dimensions <- function(p, shape, fixed, process,

ret
}


unpack_vector_process <- function(x, process) {
extra <- process(x)
err <- intersect(names(extra), names(x))
if (length(err) > 0) {
cli::cli_abort(
c("'process()' is trying to overwrite entries in your list",
i = paste("The 'process()' function should only create elements",
"that are not already present in 'scalar', 'array'",
"or 'fixed', as this lets us reverse the transformation",
"process"),
x = "{?Entry/Entries} already present: {squote(err)}"))
}
c(x, extra)
}
25 changes: 18 additions & 7 deletions tests/testthat/test-packer-grouped.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,6 @@ test_that("prevent varied names in fixed clashing with elements in packer", {
})


test_that("can't use 'progress' with a grouped packer", {
expect_error(
monty_packer_grouped(c("x", "y"), c("a", "b"), process = identity),
"process' is not yet compatible with grouped packers")
})


test_that("grouped packers require at least two groups", {
expect_error(
monty_packer_grouped("x", c("a", "b")),
Expand Down Expand Up @@ -237,3 +230,21 @@ test_that("Can print a grouped packer", {
"Packing 4 values: 'x<a>', 'y<a>', 'x<b>', and 'y<b>",
fixed = TRUE, all = FALSE)
})


richfitz marked this conversation as resolved.
Show resolved Hide resolved
test_that("can used process with grouped packer", {
richfitz marked this conversation as resolved.
Show resolved Hide resolved
process <- function(res) {
list(z = res$x + res$y)
}
p <- monty_packer_grouped(c("a", "b"), c("x", "y"), process = process)
expect_equal(p$unpack(1:4),
list(a = list(x = 1, y = 2, z = 3),
b = list(x = 3, y = 4, z = 7)))
})


test_that("can used process with grouped packer", {
richfitz marked this conversation as resolved.
Show resolved Hide resolved
expect_error(
monty_packer_grouped(c("a", "b"), c("x", "y"), process = TRUE),
"Expected a function for 'process'")
})
Loading