diff --git a/tests/test_rpc/test_mp.py b/tests/test_rpc/test_mp.py index ec539cac..2011c54e 100644 --- a/tests/test_rpc/test_mp.py +++ b/tests/test_rpc/test_mp.py @@ -8,7 +8,7 @@ from tiktorch import log from tiktorch.rpc import RPCFuture, RPCInterface, Shutdown, exposed -from tiktorch.rpc.mp import FutureStore, MPServer, create_client +from tiktorch.rpc.mp import FutureStore, MPServer, create_client_api class ITestApi(RPCInterface): @@ -64,7 +64,7 @@ def client(log_queue): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(iface_cls=ITestApi, conn=child, timeout=10) + client = create_client_api(iface_cls=ITestApi, conn=child, timeout=10) yield client @@ -108,7 +108,7 @@ def __getattr__(self, name): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(iface_cls=ITestApi, conn=SlowConn(child)) + client = create_client_api(iface_cls=ITestApi, conn=SlowConn(child)) client.fast_compute(2, 2) @@ -121,7 +121,7 @@ def test_future_timeout(client: ITestApi, log_queue): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(iface_cls=ITestApi, conn=child, timeout=0.001) + client = create_client_api(iface_cls=ITestApi, conn=child, timeout=0.001) with pytest.raises(TimeoutError): client.compute(1, 2) @@ -256,7 +256,7 @@ def _spawn(iface_cls, srv_cls): p = mp.Process(target=_run_srv, args=(srv_cls, parent, log_queue)) p.start() - data["client"] = client = create_client(iface_cls=iface_cls, conn=child) + data["client"] = client = create_client_api(iface_cls=iface_cls, conn=child) data["process"] = p return client diff --git a/tiktorch/rpc/mp.py b/tiktorch/rpc/mp.py index a51cf924..94f3a638 100644 --- a/tiktorch/rpc/mp.py +++ b/tiktorch/rpc/mp.py @@ -1,3 +1,4 @@ +import dataclasses import logging import queue import threading @@ -5,9 +6,11 @@ from functools import wraps from multiprocessing.connection import Connection from threading import Event, Thread -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Generic, List, Optional, Type, TypeVar from uuid import uuid4 +from bioimageio.core.resource_io import nodes + from .exceptions import Shutdown from .interface import get_exposed_methods from .types import RPCFuture, isfutureret @@ -72,7 +75,7 @@ def __call__(self, *args, **kwargs) -> Any: return self._client._invoke(self._method_name, *args, **kwargs) -def create_client(iface_cls: Type[T], conn: Connection, api_kwargs: Optional[Dict[str, any]] = None, timeout=None) -> T: +def create_client_api(iface_cls: Type[T], conn: Connection, timeout=None) -> T: client = MPClient(iface_cls.__name__, conn, timeout) def _make_method(method): @@ -96,16 +99,21 @@ def __call__(self, *args, **kwargs) -> Any: return MethodWrapper() - class _Client(iface_cls): - def __init__(self, kwargs: Optional[Dict]): - kwargs = kwargs or {} - super().__init__(**kwargs) + class _Api: + pass exposed_methods = get_exposed_methods(iface_cls) for method_name, method in exposed_methods.items(): - setattr(_Client, method_name, _make_method(method)) + setattr(_Api, method_name, _make_method(method)) + + return _Api() + - return _Client(api_kwargs) +@dataclasses.dataclass(frozen=True) +class Client(Generic[T]): + api: T + input_specs: List[nodes.InputTensor] + output_specs: List[nodes.OutputTensor] class MPClient: diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 9cada0e9..2ac2b632 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -38,7 +38,7 @@ def CreateModelSession( session = self.__session_manager.create_session(client) session.on_close(lease.terminate) - session.on_close(client.shutdown) + session.on_close(client.api.shutdown) return inference_pb2.ModelSession(id=session.id) @@ -46,7 +46,7 @@ def CreateDatasetDescription( self, request: inference_pb2.CreateDatasetDescriptionRequest, context ) -> inference_pb2.DatasetDescription: session = self._getModelSession(context, request.modelSessionId) - id = session.client.create_dataset_description(mean=request.mean, stddev=request.stddev) + id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) return inference_pb2.DatasetDescription(id=id) def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty: @@ -76,10 +76,10 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) input_sample = Sample.from_pb_tensors(request.tensors) - tensor_validator = InputTensorValidator(session.client.model) + tensor_validator = InputTensorValidator(session.client.input_specs) tensor_validator.check_tensors(input_sample) - res = session.client.forward(input_sample) - output_tensor_ids = [tensor.name for tensor in session.client.model.output_specs] + res = session.client.api.forward(input_sample) + output_tensor_ids = [tensor.name for tensor in session.client.output_specs] output_sample = Sample.from_raw_data(output_tensor_ids, res) return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors()) diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index 91272f83..ef17f69e 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -15,7 +15,7 @@ from tiktorch import log from tiktorch.rpc import Shutdown from tiktorch.rpc import mp as _mp_rpc -from tiktorch.rpc.mp import MPServer +from tiktorch.rpc.mp import Client, MPServer from ...converters import Sample from .backend import base @@ -23,18 +23,18 @@ class InputTensorValidator: - def __init__(self, model: PredictionPipeline): - self._model = model + def __init__(self, input_specs: List[nodes.InputTensor]): + self._input_specs = input_specs def check_tensors(self, sample: Sample): for tensor_id, tensor in sample.tensors.items(): self.check_shape(tensor_id, tensor.dims, tensor.shape) def _get_input_tensors_with_names(self) -> Dict[str, nodes.InputTensor]: - return {tensor.name: tensor for tensor in self._model.input_specs} + return {tensor.name: tensor for tensor in self._input_specs} def check_shape(self, tensor_id: str, axes: Tuple[str, ...], shape: Tuple[int, ...]): - shape = self._get_axes_with_size(axes, shape) + shape = self.get_axes_with_size(axes, shape) spec = self._get_input_spec(tensor_id) if isinstance(spec.shape, list): self._check_shape_explicit(spec, shape) @@ -45,30 +45,30 @@ def check_shape(self, tensor_id: str, axes: Tuple[str, ...], shape: Tuple[int, . def _get_input_spec(self, tensor_id: str) -> nodes.InputTensor: self._check_spec_exists(tensor_id) - specs = [spec for spec in self._model.input_specs if spec.name == tensor_id] + specs = [spec for spec in self._input_specs if spec.name == tensor_id] assert len(specs) == 1, "ids of tensor specs should be unique" return specs[0] def _check_spec_exists(self, tensor_id: str): - spec_names = [spec.name for spec in self._model.input_specs] + spec_names = [spec.name for spec in self._input_specs] if tensor_id not in spec_names: raise ValueError(f"Spec {tensor_id} doesn't exist for specs {spec_names}") def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]): - assert self._is_shape_explicit(spec) + assert self.is_shape_explicit(spec) reference_shape = {name: size for name, size in zip(spec.axes, spec.shape)} - self._check_same_axes(reference_shape, tensor_shape) + self.check_same_axes(reference_shape, tensor_shape) if reference_shape != tensor_shape: raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}") def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]): assert isinstance(spec.shape, ParametrizedInputShape) - if not self._is_shape(tensor_shape.values()): + if not self.is_shape(tensor_shape.values()): raise ValueError(f"Invalid shape's sizes {tensor_shape}") - min_shape = self._get_axes_with_size(spec.axes, tuple(spec.shape.min)) - step = self._get_axes_with_size(spec.axes, tuple(spec.shape.step)) - self._check_same_axes(tensor_shape, min_shape) + min_shape = self.get_axes_with_size(spec.axes, tuple(spec.shape.min)) + step = self.get_axes_with_size(spec.axes, tuple(spec.shape.step)) + self.check_same_axes(tensor_shape, min_shape) tensor_shapes_arr = np.array(list(tensor_shape.values())) min_shape_arr = np.array(list(min_shape.values())) @@ -80,25 +80,30 @@ def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict non_zero_idx = np.nonzero(step_arr) multipliers = diff[non_zero_idx] / step_arr[non_zero_idx] multiplier = np.unique(multipliers) - if len(multiplier) == 1 and self._is_natural_number(multiplier[0]): + if len(multiplier) == 1 and self.is_natural_number(multiplier[0]): return raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}") - def _check_same_axes(self, source: Dict[str, int], target: Dict[str, int]): + @staticmethod + def check_same_axes(source: Dict[str, int], target: Dict[str, int]): if source.keys() != target.keys(): raise ValueError(f"Incompatible axes for tensor {target} and reference {source}") - def _is_natural_number(self, n) -> bool: + @staticmethod + def is_natural_number(n) -> bool: return n % 1 == 0.0 and n >= 0 - def _is_shape(self, shape: Iterator[int]) -> bool: - return all(self._is_natural_number(dim) for dim in shape) + @staticmethod + def is_shape(shape: Iterator[int]) -> bool: + return all(InputTensorValidator.is_natural_number(dim) for dim in shape) - def _get_axes_with_size(self, axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]: + @staticmethod + def get_axes_with_size(axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]: assert len(axes) == len(shape) return {name: size for name, size in zip(axes, shape)} - def _is_shape_explicit(self, spec: nodes.InputTensor) -> bool: + @staticmethod + def is_shape_explicit(spec: nodes.InputTensor) -> bool: return isinstance(spec.shape, list) @@ -107,7 +112,6 @@ def __init__(self, model: PredictionPipeline) -> None: super().__init__(model) self._datasets = {} self._worker = base.SessionBackend(self._model) - self._shape_validator = InputTensorValidator(self._model) def forward(self, sample: Sample) -> Future: tensors_data = [sample.tensors[tensor.name] for tensor in self.model.input_specs] @@ -146,7 +150,7 @@ def _run_model_session_process( def start_model_session_process( model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None -) -> Tuple[_mp.Process, IRPCModelSession]: +) -> Tuple[_mp.Process, Client[IRPCModelSession]]: client_conn, server_conn = _mp.Pipe() prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices) proc = _mp.Process( @@ -159,9 +163,9 @@ def start_model_session_process( }, ) proc.start() - # here create the prediction pipeline, share it to the model session class and the client - return proc, _mp_rpc.create_client( - iface_cls=IRPCModelSession, api_kwargs={"model": prediction_pipeline}, conn=client_conn + api = _mp_rpc.create_client_api(iface_cls=IRPCModelSession, conn=client_conn) + return proc, Client( + input_specs=prediction_pipeline.input_specs, output_specs=prediction_pipeline.output_specs, api=api ) diff --git a/tiktorch/server/session_manager.py b/tiktorch/server/session_manager.py index 86a8a1ea..b12b2ad0 100644 --- a/tiktorch/server/session_manager.py +++ b/tiktorch/server/session_manager.py @@ -7,7 +7,7 @@ from typing import Callable, Dict, List, Optional from uuid import uuid4 -from tiktorch.server.session import IRPCModelSession +from tiktorch.rpc.mp import Client logger = getLogger(__name__) @@ -28,7 +28,7 @@ def id(self) -> str: @property @abc.abstractmethod - def client(self) -> IRPCModelSession: + def client(self) -> Client: ... @abc.abstractmethod @@ -43,13 +43,13 @@ def on_close(self, handler: CloseCallback) -> None: class _Session(ISession): - def __init__(self, id_: str, client: IRPCModelSession, manager: SessionManager) -> None: + def __init__(self, id_: str, client: Client, manager: SessionManager) -> None: self.__id = id_ self.__manager = manager self.__client = client @property - def client(self) -> IRPCModelSession: + def client(self) -> Client: return self.__client @property @@ -65,7 +65,7 @@ class SessionManager: Manages session lifecycle (create/close) """ - def create_session(self, client: IRPCModelSession) -> ISession: + def create_session(self, client: Client) -> ISession: """ Creates new session with unique id """