From 75ee23aa8b52a63dc8602bf34574e3b5bbe4c307 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 4 Nov 2024 17:10:05 +0000 Subject: [PATCH] Support domain expansion for grouped packers --- R/domain.R | 29 ++++++++++++++++++++++------- tests/testthat/test-domain.R | 20 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/R/domain.R b/R/domain.R index 4a018060..24577545 100644 --- a/R/domain.R +++ b/R/domain.R @@ -45,7 +45,6 @@ monty_domain_expand <- function(domain, packer) { } ## Below here needs some actual work for grouped packers - assert_is(packer, "monty_packer") nms <- rownames(domain) if (is.null(nms)) { cli::cli_abort("Expected 'domain' to have row names", arg = "domain") @@ -59,7 +58,14 @@ monty_domain_expand <- function(domain, packer) { nms_full <- packer$names() nms_map <- packer$unpack(nms_full) - nms_logical <- names(nms_map) + + is_grouped <- inherits(packer, "monty_packer_grouped") + + if (is_grouped) { + nms_logical <- unique(unlist(lapply(nms_map, names), FALSE, FALSE)) + } else { + nms_logical <- names(nms_map) + } i <- nms %in% nms_logical & !(nms %in% intersect(nms_logical, nms_full)) err <- !(i | nms %in% nms_full) @@ -71,11 +77,20 @@ monty_domain_expand <- function(domain, packer) { if (any(i)) { nms_expand <- nms[i] - extra <- unname(domain)[ - rep(which(i), lengths(nms_map[nms_expand])), , drop = FALSE] - rownames(extra) <- unlist(nms_map[nms_expand], FALSE, FALSE) - j <- !(rownames(extra) %in% rownames(domain)) - domain <- rbind(extra[j, , drop = FALSE], + if (is_grouped) { + j <- unlist(lapply(unname(nms_map), function(el) { + nms_el <- intersect(nms, names(el)) + set_names(rep(match(nms_el, nms), lengths(el[nms_el])), + unlist(el[nms_el])) + })) + } else { + j <- set_names(rep(which(i), lengths(nms_map[nms_expand])), + unlist(nms_map[nms_expand], FALSE, FALSE)) + } + extra <- unname(domain)[j, , drop = FALSE] + rownames(extra) <- names(j) + keep <- !(rownames(extra) %in% rownames(domain)) + domain <- rbind(extra[keep, , drop = FALSE], domain[!i, , drop = FALSE]) } diff --git a/tests/testthat/test-domain.R b/tests/testthat/test-domain.R index 6499ac7a..4ebe106f 100644 --- a/tests/testthat/test-domain.R +++ b/tests/testthat/test-domain.R @@ -54,3 +54,23 @@ test_that("can expand parameters", { monty_domain_expand(rbind(x = 0:1, b = 2:3, a = 4:5), packer), rbind(a = 4:5, b = 2:3, "x[1]" = 0:1, "x[2]" = 0:1, "x[3]" = 0:1)) }) + + +test_that("Expand domain for grouped packer", { + packer <- monty_packer_grouped(c("x", "y"), c("a", "b")) + expect_equal( + monty_domain_expand(rbind(a = 0:1), packer), + rbind("a" = 0:1, + "a" = 0:1)) +}) + + +test_that("Expand domain for grouped packer with arrays", { + packer <- monty_packer_grouped(c("x", "y"), c("a", "b"), list(c = 3, d = 2)) + expected <- cbind(c(0, 0, 0, 0, 2, 0), c(1, 1, 1, 1, 3, 1)) + rownames(expected) <- + c("c[1]", "c[2]", "c[3]", "c[1]", "c[2]", "c[3]") + expect_equal( + monty_domain_expand(rbind(c = 0:1, "c[2]" = 2:3), packer), + expected) +})