diff --git a/tests/conftest.py b/tests/conftest.py
index 31929f3e..2d840f59 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -115,6 +115,16 @@ def bioimageio_dummy_model_filepath(data_path, tmpdir):
 @pytest.fixture
 def bioimageio_dummy_model_bytes(data_path):
     rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY / "Dummy.model.yaml"
+    return _bioimageio_package(rdf_source)
+
+
+@pytest.fixture
+def bioimageio_dummy_param_model_bytes(data_path):
+    rdf_source = data_path / "dummy_param" / "Dummy.model_param.yaml"
+    return _bioimageio_package(rdf_source)
+
+
+def _bioimageio_package(rdf_source):
     data = io.BytesIO()
     export_resource_package(rdf_source, output_path=data)
     return data
diff --git a/tests/test_rpc/test_mp.py b/tests/test_rpc/test_mp.py
index 981f4553..ec539cac 100644
--- a/tests/test_rpc/test_mp.py
+++ b/tests/test_rpc/test_mp.py
@@ -64,7 +64,7 @@ def client(log_queue):
     p = mp.Process(target=_srv, args=(parent, log_queue))
     p.start()
 
-    client = create_client(ITestApi, child, timeout=10)
+    client = create_client(iface_cls=ITestApi, conn=child, timeout=10)
 
     yield client
 
@@ -108,7 +108,7 @@ def __getattr__(self, name):
     p = mp.Process(target=_srv, args=(parent, log_queue))
     p.start()
 
-    client = create_client(ITestApi, SlowConn(child))
+    client = create_client(iface_cls=ITestApi, conn=SlowConn(child))
 
     client.fast_compute(2, 2)
 
@@ -121,7 +121,7 @@ def test_future_timeout(client: ITestApi, log_queue):
     p = mp.Process(target=_srv, args=(parent, log_queue))
     p.start()
 
-    client = create_client(ITestApi, child, timeout=0.001)
+    client = create_client(iface_cls=ITestApi, conn=child, timeout=0.001)
 
     with pytest.raises(TimeoutError):
         client.compute(1, 2)
 
@@ -256,7 +256,7 @@ def _spawn(iface_cls, srv_cls):
         p = mp.Process(target=_run_srv, args=(srv_cls, parent, log_queue))
         p.start()
 
-        data["client"] = client = create_client(iface_cls, child)
+        data["client"] = client = create_client(iface_cls=iface_cls, conn=child)
         data["process"] = p
 
         return client
diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py
index cbf64477..c45d9f3d 100644
--- a/tests/test_server/test_grpc/test_inference_servicer.py
+++ b/tests/test_server/test_grpc/test_inference_servicer.py
@@ -156,25 +156,56 @@ def test_call_fails_with_unknown_model_session_id(self, grpc_stub):
 
     def test_call_predict(self, grpc_stub, bioimageio_dummy_model_bytes):
         model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes))
-        arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
+        arr = xr.DataArray(np.arange(128 * 128).reshape(1, 1, 128, 128), dims=("b", "c", "x", "y"))
         expected = arr + 1
-        input_tensors = [converters.xarray_to_pb_tensor("input", arr)]
+        input_spec_id = "input"
+        output_spec_id = "output"
+        input_tensors = [converters.xarray_to_pb_tensor(input_spec_id, arr)]
         res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
 
         grpc_stub.CloseModelSession(model)
 
         assert len(res.tensors) == 1
+        assert res.tensors[0].specId == output_spec_id
         assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0]))
 
+    def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimageio_dummy_model_bytes):
+        model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes))
+        arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
+        input_tensors = [converters.xarray_to_pb_tensor("input", arr)]
+        with pytest.raises(grpc.RpcError):
+            grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
+        grpc_stub.CloseModelSession(model)
+
+    @pytest.mark.parametrize("shape", [(1, 1, 64, 32), (1, 1, 32, 64), (0, 1, 64, 64), (1, 0, 64, 64)])
+    def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes):
+        model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes))
+        arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y"))
+        input_tensors = [converters.xarray_to_pb_tensor("param", arr)]
+        with pytest.raises(grpc.RpcError):
+            grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
+        grpc_stub.CloseModelSession(model)
+
+    @pytest.mark.parametrize("shape", [(1, 1, 64, 64), (1, 1, 66, 65), (1, 1, 68, 66), (1, 1, 70, 67)])
+    def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes):
+        model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes))
+        arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y"))
+        input_tensors = [converters.xarray_to_pb_tensor("param", arr)]
+        grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
+        grpc_stub.CloseModelSession(model)
+
     @pytest.mark.skip
     def test_call_predict_tf(self, grpc_stub, bioimageio_dummy_tensorflow_model_bytes):
         model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_tensorflow_model_bytes))
         arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
         expected = arr * -1
