From 0e3018ba8b2b40d8033f3190c28b4c8fe106565a Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 5 Nov 2024 14:24:10 +0000 Subject: [PATCH] Allow process for grouped packers --- R/packer-grouped.R | 20 ++++++++++++++++++-- R/packer.R | 16 ++++++++++++++++ tests/testthat/test-packer-grouped.R | 18 +++++++++++------- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/R/packer-grouped.R b/R/packer-grouped.R index 2806764c..d1708b9e 100644 --- a/R/packer-grouped.R +++ b/R/packer-grouped.R @@ -49,6 +49,16 @@ ##' @param shared Names of the elements in `scalar` and `array` that ##' are shared among all groups. ##' +##' @param process An arbitrary R function that will be passed the +##' final assembled list **for each region**; it may create any +##' *additional* entries, which will be concatenated onto the +##' original list. If you use this you should take care not to +##' return any values with the same names as entries listed in +##' `scalar`, `array` or `fixed`, as this is an error (this is so +##' that `pack()` is not broken). We will likely play around with +##' this process in future in order to get automatic differentiation +##' to work. +##' ##' @inheritParams monty_packer ##' ##' @return An object of class `monty_packer_grouped`, which has the @@ -96,9 +106,12 @@ monty_packer_grouped <- function(groups, scalar = NULL, array = NULL, cli::cli_abort("Expected at least two groups", arg = "groups") } if (!is.null(process)) { - cli::cli_abort("'process' is not yet compatible with grouped packers") + if (!is.function(process)) { + cli::cli_abort( + "Expected a function for 'process'", + arg = "process") + } } - if (!is.null(fixed)) { i <- names(fixed) %in% groups fixed <- list(shared = if (!all(i)) fixed[!i], @@ -186,6 +199,9 @@ monty_packer_grouped <- function(groups, scalar = NULL, array = NULL, if (!is.null(fixed$varied)) { ret[[i]][names(fixed$varied[[i]])] <- fixed$varied[[i]] } + if (!is.null(process)) { + ret[[i]] <- unpack_vector_process(ret[[i]], process) + } } names(ret) <- groups ret diff --git a/R/packer.R b/R/packer.R index 6eebca6f..9e584eca 100644 --- a/R/packer.R +++ b/R/packer.R @@ -651,3 +651,19 @@ pack_check_dimensions <- function(p, shape, fixed, process, ret } + + +unpack_vector_process <- function(x, process) { + extra <- process(x) + err <- intersect(names(extra), names(x)) + if (length(err) > 0) { + cli::cli_abort( + c("'process()' is trying to overwrite entries in your list", + i = paste("The 'process()' function should only create elements", + "that are not already present in 'scalar', 'array'", + "or 'fixed', as this lets us reverse the transformation", + "process"), + x = "{?Entry/Entries} already present: {squote(err)}")) + } + c(x, extra) +} diff --git a/tests/testthat/test-packer-grouped.R b/tests/testthat/test-packer-grouped.R index f62a9f7d..a1ef64a6 100644 --- a/tests/testthat/test-packer-grouped.R +++ b/tests/testthat/test-packer-grouped.R @@ -155,13 +155,6 @@ test_that("prevent varied names in fixed clashing with elements in packer", { }) -test_that("can't use 'progress' with a grouped packer", { - expect_error( - monty_packer_grouped(c("x", "y"), c("a", "b"), process = identity), - "process' is not yet compatible with grouped packers") -}) - - test_that("grouped packers require at least two groups", { expect_error( monty_packer_grouped("x", c("a", "b")), @@ -237,3 +230,14 @@ test_that("Can print a grouped packer", { "Packing 4 values: 'x', 'y', 'x', and 'y", fixed = TRUE, all = FALSE) }) + + +test_that("can used process with grouped packer", { + process <- function(res) { + list(z = res$x + res$y) + } + p <- monty_packer_grouped(c("a", "b"), c("x", "y"), process = process) + expect_equal(p$unpack(1:4), + list(a = list(x = 1, y = 2, z = 3), + b = list(x = 3, y = 4, z = 7))) +})