Skip to content

Commit

Permalink
add arg max_input_shape to get_axis_sizes()
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Nov 12, 2024
1 parent 81d1236 commit 8504862
Showing 1 changed file with 84 additions and 37 deletions.
121 changes: 84 additions & 37 deletions bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import numpy as np
from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from numpy.typing import NDArray
from pydantic import (
Discriminator,
Expand Down Expand Up @@ -350,10 +351,20 @@ def get_size(
SpaceOutputAxis,
SpaceOutputAxisWithHalo,
],
n: ParameterizedSize_N,
n: ParameterizedSize_N = 0,
ref_size: Optional[int] = None,
):
"""helper method to compute concrete size for a given axis and its reference axis.
If the reference axis is parameterized, `n` is used to compute the concrete size of it, see `ParameterizedSize`.
"""Compute the concrete size for a given axis and its reference axis.

Args:
axis: The axis this `SizeReference` is the size of.
ref_axis: The reference axis to compute the size from.
n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
and no fixed **ref_size** is given,
**n** is used to compute the size of the parameterized **ref_axis**.
ref_size: Overwrite the reference size instead of deriving it from
**ref_axis**
(**ref_axis.scale** is still used; any given **n** is ignored).
"""
assert (
axis.size == self
Expand All @@ -367,22 +378,22 @@ def get_size(
"`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
f" but {axis.unit}!={ref_axis.unit}"
)

if isinstance(ref_axis.size, (int, float)):
ref_size = ref_axis.size
elif isinstance(ref_axis.size, ParameterizedSize):
ref_size = ref_axis.size.get_size(n)
elif isinstance(ref_axis.size, DataDependentSize):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
)
elif isinstance(ref_axis.size, SizeReference):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be sized by a"
+ " `SizeReference` itself."
)
else:
assert_never(ref_axis.size)
if ref_size is None:
if isinstance(ref_axis.size, (int, float)):
ref_size = ref_axis.size
elif isinstance(ref_axis.size, ParameterizedSize):
ref_size = ref_axis.size.get_size(n)
elif isinstance(ref_axis.size, DataDependentSize):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
)
elif isinstance(ref_axis.size, SizeReference):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be sized by a"
+ " `SizeReference` itself."
)
else:
assert_never(ref_axis.size)

return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Expand Down Expand Up @@ -2474,6 +2485,7 @@ def get_axis_sizes(
self,
ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
batch_size: Optional[int] = None,
*,
max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
) -> _AxisSizes:
"""Determine input and output block shape for scale factors **ns**
Expand Down Expand Up @@ -2514,16 +2526,19 @@ def get_axis_sizes(
def get_axis_size(a: Union[InputAxis, OutputAxis]):
if isinstance(a, BatchAxis):
if (t_descr.id, a.id) in ns:
raise ValueError(
"No size increment factor (n) for batch axis of tensor"
+ f" '{t_descr.id}' expected."
logger.warning(
"Ignoring unexpected size increment factor (n) for batch axis"
+ " of tensor '{}'.",
t_descr.id,
)
return batch_size
elif isinstance(a.size, int):
if (t_descr.id, a.id) in ns:
raise ValueError(
"No size increment factor (n) for fixed size axis"
+ f" '{a.id}' of tensor '{t_descr.id}' expected."
logger.warning(
"Ignoring unexpected size increment factor (n) for fixed size"
+ " axis '{}' of tensor '{}'.",
a.id,
t_descr.id,
)
return a.size
elif isinstance(a.size, ParameterizedSize):
Expand All @@ -2532,39 +2547,71 @@ def get_axis_size(a: Union[InputAxis, OutputAxis]):
"Size increment factor (n) missing for parametrized axis"
+ f" '{a.id}' of tensor '{t_descr.id}'."
)
return a.size.get_size(ns[(t_descr.id, a.id)])
n = ns[(t_descr.id, a.id)]
s_n = a.size.get_size(n)
s_max = max_input_shape.get((t_descr.id, a.id))
if s_max is None:
return s_n

for n_min in range(n):
s = a.size.get_size(n_min)
if s >= s_max:
return s

return s_n # n == 0

elif isinstance(a.size, SizeReference):
if (t_descr.id, a.id) in ns:
raise ValueError(
f"No size increment factor (n) for axis '{a.id}' of tensor"
+ f" '{t_descr.id}' with size reference expected."
logger.warning(
"Ignoring unexpected size increment factor (n) for axis '{}'"
+ " of tensor '{}' with size reference.",
a.id,
t_descr.id,
)
assert not isinstance(a, BatchAxis)
ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
assert not isinstance(ref_axis, BatchAxis)
ref_key = (a.size.tensor_id, a.size.axis_id)
ref_size = inputs.get(ref_key, outputs.get(ref_key))
assert ref_size is not None, ref_key
assert not isinstance(ref_size, _DataDepSize), ref_key
return a.size.get_size(
axis=a,
ref_axis=ref_axis,
n=ns.get((a.size.tensor_id, a.size.axis_id), 0),
ref_size=ref_size,
)
elif isinstance(a.size, DataDependentSize):
if (t_descr.id, a.id) in ns:
raise ValueError(
"No size increment factor (n) for data dependent size axis"
+ f" '{a.id}' of tensor '{t_descr.id}' expected."
logger.warning(
"Ignoring unexpected increment factor (n) for data dependent"
+ " size axis '{}' of tensor '{}'.",
a.id,
t_descr.id,
)
return _DataDepSize(a.size.min, a.size.max)
else:
assert_never(a.size)

# first resolve all , but the `SizeReference` input sizes
for t_descr in self.inputs:
for a in t_descr.axes:
s = get_axis_size(a)
assert not isinstance(s, _DataDepSize)
inputs[t_descr.id, a.id] = s
if not isinstance(a.size, SizeReference):
s = get_axis_size(a)
assert not isinstance(s, _DataDepSize)
inputs[t_descr.id, a.id] = s

# resolve all other input axis sizes
for t_descr in self.inputs:
for a in t_descr.axes:
if isinstance(a.size, SizeReference):
s = get_axis_size(a)
assert not isinstance(s, _DataDepSize)
inputs[t_descr.id, a.id] = s

for t_descr in chain(self.inputs, self.outputs):
# resolve all output axis sizes
for t_descr in self.outputs:
for a in t_descr.axes:
assert not isinstance(a.size, SizeReference)
s = get_axis_size(a)
outputs[t_descr.id, a.id] = s

Expand Down

0 comments on commit 8504862

Please sign in to comment.