-        input_tensors = [converters.xarray_to_pb_tensor(arr)]
+        input_spec_id = "input"
+        output_spec_id = "output"
+        input_tensors = [converters.xarray_to_pb_tensor(input_spec_id, arr)]
         res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
 
         grpc_stub.CloseModelSession(model)
 
         assert len(res.tensors) == 1
+        assert res.tensors[0].specId == output_spec_id
         assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0]))
diff --git a/tiktorch/rpc/mp.py b/tiktorch/rpc/mp.py
index bab3ab1f..a51cf924 100644
--- a/tiktorch/rpc/mp.py
+++ b/tiktorch/rpc/mp.py
@@ -5,7 +5,7 @@
 from functools import wraps
 from multiprocessing.connection import Connection
 from threading import Event, Thread
-from typing import Any, Optional, Type, TypeVar
+from typing import Any, Dict, Optional, Type, TypeVar
 from uuid import uuid4
 
 from .exceptions import Shutdown
@@ -72,9 +72,8 @@ def __call__(self, *args, **kwargs) -> Any:
         return self._client._invoke(self._method_name, *args, **kwargs)
 
 
-def create_client(iface_cls: Type[T], conn: Connection, timeout=None) -> T:
+def create_client(iface_cls: Type[T], conn: Connection, api_kwargs: Optional[Dict[str, Any]] = None, timeout=None) -> T:
     client = MPClient(iface_cls.__name__, conn, timeout)
-    get_exposed_methods(iface_cls)
 
     def _make_method(method):
         class MethodWrapper:
@@ -98,12 +97,15 @@ def __call__(self, *args, **kwargs) -> Any:
             return MethodWrapper()
 
     class _Client(iface_cls):
-        pass
+        def __init__(self, kwargs: Optional[Dict]):
+            kwargs = kwargs or {}
+            super().__init__(**kwargs)
 
-    for method_name, method in get_exposed_methods(iface_cls).items():
+    exposed_methods = get_exposed_methods(iface_cls)
+    for method_name, method in exposed_methods.items():
         setattr(_Client, method_name, _make_method(method))
 
-    return _Client()
+    return _Client(api_kwargs)
 
 
 class MPClient:
@@ -190,7 +192,7 @@ def _shutdown(self, exc):
 
 class Message:
     def __init__(self, id_):
-        self.id = id
+        self.id = id_
 
 
 class Signal:
@@ -200,20 +202,19 @@ def __init__(self, payload):
 
 class MethodCall(Message):
     def __init__(self, id_, method_name, args, kwargs):
-        self.id = id_
+        super().__init__(id_)
         self.method_name = method_name
         self.args = args
         self.kwargs = kwargs
 
 
 class Cancellation(Message):
-    def __init__(self, id_):
-        self.id = id_
+    pass
 
 
 class MethodReturn(Message):
     def __init__(self, id_, result: Result):
-        self.id = id_
+        super().__init__(id_)
         self.result = result
 
diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py
index db0f2c17..49e670ce 100644
--- a/tiktorch/server/grpc/inference_servicer.py
+++ b/tiktorch/server/grpc/inference_servicer.py
@@ -36,7 +36,7 @@ def CreateModelSession(
                 lease.terminate()
                 raise
 
-        session = self.__session_manager.create_session()
+        session = self.__session_manager.create_session(client)
         session.on_close(lease.terminate)
         session.on_close(client.shutdown)
 
diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py
index 2147df46..408c94a4 100644
--- a/tiktorch/server/session/process.py
+++ b/tiktorch/server/session/process.py
@@ -1,52 +1,117 @@
 import multiprocessing as _mp
-import os
 import pathlib
 import tempfile
 import uuid
 from concurrent.futures import Future
 from multiprocessing.connection import Connection
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
-import numpy
+import numpy as np
 from bioimageio.core import load_resource_description
 from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
+from bioimageio.core.resource_io import nodes
+from bioimageio.core.resource_io.nodes import ParametrizedInputShape
 
 from tiktorch import log
 from tiktorch.rpc import Shutdown
 from tiktorch.rpc import mp as _mp_rpc
 from tiktorch.rpc.mp import MPServer
 
+from ...converters import Tensor
 from .backend import base
 from .rpc_interface import IRPCModelSession
 
 
-class ModelSessionProcess(IRPCModelSession):
-    def __init__(self, model_zip: bytes, devices: List[str]) -> None:
-        _tmp_file = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
-        _tmp_file.write(model_zip)
-        _tmp_file.close()
-        model = load_resource_description(pathlib.Path(_tmp_file.name))
-        os.unlink(_tmp_file.name)
-        self._model: PredictionPipeline = create_prediction_pipeline(bioimageio_model=model, devices=devices)
+class ModelSessionProcess(IRPCModelSession[PredictionPipeline]):
+    def __init__(self, model: PredictionPipeline) -> None:
+        super().__init__(model)
         self._datasets = {}
         self._worker = base.SessionBackend(self._model)
 
-    def forward(self, input_tensors: numpy.ndarray) -> Future:
-        res = self._worker.forward(input_tensors)
+    def forward(self, input_tensors: Set[Tensor]) -> Future:
+        # validate every incoming tensor against its input spec before forwarding
+        for tensor in input_tensors:
+            axes_with_size = self._get_axes_with_size(tensor.data.dims, tensor.data.shape)
+            self.check_shape(tensor.spec_id, axes_with_size)
+        tensors_data = [tensor.data for tensor in input_tensors]
+        res = self._worker.forward(tensors_data)
         return res
 
+    def _get_input_spec(self, spec_id: str) -> nodes.InputTensor:
+        self._check_spec_exists(spec_id)
+        specs = [spec for spec in self._model.input_specs if spec.name == spec_id]
+        assert len(specs) == 1, "ids of tensor specs should be unique"
+        return specs[0]
+
     def create_dataset(self, mean, stddev):
         id_ = uuid.uuid4().hex
         self._datasets[id_] = {"mean": mean, "stddev": stddev}
         return id_
 
+    def check_shape(self, spec_id: str, shape: Dict[str, int]):
+        """Validate a tensor shape against its input spec (explicit or parametrized)."""
+        spec = self._get_input_spec(spec_id)
+        if isinstance(spec.shape, list):
+            self._check_shape_explicit(spec, shape)
+        elif isinstance(spec.shape, ParametrizedInputShape):
+            self._check_shape_parameterized(spec, shape)
+        else:
+            raise ValueError(f"Unexpected shape {spec.shape}")
+
+    def _check_spec_exists(self, spec_id: str):
+        spec_names = [spec.name for spec in self._model.input_specs]
+        if spec_id not in spec_names:
+            raise ValueError(f"Spec {spec_id} doesn't exist for specs {spec_names}")
+
+    def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
+        assert self._is_shape_explicit(spec)
+        reference_shape = {name: size for name, size in zip(spec.axes, spec.shape)}
+        if reference_shape != tensor_shape:
+            raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}")
+
+    def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
+        # a parametrized shape is valid if shape == min + k * step for one natural number k shared by all axes
+        assert isinstance(spec.shape, ParametrizedInputShape)
+        if not self._is_shape(tensor_shape.values()):
+            raise ValueError(f"Invalid shape's sizes {tensor_shape}")
+
+        min_shape = self._get_axes_with_size(spec.axes, tuple(spec.shape.min))
+        step = self._get_axes_with_size(spec.axes, tuple(spec.shape.step))
+        assert min_shape.keys() == step.keys()
+        if tensor_shape.keys() != min_shape.keys():
+            raise ValueError(f"Incompatible axes for tensor {tensor_shape} and spec {spec}")
+
+        tensor_shapes_arr = np.array(list(tensor_shape.values()))
+        min_shape_arr = np.array(list(min_shape.values()))
+        step_arr = np.array(list(step.values()))
+        diff = tensor_shapes_arr - min_shape_arr
+        if any(size < 0 for size in diff):
+            raise ValueError(f"Tensor shape {tensor_shape} smaller than min shape {min_shape}")
+
+        # axes with step 0 are fixed: they must match the min shape exactly
+        if np.any(diff[step_arr == 0] != 0):
+            raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")
+
+        non_zero_idx = np.nonzero(step_arr)
+        multipliers = diff[non_zero_idx] / step_arr[non_zero_idx]
+        multiplier = np.unique(multipliers)
+        if len(multiplier) == 1 and self._is_natural_number(multiplier[0]):
+            return
+        raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")
+
+    def _is_natural_number(self, n) -> bool:
+        return np.floor(n) == np.ceil(n) and n >= 0
+
+    def _is_shape(self, shape: Iterable[int]) -> bool:
+        return all(self._is_natural_number(dim) for dim in shape)
+
+    def _get_axes_with_size(self, axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]:
+        assert len(axes) == len(shape)
+        return {name: size for name, size in zip(axes, shape)}
+
+    def _is_shape_explicit(self, spec: nodes.InputTensor) -> bool:
+        return isinstance(spec.shape, list)
+
     def shutdown(self) -> Shutdown:
         self._worker.shutdown()
         return Shutdown()
 
 
 def _run_model_session_process(
-    conn: Connection, model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
+    conn: Connection, prediction_pipeline: PredictionPipeline, log_queue: Optional[_mp.Queue] = None
 ):
     try:
         # from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
