Skip to content

Commit

Permalink
Use bioimageio prediction pipeline as model from client and server
Browse files Browse the repository at this point in the history
  • Loading branch information
thodkatz committed Aug 10, 2024
1 parent b1587c7 commit a242db4
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 42 deletions.
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,16 @@ def bioimageio_dummy_model_filepath(data_path, tmpdir):
@pytest.fixture
def bioimageio_dummy_model_bytes(data_path):
    """Return the dummy model description packaged by ``_bioimageio_package``."""
    model_yaml = data_path / TEST_BIOIMAGEIO_DUMMY / "Dummy.model.yaml"
    packaged = _bioimageio_package(model_yaml)
    return packaged


@pytest.fixture
def bioimageio_dummy_param_model_bytes(data_path):
    """Return the ``dummy_param`` model description packaged by ``_bioimageio_package``."""
    model_yaml = data_path / "dummy_param" / "Dummy.model_param.yaml"
    packaged = _bioimageio_package(model_yaml)
    return packaged


def _bioimageio_package(rdf_source):
    """Export the resource described by *rdf_source* into a new ``io.BytesIO`` and return that buffer."""
    buffer = io.BytesIO()
    export_resource_package(rdf_source, output_path=buffer)
    return buffer
Expand Down
8 changes: 4 additions & 4 deletions tests/test_rpc/test_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def client(log_queue):
p = mp.Process(target=_srv, args=(parent, log_queue))
p.start()

client = create_client(ITestApi, child, timeout=10)
client = create_client(iface_cls=ITestApi, conn=child, timeout=10)

yield client

Expand Down Expand Up @@ -108,7 +108,7 @@ def __getattr__(self, name):
p = mp.Process(target=_srv, args=(parent, log_queue))
p.start()

client = create_client(ITestApi, SlowConn(child))
client = create_client(iface_cls=ITestApi, conn=SlowConn(child))

client.fast_compute(2, 2)

Expand All @@ -121,7 +121,7 @@ def test_future_timeout(client: ITestApi, log_queue):
p = mp.Process(target=_srv, args=(parent, log_queue))
p.start()

client = create_client(ITestApi, child, timeout=0.001)
client = create_client(iface_cls=ITestApi, conn=child, timeout=0.001)

with pytest.raises(TimeoutError):
client.compute(1, 2)
Expand Down Expand Up @@ -256,7 +256,7 @@ def _spawn(iface_cls, srv_cls):
p = mp.Process(target=_run_srv, args=(srv_cls, parent, log_queue))
p.start()

data["client"] = client = create_client(iface_cls, child)
data["client"] = client = create_client(iface_cls=iface_cls, conn=child)
data["process"] = p
return client

Expand Down
37 changes: 34 additions & 3 deletions tests/test_server/test_grpc/test_inference_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,25 +156,56 @@ def test_call_fails_with_unknown_model_session_id(self, grpc_stub):

def test_call_predict(self, grpc_stub, bioimageio_dummy_model_bytes):
model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes))
arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
arr = xr.DataArray(np.arange(128 * 128).reshape(1, 1, 128, 128), dims=("b", "c", "x", "y"))
expected = arr + 1
input_tensors = [converters.xarray_to_pb_tensor("input", arr)]
input_spec_id = "input"
output_spec_id = "output"
input_tensors = [converters.xarray_to_pb_tensor(input_spec_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))

grpc_stub.CloseModelSession(model)

assert len(res.tensors) == 1
assert res.tensors[0].specId == output_spec_id
assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0]))

def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimageio_dummy_model_bytes):
    """Predict is expected to raise ``grpc.RpcError`` for a (1, 1, 32, 32) tensor sent as "input"."""
    session = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes))
    values = np.arange(32 * 32).reshape(1, 1, 32, 32)
    arr = xr.DataArray(values, dims=("b", "c", "x", "y"))
    request = inference_pb2.PredictRequest(
        modelSessionId=session.id,
        tensors=[converters.xarray_to_pb_tensor("input", arr)],
    )
    with pytest.raises(grpc.RpcError):
        grpc_stub.Predict(request)
    grpc_stub.CloseModelSession(session)

