Skip to content

Commit

Permalink
♻️ Use input_resolutions instead of resolution
Browse files Browse the repository at this point in the history
- Use `input_resolutions` instead of resolution to make engines outputs compatible with ioconfig.
- Uses input resolution as a list of dictionaries on units and resolution.
  • Loading branch information
shaneahmed committed Mar 7, 2025
1 parent a643ea6 commit 7eed649
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 106 deletions.
18 changes: 6 additions & 12 deletions tests/engines/test_engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,17 +521,15 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)

kwargs = {
"patch_input_shape": [512, 512],
"resolution": 1.75,
"units": "mpp",
"input_resolutions": [{"units": "mpp", "resolution": 1.75}],
}
with caplog.at_level(logging.WARNING):
eng.run(
np.zeros((10, 224, 224, 3)),
patch_mode=True,
save_dir=tmp_path / "dump",
patch_input_shape=kwargs["patch_input_shape"],
input_resolutions=kwargs["resolution"],
units=kwargs["units"],
input_resolutions=kwargs["input_resolutions"],
)
assert "provide a valid ModelIOConfigABC" in caplog.text
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
Expand Down Expand Up @@ -570,8 +568,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
patch_input_shape=(300, 300),
stride_shape=(300, 300),
input_resolutions=1.99,
units="baseline",
input_resolutions=[{"units": "baseline", "resolution": 1.99}],
patch_mode=True,
save_dir=f"{tmp_path}/dump",
)
Expand All @@ -586,7 +583,6 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
patch_input_shape=(300, 300),
stride_shape=(300, 300),
input_resolutions=None,
units=None,
patch_mode=True,
save_dir=f"{tmp_path}/dump",
)
Expand All @@ -599,8 +595,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
ioconfig=None,
patch_input_shape=(300, 300),
stride_shape=(300, 300),
input_resolutions=1.99,
units="baseline",
input_resolutions=[{"units": "baseline", "resolution": 1.99}],
)

assert _ioconfig.patch_input_shape == (300, 300)
Expand All @@ -614,12 +609,11 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
with pytest.raises(
ValueError,
match=r".*Must provide either `ioconfig` or "
r"`patch_input_shape`, `resolution`, and `units`*",
r"`patch_input_shape` and `input_resolutions`*",
):
eng._update_ioconfig(
ioconfig=None,
patch_input_shape=_kwargs["patch_input_shape"],
stride_shape=(1, 1),
input_resolutions=_kwargs["resolution"],
units=_kwargs["units"],
input_resolutions=_kwargs["input_resolutions"],
)
16 changes: 6 additions & 10 deletions tests/engines/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
predictor = PatchPredictor(model=model, weights=None)
kwargs = {
"patch_input_shape": [512, 512],
"resolution": 1.75,
"units": "mpp",
"input_resolutions": [{"units": "mpp", "resolution": 1.75}],
}

# test providing config / full input info for default models without weights
Expand Down Expand Up @@ -134,7 +133,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:

predictor.run(
images=[mini_wsi_svs],
input_resolutions=1.99,
input_resolutions=[{"units": "mpp", "resolution": 1.99}],
patch_mode=False,
save_dir=f"{tmp_path}/dump",
)
Expand All @@ -143,7 +142,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:

predictor.run(
images=[mini_wsi_svs],
units="baseline",
input_resolutions=[{"units": "baseline", "resolution": 1.0}],
patch_mode=False,
save_dir=f"{tmp_path}/dump",
)
Expand All @@ -152,8 +151,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:

predictor.run(
images=[mini_wsi_svs],
units="level",
input_resolutions=0,
input_resolutions=[{"units": "level", "resolution": 0}],
patch_mode=False,
save_dir=f"{tmp_path}/dump",
)
Expand All @@ -163,8 +161,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:

