Skip to content

Commit

Permalink
Allow process for grouped packers
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Nov 5, 2024
1 parent 8c96fc0 commit 0e3018b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 9 deletions.
20 changes: 18 additions & 2 deletions R/packer-grouped.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@
##' @param shared Names of the elements in `scalar` and `array` that
##' are shared among all groups.
##'
##' @param process An arbitrary R function that will be passed the
##' final assembled list **for each region**; it may create any
##' *additional* entries, which will be concatenated onto the
##' original list. If you use this you should take care not to
##' return any values with the same names as entries listed in
##' `scalar`, `array` or `fixed`, as this is an error (this is so
##' that `pack()` is not broken). We will likely play around with
##' this process in future in order to get automatic differentiation
##' to work.
##'
##' @inheritParams monty_packer
##'
##' @return An object of class `monty_packer_grouped`, which has the
Expand Down Expand Up @@ -96,9 +106,12 @@ monty_packer_grouped <- function(groups, scalar = NULL, array = NULL,
cli::cli_abort("Expected at least two groups", arg = "groups")
}
if (!is.null(process)) {
cli::cli_abort("'process' is not yet compatible with grouped packers")
if (!is.function(process)) {
cli::cli_abort(
"Expected a function for 'process'",
arg = "process")
}
}

if (!is.null(fixed)) {
i <- names(fixed) %in% groups
fixed <- list(shared = if (!all(i)) fixed[!i],
Expand Down Expand Up @@ -186,6 +199,9 @@ monty_packer_grouped <- function(groups, scalar = NULL, array = NULL,
if (!is.null(fixed$varied)) {
ret[[i]][names(fixed$varied[[i]])] <- fixed$varied[[i]]
}
if (!is.null(process)) {
ret[[i]] <- unpack_vector_process(ret[[i]], process)
}
}
names(ret) <- groups
ret
Expand Down
16 changes: 16 additions & 0 deletions R/packer.R
Original file line number Diff line number Diff line change
Expand Up @@ -651,3 +651,19 @@ pack_check_dimensions <- function(p, shape, fixed, process,

ret
}


unpack_vector_process <- function(x, process) {
extra <- process(x)
err <- intersect(names(extra), names(x))
if (length(err) > 0) {
cli::cli_abort(
c("'process()' is trying to overwrite entries in your list",
i = paste("The 'process()' function should only create elements",
"that are not already present in 'scalar', 'array'",
"or 'fixed', as this lets us reverse the transformation",
"process"),
x = "{?Entry/Entries} already present: {squote(err)}"))
}
c(x, extra)
}
18 changes: 11 additions & 7 deletions tests/testthat/test-packer-grouped.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,6 @@ test_that("prevent varied names in fixed clashing with elements in packer", {
})


test_that("can't use 'progress' with a grouped packer", {
expect_error(
monty_packer_grouped(c("x", "y"), c("a", "b"), process = identity),
"process' is not yet compatible with grouped packers")
})


test_that("grouped packers require at least two groups", {
expect_error(
monty_packer_grouped("x", c("a", "b")),
Expand Down Expand Up @@ -237,3 +230,14 @@ test_that("Can print a grouped packer", {
"Packing 4 values: 'x<a>', 'y<a>', 'x<b>', and 'y<b>",
fixed = TRUE, all = FALSE)
})


test_that("can used process with grouped packer", {
process <- function(res) {
list(z = res$x + res$y)
}
p <- monty_packer_grouped(c("a", "b"), c("x", "y"), process = process)
expect_equal(p$unpack(1:4),
list(a = list(x = 1, y = 2, z = 3),
b = list(x = 3, y = 4, z = 7)))
})

0 comments on commit 0e3018b

Please sign in to comment.