@pytest.mark.parametrize(
    "shape",
    [
        # NOTE(review): the original list contained (1, 1, 64, 32) twice; the
        # duplicate is dropped here. If a different case was intended in its
        # place, add it explicitly.
        (1, 1, 64, 32),
        (1, 1, 32, 64),
        (0, 1, 64, 64),
        (1, 0, 64, 64),
    ],
)
def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes):
    """Predict is expected to raise ``grpc.RpcError`` for each listed shape sent as the "param" input.

    The model session is closed in a ``finally`` clause so that a failing
    expectation does not leave the session open for the tests that follow.
    """
    model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes))
    try:
        arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y"))
        input_tensors = [converters.xarray_to_pb_tensor("param", arr)]
        with pytest.raises(grpc.RpcError):
            grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
    finally:
        grpc_stub.CloseModelSession(model)

@pytest.mark.parametrize("shape", [(1, 1, 64, 64), (1, 1, 66, 65), (1, 1, 68, 66), (1, 1, 70, 67)])
def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes):
    """Predict is expected to complete without raising for each listed shape sent as the "param" input."""
    session = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes))
    flat_values = np.arange(np.prod(shape))
    arr = xr.DataArray(flat_values.reshape(*shape), dims=("b", "c", "x", "y"))
    request = inference_pb2.PredictRequest(
        modelSessionId=session.id,
        tensors=[converters.xarray_to_pb_tensor("param", arr)],
    )
    grpc_stub.Predict(request)
    grpc_stub.CloseModelSession(session)

@pytest.mark.skip
def test_call_predict_tf(self, grpc_stub, bioimageio_dummy_tensorflow_model_bytes):
model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_tensorflow_model_bytes))
arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
expected = arr * -1
input_tensors = [converters.xarray_to_pb_tensor(arr)]
input_spec_id = "input"
output_spec_id = "output"
input_tensors = [converters.xarray_to_pb_tensor(input_spec_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))

grpc_stub.CloseModelSession(model)

assert len(res.tensors) == 1
assert res.tensors[0].specId == output_spec_id
assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0]))
23 changes: 12 additions & 11 deletions tiktorch/rpc/mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import wraps
from multiprocessing.connection import Connection
from threading import Event, Thread
from typing import Any, Optional, Type, TypeVar
from typing import Any, Dict, Optional, Type, TypeVar
from uuid import uuid4

from .exceptions import Shutdown
Expand Down Expand Up @@ -72,9 +72,8 @@ def __call__(self, *args, **kwargs) -> Any:
return self._client._invoke(self._method_name, *args, **kwargs)


def create_client(iface_cls: Type[T], conn: Connection, timeout=None) -> T:
def create_client(iface_cls: Type[T], conn: Connection, api_kwargs: Optional[Dict[str, any]] = None, timeout=None) -> T:
client = MPClient(iface_cls.__name__, conn, timeout)
get_exposed_methods(iface_cls)

def _make_method(method):
class MethodWrapper:
Expand All @@ -98,12 +97,15 @@ def __call__(self, *args, **kwargs) -> Any:
return MethodWrapper()

class _Client(iface_cls):
pass
def __init__(self, kwargs: Optional[Dict]):
kwargs = kwargs or {}
super().__init__(**kwargs)

for method_name, method in get_exposed_methods(iface_cls).items():
exposed_methods = get_exposed_methods(iface_cls)
for method_name, method in exposed_methods.items():
setattr(_Client, method_name, _make_method(method))

return _Client()
return _Client(api_kwargs)


class MPClient:
Expand Down Expand Up @@ -190,7 +192,7 @@ def _shutdown(self, exc):

