Skip to content

Commit

Permalink
✅ Add test to cli.
Browse files Browse the repository at this point in the history
  • Loading branch information
shaneahmed committed Mar 7, 2025
1 parent 7eed649 commit 3dff881
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 15 deletions.
19 changes: 19 additions & 0 deletions tests/engines/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@

import numpy as np
import torch
import yaml
import zarr
from click.testing import CliRunner

from tests.conftest import timed
from tiatoolbox import cli, logger, rcParam
from tiatoolbox.models import IOPatchPredictorConfig
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.vanilla import CNNModel
from tiatoolbox.models.engine.patch_predictor import PatchPredictor
from tiatoolbox.utils import env_detection as toolbox_env
Expand Down Expand Up @@ -642,6 +644,17 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
dir_path_masks = tmp_path.joinpath("new_copies_masks")
dir_path_masks.mkdir()

config = {
"input_resolutions": [{"units": "mpp", "resolution": 0.5}],
"patch_input_shape": [224, 224],
}

with Path.open(tmp_path.joinpath("config.yaml"), "w") as fptr:
yaml.dump(config, fptr)

model = "alexnet-kather100k"
weights = fetch_pretrained_weights(model)

try:
dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs)
dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs)
Expand Down Expand Up @@ -671,6 +684,12 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
str(False),
"--masks",
str(dir_path_masks),
"--model",
model,
"--weights",
str(weights),
"--yaml-config-path",
tmp_path / "config.yaml",
"--output-path",
str(tmp_path / "output"),
"--output-type",
Expand Down
8 changes: 4 additions & 4 deletions tiatoolbox/cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,17 +619,17 @@ def prepare_model_cli(
tiatoolbox_cli = TIAToolboxCLI()


def prepare_ioconfig_seg(
segment_config_class: type[IOConfigABC],
def prepare_ioconfig(
config_class: type[IOConfigABC],
pretrained_weights: str | Path | None,
yaml_config_path: str | Path,
) -> IOConfigABC | None:
"""Prepare ioconfig for segmentation."""
"""Prepare ioconfig for CLI."""
import yaml

if pretrained_weights is not None:
with Path(yaml_config_path).open() as registry_handle:
ioconfig = yaml.safe_load(registry_handle)
return segment_config_class(**ioconfig)
return config_class(**ioconfig)

return None
4 changes: 2 additions & 2 deletions tiatoolbox/cli/nucleus_instance_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
cli_pretrained_weights,
cli_verbose,
cli_yaml_config_path,
prepare_ioconfig_seg,
prepare_ioconfig,
prepare_model_cli,
tiatoolbox_cli,
)
Expand Down Expand Up @@ -77,7 +77,7 @@ def nucleus_instance_segment(
file_types=file_types,
)

ioconfig = prepare_ioconfig_seg(
ioconfig = prepare_ioconfig(

Check warning on line 80 in tiatoolbox/cli/nucleus_instance_segment.py

View check run for this annotation

Codecov / codecov/patch

tiatoolbox/cli/nucleus_instance_segment.py#L80

Added line #L80 was not covered by tests
IOInstanceSegmentorConfig,
pretrained_weights,
yaml_config_path,
Expand Down
19 changes: 12 additions & 7 deletions tiatoolbox/cli/patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
cli_output_path,
cli_output_type,
cli_patch_mode,
cli_resolution,
cli_return_labels,
cli_return_probabilities,
cli_units,
cli_verbose,
cli_weights,
cli_yaml_config_path,
prepare_ioconfig,
prepare_model_cli,
tiatoolbox_cli,
)
Expand All @@ -37,8 +37,7 @@
@cli_weights()
@cli_device(default="cpu")
@cli_batch_size(default=1)
@cli_resolution(default=0.5)
@cli_units(default="mpp")
@cli_yaml_config_path()
@cli_masks(default=None)
@cli_num_loader_workers(default=0)
@cli_output_type(
Expand All @@ -56,8 +55,7 @@ def patch_predictor(
masks: str | None,
output_path: str,
batch_size: int,
resolution: float,
units: str,
yaml_config_path: str,
num_loader_workers: int,
device: str,
output_type: str,
Expand All @@ -68,6 +66,7 @@ def patch_predictor(
verbose: bool,
) -> None:
"""Process an image/directory of input images with a patch classification CNN."""
from tiatoolbox.models.engine.io_config import IOPatchPredictorConfig
from tiatoolbox.models.engine.patch_predictor import PatchPredictor

files_all, masks_all, output_path = prepare_model_cli(
Expand All @@ -85,11 +84,17 @@ def patch_predictor(
verbose=verbose,
)

ioconfig = prepare_ioconfig(
IOPatchPredictorConfig,
pretrained_weights=weights,
yaml_config_path=yaml_config_path,
)

_ = predictor.run(
images=files_all,
masks=masks_all,
patch_mode=patch_mode,
input_resolutions=[{"units": units, "resolution": resolution}],
ioconfig=ioconfig,
device=device,
save_dir=output_path,
output_type=output_type,
Expand Down
4 changes: 2 additions & 2 deletions tiatoolbox/cli/semantic_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
cli_pretrained_weights,
cli_verbose,
cli_yaml_config_path,
prepare_ioconfig_seg,
prepare_ioconfig,
prepare_model_cli,
tiatoolbox_cli,
)
Expand Down Expand Up @@ -71,7 +71,7 @@ def semantic_segment(
file_types=file_types,
)

ioconfig = prepare_ioconfig_seg(
ioconfig = prepare_ioconfig(

Check warning on line 74 in tiatoolbox/cli/semantic_segment.py

View check run for this annotation

Codecov / codecov/patch

tiatoolbox/cli/semantic_segment.py#L74

Added line #L74 was not covered by tests
IOSegmentorConfig,
pretrained_weights,
yaml_config_path,
Expand Down

0 comments on commit 3dff881

Please sign in to comment.