From 37acc798a678409b1b3ce142340ecc0b15442084 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 15 Nov 2024 15:44:50 +0100 Subject: [PATCH] improve error messages for split_multiple_shapes_into_blocks --- bioimageio/core/block_meta.py | 36 ++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index 22f29ded..ff875404 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -317,25 +317,34 @@ def split_multiple_shapes_into_blocks( strides: Optional[PerMember[PerAxis[int]]] = None, broadcast: bool = False, ) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]: - assert not ( - missing := [t for t in block_shapes if t not in shapes] - ), f"block shape specified for unknown tensors: {missing}" + if unknown_blocks := [t for t in block_shapes if t not in shapes]: + raise ValueError( + f"block shape specified for unknown tensors: {unknown_blocks}." + ) + if not block_shapes: block_shapes = shapes - assert broadcast or not ( - missing := [t for t in shapes if t not in block_shapes] - ), f"no block shape specified for {missing} (set `broadcast` to True if these tensors should be repeated for each block)" - assert not ( - missing := [t for t in halo if t not in block_shapes] - ), f"`halo` specified for tensors without block shape: {missing}" + if not broadcast and ( + missing_blocks := [t for t in shapes if t not in block_shapes] + ): + raise ValueError( + f"no block shape specified for {missing_blocks}." + + " Set `broadcast` to True if these tensors should be repeated" + + " as a whole for each block." + ) + + if extra_halo := [t for t in halo if t not in block_shapes]: + raise ValueError( + f"`halo` specified for tensors without block shape: {extra_halo}." + ) if strides is None: strides = {} assert not ( - missing := [t for t in strides if t not in block_shapes] - ), f"`stride` specified for tensors without block shape: {missing}" + unknown_block := [t for t in strides if t not in block_shapes] + ), f"`stride` specified for tensors without block shape: {unknown_block}" blocks: Dict[MemberId, Iterable[BlockMeta]] = {} n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {} @@ -355,8 +364,9 @@ def split_multiple_shapes_into_blocks( if len(unique_n_blocks) == 2 and 1 in unique_n_blocks: if not broadcast: raise ValueError( - f"Mismatch for total number of blocks due to unsplit (single block) tensors: {n_blocks}." - + " Set `broadcast` to True if you want to repeat unsplit (single block) tensors." + "Mismatch for total number of blocks due to unsplit (single block)" + + f" tensors: {n_blocks}. Set `broadcast` to True if you want to" + + " repeat unsplit (single block) tensors." ) blocks = {