class Message:
def __init__(self, id_):
self.id = id
self.id = id_


class Signal:
Expand All @@ -200,20 +202,19 @@ def __init__(self, payload):

class MethodCall(Message):
def __init__(self, id_, method_name, args, kwargs):
self.id = id_
super().__init__(id_)
self.method_name = method_name
self.args = args
self.kwargs = kwargs


class Cancellation(Message):
def __init__(self, id_):
self.id = id_
pass


class MethodReturn(Message):
def __init__(self, id_, result: Result):
self.id = id_
super().__init__(id_)
self.result = result


Expand Down
2 changes: 1 addition & 1 deletion tiktorch/server/grpc/inference_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def CreateModelSession(
lease.terminate()
raise

session = self.__session_manager.create_session()
session = self.__session_manager.create_session(client)
session.on_close(lease.terminate)
session.on_close(client.shutdown)

Expand Down
115 changes: 98 additions & 17 deletions tiktorch/server/session/process.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,117 @@
import multiprocessing as _mp
import os
import pathlib
import tempfile
import uuid
from concurrent.futures import Future
from multiprocessing.connection import Connection
from typing import List, Optional, Tuple
from typing import Dict, Iterator, List, Optional, Set, Tuple

import numpy
import numpy as np
from bioimageio.core import load_resource_description
from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
from bioimageio.core.resource_io import nodes
from bioimageio.core.resource_io.nodes import ParametrizedInputShape

from tiktorch import log
from tiktorch.rpc import Shutdown
from tiktorch.rpc import mp as _mp_rpc
from tiktorch.rpc.mp import MPServer

from ...converters import Tensor
from .backend import base
from .rpc_interface import IRPCModelSession


class ModelSessionProcess(IRPCModelSession):
def __init__(self, model_zip: bytes, devices: List[str]) -> None:
_tmp_file = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
_tmp_file.write(model_zip)
_tmp_file.close()
model = load_resource_description(pathlib.Path(_tmp_file.name))
os.unlink(_tmp_file.name)
self._model: PredictionPipeline = create_prediction_pipeline(bioimageio_model=model, devices=devices)
class ModelSessionProcess(IRPCModelSession[PredictionPipeline]):
def __init__(self, model: PredictionPipeline) -> None:
super().__init__(model)
self._datasets = {}
self._worker = base.SessionBackend(self._model)

def forward(self, input_tensors: numpy.ndarray) -> Future:
res = self._worker.forward(input_tensors)
def forward(self, input_tensors: Set[Tensor]) -> Future:
for tensor in input_tensors:
axes_wih_size = self._get_axes_with_size(tensor.data.dims, tensor.data.shape)
self.check_shape(tensor.spec_id, axes_wih_size)
tensors_data = [tensor.data for tensor in input_tensors]
res = self._worker.forward(tensors_data)
return res

def _get_input_spec(self, spec_id: str) -> nodes.InputTensor:
    """Return the single model input spec whose ``name`` equals *spec_id*.

    ``_check_spec_exists`` is called first, so an unknown id fails there.
    """
    self._check_spec_exists(spec_id)
    matches = []
    for candidate in self._model.input_specs:
        if candidate.name == spec_id:
            matches.append(candidate)
    assert len(matches) == 1, "ids of tensor specs should be unique"
    (only_match,) = matches
    return only_match

def create_dataset(self, mean, stddev):
    """Store *mean* and *stddev* under a newly generated hex id and return that id."""
    dataset_id = uuid.uuid4().hex
    entry = {"mean": mean, "stddev": stddev}
    self._datasets[dataset_id] = entry
    return dataset_id

def check_shape(self, spec_id: str, shape: Dict[str, int]):
    """Validate *shape* (axis name -> size) against the input spec named *spec_id*.

    Dispatches on the type of the spec's declared shape: a ``list`` goes to
    the explicit check, a ``ParametrizedInputShape`` to the parameterized
    check. Any other type raises ``ValueError``.
    """
    spec = self._get_input_spec(spec_id)
    declared = spec.shape
    if isinstance(declared, list):
        self._check_shape_explicit(spec, shape)
        return
    if isinstance(declared, ParametrizedInputShape):
        self._check_shape_parameterized(spec, shape)
        return
    raise ValueError(f"Unexpected shape {spec.shape}")

def _check_spec_exists(self, spec_id: str):
    """Raise ``ValueError`` unless *spec_id* is the name of one of the model's input specs."""
    spec_names = []
    for input_spec in self._model.input_specs:
        spec_names.append(input_spec.name)
    if spec_id in spec_names:
        return
    raise ValueError(f"Spec {spec_id} doesn't exist for specs {spec_names}")

def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
    """Raise ``ValueError`` unless *tensor_shape* equals the spec's axes paired with its explicit sizes."""
    assert self._is_shape_explicit(spec)
    reference_shape = dict(zip(spec.axes, spec.shape))
    if reference_shape == tensor_shape:
        return
    raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}")

