Skip to content

Commit

Permalink
Merge pull request #654 from bioimage-io/get_axis_size
Browse files Browse the repository at this point in the history
add arg max_input_shape to v0_5.Model.get_axis_sizes()
  • Loading branch information
FynnBe authored Nov 13, 2024
2 parents e77a24d + b77efd1 commit 5f44013
Show file tree
Hide file tree
Showing 19 changed files with 338 additions and 337 deletions.
22 changes: 16 additions & 6 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ jobs:
include:
- python-version: "3.12"
is-dev-version: true
run_expensive_tests: true
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
Expand All @@ -29,12 +30,13 @@ jobs:
- name: Get Date
id: get-date
run: |
echo "week=$(/bin/date -u "+%Y-%U")" >> $GITHUB_OUTPUT
echo "date=$(date +'%Y-%b')"
echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT
shell: bash
- uses: actions/cache@v3
- uses: actions/cache/restore@v4
with:
path: tests/cache
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ steps.get-date.outputs.week }}-${{ hashFiles('**/lockfiles') }}
path: bioimageio_cache
key: "py${{ matrix.python-version }}-${{ steps.get-date.outputs.date }}"
- name: Check autogenerated imports
run: python scripts/generate_version_submodule_imports.py check
- run: black --check .
Expand All @@ -49,14 +51,22 @@ jobs:
- run: pyright -p pyproject.toml --pythonversion ${{ matrix.python-version }}
if: matrix.is-dev-version
- run: pytest
env:
BIOIMAGEIO_CACHE_PATH: bioimageio_cache
SKIP_EXPENSIVE_TESTS: ${{ matrix.run_expensive_tests && 'false' || 'true' }}
- uses: actions/cache/save@v4
# explicit restore/save instead of cache action to cache even if coverage fails
with:
path: bioimageio_cache
key: "py${{ matrix.python-version }}-${{ steps.get-date.outputs.date }}"
- if: matrix.is-dev-version && github.event_name == 'pull_request'
uses: orgoro/[email protected]
with:
coverageFile: coverage.xml
token: ${{ secrets.GITHUB_TOKEN }}
thresholdAll: 0.75
thresholdAll: 0.7
thresholdNew: 0.9
thresholdModified: 0.85
thresholdModified: 0.6
- if: matrix.is-dev-version
run: |
pip install genbadge[coverage]
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ coverage.xml
dist/
docs/
output/
tests/cache
tests/generated_json_schemas
tmp/
user_docs/
scripts/pdoc/original.py
scripts/pdoc/patched.py
bioimageio_cache/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ To keep the bioimageio.spec Python package version in sync with the (model) desc
* fix summary formatting
* improve logged origin for logged messages
* make the `model.v0_5.ModelDescr.training_data` field a `left_to_right` Union to avoid warnings
* the deprecated `version_number` is no longer appended to the `id`, but instead set as `version` if no `version` is specified.

#### bioimageio.spec 0.5.3.3

