Simplify session manager to work only with sessions associated with a bio model client
thodkatz committed Aug 15, 2024
1 parent fae2741 commit 85a4c8c
Showing 4 changed files with 32 additions and 49 deletions.
7 changes: 4 additions & 3 deletions tiktorch/rpc/mp.py
@@ -6,11 +6,12 @@
 from functools import wraps
 from multiprocessing.connection import Connection
 from threading import Event, Thread
-from typing import Any, Generic, List, Optional, Type, TypeVar
+from typing import Any, List, Optional, Type, TypeVar
 from uuid import uuid4
 
 from bioimageio.core.resource_io import nodes
 
+from ..server.session import IRPCModelSession
 from .exceptions import Shutdown
 from .interface import get_exposed_methods
 from .types import RPCFuture, isfutureret
@@ -110,8 +111,8 @@ class _Api:
 
 
 @dataclasses.dataclass(frozen=True)
-class Client(Generic[T]):
-    api: T
+class BioModelClient:
+    api: IRPCModelSession
     input_specs: List[nodes.InputTensor]
     output_specs: List[nodes.OutputTensor]
 
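
The BioModelClient introduced here is a frozen dataclass that pairs the RPC proxy for a model session (api, typed as IRPCModelSession) with the loaded model's input and output tensor specs, replacing the generic Client[T]. A minimal sketch of its shape, using a stand-in api object and empty spec lists purely for illustration:

```python
# Illustration only: the real api proxy and specs come from
# start_model_session_process (see tiktorch/server/session/process.py below).
from tiktorch.rpc.mp import BioModelClient


class _StubApi:
    """Stand-in for the IRPCModelSession proxy created by _mp_rpc.create_client_api."""

    def forward(self, sample):
        return sample


client = BioModelClient(api=_StubApi(), input_specs=[], output_specs=[])
print(client.api, client.input_specs, client.output_specs)
```
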
12 changes: 6 additions & 6 deletions tiktorch/server/grpc/inference_servicer.py
@@ -7,7 +7,7 @@
 from tiktorch.server.data_store import IDataStore
 from tiktorch.server.device_pool import DeviceStatus, IDevicePool
 from tiktorch.server.session.process import InputTensorValidator, start_model_session_process
-from tiktorch.server.session_manager import ISession, SessionManager
+from tiktorch.server.session_manager import Session, SessionManager
 
 
 class InferenceServicer(inference_pb2_grpc.InferenceServicer):
@@ -46,7 +46,7 @@ def CreateDatasetDescription(
         self, request: inference_pb2.CreateDatasetDescriptionRequest, context
     ) -> inference_pb2.DatasetDescription:
         session = self._getModelSession(context, request.modelSessionId)
-        id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev)
+        id = session.bio_model_client.api.create_dataset_description(mean=request.mean, stddev=request.stddev)
         return inference_pb2.DatasetDescription(id=id)
 
     def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty:
@@ -76,14 +76,14 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De
     def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse:
         session = self._getModelSession(context, request.modelSessionId)
         input_sample = Sample.from_pb_tensors(request.tensors)
-        tensor_validator = InputTensorValidator(session.client.input_specs)
+        tensor_validator = InputTensorValidator(session.bio_model_client.input_specs)
         tensor_validator.check_tensors(input_sample)
-        res = session.client.api.forward(input_sample)
-        output_tensor_ids = [tensor.name for tensor in session.client.output_specs]
+        res = session.bio_model_client.api.forward(input_sample)
+        output_tensor_ids = [tensor.name for tensor in session.bio_model_client.output_specs]
         output_sample = Sample.from_xr_tensors(output_tensor_ids, res)
         return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors())
 
-    def _getModelSession(self, context, modelSessionId: str) -> ISession:
+    def _getModelSession(self, context, modelSessionId: str) -> Session:
         if not modelSessionId:
             context.abort(grpc.StatusCode.FAILED_PRECONDITION, "model-session-id has not been provided by client")
 
6 changes: 3 additions & 3 deletions tiktorch/server/session/process.py
@@ -15,7 +15,7 @@
 from tiktorch import log
 from tiktorch.rpc import Shutdown
 from tiktorch.rpc import mp as _mp_rpc
-from tiktorch.rpc.mp import Client, MPServer
+from tiktorch.rpc.mp import BioModelClient, MPServer
 
 from ...converters import Sample
 from .backend import base
@@ -150,7 +150,7 @@ def _run_model_session_process(
 
 def start_model_session_process(
     model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
-) -> Tuple[_mp.Process, Client[IRPCModelSession]]:
+) -> Tuple[_mp.Process, BioModelClient]:
     client_conn, server_conn = _mp.Pipe()
     prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices)
     proc = _mp.Process(
@@ -164,7 +164,7 @@ def start_model_session_process(
     )
     proc.start()
     api = _mp_rpc.create_client_api(iface_cls=IRPCModelSession, conn=client_conn)
-    return proc, Client(
+    return proc, BioModelClient(
         input_specs=prediction_pipeline.input_specs, output_specs=prediction_pipeline.output_specs, api=api
     )
 
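
With this change, start_model_session_process returns the worker process together with a BioModelClient rather than a bare Client[IRPCModelSession], so the tensor specs of the loaded model travel with the handle. A rough usage sketch, assuming a bioimage.io model zip on disk (the path and device list below are placeholders):

```python
# Hypothetical caller; "path/to/model.zip" and the device list are placeholders.
from tiktorch.server.session.process import start_model_session_process

with open("path/to/model.zip", "rb") as f:
    model_zip = f.read()

proc, bio_model_client = start_model_session_process(model_zip=model_zip, devices=["cpu"])

# The specs needed for input validation now come from the client itself ...
input_names = [spec.name for spec in bio_model_client.input_specs]
output_names = [spec.name for spec in bio_model_client.output_specs]

# ... and bio_model_client.api is the IRPCModelSession proxy used for inference,
# e.g. bio_model_client.api.forward(sample) as in InferenceServicer.Predict.

proc.terminate()  # tear the worker process down when done with it
proc.join()
```
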
56 changes: 19 additions & 37 deletions tiktorch/server/session_manager.py
@@ -1,62 +1,44 @@
 from __future__ import annotations
 
-import abc
 import threading
 from collections import defaultdict
 from logging import getLogger
 from typing import Callable, Dict, List, Optional
 from uuid import uuid4
 
-from tiktorch.rpc.mp import Client
+from tiktorch.rpc.mp import BioModelClient
 
 logger = getLogger(__name__)
 
+CloseCallback = Callable[[], None]
+
 
-class ISession(abc.ABC):
+class Session:
     """
     session object has unique id
     Used for resource managent
     """
 
-    @property
-    @abc.abstractmethod
-    def id(self) -> str:
-        """
-        Returns unique id assigned to this session
-        """
-        ...
-
-    @property
-    @abc.abstractmethod
-    def client(self) -> Client:
-        ...
-
-    @abc.abstractmethod
-    def on_close(self, handler: CloseCallback) -> None:
-        """
-        Register cleanup function to be called when session ends
-        """
-        ...
-
-
-CloseCallback = Callable[[], None]
-
-
-class _Session(ISession):
-    def __init__(self, id_: str, client: Client, manager: SessionManager) -> None:
+    def __init__(self, id_: str, bio_model_client: BioModelClient, manager: SessionManager) -> None:
         self.__id = id_
         self.__manager = manager
-        self.__client = client
+        self.__bio_model_client = bio_model_client
 
     @property
-    def client(self) -> Client:
-        return self.__client
+    def bio_model_client(self) -> BioModelClient:
+        return self.__bio_model_client
 
     @property
     def id(self) -> str:
         """
         Returns unique id assigned to this session
        """
         return self.__id
 
     def on_close(self, handler: CloseCallback) -> None:
         """
         Register cleanup function to be called when session ends
         """
         self.__manager._on_close(self, handler)
 
@@ -65,18 +47,18 @@ class SessionManager:
     Manages session lifecycle (create/close)
     """
 
-    def create_session(self, client: Client) -> ISession:
+    def create_session(self, bio_model_client: BioModelClient) -> Session:
         """
         Creates new session with unique id
         """
         with self.__lock:
             session_id = uuid4().hex
-            session = _Session(session_id, client=client, manager=self)
+            session = Session(session_id, bio_model_client=bio_model_client, manager=self)
             self.__session_by_id[session_id] = session
             logger.info("Created session %s", session.id)
             return session
 
-    def get(self, session_id: str) -> Optional[ISession]:
+    def get(self, session_id: str) -> Optional[Session]:
         """
         Returns existing session with given id if it exists
         """
@@ -102,10 +84,10 @@ def close_session(self, session_id: str) -> None:
 
     def __init__(self) -> None:
         self.__lock = threading.Lock()
-        self.__session_by_id: Dict[str, ISession] = {}
+        self.__session_by_id: Dict[str, Session] = {}
         self.__close_handlers_by_session_id: Dict[str, List[CloseCallback]] = defaultdict(list)
 
-    def _on_close(self, session: ISession, handler: CloseCallback):
+    def _on_close(self, session: Session, handler: CloseCallback):
         with self.__lock:
             logger.debug("Registered close handler %s for session %s", handler, session.id)
             self.__close_handlers_by_session_id[session.id].append(handler)
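
The session manager is now built around the concrete Session type: ISession and the private _Session are gone, create_session takes a BioModelClient, and each Session exposes it through the bio_model_client property. A short sketch of the simplified lifecycle, assuming bio_model_client is the BioModelClient returned by start_model_session_process in the previous sketch:

```python
from tiktorch.server.session_manager import SessionManager

manager = SessionManager()

# create_session now requires a BioModelClient and returns a concrete Session
session = manager.create_session(bio_model_client)
assert manager.get(session.id) is session

# cleanup hooks are registered the same way as before ...
session.on_close(lambda: print(f"session {session.id} closed"))

# ... and are expected to run when the session is closed
manager.close_session(session.id)
```
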