def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
    """Validate *tensor_shape* against a parametrized input shape.

    A shape is accepted when, for one shared non-negative integer ``n``,
    every axis satisfies ``size == min + n * step``. Axes whose step is 0
    must therefore match ``min`` exactly.

    Args:
        spec: input spec whose ``shape`` is a ``ParametrizedInputShape``
            providing ``min`` and ``step`` in the order of ``spec.axes``.
        tensor_shape: mapping of axis name to size for the tensor to check.

    Raises:
        ValueError: if a size is not a natural number, the axis names differ
            from the spec's, a size is below the minimum, a zero-step axis
            differs from its minimum, or no single ``n`` fits all axes.
    """
    assert isinstance(spec.shape, ParametrizedInputShape)
    if not self._is_shape(tensor_shape.values()):
        raise ValueError(f"Invalid shape's sizes {tensor_shape}")

    min_shape = self._get_axes_with_size(spec.axes, tuple(spec.shape.min))
    step = self._get_axes_with_size(spec.axes, tuple(spec.shape.step))
    assert min_shape.keys() == step.keys()
    if tensor_shape.keys() != min_shape.keys():
        raise ValueError(f"Incompatible axes for tensor {tensor_shape} and spec {spec}")

    # The key comparison above is set-like and ignores ordering, so sizes are
    # looked up by axis name rather than relying on dict insertion order.
    axes = list(min_shape)
    tensor_shapes_arr = np.array([tensor_shape[axis] for axis in axes])
    min_shape_arr = np.array([min_shape[axis] for axis in axes])
    step_arr = np.array([step[axis] for axis in axes])
    diff = tensor_shapes_arr - min_shape_arr
    if np.any(diff < 0):
        raise ValueError(f"Tensor shape {tensor_shape} smaller than min shape {min_shape}")

    # A zero step means the axis cannot grow: its size has to equal the minimum.
    fixed = step_arr == 0
    if np.any(diff[fixed] != 0):
        raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")
    if np.all(fixed):
        # No axis can vary and all sizes already equal the minimum (n == 0).
        return

    multipliers = diff[~fixed] / step_arr[~fixed]
    multiplier = np.unique(multipliers)
    if len(multiplier) == 1 and self._is_natural_number(multiplier[0]):
        return
    raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")

def _is_natural_number(self, n) -> bool:
    """Tell whether *n* has no fractional part (floor equals ceil) and is not negative."""
    has_no_fraction = np.floor(n) == np.ceil(n)
    return has_no_fraction and n >= 0

def _is_shape(self, shape: Iterator[int]) -> bool:
    """Return True when every entry of *shape* passes ``_is_natural_number`` (True for an empty shape)."""
    for dim in shape:
        if not self._is_natural_number(dim):
            return False
    return True

