diff --git a/tiktorch/rpc/mp.py b/tiktorch/rpc/mp.py index 94f3a638..0cd69bf4 100644 --- a/tiktorch/rpc/mp.py +++ b/tiktorch/rpc/mp.py @@ -6,11 +6,12 @@ from functools import wraps from multiprocessing.connection import Connection from threading import Event, Thread -from typing import Any, Generic, List, Optional, Type, TypeVar +from typing import Any, List, Optional, Type, TypeVar from uuid import uuid4 from bioimageio.core.resource_io import nodes +from ..server.session import IRPCModelSession from .exceptions import Shutdown from .interface import get_exposed_methods from .types import RPCFuture, isfutureret @@ -110,8 +111,8 @@ class _Api: @dataclasses.dataclass(frozen=True) -class Client(Generic[T]): - api: T +class BioModelClient: + api: IRPCModelSession input_specs: List[nodes.InputTensor] output_specs: List[nodes.OutputTensor] diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 4034098a..f09e0bae 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -7,7 +7,7 @@ from tiktorch.server.data_store import IDataStore from tiktorch.server.device_pool import DeviceStatus, IDevicePool from tiktorch.server.session.process import InputTensorValidator, start_model_session_process -from tiktorch.server.session_manager import ISession, SessionManager +from tiktorch.server.session_manager import Session, SessionManager class InferenceServicer(inference_pb2_grpc.InferenceServicer): @@ -46,7 +46,7 @@ def CreateDatasetDescription( self, request: inference_pb2.CreateDatasetDescriptionRequest, context ) -> inference_pb2.DatasetDescription: session = self._getModelSession(context, request.modelSessionId) - id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) + id = session.bio_model_client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) return inference_pb2.DatasetDescription(id=id) def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty: @@ -76,14 +76,14 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) input_sample = Sample.from_pb_tensors(request.tensors) - tensor_validator = InputTensorValidator(session.client.input_specs) + tensor_validator = InputTensorValidator(session.bio_model_client.input_specs) tensor_validator.check_tensors(input_sample) - res = session.client.api.forward(input_sample) - output_tensor_ids = [tensor.name for tensor in session.client.output_specs] + res = session.bio_model_client.api.forward(input_sample) + output_tensor_ids = [tensor.name for tensor in session.bio_model_client.output_specs] output_sample = Sample.from_xr_tensors(output_tensor_ids, res) return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors()) - def _getModelSession(self, context, modelSessionId: str) -> ISession: + def _getModelSession(self, context, modelSessionId: str) -> Session: if not modelSessionId: context.abort(grpc.StatusCode.FAILED_PRECONDITION, "model-session-id has not been provided by client") diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index ef17f69e..c9e0186e 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -15,7 +15,7 @@ from tiktorch import log from tiktorch.rpc import Shutdown from tiktorch.rpc import mp as _mp_rpc -from tiktorch.rpc.mp import Client, MPServer +from tiktorch.rpc.mp import BioModelClient, MPServer from ...converters import Sample from .backend import base @@ -150,7 +150,7 @@ def _run_model_session_process( def start_model_session_process( model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None -) -> Tuple[_mp.Process, Client[IRPCModelSession]]: +) -> Tuple[_mp.Process, BioModelClient]: client_conn, server_conn = _mp.Pipe() prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices) proc = _mp.Process( @@ -164,7 +164,7 @@ def start_model_session_process( ) proc.start() api = _mp_rpc.create_client_api(iface_cls=IRPCModelSession, conn=client_conn) - return proc, Client( + return proc, BioModelClient( input_specs=prediction_pipeline.input_specs, output_specs=prediction_pipeline.output_specs, api=api ) diff --git a/tiktorch/server/session_manager.py b/tiktorch/server/session_manager.py index b12b2ad0..3807e130 100644 --- a/tiktorch/server/session_manager.py +++ b/tiktorch/server/session_manager.py @@ -1,62 +1,44 @@ from __future__ import annotations -import abc import threading from collections import defaultdict from logging import getLogger from typing import Callable, Dict, List, Optional from uuid import uuid4 -from tiktorch.rpc.mp import Client +from tiktorch.rpc.mp import BioModelClient logger = getLogger(__name__) +CloseCallback = Callable[[], None] + -class ISession(abc.ABC): +class Session: """ session object has unique id Used for resource managent """ - @property - @abc.abstractmethod - def id(self) -> str: - """ - Returns unique id assigned to this session - """ - ... - - @property - @abc.abstractmethod - def client(self) -> Client: - ... - - @abc.abstractmethod - def on_close(self, handler: CloseCallback) -> None: - """ - Register cleanup function to be called when session ends - """ - ... - - -CloseCallback = Callable[[], None] - - -class _Session(ISession): - def __init__(self, id_: str, client: Client, manager: SessionManager) -> None: + def __init__(self, id_: str, bio_model_client: BioModelClient, manager: SessionManager) -> None: self.__id = id_ self.__manager = manager - self.__client = client + self.__bio_model_client = bio_model_client @property - def client(self) -> Client: - return self.__client + def bio_model_client(self) -> BioModelClient: + return self.__bio_model_client @property def id(self) -> str: + """ + Returns unique id assigned to this session + """ return self.__id def on_close(self, handler: CloseCallback) -> None: + """ + Register cleanup function to be called when session ends + """ self.__manager._on_close(self, handler) @@ -65,18 +47,18 @@ class SessionManager: Manages session lifecycle (create/close) """ - def create_session(self, client: Client) -> ISession: + def create_session(self, bio_model_client: BioModelClient) -> Session: """ Creates new session with unique id """ with self.__lock: session_id = uuid4().hex - session = _Session(session_id, client=client, manager=self) + session = Session(session_id, bio_model_client=bio_model_client, manager=self) self.__session_by_id[session_id] = session logger.info("Created session %s", session.id) return session - def get(self, session_id: str) -> Optional[ISession]: + def get(self, session_id: str) -> Optional[Session]: """ Returns existing session with given id if it exists """ @@ -102,10 +84,10 @@ def close_session(self, session_id: str) -> None: def __init__(self) -> None: self.__lock = threading.Lock() - self.__session_by_id: Dict[str, ISession] = {} + self.__session_by_id: Dict[str, Session] = {} self.__close_handlers_by_session_id: Dict[str, List[CloseCallback]] = defaultdict(list) - def _on_close(self, session: ISession, handler: CloseCallback): + def _on_close(self, session: Session, handler: CloseCallback): with self.__lock: logger.debug("Registered close handler %s for session %s", handler, session.id) self.__close_handlers_by_session_id[session.id].append(handler)