From ad5b34d8cc6cc488b00fb88f00e7a3974af710e9 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Tue, 3 Sep 2024 15:12:10 +0200 Subject: [PATCH] Add tests for v4 models and parameterized weights - Weights are parameterized for pytorch and torchscript workflows --- tests/conftest.py | 346 ++++++++++++++---- .../test_grpc/test_inference_servicer.py | 15 +- tiktorch/server/session/process.py | 11 +- 3 files changed, 287 insertions(+), 85 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 613cab3e..7d401a3e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import faulthandler import io import logging.handlers @@ -6,6 +8,8 @@ import sys import tempfile import threading +from datetime import datetime +from enum import Enum from os import getenv from pathlib import Path from random import randint @@ -17,6 +21,7 @@ import xarray as xr from bioimageio.core import AxisId from bioimageio.spec import save_bioimageio_package_to_stream +from bioimageio.spec.model import v0_4 from bioimageio.spec.model.v0_5 import ( ArchitectureFromLibraryDescr, Author, @@ -26,22 +31,31 @@ Doi, FileDescr, HttpUrl, + Identifier, InputAxis, InputTensorDescr, LicenseId, ModelDescr, + OutputAxis, OutputTensorDescr, ParameterizedSize, PytorchStateDictWeightsDescr, SizeReference, SpaceInputAxis, SpaceOutputAxis, + TensorId, + TorchscriptWeightsDescr, Version, WeightsDescr, ) from torch import nn +class WeightsFormat(Enum): + PYTORCH = ("pytorch",) + TORCHSCRIPT = "torchscript" + + @pytest.fixture def srv_port(): return getenv("TEST_TIKTORCH_PORT", randint(5500, 8000)) @@ -88,51 +102,38 @@ def assert_threads_cleanup(): pytest.fail("Threads still running:\n\t%s" % "\n\t".join(running_threads)) -@pytest.fixture -def bioimage_model_explicit_siso() -> Tuple[io.BytesIO, xr.DataArray]: - test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) - model_descr, expected_output = _bioimage_model_siso( - [ - BatchAxis(), - ChannelAxis(channel_names=["channel1", "channel2"]), - SpaceInputAxis(id="x", size=10), - SpaceInputAxis(id="y", size=10), - ], - test_tensor, - ) - model_bytes = io.BytesIO() - save_bioimageio_package_to_stream(model_descr, output_stream=model_bytes) - return model_bytes, expected_output - - -@pytest.fixture -def bioimage_model_param_siso() -> Tuple[io.BytesIO, xr.DataArray]: - test_tensor = np.arange(1 * 2 * 10 * 20, dtype="float32").reshape(1, 2, 10, 20) - model_descr, expected_output = _bioimage_model_siso( - [ - BatchAxis(), - ChannelAxis(channel_names=["channel1", "channel2"]), - SpaceInputAxis(id="x", size=ParameterizedSize(min=10, step=2)), - SpaceInputAxis(id="y", size=ParameterizedSize(min=20, step=3)), - ], - test_tensor, - ) - model_bytes = io.BytesIO() - save_bioimageio_package_to_stream(model_descr, output_stream=model_bytes) - return model_bytes, expected_output - - -def _bioimage_model_siso(input_axes: List[InputAxis], test_tensor: np.array) -> Tuple[ModelDescr, xr.DataArray]: - """ - Mocked bioimageio prediction pipeline with single input single output - """ - with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as test_tensor_file: - np.save(test_tensor_file.name, test_tensor) - - input_tensor = InputTensorDescr( - id="input", axes=input_axes, description="", test_tensor=FileDescr(source=Path(test_tensor_file.name)) - ) - return _bioimage_model([input_tensor]) +@pytest.fixture(params=[WeightsFormat.PYTORCH, WeightsFormat.TORCHSCRIPT]) +def bioimage_model_explicit_siso(request) -> Tuple[io.BytesIO, xr.DataArray]: + input_axes = [ + BatchAxis(), + ChannelAxis(channel_names=[Identifier("channel1"), Identifier("channel2")]), + SpaceInputAxis(id=AxisId("x"), size=10), + SpaceInputAxis(id=AxisId("y"), size=10), + ] + input_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) + if request.param == WeightsFormat.PYTORCH: + return _bioimage_model_dummy_v5_siso_pytorch(input_axes, input_test_tensor) + elif request.param == WeightsFormat.TORCHSCRIPT: + return _bioimage_model_dummy_v5_siso_torchscript(input_axes, input_test_tensor) + else: + raise NotImplementedError(f"{request.param}") + + +@pytest.fixture(params=[WeightsFormat.PYTORCH, WeightsFormat.TORCHSCRIPT]) +def bioimage_model_param_siso(request) -> Tuple[io.BytesIO, xr.DataArray]: + input_test_tensor = np.arange(1 * 2 * 10 * 20, dtype="float32").reshape(1, 2, 10, 20) + input_axes = [ + BatchAxis(), + ChannelAxis(channel_names=[Identifier("channel1"), Identifier("channel2")]), + SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=10, step=2)), + SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=20, step=3)), + ] + if request.param == WeightsFormat.PYTORCH: + return _bioimage_model_dummy_v5_siso_pytorch(input_axes, input_test_tensor) + elif request.param == WeightsFormat.TORCHSCRIPT: + return _bioimage_model_dummy_v5_siso_torchscript(input_axes, input_test_tensor) + else: + raise NotImplementedError(f"{request.param}") @pytest.fixture @@ -152,22 +153,22 @@ def bioimage_model_miso() -> Tuple[io.BytesIO, xr.DataArray]: np.save(test_tensor3_file.name, test_tensor3) input1 = InputTensorDescr( - id="input1", + id=TensorId("input1"), axes=[ BatchAxis(), - ChannelAxis(channel_names=["channel1", "channel2"]), + ChannelAxis(channel_names=[Identifier("channel1"), Identifier("channel2")]), SpaceInputAxis(id=AxisId("x"), size=10), - SpaceInputAxis(id=AxisId("y"), size=SizeReference(tensor_id="input3", axis_id="y")), + SpaceInputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("input3"), axis_id=AxisId("y"))), ], description="", test_tensor=FileDescr(source=Path(test_tensor1_file.name)), ) input2 = InputTensorDescr( - id="input2", + id=TensorId("input2"), axes=[ BatchAxis(), - ChannelAxis(channel_names=["channel1", "channel2"]), + ChannelAxis(channel_names=[Identifier("channel1"), Identifier("channel2")]), SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=10, step=2)), SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=10, step=5)), ], @@ -176,46 +177,158 @@ def bioimage_model_miso() -> Tuple[io.BytesIO, xr.DataArray]: ) input3 = InputTensorDescr( - id="input3", + id=TensorId("input3"), axes=[ BatchAxis(), - ChannelAxis(channel_names=["channel1", "channel2"]), - SpaceInputAxis(id="x", size=SizeReference(tensor_id="input2", axis_id="x")), - SpaceInputAxis(id="y", size=10), + ChannelAxis(channel_names=[Identifier("channel1"), Identifier("channel2")]), + SpaceInputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("input2"), axis_id=AxisId("x"))), + SpaceInputAxis(id=AxisId("y"), size=10), ], description="", test_tensor=FileDescr(source=Path(test_tensor1_file.name)), ) - model_descr, expected_output = _bioimage_model([input1, input2, input3]) - model_bytes = io.BytesIO() - save_bioimageio_package_to_stream(model_descr, output_stream=model_bytes) + dummy_model = _DummyNetwork() + expected_output = _dummy_network_output + with tempfile.NamedTemporaryFile(suffix=".pts", delete=False) as weights_file: + torch.save(dummy_model.state_dict(), weights_file.name) + weights = WeightsDescr( + pytorch_state_dict=PytorchStateDictWeightsDescr( + source=Path(weights_file.name), + architecture=ArchitectureFromLibraryDescr( + import_from="tests.conftest", + callable=Identifier(f"{_DummyNetwork.__name__}"), + ), + pytorch_version=Version("1.1.1"), + ) + ) + + output_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) + output_axes = [ + BatchAxis(), + ChannelAxis(channel_names=[Identifier("channel1"), Identifier("channel2")]), + SpaceOutputAxis(id=AxisId("x"), size=10), + SpaceOutputAxis(id=AxisId("y"), size=10), + ] + + with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as output_test_tensor_file: + np.save(output_test_tensor_file.name, output_test_tensor) + + output_tensor = OutputTensorDescr( + id=TensorId("output"), + axes=output_axes, + description="", + test_tensor=FileDescr(source=Path(output_test_tensor_file.name)), + ) + + model_bytes = _bioimage_model_v5(weights=weights, inputs=[input1, input2, input3], outputs=[output_tensor]) return model_bytes, expected_output -def _bioimage_model(inputs: List[InputTensorDescr]) -> Tuple[ModelDescr, xr.DataArray]: - test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) - with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as test_tensor_file: - np.save(test_tensor_file.name, test_tensor) +def _bioimage_model_dummy_v5_siso_torchscript( + input_axes: List[InputAxis], input_test_tensor: np.ndarray +) -> Tuple[io.BytesIO, xr.DataArray]: + dummy_model = _DummyNetwork() + expected_output = _dummy_network_output + traced_model = torch.jit.trace(dummy_model, example_inputs=torch.from_numpy(input_test_tensor)) + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as model_file: + traced_model.save(model_file.name) + weights = WeightsDescr( + torchscript=TorchscriptWeightsDescr(source=Path(model_file.name), pytorch_version=Version("1.1.1")) + ) + + output_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) + output_axes = [ + BatchAxis(), + ChannelAxis(channel_names=[Identifier("channel1"), Identifier("channel2")]), + SpaceOutputAxis(id=AxisId("x"), size=10), + SpaceOutputAxis(id=AxisId("y"), size=10), + ] + + return ( + _bioimage_model_v5_siso( + weights=weights, + input_axes=input_axes, + output_axes=output_axes, + input_test_tensor=input_test_tensor, + output_test_tensor=output_test_tensor, + ), + expected_output, + ) + +def _bioimage_model_dummy_v5_siso_pytorch( + input_axes: List[InputAxis], input_test_tensor: np.ndarray +) -> Tuple[io.BytesIO, xr.DataArray]: dummy_model = _DummyNetwork() + expected_output = _dummy_network_output with tempfile.NamedTemporaryFile(suffix=".pts", delete=False) as weights_file: torch.save(dummy_model.state_dict(), weights_file.name) + weights = WeightsDescr( + pytorch_state_dict=PytorchStateDictWeightsDescr( + source=Path(weights_file.name), + architecture=ArchitectureFromLibraryDescr( + import_from="tests.conftest", + callable=Identifier(f"{_DummyNetwork.__name__}"), + ), + pytorch_version=Version("1.1.1"), + ) + ) + + output_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) + output_axes = [ + BatchAxis(), + ChannelAxis(channel_names=[Identifier("channel1"), Identifier("channel2")]), + SpaceOutputAxis(id=AxisId("x"), size=10), + SpaceOutputAxis(id=AxisId("y"), size=10), + ] + + return ( + _bioimage_model_v5_siso( + weights=weights, + input_axes=input_axes, + output_axes=output_axes, + input_test_tensor=input_test_tensor, + output_test_tensor=output_test_tensor, + ), + expected_output, + ) + + +def _bioimage_model_v5_siso( + weights: WeightsDescr, + input_axes: List[InputAxis], + output_axes: List[OutputAxis], + input_test_tensor: np.ndarray, + output_test_tensor: np.ndarray, +) -> io.BytesIO: + with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as input_test_tensor_file: + np.save(input_test_tensor_file.name, input_test_tensor) + + input_tensor = InputTensorDescr( + id=TensorId("input"), + axes=input_axes, + description="", + test_tensor=FileDescr(source=Path(input_test_tensor_file.name)), + ) + + with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as output_test_tensor_file: + np.save(output_test_tensor_file.name, output_test_tensor) output_tensor = OutputTensorDescr( - id="output", - axes=[ - BatchAxis(), - ChannelAxis(channel_names=["channel1", "channel2"]), - SpaceOutputAxis(id=AxisId("x"), size=10), - SpaceOutputAxis(id=AxisId("y"), size=10), - ], + id=TensorId("output"), + axes=output_axes, description="", - test_tensor=FileDescr(source=Path(test_tensor_file.name)), + test_tensor=FileDescr(source=Path(output_test_tensor_file.name)), ) + return _bioimage_model_v5(weights=weights, inputs=[input_tensor], outputs=[output_tensor]) + +def _bioimage_model_v5( + weights: WeightsDescr, inputs: List[InputTensorDescr], outputs: List[OutputTensorDescr] +) -> io.BytesIO: mocked_descr = ModelDescr( - name="mocked model", + name="mocked v5 model", description="A test model for demonstration purposes only", authors=[Author(name="me", affiliation="my institute", github_user="bioimageiobot")], # change github_user to your GitHub account name @@ -224,18 +337,89 @@ def _bioimage_model(inputs: List[InputTensorDescr]) -> Tuple[ModelDescr, xr.Data documentation=HttpUrl("https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/README.md"), git_repo=HttpUrl("https://github.com/bioimage-io/spec-bioimage-io"), inputs=inputs, + outputs=outputs, + weights=weights, + ) + model_bytes = io.BytesIO() + save_bioimageio_package_to_stream(mocked_descr, output_stream=model_bytes) + return model_bytes + + +@pytest.fixture(params=[WeightsFormat.PYTORCH, WeightsFormat.TORCHSCRIPT]) +def bioimage_model_v4(request) -> Tuple[io.BytesIO, xr.DataArray]: + if request.param == WeightsFormat.PYTORCH: + return _bioimage_model_dummy_v4_siso_pytorch() + elif request.param == WeightsFormat.TORCHSCRIPT: + return _bioimage_model_dummy_v4_siso_torchscript() + else: + raise NotImplementedError(f"{request.param}") + + +def _bioimage_model_dummy_v4_siso_pytorch() -> Tuple[io.BytesIO, xr.DataArray]: + dummy_model = _DummyNetwork() + dummy_model_expected_output = _dummy_network_output + input_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) + output_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) + traced_model = torch.jit.trace(dummy_model, example_inputs=torch.from_numpy(input_test_tensor)) + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as weights_file: + traced_model.save(weights_file.name) + weights = v0_4.WeightsDescr(torchscript=v0_4.TorchscriptWeightsDescr(source=Path(weights_file.name))) + model_bytes = _bioimage_model_v4_siso( + weights=weights, input_test_tensor=input_test_tensor, output_test_tensor=output_test_tensor + ) + return model_bytes, dummy_model_expected_output + + +def _bioimage_model_dummy_v4_siso_torchscript() -> Tuple[io.BytesIO, xr.DataArray]: + dummy_model = _DummyNetwork() + dummy_model_expected_output = _dummy_network_output + input_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) + output_test_tensor = np.arange(1 * 2 * 10 * 10, dtype="float32").reshape(1, 2, 10, 10) + traced_model = torch.jit.trace(dummy_model, example_inputs=torch.from_numpy(input_test_tensor)) + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as model_file: + traced_model.save(model_file.name) + weights = v0_4.WeightsDescr(torchscript=v0_4.TorchscriptWeightsDescr(source=Path(model_file.name))) + model_bytes = _bioimage_model_v4_siso( + weights=weights, input_test_tensor=input_test_tensor, output_test_tensor=output_test_tensor + ) + return model_bytes, dummy_model_expected_output + + +def _bioimage_model_v4_siso( + weights: v0_4.WeightsDescr, input_test_tensor: np.ndarray, output_test_tensor: np.ndarray +) -> io.BytesIO: + input_tensor = v0_4.InputTensorDescr( + name=v0_4.TensorName("input"), description="", axes="bcxy", shape=input_test_tensor.shape, data_type="float32" + ) + + output_tensor = v0_4.OutputTensorDescr( + name=v0_4.TensorName("output"), description="", axes="bcxy", shape=output_test_tensor.shape, data_type="float32" + ) + + with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as output_test_tensor_file: + np.save(output_test_tensor_file.name, output_test_tensor) + + with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as input_test_tensor_file: + np.save(input_test_tensor_file.name, input_test_tensor) + + model_descr = v0_4.ModelDescr( + name="mocked v4 model", + authors=[v0_4.Author(name="me")], + cite=[v0_4.CiteEntry(text="for model training see my paper", url=HttpUrl("https://doi.org/10.1234something"))], + description="", + inputs=[input_tensor], outputs=[output_tensor], - weights=WeightsDescr( - pytorch_state_dict=PytorchStateDictWeightsDescr( - source=weights_file.name, - architecture=ArchitectureFromLibraryDescr( - import_from="tests.conftest", callable=_DummyNetwork.__name__ - ), - pytorch_version=Version("1.1.1"), - ) - ), + documentation=HttpUrl("https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/README.md"), + license="MIT", + test_inputs=[Path(input_test_tensor_file.name)], + test_outputs=[Path(output_test_tensor_file.name)], + timestamp=v0_4.Datetime(root=datetime.now()), + weights=weights, ) - return mocked_descr, _dummy_network_output + + model_bytes = io.BytesIO() + save_bioimageio_package_to_stream(model_descr, output_stream=model_bytes) + return model_bytes _dummy_network_output = xr.DataArray(np.arange(2 * 10 * 10).reshape(1, 2, 10, 10), dims=["batch", "channel", "x", "y"]) @@ -243,4 +427,4 @@ def _bioimage_model(inputs: List[InputTensorDescr]) -> Tuple[ModelDescr, xr.Data class _DummyNetwork(nn.Module): def forward(self, *args): - return _dummy_network_output + return torch.from_numpy(_dummy_network_output.values) diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index a7974d48..65ce7bf4 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -171,6 +171,19 @@ def test_call_predict_valid_explicit(self, grpc_stub, bioimage_model_explicit_si assert pb_tensor.tensorId == "output" assert_array_equal(pb_tensor_to_xarray(res.tensors[0]), expected_output) + def test_call_predict_valid_explicit_v4(self, grpc_stub, bioimage_model_v4): + model_bytes, expected_output = bioimage_model_v4 + model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + arr = xr.DataArray(np.arange(2 * 10 * 10).reshape(1, 2, 10, 10), dims=("batch", "channel", "x", "y")) + input_tensor_id = "input" + input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] + res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.CloseModelSession(model) + assert len(res.tensors) == 1 + pb_tensor = res.tensors[0] + assert pb_tensor.tensorId == "output" + assert_array_equal(pb_tensor_to_xarray(res.tensors[0]), expected_output) + def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimage_model_explicit_siso): model_bytes, expected_output = bioimage_model_explicit_siso model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) @@ -239,7 +252,7 @@ def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimage_model_explici input_tensors = [converters.xarray_to_pb_tensor("invalidTensorName", arr)] with pytest.raises(grpc.RpcError) as error: grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) - assert error.value.details().startswith("Exception calling application: Spec invalidTensorName doesn't exist") + assert error.value.details().startswith("Exception calling application: Spec 'invalidTensorName' doesn't exist") grpc_stub.CloseModelSession(model) @pytest.mark.parametrize( diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index 48e13be6..07aa28b8 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Union from bioimageio.core import PredictionPipeline, Tensor, create_prediction_pipeline -from bioimageio.spec import load_description +from bioimageio.spec import InvalidDescr, load_description from bioimageio.spec.model import v0_5 from bioimageio.spec.model.v0_5 import BatchAxis @@ -75,7 +75,9 @@ def _realize_size_reference(self, size: v0_5.SizeReference) -> Union[int, v0_5.P def _get_spec(self, tensor_id: str) -> v0_5.InputTensorDescr: specs = [spec for spec in self._specs if tensor_id == spec.id] if len(specs) == 0: - raise ValueError(f"Spec {tensor_id} doesn't exist for specs {[spec.id for spec in self._specs]}") + raise ValueError( + f"Spec '{tensor_id}' doesn't exist for specs {','.join([spec.id for spec in self._specs])}" + ) assert len(specs) == 1, "ids of tensor specs should be unique" return specs[0] @@ -148,4 +150,7 @@ def _get_model_descr_from_model_bytes(model_bytes: bytes) -> v0_5.ModelDescr: with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file: _tmp_file.write(model_bytes) temp_file_path = pathlib.Path(_tmp_file.name) - return load_description(temp_file_path) + model_descr = load_description(temp_file_path, format_version="latest") + if isinstance(model_descr, InvalidDescr): + raise ValueError(f"Failed to load valid model descriptor {model_descr.validation_summary}") + return model_descr