def _get_axes_with_size(self, axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]:
    """Pair each axis name with the size at the same position; both sequences must have equal length."""
    assert len(axes) == len(shape)
    return dict(zip(axes, shape))

def _is_shape_explicit(self, spec: nodes.InputTensor) -> bool:
    """Return True when the spec's ``shape`` attribute is a ``list``."""
    declared = spec.shape
    return isinstance(declared, list)

def shutdown(self) -> Shutdown:
    """Shut the backend worker down, then return a new ``Shutdown`` instance."""
    worker = self._worker
    worker.shutdown()
    return Shutdown()


def _run_model_session_process(
conn: Connection, model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
conn: Connection, prediction_pipeline: PredictionPipeline, log_queue: Optional[_mp.Queue] = None
):
try:
# from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
Expand All @@ -60,7 +125,7 @@ def _run_model_session_process(
if log_queue:
log.configure(log_queue)

session_proc = ModelSessionProcess(model_zip, devices)
session_proc = ModelSessionProcess(prediction_pipeline)
srv = MPServer(session_proc, conn)
srv.listen()

Expand All @@ -69,10 +134,26 @@ def start_model_session_process(
model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
) -> Tuple[_mp.Process, IRPCModelSession]:
client_conn, server_conn = _mp.Pipe()
prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices)
proc = _mp.Process(
target=_run_model_session_process,
name="ModelSessionProcess",
kwargs={"conn": server_conn, "devices": devices, "log_queue": log_queue, "model_zip": model_zip},
kwargs={
"conn": server_conn,
"log_queue": log_queue,
"prediction_pipeline": prediction_pipeline,
},
)
proc.start()
return proc, _mp_rpc.create_client(IRPCModelSession, client_conn)
# here create the prediction pipeline, share it to the model session class and the client
return proc, _mp_rpc.create_client(
iface_cls=IRPCModelSession, api_kwargs={"model": prediction_pipeline}, conn=client_conn
)


def _get_prediction_pipeline_from_model_bytes(model_zip: bytes, devices: List[str]) -> PredictionPipeline:
    """Build a prediction pipeline from a zipped model package held in memory.

    The bytes are written to a temporary ``.zip`` file because the resource
    description is loaded from a path. The file is removed once loading has
    finished (successfully or not) so it does not accumulate on disk.

    Args:
        model_zip: raw bytes of the packaged model.
        devices: device names forwarded to ``create_prediction_pipeline``.

    Returns:
        The pipeline created from the loaded resource description.
    """
    # delete=False keeps the file on disk after the handle is closed, so it
    # can be reopened by path; it is removed explicitly below.
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file:
        _tmp_file.write(model_zip)
        temp_file_path = pathlib.Path(_tmp_file.name)
    try:
        model = load_resource_description(temp_file_path)
    finally:
        temp_file_path.unlink()
    return create_prediction_pipeline(bioimageio_model=model, devices=devices)
17 changes: 14 additions & 3 deletions tiktorch/server/session/rpc_interface.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from typing import List
from typing import Generic, List, Set, TypeVar

from tiktorch.converters import Tensor
from tiktorch.rpc import RPCInterface, Shutdown, exposed
from tiktorch.tiktypes import TikTensorBatch
from tiktorch.types import ModelState

ModelType = TypeVar("ModelType")


class IRPCModelSession(RPCInterface, Generic[ModelType]):
def __init__(self, model: ModelType):
super().__init__()
self._model = model

@property
def model(self):
return self._model

class IRPCModelSession(RPCInterface):
@exposed
def shutdown(self) -> Shutdown:
raise NotImplementedError
Expand Down Expand Up @@ -43,5 +54,5 @@ def create_dataset_description(self, mean, stddev) -> str:
raise NotImplementedError

@exposed
def forward(self, input_tensors):
def forward(self, input_tensors: Set[Tensor]):
raise NotImplementedError
Loading

0 comments on commit a242db4

Please sign in to comment.