Expand Down
3 changes: 2 additions & 1 deletion bioimageio/spec/application/v0_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class ApplicationDescr(GenericDescrBase, title="bioimage.io application specific
type: Literal["application"] = "application"

id: Optional[ApplicationId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""
"""bioimage.io-wide unique resource identifier
assigned by bioimage.io; version **un**specific."""

source: Annotated[
Optional[ImportantFileSource],
Expand Down
3 changes: 2 additions & 1 deletion bioimageio/spec/application/v0_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class ApplicationDescr(GenericDescrBase, title="bioimage.io application specific
type: Literal["application"] = "application"

id: Optional[ApplicationId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""
"""bioimage.io-wide unique resource identifier
assigned by bioimage.io; version **un**specific."""

parent: Optional[ApplicationId] = None
"""The description from which this one is derived"""
Expand Down
3 changes: 2 additions & 1 deletion bioimageio/spec/dataset/v0_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class DatasetDescr(GenericDescrBase, title="bioimage.io dataset specification"):
type: Literal["dataset"] = "dataset"

id: Optional[DatasetId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""
"""bioimage.io-wide unique resource identifier
assigned by bioimage.io; version **un**specific."""

source: Optional[HttpUrl] = None
""""URL to the source of the dataset."""
Expand Down
3 changes: 2 additions & 1 deletion bioimageio/spec/dataset/v0_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class DatasetDescr(GenericDescrBase, title="bioimage.io dataset specification"):
type: Literal["dataset"] = "dataset"

id: Optional[DatasetId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""
"""bioimage.io-wide unique resource identifier
assigned by bioimage.io; version **un**specific."""

parent: Optional[DatasetId] = None
"""The description from which this one is derived"""
Expand Down
3 changes: 2 additions & 1 deletion bioimageio/spec/generic/v0_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ class GenericDescr(
"""The resource type assigns a broad category to the resource."""

id: Optional[ResourceId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""
"""bioimage.io-wide unique resource identifier
assigned by bioimage.io; version **un**specific."""

source: Optional[HttpUrl] = None
"""The primary source of the resource"""
Expand Down
14 changes: 9 additions & 5 deletions bioimageio/spec/generic/v0_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,8 @@ def _remove_version_number( # pyright: ignore[reportUnknownParameterType]
):
if isinstance(value, dict):
vn: Any = value.pop("version_number", None)
if vn is not None and "id" in value:
value["id"] = f"{value['id']}/{vn}"
if vn is not None and value.get("version") is None:
value["version"] = vn

return value # pyright: ignore[reportUnknownVariableType]

Expand Down Expand Up @@ -420,7 +420,8 @@ class GenericDescr(
"""The resource type assigns a broad category to the resource."""

id: Optional[ResourceId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""
"""bioimage.io-wide unique resource identifier
assigned by bioimage.io; version **un**specific."""

parent: Optional[ResourceId] = None
"""The description from which this one is derived"""
Expand Down Expand Up @@ -448,7 +449,10 @@ def _remove_version_number( # pyright: ignore[reportUnknownParameterType]
):
if isinstance(value, dict):
vn: Any = value.pop("version_number", None)
if vn is not None and "id" in value:
value["id"] = f"{value['id']}/{vn}"
if vn is not None and value.get("version") is None:
value["version"] = vn

return value # pyright: ignore[reportUnknownVariableType]

version: Optional[Version] = None
"""The version of the linked resource following SemVer 2.0."""
3 changes: 2 additions & 1 deletion bioimageio/spec/model/v0_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,8 @@ class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification")
"""Specialized resource type 'model'"""

id: Optional[ModelId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""
"""bioimage.io-wide unique resource identifier
assigned by bioimage.io; version **un**specific."""

authors: NotEmpty[ # pyright: ignore[reportGeneralTypeIssues] # make mandatory
List[Author]
Expand Down
151 changes: 112 additions & 39 deletions bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import numpy as np
from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
from imageio.v3 import imread, imwrite # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from numpy.typing import NDArray
from pydantic import (
Discriminator,
Expand Down Expand Up @@ -350,10 +351,20 @@ def get_size(
SpaceOutputAxis,
SpaceOutputAxisWithHalo,
],
n: ParameterizedSize_N,
n: ParameterizedSize_N = 0,
ref_size: Optional[int] = None,
):
"""helper method to compute concrete size for a given axis and its reference axis.
If the reference axis is parameterized, `n` is used to compute the concrete size of it, see `ParameterizedSize`.
"""Compute the concrete size for a given axis and its reference axis.
Args:
axis: The axis this `SizeReference` is the size of.
ref_axis: The reference axis to compute the size from.
n: If the **ref_axis** is parameterized (of type `ParameterizedSize`)
and no fixed **ref_size** is given,
**n** is used to compute the size of the parameterized **ref_axis**.
ref_size: Overwrite the reference size instead of deriving it from
**ref_axis**
(**ref_axis.scale** is still used; any given **n** is ignored).
"""
assert (
axis.size == self
Expand All @@ -367,22 +378,22 @@ def get_size(
"`SizeReference` requires `axis` and `ref_axis` to have the same `unit`,"
f" but {axis.unit}!={ref_axis.unit}"
)

if isinstance(ref_axis.size, (int, float)):
ref_size = ref_axis.size
elif isinstance(ref_axis.size, ParameterizedSize):
ref_size = ref_axis.size.get_size(n)
elif isinstance(ref_axis.size, DataDependentSize):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
)
elif isinstance(ref_axis.size, SizeReference):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be sized by a"
+ " `SizeReference` itself."
)
else:
assert_never(ref_axis.size)
if ref_size is None:
if isinstance(ref_axis.size, (int, float)):
ref_size = ref_axis.size
elif isinstance(ref_axis.size, ParameterizedSize):
ref_size = ref_axis.size.get_size(n)
elif isinstance(ref_axis.size, DataDependentSize):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
)
elif isinstance(ref_axis.size, SizeReference):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be sized by a"
+ " `SizeReference` itself."
)
else:
assert_never(ref_axis.size)

return int(ref_size * ref_axis.scale / axis.scale + self.offset)

Expand Down Expand Up @@ -2032,7 +2043,8 @@ class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification")
"""Specialized resource type 'model'"""

id: Optional[ModelId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""
"""bioimage.io-wide unique resource identifier
assigned by bioimage.io; version **un**specific."""

authors: NotEmpty[List[Author]]
"""The authors are the creators of the model RDF and the primary points of contact."""
Expand Down Expand Up @@ -2471,8 +2483,40 @@ def get_tensor_sizes(
)

def get_axis_sizes(
self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
self,
ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
batch_size: Optional[int] = None,
*,
max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
) -> _AxisSizes:
"""Determine input and output block shape for scale factors **ns**
of parameterized input sizes.
Args:
ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
that is parameterized as `size = min + n * step`.
batch_size: The desired size of the batch dimension.
If given **batch_size** overwrites any batch size present in
**max_input_shape**. Default 1.
max_input_shape: Limits the derived block shapes.
Each axis for which the input size, parameterized by `n`, is larger
than **max_input_shape** is set to the minimal value `n_min` for which
this is still true.
Use this for small input samples or large values of **ns**.
Or simply whenever you know the full input shape.
Returns:
Resolved axis sizes for model inputs and outputs.
"""
max_input_shape = max_input_shape or {}
if batch_size is None:
for (_t_id, a_id), s in max_input_shape.items():
if a_id == BATCH_AXIS_ID:
batch_size = s
break
else:
batch_size = 1

all_axes = {
t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
}
Expand All @@ -2483,16 +2527,19 @@ def get_axis_sizes(
def get_axis_size(a: Union[InputAxis, OutputAxis]):
if isinstance(a, BatchAxis):
if (t_descr.id, a.id) in ns:
raise ValueError(
"No size increment factor (n) for batch axis of tensor"
+ f" '{t_descr.id}' expected."
logger.warning(
"Ignoring unexpected size increment factor (n) for batch axis"
+ " of tensor '{}'.",
t_descr.id,
)
return batch_size
elif isinstance(a.size, int):
if (t_descr.id, a.id) in ns:
raise ValueError(
"No size increment factor (n) for fixed size axis"
+ f" '{a.id}' of tensor '{t_descr.id}' expected."
logger.warning(
"Ignoring unexpected size increment factor (n) for fixed size"
+ " axis '{}' of tensor '{}'.",
a.id,
t_descr.id,
)
return a.size
elif isinstance(a.size, ParameterizedSize):
Expand All @@ -2501,39 +2548,65 @@ def get_axis_size(a: Union[InputAxis, OutputAxis]):
"Size increment factor (n) missing for parametrized axis"
+ f" '{a.id}' of tensor '{t_descr.id}'."
)
return a.size.get_size(ns[(t_descr.id, a.id)])
n = ns[(t_descr.id, a.id)]
s_max = max_input_shape.get((t_descr.id, a.id))
if s_max is not None:
n = min(n, a.size.get_n(s_max))

return a.size.get_size(n)

elif isinstance(a.size, SizeReference):
if (t_descr.id, a.id) in ns:
raise ValueError(
f"No size increment factor (n) for axis '{a.id}' of tensor"
+ f" '{t_descr.id}' with size reference expected."
logger.warning(
"Ignoring unexpected size increment factor (n) for axis '{}'"
+ " of tensor '{}' with size reference.",
a.id,
t_descr.id,
)
assert not isinstance(a, BatchAxis)
ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
assert not isinstance(ref_axis, BatchAxis)
ref_key = (a.size.tensor_id, a.size.axis_id)
ref_size = inputs.get(ref_key, outputs.get(ref_key))
assert ref_size is not None, ref_key
assert not isinstance(ref_size, _DataDepSize), ref_key
return a.size.get_size(
axis=a,
ref_axis=ref_axis,
n=ns.get((a.size.tensor_id, a.size.axis_id), 0),
ref_size=ref_size,
)
elif isinstance(a.size, DataDependentSize):
if (t_descr.id, a.id) in ns:
raise ValueError(
"No size increment factor (n) for data dependent size axis"
+ f" '{a.id}' of tensor '{t_descr.id}' expected."
logger.warning(
"Ignoring unexpected increment factor (n) for data dependent"
+ " size axis '{}' of tensor '{}'.",
a.id,
t_descr.id,
)
return _DataDepSize(a.size.min, a.size.max)
else:
assert_never(a.size)

        # first resolve all but the `SizeReference` input sizes
for t_descr in self.inputs:
for a in t_descr.axes:
s = get_axis_size(a)
assert not isinstance(s, _DataDepSize)
inputs[t_descr.id, a.id] = s
if not isinstance(a.size, SizeReference):
s = get_axis_size(a)
assert not isinstance(s, _DataDepSize)
inputs[t_descr.id, a.id] = s

# resolve all other input axis sizes
for t_descr in self.inputs:
for a in t_descr.axes:
if isinstance(a.size, SizeReference):
s = get_axis_size(a)
assert not isinstance(s, _DataDepSize)
inputs[t_descr.id, a.id] = s

for t_descr in chain(self.inputs, self.outputs):
# resolve all output axis sizes
for t_descr in self.outputs:
for a in t_descr.axes:
assert not isinstance(a.size, ParameterizedSize)
s = get_axis_size(a)
outputs[t_descr.id, a.id] = s

Expand Down
Loading

0 comments on commit 5f44013

Please sign in to comment.