predictor.run(
images=[mini_wsi_svs],
units="power",
input_resolutions=20,
input_resolutions=[{"units": "power", "resolution": 20}],
patch_mode=False,
save_dir=f"{tmp_path}/dump",
)
Expand Down Expand Up @@ -262,8 +259,7 @@ def test_wsi_predictor_api(
kwargs = {
"patch_input_shape": patch_size,
"stride_shape": patch_size,
"resolution": 1.0,
"units": "baseline",
"input_resolutions": [{"units": "baseline", "resolution": 1.0}],
"save_dir": save_dir,
}
# ! add this test back once the read at `baseline` is fixed
Expand Down
3 changes: 1 addition & 2 deletions tiatoolbox/cli/patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ def patch_predictor(
images=files_all,
masks=masks_all,
patch_mode=patch_mode,
resolution=resolution,
units=units,
input_resolutions=[{"units": units, "resolution": resolution}],
device=device,
save_dir=output_path,
output_type=output_type,
Expand Down
10 changes: 7 additions & 3 deletions tiatoolbox/models/architecture/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING

import numpy as np
import torch
from torch import nn

from tiatoolbox import logger

if TYPE_CHECKING: # pragma: no cover
from tiatoolbox.models.models_abc import ModelABC


def is_torch_compile_compatible() -> bool:
"""Check if the current GPU is compatible with torch-compile.
Expand Down Expand Up @@ -45,10 +49,10 @@ def is_torch_compile_compatible() -> bool:


def compile_model(
model: nn.Module | None = None,
model: nn.Module | ModelABC | None = None,
*,
mode: str = "default",
) -> nn.Module:
) -> torch.nn.Module | ModelABC:
"""A decorator to compile a model using torch-compile.
Args:
Expand All @@ -67,7 +71,7 @@ def compile_model(
CUDA graphs
Returns:
torch.nn.Module:
torch.nn.Module or ModelABC:
Compiled model.
"""
Expand Down
83 changes: 34 additions & 49 deletions tiatoolbox/models/engine/engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,23 +132,21 @@ class EngineABCRunParams(TypedDict, total=False):
Shape of patches input to the model as tuple of height and width (HW).
Patches are requested at read resolution, not with respect to level 0,
and must be positive.
input_resolutions (Resolution):
Resolution used for reading the image. Please see
input_resolutions (list(dict(Units, Resolution))):
List of Python dictionaries with units and resolution for each
input head for model inference for reading the image. Supported
units are `level`, `power` and `mpp`. Keys should be "units" and
"resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see
:class:`WSIReader` for details.
scale_factor (tuple[float, float]):
The scale factor to use when loading the
annotations. All coordinates will be multiplied by this factor to allow
conversion of annotations saved at non-baseline resolution to baseline.
Should be model_mpp/slide_mpp.
The scale factor to use when loading the annotations. All coordinates
will be multiplied by this factor to allow conversion of annotations
saved at non-baseline resolution to baseline. Should be model_mpp/slide_mpp.
stride_shape (tuple):
Stride used during WSI processing. Stride is
at requested read resolution, not with respect to
level 0, and must be positive. If not provided,
`stride_shape=patch_input_shape`.
units (Units):
Units of resolution used for reading the image. Choose
from either `level`, `power` or `mpp`. Please see
:class:`WSIReader` for details.
verbose (bool):
Whether to output logging information.
Expand All @@ -164,11 +162,10 @@ class EngineABCRunParams(TypedDict, total=False):
num_post_proc_workers: int
output_file: str
patch_input_shape: IntPair
input_resolutions: Resolution
input_resolutions: list[dict[Units, Resolution]]
return_labels: bool
scale_factor: tuple[float, float]
stride_shape: IntPair
units: Units
verbose: bool


Expand Down Expand Up @@ -242,13 +239,12 @@ class EngineABC(ABC): # noqa: B024
Runtime ioconfig.
return_labels (bool):
Whether to return the labels with the predictions.
input_resolutions (Resolution):
Resolution used for reading the image. Please see
:obj:`WSIReader` for details.
units (Units):
Units of resolution used for reading the image. Choose
from either `level`, `power` or `mpp`. Please see
:obj:`WSIReader` for details.
input_resolutions (list(dict(Units, Resolution))):
List of Python dictionaries with units and resolution for each
input head for model inference for reading the image. Supported
units are `level`, `power` and `mpp`. Keys should be "units" and
"resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see
:class:`WSIReader` for details.
patch_input_shape (tuple):
Shape of patches input to the model as tupled of HW. Patches are at
requested read resolution, not with respect to level 0,
Expand Down Expand Up @@ -283,20 +279,14 @@ class EngineABC(ABC): # noqa: B024
Number of workers to postprocess the results of the model.
return_labels (bool):
Whether to return the output labels. Default value is False.
input_resolutions (Resolution):
Resolution used for reading the image. Please see
:class:`WSIReader` for details.
When `patch_mode` is True, the input image patches are expected to be at
the correct resolution and units. When `patch_mode` is False, the patches
are extracted at the requested resolution and units. Default value is 1.0.
units (Units):
Units of resolution used for reading the image. Choose
from either `baseline`, `level`, `power` or `mpp`. Please see
:class:`WSIReader` for details.
When `patch_mode` is True, the input image patches are expected to be at
the correct resolution and units. When `patch_mode` is False, the patches
are extracted at the requested resolution and units.
Default value is `baseline`.
input_resolutions (list(dict(Units, Resolution))):
List of Python dictionaries with units and resolution for each
input head for model inference for reading the image. Supported
units are `level`, `power` and `mpp`. When `patch_mode` is `True`,
the input image patches are expected to be at the correct resolution and
units. When `patch_mode` is `False`, the patches are extracted at the
requested resolution and units. Default value is [{"units": "baseline",
"resolution": 1.0}].
verbose (bool):
Whether to output logging information. Default value is False.
Expand Down Expand Up @@ -371,10 +361,9 @@ def __init__(
self.num_loader_workers = num_loader_workers
self.num_post_proc_workers = num_post_proc_workers
self.patch_input_shape: IntPair | None = None
self.input_resolutions: Resolution | None = None
self.input_resolutions: list[dict[Units, Resolution]] | None = None
self.return_labels: bool = False
self.stride_shape: IntPair | None = None
self.units: Units | None = None
self.verbose = verbose

@staticmethod
Expand Down Expand Up @@ -791,8 +780,7 @@ def _update_ioconfig(
ioconfig: ModelIOConfigABC,
patch_input_shape: IntPair,
stride_shape: IntPair,
input_resolutions: Resolution,
units: Units,
input_resolutions: list[dict[Units, Resolution]],
) -> ModelIOConfigABC:
"""Update IOConfig.
Expand All @@ -808,11 +796,12 @@ def _update_ioconfig(
at requested read resolution, not with respect to
level 0, and must be positive. If not provided,
`stride_shape=patch_input_shape`.
input_resolutions (Resolution):
Resolution used for reading the image. Please see
:obj:`WSIReader` for details.
units (Units):
Units of resolution used for reading the image.
input_resolutions (list(dict(Units, Resolution))):
List of Python dictionaries with units and resolution for each
input head for model inference for reading the image. Supported
units are `level`, `power` and `mpp`. Keys should be "units" and
"resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see
:class:`WSIReader` for details.
Returns:
Updated Patch Predictor IO configuration.
Expand All @@ -821,15 +810,14 @@ def _update_ioconfig(
config_flag = (
patch_input_shape is None,
input_resolutions is None,
units is None,
)
if isinstance(ioconfig, ModelIOConfigABC):
return ioconfig

if self.ioconfig is None and any(config_flag):
msg = (
"Must provide either "
"`ioconfig` or `patch_input_shape`, `resolution`, and `units`."
"`ioconfig` or `patch_input_shape` and `input_resolutions`."
)
raise ValueError(
msg,
Expand All @@ -846,14 +834,12 @@ def _update_ioconfig(
if stride_shape is not None:
ioconfig.stride_shape = stride_shape
if input_resolutions is not None:
ioconfig.input_resolutions[0]["resolution"] = input_resolutions
if units is not None:
ioconfig.input_resolutions[0]["units"] = units
ioconfig.input_resolutions = input_resolutions

return ioconfig

return ModelIOConfigABC(
input_resolutions=[{"resolution": input_resolutions, "units": units}],
input_resolutions=input_resolutions,
patch_input_shape=patch_input_shape,
stride_shape=stride_shape,
output_resolutions=[],
Expand Down Expand Up @@ -956,7 +942,6 @@ def _update_run_params(
self.patch_input_shape,
self.stride_shape,
self.input_resolutions,
self.units,
)

return prepare_engines_save_dir(
Expand Down
8 changes: 5 additions & 3 deletions tiatoolbox/models/engine/io_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np

if TYPE_CHECKING: # pragma: no cover
from tiatoolbox.type_hints import Units
from tiatoolbox.type_hints import Resolution, Units


@dataclass
Expand Down Expand Up @@ -107,7 +107,9 @@ def _validate(self: ModelIOConfigABC) -> None:
raise ValueError(msg)

@staticmethod
def scale_to_highest(resolutions: list[dict], units: Units) -> np.array:
def scale_to_highest(
resolutions: list[dict[Units, Resolution]], units: Units
) -> np.array:
"""Get the scaling factor from input resolutions.
This will convert resolutions to a scaling factor with respect to
Expand All @@ -117,7 +119,7 @@ def scale_to_highest(resolutions: list[dict], units: Units) -> np.array:
and will be scaled for low resolution requirements using interpolation.
Args:
resolutions (list):
resolutions (list(dict(Units, Resolution))):
A list of resolutions where one is defined as
`{'resolution': value, 'unit': value}`
units (Units):
Expand Down
Loading

0 comments on commit 7eed649

Please sign in to comment.