@@ -60,7 +125,7 @@ def _run_model_session_process(
     if log_queue:
         log.configure(log_queue)
 
-    session_proc = ModelSessionProcess(model_zip, devices)
+    session_proc = ModelSessionProcess(prediction_pipeline)
     srv = MPServer(session_proc, conn)
     srv.listen()
 
@@ -69,10 +134,26 @@ def start_model_session_process(
     model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
 ) -> Tuple[_mp.Process, IRPCModelSession]:
     client_conn, server_conn = _mp.Pipe()
+    prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices)
     proc = _mp.Process(
         target=_run_model_session_process,
         name="ModelSessionProcess",
-        kwargs={"conn": server_conn, "devices": devices, "log_queue": log_queue, "model_zip": model_zip},
+        kwargs={
+            "conn": server_conn,
+            "log_queue": log_queue,
+            "prediction_pipeline": prediction_pipeline,
+        },
     )
     proc.start()
-    return proc, _mp_rpc.create_client(IRPCModelSession, client_conn)
+    # the prediction pipeline is created once and shared with both the session process and the client proxy
+    return proc, _mp_rpc.create_client(
+        iface_cls=IRPCModelSession, api_kwargs={"model": prediction_pipeline}, conn=client_conn
+    )
+
+
+def _get_prediction_pipeline_from_model_bytes(model_zip: bytes, devices: List[str]) -> PredictionPipeline:
+    # delete=False so the closed file can be reopened by name (required on Windows); remove it after loading
+    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file:
+        _tmp_file.write(model_zip)
+        temp_file_path = pathlib.Path(_tmp_file.name)
+    model = load_resource_description(temp_file_path)
+    temp_file_path.unlink()
+    return create_prediction_pipeline(bioimageio_model=model, devices=devices)
diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py
index 2e505d18..b75a9a81 100644
--- a/tiktorch/server/session/rpc_interface.py
+++ b/tiktorch/server/session/rpc_interface.py
@@ -1,11 +1,22 @@
-from typing import List
+from typing import Generic, List, Set, TypeVar
 
+from tiktorch.converters import Tensor
 from tiktorch.rpc import RPCInterface, Shutdown, exposed
 from tiktorch.tiktypes import TikTensorBatch
 from tiktorch.types import ModelState
 
+ModelType = TypeVar("ModelType")
+
+
+class IRPCModelSession(RPCInterface, Generic[ModelType]):
+    def __init__(self, model: ModelType):
+        super().__init__()
+        self._model = model
+
+    @property
+    def model(self):
+        return self._model
 
-class IRPCModelSession(RPCInterface):
     @exposed
     def shutdown(self) -> Shutdown:
         raise NotImplementedError
@@ -43,5 +54,5 @@ def create_dataset_description(self, mean, stddev) -> str:
         raise NotImplementedError
 
     @exposed
-    def forward(self, input_tensors):
+    def forward(self, input_tensors: Set[Tensor]):
         raise NotImplementedError
diff --git a/tiktorch/server/session_manager.py b/tiktorch/server/session_manager.py
index 37bc07ab..86a8a1ea 100644
--- a/tiktorch/server/session_manager.py
+++ b/tiktorch/server/session_manager.py
@@ -7,6 +7,8 @@
 from typing import Callable, Dict, List, Optional
 from uuid import uuid4
 
+from tiktorch.server.session import IRPCModelSession
+
 logger = getLogger(__name__)
 
 
@@ -24,6 +26,11 @@ def id(self) -> str:
         """
         ...
 
+    @property
+    @abc.abstractmethod
+    def client(self) -> IRPCModelSession:
+        ...
+
     @abc.abstractmethod
     def on_close(self, handler: CloseCallback) -> None:
         """
@@ -36,9 +43,14 @@ def on_close(self, handler: CloseCallback) -> None:
 
 
 class _Session(ISession):
-    def __init__(self, id_: str, manager: SessionManager) -> None:
+    def __init__(self, id_: str, client: IRPCModelSession, manager: SessionManager) -> None:
         self.__id = id_
         self.__manager = manager
+        self.__client = client
+
+    @property
+    def client(self) -> IRPCModelSession:
+        return self.__client
 
     @property
     def id(self) -> str:
@@ -53,13 +65,13 @@ class SessionManager:
     Manages session lifecycle (create/close)
     """
 
-    def create_session(self) -> ISession:
+    def create_session(self, client: IRPCModelSession) -> ISession:
         """
         Creates new session with unique id
         """
         with self.__lock:
             session_id = uuid4().hex
-            session = _Session(session_id, manager=self)
+            session = _Session(session_id, client=client, manager=self)
             self.__session_by_id[session_id] = session
             logger.info("Created session %s", session.id)
             return session