diff --git a/proto/inference.proto b/proto/inference.proto index c6e85629..39845dc5 100644 --- a/proto/inference.proto +++ b/proto/inference.proto @@ -86,14 +86,6 @@ message OutputShape { message ModelSession { string id = 1; - string name = 2; - repeated string inputAxes = 3; - repeated string outputAxes = 4; - bool hasTraining = 5; - repeated InputShape inputShapes = 6; - repeated OutputShape outputShapes = 7; - repeated string inputNames = 8; - repeated string outputNames = 9; } message LogEntry { @@ -128,7 +120,8 @@ message NamedFloat { message Tensor { bytes buffer = 1; string dtype = 2; - repeated NamedInt shape = 3; + string tensorId = 3; + repeated NamedInt shape = 4; } message PredictRequest { diff --git a/tests/conftest.py b/tests/conftest.py index 31929f3e..1118935a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,9 @@ TEST_DATA = "data" TEST_BIOIMAGEIO_ZIPFOLDER = "unet2d" TEST_BIOIMAGEIO_ONNX = "unet2d_onnx" -TEST_BIOIMAGEIO_DUMMY = "dummy" +TEST_BIOIMAGEIO_DUMMY_EXPLICIT = "dummy" +TEST_BIOIMAGEIO_DUMMY_EXPLICIT_RDF = f"{TEST_BIOIMAGEIO_DUMMY_EXPLICIT}/Dummy.model.yaml" +TEST_BIOIMAGEIO_DUMMY_PARAM_RDF = "dummy_param/Dummy.model_param.yaml" TEST_BIOIMAGEIO_TENSORFLOW_DUMMY = "dummy_tensorflow" TEST_BIOIMAGEIO_TORCHSCRIPT = "unet2d_torchscript" @@ -98,7 +100,7 @@ def bioimageio_model_zipfile(bioimageio_model_bytes): @pytest.fixture def bioimageio_dummy_model_filepath(data_path, tmpdir): - bioimageio_net_dir = Path(data_path) / TEST_BIOIMAGEIO_DUMMY + bioimageio_net_dir = Path(data_path) / TEST_BIOIMAGEIO_DUMMY_EXPLICIT path = tmpdir / "dummy_model.zip" with ZipFile(path, mode="w") as zip_model: @@ -113,8 +115,24 @@ def bioimageio_dummy_model_filepath(data_path, tmpdir): @pytest.fixture -def bioimageio_dummy_model_bytes(data_path): - rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY / "Dummy.model.yaml" +def bioimageio_dummy_explicit_model_bytes(data_path): + rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY_EXPLICIT_RDF + return _bioimageio_package(rdf_source) + + +@pytest.fixture +def bioimageio_dummy_param_model_bytes(data_path): + rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY_PARAM_RDF + return _bioimageio_package(rdf_source) + + +@pytest.fixture(params=[(TEST_BIOIMAGEIO_DUMMY_PARAM_RDF, "param"), (TEST_BIOIMAGEIO_DUMMY_EXPLICIT_RDF, "input")]) +def bioimageio_dummy_model(request, data_path): + path, tensor_id = request.param + yield _bioimageio_package(data_path / path), tensor_id + + +def _bioimageio_package(rdf_source): data = io.BytesIO() export_resource_package(rdf_source, output_path=data) return data diff --git a/tests/data/dummy/Dummy.model.yaml b/tests/data/dummy/Dummy.model.yaml index a27a1d4d..1aa8fe90 100644 --- a/tests/data/dummy/Dummy.model.yaml +++ b/tests/data/dummy/Dummy.model.yaml @@ -41,15 +41,16 @@ inputs: data_range: [-inf, inf] shape: [1, 1, 128, 128] + outputs: - name: output axes: bcyx data_type: float32 data_range: [0, 1] shape: - reference_tensor: input # FIXME(m-novikov) ignoring for now + reference_tensor: input scale: [1, 1, 1, 1] offset: [0, 0, 0, 0] - halo: [0, 0, 32, 32] # Should be moved to outputs + halo: [0, 0, 32, 32] type: model diff --git a/tests/data/dummy_param/Dummy.model_param.yaml b/tests/data/dummy_param/Dummy.model_param.yaml new file mode 100644 index 00000000..87ddd885 --- /dev/null +++ b/tests/data/dummy_param/Dummy.model_param.yaml @@ -0,0 +1,57 @@ +format_version: 0.3.3 +language: python +framework: pytorch + +name: UNet2DNucleiBroad +description: A 2d U-Net pretrained on broad nucleus dataset. +cite: + - text: "Ronneberger, Olaf et al. U-net: Convolutional networks for biomedical image segmentation. MICCAI 2015." + doi: https://doi.org/10.1007/978-3-319-24574-4_28 +authors: + - name: "ilastik-team" + affiliation: "EMBL Heidelberg" + +documentation: dummy.md +tags: [pytorch, nucleus-segmentation] +license: MIT +git_repo: https://github.com/ilastik/tiktorch +covers: [] + +source: dummy.py::Dummy +sha256: 00ffb1647cf7ec524892206dce6258d9da498fe040c62838f31b501a09bfd573 +timestamp: 2019-12-11T12:22:32Z # ISO 8601 + +test_inputs: [dummy_in.npy] +test_outputs: [dummy_out.npy] + +weights: + pytorch_state_dict: + source: ./weights + sha256: 518cb80bad2eb3ec3dfbe6bab74920951391ce8fb24e15cf59b9b9f052a575a6 + authors: + - name: "ilastik-team" + affiliation: "EMBL Heidelberg" + + +# TODO double check inputs/outputs +inputs: + - name: param + axes: bcyx + data_type: float32 + data_range: [-inf, inf] + shape: + min: [1, 1, 64, 64] + step: [0, 0, 2, 1] + +outputs: + - name: output + axes: bcyx + data_type: float32 + data_range: [0, 1] + shape: + reference_tensor: param + scale: [1, 1, 1, 1] + offset: [0, 0, 0, 0] + halo: [0, 0, 8, 8] + +type: model diff --git a/tests/data/dummy_param/dummy.md b/tests/data/dummy_param/dummy.md new file mode 100644 index 00000000..e69de29b diff --git a/tests/data/dummy_param/dummy.py b/tests/data/dummy_param/dummy.py new file mode 100644 index 00000000..195e98e3 --- /dev/null +++ b/tests/data/dummy_param/dummy.py @@ -0,0 +1,7 @@ +from torch import nn + + +class Dummy(nn.Module): + def forward(self, input): + x = input + return x + 1 diff --git a/tests/data/dummy_param/dummy_in.npy b/tests/data/dummy_param/dummy_in.npy new file mode 100644 index 00000000..96a78a7b Binary files /dev/null and b/tests/data/dummy_param/dummy_in.npy differ diff --git a/tests/data/dummy_param/dummy_out.npy b/tests/data/dummy_param/dummy_out.npy new file mode 100644 index 00000000..56f76ca7 Binary files /dev/null and b/tests/data/dummy_param/dummy_out.npy differ diff --git a/tests/data/dummy_param/environment.yaml b/tests/data/dummy_param/environment.yaml new file mode 100644 index 00000000..e69de29b diff --git a/tests/data/dummy_param/weights b/tests/data/dummy_param/weights new file mode 100644 index 00000000..da14f342 Binary files /dev/null and b/tests/data/dummy_param/weights differ diff --git a/tests/test_converters.py b/tests/test_converters.py index c775ede3..be268e42 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -7,6 +7,7 @@ NamedExplicitOutputShape, NamedImplicitOutputShape, NamedParametrizedShape, + Sample, input_shape_to_pb_input_shape, numpy_to_pb_tensor, output_shape_to_pb_output_shape, @@ -27,6 +28,16 @@ def _numpy_to_pb_tensor(arr): return parsed +def to_pb_tensor(tensor_id: str, arr: xr.DataArray): + """ + Makes sure that tensor was serialized/deserialized + """ + tensor = xarray_to_pb_tensor(tensor_id, arr) + parsed = inference_pb2.Tensor() + parsed.ParseFromString(tensor.SerializeToString()) + return parsed + + class TestNumpyToPBTensor: def test_should_serialize_to_tensor_type(self): arr = np.arange(9) @@ -97,18 +108,9 @@ def test_should_same_data(self, shape): class TestXarrayToPBTensor: - def to_pb_tensor(self, arr): - """ - Makes sure that tensor was serialized/deserialized - """ - tensor = xarray_to_pb_tensor(arr) - parsed = inference_pb2.Tensor() - parsed.ParseFromString(tensor.SerializeToString()) - return parsed - def test_should_serialize_to_tensor_type(self): xarr = xr.DataArray(np.arange(8).reshape((2, 4)), dims=("x", "y")) - pb_tensor = self.to_pb_tensor(xarr) + pb_tensor = to_pb_tensor("input0", xarr) assert isinstance(pb_tensor, inference_pb2.Tensor) assert len(pb_tensor.shape) == 2 dim1 = pb_tensor.shape[0] @@ -123,28 +125,19 @@ def test_should_serialize_to_tensor_type(self): @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)]) def test_should_have_shape(self, shape): arr = xr.DataArray(np.zeros(shape)) - tensor = self.to_pb_tensor(arr) + tensor = to_pb_tensor("input0", arr) assert tensor.shape assert list(shape) == [dim.size for dim in tensor.shape] def test_should_have_serialized_bytes(self): arr = xr.DataArray(np.arange(9, dtype=np.uint8)) expected = bytes(arr.data) - tensor = self.to_pb_tensor(arr) + tensor = to_pb_tensor("input0", arr) assert expected == tensor.buffer class TestPBTensorToXarray: - def to_pb_tensor(self, arr): - """ - Makes sure that tensor was serialized/deserialized - """ - tensor = xarray_to_pb_tensor(arr) - parsed = inference_pb2.Tensor() - parsed.ParseFromString(tensor.SerializeToString()) - return parsed - def test_should_raise_on_empty_dtype(self): tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)]) with pytest.raises(ValueError): @@ -155,33 +148,32 @@ def test_should_raise_on_empty_shape(self): with pytest.raises(ValueError): pb_tensor_to_xarray(tensor) - def test_should_return_ndarray(self): + def test_should_return_xarray(self): arr = xr.DataArray(np.arange(9)) - parsed = self.to_pb_tensor(arr) - result_arr = pb_tensor_to_xarray(parsed) - - assert isinstance(result_arr, xr.DataArray) + parsed = to_pb_tensor("input0", arr) + result_tensor = pb_tensor_to_xarray(parsed) + assert isinstance(result_tensor, xr.DataArray) @pytest.mark.parametrize("np_dtype,dtype_str", [(np.int64, "int64"), (np.uint8, "uint8"), (np.float32, "float32")]) def test_should_have_same_dtype(self, np_dtype, dtype_str): arr = xr.DataArray(np.arange(9, dtype=np_dtype)) - tensor = self.to_pb_tensor(arr) - result_arr = pb_tensor_to_xarray(tensor) + pb_tensor = to_pb_tensor("input0", arr) + result_arr = pb_tensor_to_xarray(pb_tensor) assert arr.dtype == result_arr.dtype @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)]) def test_should_same_shape(self, shape): arr = xr.DataArray(np.zeros(shape)) - tensor = self.to_pb_tensor(arr) - result_arr = pb_tensor_to_xarray(tensor) + pb_tensor = to_pb_tensor("input0", arr) + result_arr = pb_tensor_to_xarray(pb_tensor) assert arr.shape == result_arr.shape @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)]) def test_should_same_data(self, shape): arr = xr.DataArray(np.random.random(shape)) - tensor = self.to_pb_tensor(arr) - result_arr = pb_tensor_to_xarray(tensor) + pb_tensor = to_pb_tensor("input0", arr) + result_arr = pb_tensor_to_xarray(pb_tensor) assert_array_equal(arr, result_arr) @@ -276,3 +268,64 @@ def test_parametrized_input_shape(self, min_shape, axes, step): assert [(d.name, d.size) for d in pb_shape.stepShape.namedInts] == [ (name, size) for name, size in zip(axes, step) ] + + +class TestSample: + def test_create_sample_from_pb_tensors(self): + arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32) + tensor_1 = inference_pb2.Tensor( + dtype="int64", + tensorId="input1", + buffer=bytes(arr_1), + shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)], + ) + + arr_2 = np.arange(64 * 64, dtype=int).reshape(64, 64) + tensor_2 = inference_pb2.Tensor( + dtype="int64", + tensorId="input2", + buffer=bytes(arr_2), + shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)], + ) + + sample = Sample.from_pb_tensors([tensor_1, tensor_2]) + assert len(sample.tensors) == 2 + assert sample.tensors["input1"].equals(xr.DataArray(arr_1, dims=["x", "y"])) + assert sample.tensors["input2"].equals(xr.DataArray(arr_2, dims=["x", "y"])) + + def test_create_sample_from_raw_data(self): + arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32) + tensor_1 = xr.DataArray(arr_1, dims=["x", "y"]) + arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64) + tensor_2 = xr.DataArray(arr_2, dims=["x", "y"]) + tensors_ids = ["input1", "input2"] + actual_sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2]) + + expected_dict = {tensors_ids[0]: tensor_1, tensors_ids[1]: tensor_2} + expected_sample = Sample(expected_dict) + assert actual_sample == expected_sample + + def test_sample_to_pb_tensors(self): + arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32) + tensor_1 = xr.DataArray(arr_1, dims=["x", "y"]) + arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64) + tensor_2 = xr.DataArray(arr_2, dims=["x", "y"]) + tensors_ids = ["input1", "input2"] + sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2]) + + pb_tensor_1 = inference_pb2.Tensor( + dtype="int64", + tensorId="input1", + buffer=bytes(arr_1), + shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)], + ) + pb_tensor_2 = inference_pb2.Tensor( + dtype="int64", + tensorId="input2", + buffer=bytes(arr_2), + shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)], + ) + expected_tensors = [pb_tensor_1, pb_tensor_2] + + actual_tensors = sample.to_pb_tensors() + assert expected_tensors == actual_tensors diff --git a/tests/test_rpc/test_mp.py b/tests/test_rpc/test_mp.py index 981f4553..2011c54e 100644 --- a/tests/test_rpc/test_mp.py +++ b/tests/test_rpc/test_mp.py @@ -8,7 +8,7 @@ from tiktorch import log from tiktorch.rpc import RPCFuture, RPCInterface, Shutdown, exposed -from tiktorch.rpc.mp import FutureStore, MPServer, create_client +from tiktorch.rpc.mp import FutureStore, MPServer, create_client_api class ITestApi(RPCInterface): @@ -64,7 +64,7 @@ def client(log_queue): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(ITestApi, child, timeout=10) + client = create_client_api(iface_cls=ITestApi, conn=child, timeout=10) yield client @@ -108,7 +108,7 @@ def __getattr__(self, name): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(ITestApi, SlowConn(child)) + client = create_client_api(iface_cls=ITestApi, conn=SlowConn(child)) client.fast_compute(2, 2) @@ -121,7 +121,7 @@ def test_future_timeout(client: ITestApi, log_queue): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(ITestApi, child, timeout=0.001) + client = create_client_api(iface_cls=ITestApi, conn=child, timeout=0.001) with pytest.raises(TimeoutError): client.compute(1, 2) @@ -256,7 +256,7 @@ def _spawn(iface_cls, srv_cls): p = mp.Process(target=_run_srv, args=(srv_cls, parent, log_queue)) p.start() - data["client"] = client = create_client(iface_cls, child) + data["client"] = client = create_client_api(iface_cls=iface_cls, conn=child) data["process"] = p return client diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index 8a5188af..864c4d6a 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -47,12 +47,10 @@ def method_requiring_session(self, request, grpc_stub): def test_model_session_creation(self, grpc_stub, bioimageio_model_bytes): model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes)) assert model.id - assert hasattr(model, "outputShapes") - assert hasattr(model, "inputShapes") grpc_stub.CloseModelSession(model) - def test_model_session_creation_using_upload_id(self, grpc_stub, data_store, bioimageio_dummy_model_bytes): - id_ = data_store.put(bioimageio_dummy_model_bytes.getvalue()) + def test_model_session_creation_using_upload_id(self, grpc_stub, data_store, bioimageio_dummy_explicit_model_bytes): + id_ = data_store.put(bioimageio_dummy_explicit_model_bytes.getvalue()) rq = inference_pb2.CreateModelSessionRequest(model_uri=f"upload://{id_}", deviceIds=["cpu"]) model = grpc_stub.CreateModelSession(rq) @@ -156,27 +154,81 @@ def test_call_fails_with_unknown_model_session_id(self, grpc_stub): assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code() assert "model-session with id myid1 doesn't exist" in e.value.details() - def test_call_predict(self, grpc_stub, bioimageio_dummy_model_bytes): - model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes)) - arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y")) + def test_call_predict_valid_explicit(self, grpc_stub, bioimageio_dummy_explicit_model_bytes): + model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_explicit_model_bytes)) + arr = xr.DataArray(np.arange(128 * 128).reshape(1, 1, 128, 128), dims=("b", "c", "x", "y")) expected = arr + 1 - input_tensors = [converters.xarray_to_pb_tensor(arr)] + input_tensor_id = "input" + output_tensor_id = "output" + input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) grpc_stub.CloseModelSession(model) assert len(res.tensors) == 1 + assert res.tensors[0].tensorId == output_tensor_id assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0])) + def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimageio_dummy_explicit_model_bytes): + model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_explicit_model_bytes)) + arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y")) + input_tensors = [converters.xarray_to_pb_tensor("input", arr)] + with pytest.raises(grpc.RpcError): + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.CloseModelSession(model) + + @pytest.mark.parametrize( + "shape", + [(1, 1, 64, 32), (1, 1, 32, 64), (1, 1, 64, 32), (0, 1, 64, 64), (1, 0, 64, 64)], + ) + def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes): + model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes)) + arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y")) + input_tensors = [converters.xarray_to_pb_tensor("param", arr)] + with pytest.raises(grpc.RpcError): + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.CloseModelSession(model) + + def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimageio_dummy_model): + model_bytes, _ = bioimageio_dummy_model + model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + arr = xr.DataArray(np.arange(32 * 32).reshape(32, 32), dims=("x", "y")) + input_tensors = [converters.xarray_to_pb_tensor("invalidTensorName", arr)] + with pytest.raises(grpc.RpcError) as error: + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + assert error.value.details().startswith("Exception calling application: Spec invalidTensorName doesn't exist") + grpc_stub.CloseModelSession(model) + + def test_call_predict_invalid_axes(self, grpc_stub, bioimageio_dummy_model): + model_bytes, tensor_id = bioimageio_dummy_model + model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + arr = xr.DataArray(np.arange(32 * 32).reshape(32, 32), dims=("invalidAxis", "y")) + input_tensors = [converters.xarray_to_pb_tensor(tensor_id, arr)] + with pytest.raises(grpc.RpcError) as error: + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + assert error.value.details().startswith("Exception calling application: Incompatible axes") + grpc_stub.CloseModelSession(model) + + @pytest.mark.parametrize("shape", [(1, 1, 64, 64), (1, 1, 66, 65), (1, 1, 68, 66), (1, 1, 70, 67)]) + def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes): + model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes)) + arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y")) + input_tensors = [converters.xarray_to_pb_tensor("param", arr)] + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.CloseModelSession(model) + @pytest.mark.skip def test_call_predict_tf(self, grpc_stub, bioimageio_dummy_tensorflow_model_bytes): model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_tensorflow_model_bytes)) arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y")) expected = arr * -1 - input_tensors = [converters.xarray_to_pb_tensor(arr)] + input_tensor_id = "input" + output_tensor_id = "output" + input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) grpc_stub.CloseModelSession(model) assert len(res.tensors) == 1 + assert res.tensors[0].tensorId == output_tensor_id assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0])) diff --git a/tests/test_server/test_modelinfo.py b/tests/test_server/test_modelinfo.py deleted file mode 100644 index 4fbb72c5..00000000 --- a/tests/test_server/test_modelinfo.py +++ /dev/null @@ -1,153 +0,0 @@ -import random -from unittest import mock - -import pytest -from bioimageio.spec.shared.raw_nodes import ImplicitOutputShape, ParametrizedInputShape -from marshmallow import missing - -from tiktorch.converters import NamedExplicitOutputShape, NamedImplicitOutputShape, NamedParametrizedShape -from tiktorch.server.session.process import ModelInfo - - -@pytest.fixture -def implicit_output_spec(): - """output spec with ImplicitOutputShape""" - shape = ImplicitOutputShape( - reference_tensor="blah", - scale=[1.0] + [float(random.randint(0, 2**32)) for _ in range(4)], - offset=[0.0] + [float(random.randint(0, 2**32)) for _ in range(4)], - ) - output_spec = mock.Mock(axes=("x", "y"), shape=shape, halo=[5, 12]) - output_spec.name = "implicit_out" - return output_spec - - -@pytest.fixture -def parametrized_input_spec(): - shape = ParametrizedInputShape( - min=[random.randint(0, 2**32) for _ in range(5)], step=[float(random.randint(0, 2**32)) for _ in range(5)] - ) - input_spec = mock.Mock(axes=("b", "x", "y", "z", "c"), shape=shape) - input_spec.name = "param_in" - return input_spec - - -@pytest.fixture -def explicit_input_spec(): - input_shape = [random.randint(0, 2**32) for _ in range(3)] - input_spec = mock.Mock(axes=("b", "x", "y"), shape=input_shape) - input_spec.name = "explicit_in" - return input_spec - - -@pytest.fixture -def explicit_output_spec(): - output_shape = [random.randint(0, 2**32) for _ in range(3)] - halo = [0] + [random.randint(0, 2**32) for _ in range(2)] - output_spec = mock.Mock(axes=("b", "x", "y"), shape=output_shape, halo=halo) - output_spec.name = "explicit_out" - return output_spec - - -def test_model_info_explicit_shapes(explicit_input_spec, explicit_output_spec): - prediction_pipeline = mock.Mock(input_specs=[explicit_input_spec], output_specs=[explicit_output_spec], name="bleh") - - model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline) - - assert model_info.input_axes == ["".join(explicit_input_spec.axes)] - assert model_info.output_axes == ["".join(explicit_output_spec.axes)] - assert len(model_info.input_shapes) == 1 - assert len(model_info.output_shapes) == 1 - assert isinstance(model_info.input_shapes[0], list) - assert model_info.input_shapes[0] == [(ax, s) for ax, s in zip(explicit_input_spec.axes, explicit_input_spec.shape)] - assert isinstance(model_info.output_shapes[0], NamedExplicitOutputShape) - assert model_info.output_shapes[0].shape == [ - (ax, s) for ax, s in zip(explicit_output_spec.axes, explicit_output_spec.shape) - ] - assert model_info.output_shapes[0].halo == [ - (ax, s) for ax, s in zip(explicit_output_spec.axes, explicit_output_spec.halo) - ] - assert model_info.input_names == ["explicit_in"] - assert model_info.output_names == ["explicit_out"] - - -def test_model_info_explicit_shapes_missing_halo(explicit_input_spec, explicit_output_spec): - explicit_output_spec.halo = missing - - prediction_pipeline = mock.Mock(input_specs=[explicit_input_spec], output_specs=[explicit_output_spec], name="bleh") - - model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline) - - assert model_info.input_axes == ["".join(explicit_input_spec.axes)] - assert model_info.output_axes == ["".join(explicit_output_spec.axes)] - assert len(model_info.input_shapes) == 1 - assert len(model_info.output_shapes) == 1 - assert isinstance(model_info.input_shapes[0], list) - assert model_info.input_shapes[0] == [(ax, s) for ax, s in zip(explicit_input_spec.axes, explicit_input_spec.shape)] - assert isinstance(model_info.output_shapes[0], NamedExplicitOutputShape) - assert model_info.output_shapes[0].shape == [ - (ax, s) for ax, s in zip(explicit_output_spec.axes, explicit_output_spec.shape) - ] - assert model_info.output_shapes[0].halo == [(ax, s) for ax, s in zip(explicit_output_spec.axes, [0, 0, 0])] - - -def test_model_info_implicit_shapes(parametrized_input_spec, implicit_output_spec): - prediction_pipeline = mock.Mock( - input_specs=[parametrized_input_spec], output_specs=[implicit_output_spec], name="bleh" - ) - - model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline) - assert model_info.input_axes == ["".join(parametrized_input_spec.axes)] - assert model_info.output_axes == ["".join(implicit_output_spec.axes)] - assert len(model_info.input_shapes) == 1 - assert len(model_info.output_shapes) == 1 - assert isinstance(model_info.input_shapes[0], NamedParametrizedShape) - assert model_info.input_shapes[0].min_shape == [ - (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.min) - ] - assert model_info.input_shapes[0].step_shape == [ - (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.step) - ] - assert isinstance(model_info.output_shapes[0], NamedImplicitOutputShape) - assert model_info.output_shapes[0].offset == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.offset) - ] - assert model_info.output_shapes[0].scale == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.scale) - ] - assert model_info.output_shapes[0].halo == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.halo) - ] - assert model_info.output_shapes[0].reference_tensor == implicit_output_spec.shape.reference_tensor - - assert model_info.input_names == ["param_in"] - assert model_info.output_names == ["implicit_out"] - - -def test_model_info_implicit_shapes_missing_halo(parametrized_input_spec, implicit_output_spec): - implicit_output_spec.halo = missing - prediction_pipeline = mock.Mock( - input_specs=[parametrized_input_spec], output_specs=[implicit_output_spec], name="bleh" - ) - - model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline) - assert model_info.input_axes == ["".join(parametrized_input_spec.axes)] - assert model_info.output_axes == ["".join(implicit_output_spec.axes)] - assert len(model_info.input_shapes) == 1 - assert len(model_info.output_shapes) == 1 - assert isinstance(model_info.input_shapes[0], NamedParametrizedShape) - assert model_info.input_shapes[0].min_shape == [ - (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.min) - ] - assert model_info.input_shapes[0].step_shape == [ - (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.step) - ] - assert isinstance(model_info.output_shapes[0], NamedImplicitOutputShape) - assert model_info.output_shapes[0].offset == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.offset) - ] - assert model_info.output_shapes[0].scale == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.scale) - ] - assert model_info.output_shapes[0].halo == [(ax, s) for ax, s in zip(implicit_output_spec.axes, [0, 0, 0, 0, 0])] - assert model_info.output_shapes[0].reference_tensor == implicit_output_spec.shape.reference_tensor diff --git a/tiktorch/converters.py b/tiktorch/converters.py index 4093bec3..b7fee3fe 100644 --- a/tiktorch/converters.py +++ b/tiktorch/converters.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import dataclasses -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union import numpy as np import xarray as xr @@ -33,6 +35,23 @@ class NamedImplicitOutputShape: halo: NamedShape +@dataclasses.dataclass(frozen=True) +class Sample: + tensors: Dict[str, xr.DataArray] + + @classmethod + def from_pb_tensors(cls, pb_tensors: List[inference_pb2.Tensor]) -> Sample: + return Sample({tensor.tensorId: pb_tensor_to_xarray(tensor) for tensor in pb_tensors}) + + @classmethod + def from_xr_tensors(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]) -> Sample: + assert len(tensor_ids) == len(tensors_data) + return Sample({tensor_id: tensor_data for tensor_id, tensor_data in zip(tensor_ids, tensors_data)}) + + def to_pb_tensors(self) -> List[inference_pb2.Tensor]: + return [xarray_to_pb_tensor(tensor_id, res_tensor) for tensor_id, res_tensor in self.tensors.items()] + + def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor: if axistags: shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)] @@ -41,9 +60,9 @@ def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array)) -def xarray_to_pb_tensor(array: xr.DataArray) -> inference_pb2.Tensor: +def xarray_to_pb_tensor(tensor_id: str, array: xr.DataArray) -> inference_pb2.Tensor: shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, array.dims)] - return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array.data)) + return inference_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array.data)) def name_int_tuples_to_pb_NamedInts(name_int_tuples) -> inference_pb2.NamedInts: diff --git a/tiktorch/proto/inference_pb2.py b/tiktorch/proto/inference_pb2.py index e098a4df..dc5c6c9c 100644 --- a/tiktorch/proto/inference_pb2.py +++ b/tiktorch/proto/inference_pb2.py @@ -20,7 +20,7 @@ syntax='proto3', serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\xd3\x01\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tinputAxes\x18\x03 \x03(\t\x12\x12\n\noutputAxes\x18\x04 \x03(\t\x12\x13\n\x0bhasTraining\x18\x05 \x01(\x08\x12 \n\x0binputShapes\x18\x06 \x03(\x0b\x32\x0b.InputShape\x12\"\n\x0coutputShapes\x18\x07 \x03(\x0b\x32\x0c.OutputShape\x12\x12\n\ninputNames\x18\x08 \x03(\t\x12\x13\n\x0boutputNames\x18\t \x03(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"A\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x18\n\x05shape\x18\x03 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3' + serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"S\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x10\n\x08tensorId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3' ) @@ -140,8 +140,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=1165, - serialized_end=1243, + serialized_start=979, + serialized_end=1057, ) _sym_db.RegisterEnumDescriptor(_LOGENTRY_LEVEL) @@ -548,62 +548,6 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='name', full_name='ModelSession.name', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='inputAxes', full_name='ModelSession.inputAxes', index=2, - number=3, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='outputAxes', full_name='ModelSession.outputAxes', index=3, - number=4, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='hasTraining', full_name='ModelSession.hasTraining', index=4, - number=5, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='inputShapes', full_name='ModelSession.inputShapes', index=5, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='outputShapes', full_name='ModelSession.outputShapes', index=6, - number=7, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='inputNames', full_name='ModelSession.inputNames', index=7, - number=8, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='outputNames', full_name='ModelSession.outputNames', index=8, - number=9, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ], extensions=[ ], @@ -616,8 +560,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=871, - serialized_end=1082, + serialized_start=870, + serialized_end=896, ) @@ -663,8 +607,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1085, - serialized_end=1243, + serialized_start=899, + serialized_end=1057, ) @@ -695,8 +639,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1245, - serialized_end=1280, + serialized_start=1059, + serialized_end=1094, ) @@ -734,8 +678,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1282, - serialized_end=1320, + serialized_start=1096, + serialized_end=1134, ) @@ -773,8 +717,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1322, - serialized_end=1362, + serialized_start=1136, + serialized_end=1176, ) @@ -801,8 +745,15 @@ is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), _descriptor.FieldDescriptor( - name='shape', full_name='Tensor.shape', index=2, - number=3, type=11, cpp_type=10, label=3, + name='tensorId', full_name='Tensor.tensorId', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='shape', full_name='Tensor.shape', index=3, + number=4, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, @@ -819,8 +770,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1364, - serialized_end=1429, + serialized_start=1178, + serialized_end=1261, ) @@ -865,8 +816,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1431, - serialized_end=1516, + serialized_start=1263, + serialized_end=1348, ) @@ -897,8 +848,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1518, - serialized_end=1561, + serialized_start=1350, + serialized_end=1393, ) @@ -922,8 +873,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1563, - serialized_end=1570, + serialized_start=1395, + serialized_end=1402, ) @@ -954,8 +905,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1572, - serialized_end=1602, + serialized_start=1404, + serialized_end=1434, ) @@ -998,8 +949,8 @@ create_key=_descriptor._internal_create_key, fields=[]), ], - serialized_start=1604, - serialized_end=1698, + serialized_start=1436, + serialized_end=1530, ) _DEVICE.fields_by_name['status'].enum_type = _DEVICE_STATUS @@ -1023,8 +974,6 @@ _OUTPUTSHAPE.fields_by_name['scale'].message_type = _NAMEDFLOATS _OUTPUTSHAPE.fields_by_name['offset'].message_type = _NAMEDFLOATS _OUTPUTSHAPE_SHAPETYPE.containing_type = _OUTPUTSHAPE -_MODELSESSION.fields_by_name['inputShapes'].message_type = _INPUTSHAPE -_MODELSESSION.fields_by_name['outputShapes'].message_type = _OUTPUTSHAPE _LOGENTRY.fields_by_name['level'].enum_type = _LOGENTRY_LEVEL _LOGENTRY_LEVEL.containing_type = _LOGENTRY _DEVICES.fields_by_name['devices'].message_type = _DEVICE @@ -1210,8 +1159,8 @@ index=0, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=1701, - serialized_end=2027, + serialized_start=1533, + serialized_end=1859, methods=[ _descriptor.MethodDescriptor( name='CreateModelSession', @@ -1286,8 +1235,8 @@ index=1, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=2029, - serialized_end=2100, + serialized_start=1861, + serialized_end=1932, methods=[ _descriptor.MethodDescriptor( name='Ping', diff --git a/tiktorch/rpc/mp.py b/tiktorch/rpc/mp.py index bab3ab1f..0cd69bf4 100644 --- a/tiktorch/rpc/mp.py +++ b/tiktorch/rpc/mp.py @@ -1,3 +1,4 @@ +import dataclasses import logging import queue import threading @@ -5,9 +6,12 @@ from functools import wraps from multiprocessing.connection import Connection from threading import Event, Thread -from typing import Any, Optional, Type, TypeVar +from typing import Any, List, Optional, Type, TypeVar from uuid import uuid4 +from bioimageio.core.resource_io import nodes + +from ..server.session import IRPCModelSession from .exceptions import Shutdown from .interface import get_exposed_methods from .types import RPCFuture, isfutureret @@ -72,9 +76,8 @@ def __call__(self, *args, **kwargs) -> Any: return self._client._invoke(self._method_name, *args, **kwargs) -def create_client(iface_cls: Type[T], conn: Connection, timeout=None) -> T: +def create_client_api(iface_cls: Type[T], conn: Connection, timeout=None) -> T: client = MPClient(iface_cls.__name__, conn, timeout) - get_exposed_methods(iface_cls) def _make_method(method): class MethodWrapper: @@ -97,13 +100,21 @@ def __call__(self, *args, **kwargs) -> Any: return MethodWrapper() - class _Client(iface_cls): + class _Api: pass - for method_name, method in get_exposed_methods(iface_cls).items(): - setattr(_Client, method_name, _make_method(method)) + exposed_methods = get_exposed_methods(iface_cls) + for method_name, method in exposed_methods.items(): + setattr(_Api, method_name, _make_method(method)) + + return _Api() + - return _Client() +@dataclasses.dataclass(frozen=True) +class BioModelClient: + api: IRPCModelSession + input_specs: List[nodes.InputTensor] + output_specs: List[nodes.OutputTensor] class MPClient: @@ -190,7 +201,7 @@ def _shutdown(self, exc): class Message: def __init__(self, id_): - self.id = id + self.id = id_ class Signal: @@ -200,20 +211,19 @@ def __init__(self, payload): class MethodCall(Message): def __init__(self, id_, method_name, args, kwargs): - self.id = id_ + super().__init__(id_) self.method_name = method_name self.args = args self.kwargs = kwargs class Cancellation(Message): - def __init__(self, id_): - self.id = id_ + pass class MethodReturn(Message): def __init__(self, id_, result: Result): - self.id = id_ + super().__init__(id_) self.result = result diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 5e188787..f09e0bae 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -2,12 +2,12 @@ import grpc -from tiktorch import converters +from tiktorch.converters import Sample from tiktorch.proto import inference_pb2, inference_pb2_grpc from tiktorch.server.data_store import IDataStore from tiktorch.server.device_pool import DeviceStatus, IDevicePool -from tiktorch.server.session.process import start_model_session_process -from tiktorch.server.session_manager import ISession, SessionManager +from tiktorch.server.session.process import InputTensorValidator, start_model_session_process +from tiktorch.server.session_manager import Session, SessionManager class InferenceServicer(inference_pb2_grpc.InferenceServicer): @@ -36,37 +36,17 @@ def CreateModelSession( lease.terminate() raise - session = self.__session_manager.create_session() + session = self.__session_manager.create_session(client) session.on_close(lease.terminate) - session.on_close(client.shutdown) - session.client = client + session.on_close(client.api.shutdown) - try: - model_info = session.client.get_model_info() - except Exception: - lease.terminate() - raise - - pb_input_shapes = [converters.input_shape_to_pb_input_shape(shape) for shape in model_info.input_shapes] - pb_output_shapes = [converters.output_shape_to_pb_output_shape(shape) for shape in model_info.output_shapes] - - return inference_pb2.ModelSession( - id=session.id, - name=model_info.name, - inputAxes=model_info.input_axes, - outputAxes=model_info.output_axes, - inputShapes=pb_input_shapes, - hasTraining=False, - outputShapes=pb_output_shapes, - inputNames=model_info.input_names, - outputNames=model_info.output_names, - ) + return inference_pb2.ModelSession(id=session.id) def CreateDatasetDescription( self, request: inference_pb2.CreateDatasetDescriptionRequest, context ) -> inference_pb2.DatasetDescription: session = self._getModelSession(context, request.modelSessionId) - id = session.client.create_dataset_description(mean=request.mean, stddev=request.stddev) + id = session.bio_model_client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) return inference_pb2.DatasetDescription(id=id) def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty: @@ -95,12 +75,15 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) - arrs = [converters.pb_tensor_to_xarray(tensor) for tensor in request.tensors] - res = session.client.forward(arrs) - pb_tensors = [converters.xarray_to_pb_tensor(res_tensor) for res_tensor in res] - return inference_pb2.PredictResponse(tensors=pb_tensors) - - def _getModelSession(self, context, modelSessionId: str) -> ISession: + input_sample = Sample.from_pb_tensors(request.tensors) + tensor_validator = InputTensorValidator(session.bio_model_client.input_specs) + tensor_validator.check_tensors(input_sample) + res = session.bio_model_client.api.forward(input_sample) + output_tensor_ids = [tensor.name for tensor in session.bio_model_client.output_specs] + output_sample = Sample.from_xr_tensors(output_tensor_ids, res) + return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors()) + + def _getModelSession(self, context, modelSessionId: str) -> Session: if not modelSessionId: context.abort(grpc.StatusCode.FAILED_PRECONDITION, "model-session-id has not been provided by client") diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index 91b00f83..c9e0186e 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -1,106 +1,121 @@ -import dataclasses import multiprocessing as _mp -import os import pathlib import tempfile import uuid from concurrent.futures import Future from multiprocessing.connection import Connection -from typing import List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple -import numpy +import numpy as np from bioimageio.core import load_resource_description from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline -from bioimageio.spec.shared.raw_nodes import ImplicitOutputShape, ParametrizedInputShape -from marshmallow import missing +from bioimageio.core.resource_io import nodes +from bioimageio.core.resource_io.nodes import ParametrizedInputShape from tiktorch import log -from tiktorch.converters import NamedExplicitOutputShape, NamedImplicitOutputShape, NamedParametrizedShape, NamedShape from tiktorch.rpc import Shutdown from tiktorch.rpc import mp as _mp_rpc -from tiktorch.rpc.mp import MPServer +from tiktorch.rpc.mp import BioModelClient, MPServer +from ...converters import Sample from .backend import base from .rpc_interface import IRPCModelSession -@dataclasses.dataclass -class ModelInfo: - """Intermediate representation of bioimageio neural network model - - TODO (k-dominik): ModelInfo only used in inference_servicer to convert to - protobuf modelinfo. - - """ - - name: str - input_axes: List[str] # one per input - output_axes: List[str] # one per output - input_shapes: List[Union[NamedShape, NamedParametrizedShape]] # per input multiple shapes - output_shapes: List[Union[NamedExplicitOutputShape, NamedImplicitOutputShape]] - input_names: List[str] # one per input - output_names: List[str] # one per output - - @classmethod - def from_prediction_pipeline(cls, prediction_pipeline: PredictionPipeline) -> "ModelInfo": - input_shapes = [] - for input_spec in prediction_pipeline.input_specs: - if isinstance(input_spec.shape, ParametrizedInputShape): - input_shapes.append( - NamedParametrizedShape( - min_shape=list(map(tuple, zip(input_spec.axes, input_spec.shape.min))), - step_shape=list(map(tuple, zip(input_spec.axes, input_spec.shape.step))), - ) - ) - else: - input_shapes.append(list(map(tuple, zip(input_spec.axes, input_spec.shape)))) - - output_shapes = [] - for output_spec in prediction_pipeline.output_specs: - # halo is not required by spec. We could alternatively make it optional in the - # respective grpc message types and handle missing values in ilastik - halo = [0 for _ in output_spec.axes] if output_spec.halo == missing else output_spec.halo - if isinstance(output_spec.shape, ImplicitOutputShape): - output_shapes.append( - NamedImplicitOutputShape( - reference_tensor=output_spec.shape.reference_tensor, - scale=list(map(tuple, zip(output_spec.axes, output_spec.shape.scale))), - offset=list(map(tuple, zip(output_spec.axes, output_spec.shape.offset))), - halo=list(map(tuple, zip(output_spec.axes, halo))), - ) - ) - else: # isinstance(output_spec.shape, ExplicitShape): - output_shapes.append( - NamedExplicitOutputShape( - shape=list(map(tuple, zip(output_spec.axes, output_spec.shape))), - halo=list(map(tuple, zip(output_spec.axes, halo))), - ) - ) - - return cls( - name=prediction_pipeline.name, - input_axes=["".join(input_spec.axes) for input_spec in prediction_pipeline.input_specs], - output_axes=["".join(output_spec.axes) for output_spec in prediction_pipeline.output_specs], - input_shapes=input_shapes, - output_shapes=output_shapes, - input_names=[input_spec.name for input_spec in prediction_pipeline.input_specs], - output_names=[output_spec.name for output_spec in prediction_pipeline.output_specs], - ) - - -class ModelSessionProcess(IRPCModelSession): - def __init__(self, model_zip: bytes, devices: List[str]) -> None: - _tmp_file = tempfile.NamedTemporaryFile(suffix=".zip", delete=False) - _tmp_file.write(model_zip) - _tmp_file.close() - model = load_resource_description(pathlib.Path(_tmp_file.name)) - os.unlink(_tmp_file.name) - self._model: PredictionPipeline = create_prediction_pipeline(bioimageio_model=model, devices=devices) +class InputTensorValidator: + def __init__(self, input_specs: List[nodes.InputTensor]): + self._input_specs = input_specs + + def check_tensors(self, sample: Sample): + for tensor_id, tensor in sample.tensors.items(): + self.check_shape(tensor_id, tensor.dims, tensor.shape) + + def _get_input_tensors_with_names(self) -> Dict[str, nodes.InputTensor]: + return {tensor.name: tensor for tensor in self._input_specs} + + def check_shape(self, tensor_id: str, axes: Tuple[str, ...], shape: Tuple[int, ...]): + shape = self.get_axes_with_size(axes, shape) + spec = self._get_input_spec(tensor_id) + if isinstance(spec.shape, list): + self._check_shape_explicit(spec, shape) + elif isinstance(spec.shape, ParametrizedInputShape): + self._check_shape_parameterized(spec, shape) + else: + raise ValueError(f"Unexpected shape {spec.shape}") + + def _get_input_spec(self, tensor_id: str) -> nodes.InputTensor: + self._check_spec_exists(tensor_id) + specs = [spec for spec in self._input_specs if spec.name == tensor_id] + assert len(specs) == 1, "ids of tensor specs should be unique" + return specs[0] + + def _check_spec_exists(self, tensor_id: str): + spec_names = [spec.name for spec in self._input_specs] + if tensor_id not in spec_names: + raise ValueError(f"Spec {tensor_id} doesn't exist for specs {spec_names}") + + def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]): + assert self.is_shape_explicit(spec) + reference_shape = {name: size for name, size in zip(spec.axes, spec.shape)} + self.check_same_axes(reference_shape, tensor_shape) + if reference_shape != tensor_shape: + raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}") + + def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]): + assert isinstance(spec.shape, ParametrizedInputShape) + if not self.is_shape(tensor_shape.values()): + raise ValueError(f"Invalid shape's sizes {tensor_shape}") + + min_shape = self.get_axes_with_size(spec.axes, tuple(spec.shape.min)) + step = self.get_axes_with_size(spec.axes, tuple(spec.shape.step)) + self.check_same_axes(tensor_shape, min_shape) + + tensor_shapes_arr = np.array(list(tensor_shape.values())) + min_shape_arr = np.array(list(min_shape.values())) + step_arr = np.array(list(step.values())) + diff = tensor_shapes_arr - min_shape_arr + if any(size < 0 for size in diff): + raise ValueError(f"Tensor shape {tensor_shape} smaller than min shape {min_shape}") + + non_zero_idx = np.nonzero(step_arr) + multipliers = diff[non_zero_idx] / step_arr[non_zero_idx] + multiplier = np.unique(multipliers) + if len(multiplier) == 1 and self.is_natural_number(multiplier[0]): + return + raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}") + + @staticmethod + def check_same_axes(source: Dict[str, int], target: Dict[str, int]): + if source.keys() != target.keys(): + raise ValueError(f"Incompatible axes for tensor {target} and reference {source}") + + @staticmethod + def is_natural_number(n) -> bool: + return n % 1 == 0.0 and n >= 0 + + @staticmethod + def is_shape(shape: Iterator[int]) -> bool: + return all(InputTensorValidator.is_natural_number(dim) for dim in shape) + + @staticmethod + def get_axes_with_size(axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]: + assert len(axes) == len(shape) + return {name: size for name, size in zip(axes, shape)} + + @staticmethod + def is_shape_explicit(spec: nodes.InputTensor) -> bool: + return isinstance(spec.shape, list) + + +class ModelSessionProcess(IRPCModelSession[PredictionPipeline]): + def __init__(self, model: PredictionPipeline) -> None: + super().__init__(model) self._datasets = {} self._worker = base.SessionBackend(self._model) - def forward(self, input_tensors: numpy.ndarray) -> Future: - res = self._worker.forward(input_tensors) + def forward(self, sample: Sample) -> Future: + tensors_data = [sample.tensors[tensor.name] for tensor in self.model.input_specs] + res = self._worker.forward(tensors_data) return res def create_dataset(self, mean, stddev): @@ -108,16 +123,13 @@ def create_dataset(self, mean, stddev): self._datasets[id_] = {"mean": mean, "stddev": stddev} return id_ - def get_model_info(self) -> ModelInfo: - return ModelInfo.from_prediction_pipeline(self._model) - def shutdown(self) -> Shutdown: self._worker.shutdown() return Shutdown() def _run_model_session_process( - conn: Connection, model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None + conn: Connection, prediction_pipeline: PredictionPipeline, log_queue: Optional[_mp.Queue] = None ): try: # from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667 @@ -131,19 +143,35 @@ def _run_model_session_process( if log_queue: log.configure(log_queue) - session_proc = ModelSessionProcess(model_zip, devices) + session_proc = ModelSessionProcess(prediction_pipeline) srv = MPServer(session_proc, conn) srv.listen() def start_model_session_process( model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None -) -> Tuple[_mp.Process, IRPCModelSession]: +) -> Tuple[_mp.Process, BioModelClient]: client_conn, server_conn = _mp.Pipe() + prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices) proc = _mp.Process( target=_run_model_session_process, name="ModelSessionProcess", - kwargs={"conn": server_conn, "devices": devices, "log_queue": log_queue, "model_zip": model_zip}, + kwargs={ + "conn": server_conn, + "log_queue": log_queue, + "prediction_pipeline": prediction_pipeline, + }, ) proc.start() - return proc, _mp_rpc.create_client(IRPCModelSession, client_conn) + api = _mp_rpc.create_client_api(iface_cls=IRPCModelSession, conn=client_conn) + return proc, BioModelClient( + input_specs=prediction_pipeline.input_specs, output_specs=prediction_pipeline.output_specs, api=api + ) + + +def _get_prediction_pipeline_from_model_bytes(model_zip: bytes, devices: List[str]) -> PredictionPipeline: + with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file: + _tmp_file.write(model_zip) + temp_file_path = pathlib.Path(_tmp_file.name) + model = load_resource_description(temp_file_path) + return create_prediction_pipeline(bioimageio_model=model, devices=devices) diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py index a01dd02a..bb414138 100644 --- a/tiktorch/server/session/rpc_interface.py +++ b/tiktorch/server/session/rpc_interface.py @@ -1,11 +1,22 @@ -from typing import List +from typing import Generic, List, TypeVar +from tiktorch.converters import Sample from tiktorch.rpc import RPCInterface, Shutdown, exposed from tiktorch.tiktypes import TikTensorBatch from tiktorch.types import ModelState +ModelType = TypeVar("ModelType") + + +class IRPCModelSession(RPCInterface, Generic[ModelType]): + def __init__(self, model: ModelType): + super().__init__() + self._model = model + + @property + def model(self): + return self._model -class IRPCModelSession(RPCInterface): @exposed def shutdown(self) -> Shutdown: raise NotImplementedError @@ -43,9 +54,5 @@ def create_dataset_description(self, mean, stddev) -> str: raise NotImplementedError @exposed - def forward(self, input_tensors): - raise NotImplementedError - - @exposed - def get_model_info(self): + def forward(self, input_tensors: Sample): raise NotImplementedError diff --git a/tiktorch/server/session_manager.py b/tiktorch/server/session_manager.py index 37bc07ab..3807e130 100644 --- a/tiktorch/server/session_manager.py +++ b/tiktorch/server/session_manager.py @@ -1,50 +1,44 @@ from __future__ import annotations -import abc import threading from collections import defaultdict from logging import getLogger from typing import Callable, Dict, List, Optional from uuid import uuid4 +from tiktorch.rpc.mp import BioModelClient + logger = getLogger(__name__) +CloseCallback = Callable[[], None] + -class ISession(abc.ABC): +class Session: """ session object has unique id Used for resource managent """ + def __init__(self, id_: str, bio_model_client: BioModelClient, manager: SessionManager) -> None: + self.__id = id_ + self.__manager = manager + self.__bio_model_client = bio_model_client + + @property + def bio_model_client(self) -> BioModelClient: + return self.__bio_model_client + @property - @abc.abstractmethod def id(self) -> str: """ Returns unique id assigned to this session """ - ... + return self.__id - @abc.abstractmethod def on_close(self, handler: CloseCallback) -> None: """ Register cleanup function to be called when session ends """ - ... - - -CloseCallback = Callable[[], None] - - -class _Session(ISession): - def __init__(self, id_: str, manager: SessionManager) -> None: - self.__id = id_ - self.__manager = manager - - @property - def id(self) -> str: - return self.__id - - def on_close(self, handler: CloseCallback) -> None: self.__manager._on_close(self, handler) @@ -53,18 +47,18 @@ class SessionManager: Manages session lifecycle (create/close) """ - def create_session(self) -> ISession: + def create_session(self, bio_model_client: BioModelClient) -> Session: """ Creates new session with unique id """ with self.__lock: session_id = uuid4().hex - session = _Session(session_id, manager=self) + session = Session(session_id, bio_model_client=bio_model_client, manager=self) self.__session_by_id[session_id] = session logger.info("Created session %s", session.id) return session - def get(self, session_id: str) -> Optional[ISession]: + def get(self, session_id: str) -> Optional[Session]: """ Returns existing session with given id if it exists """ @@ -90,10 +84,10 @@ def close_session(self, session_id: str) -> None: def __init__(self) -> None: self.__lock = threading.Lock() - self.__session_by_id: Dict[str, ISession] = {} + self.__session_by_id: Dict[str, Session] = {} self.__close_handlers_by_session_id: Dict[str, List[CloseCallback]] = defaultdict(list) - def _on_close(self, session: ISession, handler: CloseCallback): + def _on_close(self, session: Session, handler: CloseCallback): with self.__lock: logger.debug("Registered close handler %s for session %s", handler, session.id) self.__close_handlers_by_session_id[session.id].append(handler)