Skip to content

Commit

Permalink
improve error messages for split_multiple_shapes_into_blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Nov 15, 2024
1 parent 4cbf9ef commit 37acc79
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions bioimageio/core/block_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,25 +317,34 @@ def split_multiple_shapes_into_blocks(
strides: Optional[PerMember[PerAxis[int]]] = None,
broadcast: bool = False,
) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]:
assert not (
missing := [t for t in block_shapes if t not in shapes]
), f"block shape specified for unknown tensors: {missing}"
if unknown_blocks := [t for t in block_shapes if t not in shapes]:
raise ValueError(
f"block shape specified for unknown tensors: {unknown_blocks}."
)

if not block_shapes:
block_shapes = shapes

assert broadcast or not (
missing := [t for t in shapes if t not in block_shapes]
), f"no block shape specified for {missing} (set `broadcast` to True if these tensors should be repeated for each block)"
assert not (
missing := [t for t in halo if t not in block_shapes]
), f"`halo` specified for tensors without block shape: {missing}"
if not broadcast and (
missing_blocks := [t for t in shapes if t not in block_shapes]
):
raise ValueError(
f"no block shape specified for {missing_blocks}."
+ " Set `broadcast` to True if these tensors should be repeated"
+ " as a whole for each block."
)

if extra_halo := [t for t in halo if t not in block_shapes]:
raise ValueError(
f"`halo` specified for tensors without block shape: {extra_halo}."
)

if strides is None:
strides = {}

assert not (
missing := [t for t in strides if t not in block_shapes]
), f"`stride` specified for tensors without block shape: {missing}"
unknown_block := [t for t in strides if t not in block_shapes]
), f"`stride` specified for tensors without block shape: {unknown_block}"

blocks: Dict[MemberId, Iterable[BlockMeta]] = {}
n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {}
Expand All @@ -355,8 +364,9 @@ def split_multiple_shapes_into_blocks(
if len(unique_n_blocks) == 2 and 1 in unique_n_blocks:
if not broadcast:
raise ValueError(
f"Mismatch for total number of blocks due to unsplit (single block) tensors: {n_blocks}."
+ " Set `broadcast` to True if you want to repeat unsplit (single block) tensors."
"Mismatch for total number of blocks due to unsplit (single block)"
+ f" tensors: {n_blocks}. Set `broadcast` to True if you want to"
+ " repeat unsplit (single block) tensors."
)

blocks = {
Expand Down

0 comments on commit 37acc79

Please sign in to comment.