From 8a772975e462ac4cc5761675e07711d1fc9b9d81 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Sat, 10 Aug 2024 19:25:47 +0200 Subject: [PATCH 1/8] Remove ModelInfo --- proto/inference.proto | 8 - .../test_grpc/test_inference_servicer.py | 2 - tests/test_server/test_modelinfo.py | 153 ------------------ tiktorch/proto/inference_pb2.py | 116 ++++--------- tiktorch/server/grpc/inference_servicer.py | 22 +-- tiktorch/server/session/process.py | 73 +-------- tiktorch/server/session/rpc_interface.py | 4 - 7 files changed, 31 insertions(+), 347 deletions(-) delete mode 100644 tests/test_server/test_modelinfo.py diff --git a/proto/inference.proto b/proto/inference.proto index c6e85629..2c5c55cb 100644 --- a/proto/inference.proto +++ b/proto/inference.proto @@ -86,14 +86,6 @@ message OutputShape { message ModelSession { string id = 1; - string name = 2; - repeated string inputAxes = 3; - repeated string outputAxes = 4; - bool hasTraining = 5; - repeated InputShape inputShapes = 6; - repeated OutputShape outputShapes = 7; - repeated string inputNames = 8; - repeated string outputNames = 9; } message LogEntry { diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index 8a5188af..dbfb2c52 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -47,8 +47,6 @@ def method_requiring_session(self, request, grpc_stub): def test_model_session_creation(self, grpc_stub, bioimageio_model_bytes): model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes)) assert model.id - assert hasattr(model, "outputShapes") - assert hasattr(model, "inputShapes") grpc_stub.CloseModelSession(model) def test_model_session_creation_using_upload_id(self, grpc_stub, data_store, bioimageio_dummy_model_bytes): diff --git a/tests/test_server/test_modelinfo.py b/tests/test_server/test_modelinfo.py deleted file mode 100644 index 4fbb72c5..00000000 --- a/tests/test_server/test_modelinfo.py +++ /dev/null @@ -1,153 +0,0 @@ -import random -from unittest import mock - -import pytest -from bioimageio.spec.shared.raw_nodes import ImplicitOutputShape, ParametrizedInputShape -from marshmallow import missing - -from tiktorch.converters import NamedExplicitOutputShape, NamedImplicitOutputShape, NamedParametrizedShape -from tiktorch.server.session.process import ModelInfo - - -@pytest.fixture -def implicit_output_spec(): - """output spec with ImplicitOutputShape""" - shape = ImplicitOutputShape( - reference_tensor="blah", - scale=[1.0] + [float(random.randint(0, 2**32)) for _ in range(4)], - offset=[0.0] + [float(random.randint(0, 2**32)) for _ in range(4)], - ) - output_spec = mock.Mock(axes=("x", "y"), shape=shape, halo=[5, 12]) - output_spec.name = "implicit_out" - return output_spec - - -@pytest.fixture -def parametrized_input_spec(): - shape = ParametrizedInputShape( - min=[random.randint(0, 2**32) for _ in range(5)], step=[float(random.randint(0, 2**32)) for _ in range(5)] - ) - input_spec = mock.Mock(axes=("b", "x", "y", "z", "c"), shape=shape) - input_spec.name = "param_in" - return input_spec - - -@pytest.fixture -def explicit_input_spec(): - input_shape = [random.randint(0, 2**32) for _ in range(3)] - input_spec = mock.Mock(axes=("b", "x", "y"), shape=input_shape) - input_spec.name = "explicit_in" - return input_spec - - -@pytest.fixture -def explicit_output_spec(): - output_shape = [random.randint(0, 2**32) for _ in range(3)] - halo = [0] + [random.randint(0, 2**32) for _ in range(2)] - output_spec = mock.Mock(axes=("b", "x", "y"), shape=output_shape, halo=halo) - output_spec.name = "explicit_out" - return output_spec - - -def test_model_info_explicit_shapes(explicit_input_spec, explicit_output_spec): - prediction_pipeline = mock.Mock(input_specs=[explicit_input_spec], output_specs=[explicit_output_spec], name="bleh") - - model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline) - - assert model_info.input_axes == ["".join(explicit_input_spec.axes)] - assert model_info.output_axes == ["".join(explicit_output_spec.axes)] - assert len(model_info.input_shapes) == 1 - assert len(model_info.output_shapes) == 1 - assert isinstance(model_info.input_shapes[0], list) - assert model_info.input_shapes[0] == [(ax, s) for ax, s in zip(explicit_input_spec.axes, explicit_input_spec.shape)] - assert isinstance(model_info.output_shapes[0], NamedExplicitOutputShape) - assert model_info.output_shapes[0].shape == [ - (ax, s) for ax, s in zip(explicit_output_spec.axes, explicit_output_spec.shape) - ] - assert model_info.output_shapes[0].halo == [ - (ax, s) for ax, s in zip(explicit_output_spec.axes, explicit_output_spec.halo) - ] - assert model_info.input_names == ["explicit_in"] - assert model_info.output_names == ["explicit_out"] - - -def test_model_info_explicit_shapes_missing_halo(explicit_input_spec, explicit_output_spec): - explicit_output_spec.halo = missing - - prediction_pipeline = mock.Mock(input_specs=[explicit_input_spec], output_specs=[explicit_output_spec], name="bleh") - - model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline) - - assert model_info.input_axes == ["".join(explicit_input_spec.axes)] - assert model_info.output_axes == ["".join(explicit_output_spec.axes)] - assert len(model_info.input_shapes) == 1 - assert len(model_info.output_shapes) == 1 - assert isinstance(model_info.input_shapes[0], list) - assert model_info.input_shapes[0] == [(ax, s) for ax, s in zip(explicit_input_spec.axes, explicit_input_spec.shape)] - assert isinstance(model_info.output_shapes[0], NamedExplicitOutputShape) - assert model_info.output_shapes[0].shape == [ - (ax, s) for ax, s in zip(explicit_output_spec.axes, explicit_output_spec.shape) - ] - assert model_info.output_shapes[0].halo == [(ax, s) for ax, s in zip(explicit_output_spec.axes, [0, 0, 0])] - - -def test_model_info_implicit_shapes(parametrized_input_spec, implicit_output_spec): - prediction_pipeline = mock.Mock( - input_specs=[parametrized_input_spec], output_specs=[implicit_output_spec], name="bleh" - ) - - model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline) - assert model_info.input_axes == ["".join(parametrized_input_spec.axes)] - assert model_info.output_axes == ["".join(implicit_output_spec.axes)] - assert len(model_info.input_shapes) == 1 - assert len(model_info.output_shapes) == 1 - assert isinstance(model_info.input_shapes[0], NamedParametrizedShape) - assert model_info.input_shapes[0].min_shape == [ - (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.min) - ] - assert model_info.input_shapes[0].step_shape == [ - (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.step) - ] - assert isinstance(model_info.output_shapes[0], NamedImplicitOutputShape) - assert model_info.output_shapes[0].offset == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.offset) - ] - assert model_info.output_shapes[0].scale == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.scale) - ] - assert model_info.output_shapes[0].halo == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.halo) - ] - assert model_info.output_shapes[0].reference_tensor == implicit_output_spec.shape.reference_tensor - - assert model_info.input_names == ["param_in"] - assert model_info.output_names == ["implicit_out"] - - -def test_model_info_implicit_shapes_missing_halo(parametrized_input_spec, implicit_output_spec): - implicit_output_spec.halo = missing - prediction_pipeline = mock.Mock( - input_specs=[parametrized_input_spec], output_specs=[implicit_output_spec], name="bleh" - ) - - model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline) - assert model_info.input_axes == ["".join(parametrized_input_spec.axes)] - assert model_info.output_axes == ["".join(implicit_output_spec.axes)] - assert len(model_info.input_shapes) == 1 - assert len(model_info.output_shapes) == 1 - assert isinstance(model_info.input_shapes[0], NamedParametrizedShape) - assert model_info.input_shapes[0].min_shape == [ - (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.min) - ] - assert model_info.input_shapes[0].step_shape == [ - (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.step) - ] - assert isinstance(model_info.output_shapes[0], NamedImplicitOutputShape) - assert model_info.output_shapes[0].offset == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.offset) - ] - assert model_info.output_shapes[0].scale == [ - (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.scale) - ] - assert model_info.output_shapes[0].halo == [(ax, s) for ax, s in zip(implicit_output_spec.axes, [0, 0, 0, 0, 0])] - assert model_info.output_shapes[0].reference_tensor == implicit_output_spec.shape.reference_tensor diff --git a/tiktorch/proto/inference_pb2.py b/tiktorch/proto/inference_pb2.py index e098a4df..8ac8db2e 100644 --- a/tiktorch/proto/inference_pb2.py +++ b/tiktorch/proto/inference_pb2.py @@ -20,7 +20,7 @@ syntax='proto3', serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\xd3\x01\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tinputAxes\x18\x03 \x03(\t\x12\x12\n\noutputAxes\x18\x04 \x03(\t\x12\x13\n\x0bhasTraining\x18\x05 \x01(\x08\x12 \n\x0binputShapes\x18\x06 \x03(\x0b\x32\x0b.InputShape\x12\"\n\x0coutputShapes\x18\x07 \x03(\x0b\x32\x0c.OutputShape\x12\x12\n\ninputNames\x18\x08 \x03(\t\x12\x13\n\x0boutputNames\x18\t \x03(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"A\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x18\n\x05shape\x18\x03 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3' + serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"A\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x18\n\x05shape\x18\x03 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3' ) @@ -140,8 +140,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=1165, - serialized_end=1243, + serialized_start=979, + serialized_end=1057, ) _sym_db.RegisterEnumDescriptor(_LOGENTRY_LEVEL) @@ -548,62 +548,6 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='name', full_name='ModelSession.name', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='inputAxes', full_name='ModelSession.inputAxes', index=2, - number=3, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='outputAxes', full_name='ModelSession.outputAxes', index=3, - number=4, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='hasTraining', full_name='ModelSession.hasTraining', index=4, - number=5, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='inputShapes', full_name='ModelSession.inputShapes', index=5, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='outputShapes', full_name='ModelSession.outputShapes', index=6, - number=7, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='inputNames', full_name='ModelSession.inputNames', index=7, - number=8, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='outputNames', full_name='ModelSession.outputNames', index=8, - number=9, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ], extensions=[ ], @@ -616,8 +560,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=871, - serialized_end=1082, + serialized_start=870, + serialized_end=896, ) @@ -663,8 +607,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1085, - serialized_end=1243, + serialized_start=899, + serialized_end=1057, ) @@ -695,8 +639,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1245, - serialized_end=1280, + serialized_start=1059, + serialized_end=1094, ) @@ -734,8 +678,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1282, - serialized_end=1320, + serialized_start=1096, + serialized_end=1134, ) @@ -773,8 +717,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1322, - serialized_end=1362, + serialized_start=1136, + serialized_end=1176, ) @@ -819,8 +763,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1364, - serialized_end=1429, + serialized_start=1178, + serialized_end=1243, ) @@ -865,8 +809,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1431, - serialized_end=1516, + serialized_start=1245, + serialized_end=1330, ) @@ -897,8 +841,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1518, - serialized_end=1561, + serialized_start=1332, + serialized_end=1375, ) @@ -922,8 +866,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1563, - serialized_end=1570, + serialized_start=1377, + serialized_end=1384, ) @@ -954,8 +898,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1572, - serialized_end=1602, + serialized_start=1386, + serialized_end=1416, ) @@ -998,8 +942,8 @@ create_key=_descriptor._internal_create_key, fields=[]), ], - serialized_start=1604, - serialized_end=1698, + serialized_start=1418, + serialized_end=1512, ) _DEVICE.fields_by_name['status'].enum_type = _DEVICE_STATUS @@ -1023,8 +967,6 @@ _OUTPUTSHAPE.fields_by_name['scale'].message_type = _NAMEDFLOATS _OUTPUTSHAPE.fields_by_name['offset'].message_type = _NAMEDFLOATS _OUTPUTSHAPE_SHAPETYPE.containing_type = _OUTPUTSHAPE -_MODELSESSION.fields_by_name['inputShapes'].message_type = _INPUTSHAPE -_MODELSESSION.fields_by_name['outputShapes'].message_type = _OUTPUTSHAPE _LOGENTRY.fields_by_name['level'].enum_type = _LOGENTRY_LEVEL _LOGENTRY_LEVEL.containing_type = _LOGENTRY _DEVICES.fields_by_name['devices'].message_type = _DEVICE @@ -1210,8 +1152,8 @@ index=0, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=1701, - serialized_end=2027, + serialized_start=1515, + serialized_end=1841, methods=[ _descriptor.MethodDescriptor( name='CreateModelSession', @@ -1286,8 +1228,8 @@ index=1, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=2029, - serialized_end=2100, + serialized_start=1843, + serialized_end=1914, methods=[ _descriptor.MethodDescriptor( name='Ping', diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 5e188787..3a119549 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -39,28 +39,8 @@ def CreateModelSession( session = self.__session_manager.create_session() session.on_close(lease.terminate) session.on_close(client.shutdown) - session.client = client - try: - model_info = session.client.get_model_info() - except Exception: - lease.terminate() - raise - - pb_input_shapes = [converters.input_shape_to_pb_input_shape(shape) for shape in model_info.input_shapes] - pb_output_shapes = [converters.output_shape_to_pb_output_shape(shape) for shape in model_info.output_shapes] - - return inference_pb2.ModelSession( - id=session.id, - name=model_info.name, - inputAxes=model_info.input_axes, - outputAxes=model_info.output_axes, - inputShapes=pb_input_shapes, - hasTraining=False, - outputShapes=pb_output_shapes, - inputNames=model_info.input_names, - outputNames=model_info.output_names, - ) + return inference_pb2.ModelSession(id=session.id) def CreateDatasetDescription( self, request: inference_pb2.CreateDatasetDescriptionRequest, context diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index 91b00f83..2147df46 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -1,4 +1,3 @@ -import dataclasses import multiprocessing as _mp import os import pathlib @@ -6,16 +5,13 @@ import uuid from concurrent.futures import Future from multiprocessing.connection import Connection -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import numpy from bioimageio.core import load_resource_description from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline -from bioimageio.spec.shared.raw_nodes import ImplicitOutputShape, ParametrizedInputShape -from marshmallow import missing from tiktorch import log -from tiktorch.converters import NamedExplicitOutputShape, NamedImplicitOutputShape, NamedParametrizedShape, NamedShape from tiktorch.rpc import Shutdown from tiktorch.rpc import mp as _mp_rpc from tiktorch.rpc.mp import MPServer @@ -24,70 +20,6 @@ from .rpc_interface import IRPCModelSession -@dataclasses.dataclass -class ModelInfo: - """Intermediate representation of bioimageio neural network model - - TODO (k-dominik): ModelInfo only used in inference_servicer to convert to - protobuf modelinfo. - - """ - - name: str - input_axes: List[str] # one per input - output_axes: List[str] # one per output - input_shapes: List[Union[NamedShape, NamedParametrizedShape]] # per input multiple shapes - output_shapes: List[Union[NamedExplicitOutputShape, NamedImplicitOutputShape]] - input_names: List[str] # one per input - output_names: List[str] # one per output - - @classmethod - def from_prediction_pipeline(cls, prediction_pipeline: PredictionPipeline) -> "ModelInfo": - input_shapes = [] - for input_spec in prediction_pipeline.input_specs: - if isinstance(input_spec.shape, ParametrizedInputShape): - input_shapes.append( - NamedParametrizedShape( - min_shape=list(map(tuple, zip(input_spec.axes, input_spec.shape.min))), - step_shape=list(map(tuple, zip(input_spec.axes, input_spec.shape.step))), - ) - ) - else: - input_shapes.append(list(map(tuple, zip(input_spec.axes, input_spec.shape)))) - - output_shapes = [] - for output_spec in prediction_pipeline.output_specs: - # halo is not required by spec. We could alternatively make it optional in the - # respective grpc message types and handle missing values in ilastik - halo = [0 for _ in output_spec.axes] if output_spec.halo == missing else output_spec.halo - if isinstance(output_spec.shape, ImplicitOutputShape): - output_shapes.append( - NamedImplicitOutputShape( - reference_tensor=output_spec.shape.reference_tensor, - scale=list(map(tuple, zip(output_spec.axes, output_spec.shape.scale))), - offset=list(map(tuple, zip(output_spec.axes, output_spec.shape.offset))), - halo=list(map(tuple, zip(output_spec.axes, halo))), - ) - ) - else: # isinstance(output_spec.shape, ExplicitShape): - output_shapes.append( - NamedExplicitOutputShape( - shape=list(map(tuple, zip(output_spec.axes, output_spec.shape))), - halo=list(map(tuple, zip(output_spec.axes, halo))), - ) - ) - - return cls( - name=prediction_pipeline.name, - input_axes=["".join(input_spec.axes) for input_spec in prediction_pipeline.input_specs], - output_axes=["".join(output_spec.axes) for output_spec in prediction_pipeline.output_specs], - input_shapes=input_shapes, - output_shapes=output_shapes, - input_names=[input_spec.name for input_spec in prediction_pipeline.input_specs], - output_names=[output_spec.name for output_spec in prediction_pipeline.output_specs], - ) - - class ModelSessionProcess(IRPCModelSession): def __init__(self, model_zip: bytes, devices: List[str]) -> None: _tmp_file = tempfile.NamedTemporaryFile(suffix=".zip", delete=False) @@ -108,9 +40,6 @@ def create_dataset(self, mean, stddev): self._datasets[id_] = {"mean": mean, "stddev": stddev} return id_ - def get_model_info(self) -> ModelInfo: - return ModelInfo.from_prediction_pipeline(self._model) - def shutdown(self) -> Shutdown: self._worker.shutdown() return Shutdown() diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py index a01dd02a..2e505d18 100644 --- a/tiktorch/server/session/rpc_interface.py +++ b/tiktorch/server/session/rpc_interface.py @@ -45,7 +45,3 @@ def create_dataset_description(self, mean, stddev) -> str: @exposed def forward(self, input_tensors): raise NotImplementedError - - @exposed - def get_model_info(self): - raise NotImplementedError From 2034a29f698c31b8cec9ae522138721e2b3d5224 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Sat, 10 Aug 2024 19:37:20 +0200 Subject: [PATCH 2/8] Add the concept of Tensor to encapsulate DataArray and spec id - It is useful for tensors to contain the information of the spec id, that corresponds to the bioimage tensor spec, to be aware of valid shapes for both inputs and output --- proto/inference.proto | 3 +- tests/test_converters.py | 68 ++++++++----------- .../test_grpc/test_inference_servicer.py | 2 +- tiktorch/converters.py | 27 ++++++-- tiktorch/proto/inference_pb2.py | 43 +++++++----- tiktorch/server/grpc/inference_servicer.py | 10 ++- 6 files changed, 88 insertions(+), 65 deletions(-) diff --git a/proto/inference.proto b/proto/inference.proto index 2c5c55cb..1959374d 100644 --- a/proto/inference.proto +++ b/proto/inference.proto @@ -120,7 +120,8 @@ message NamedFloat { message Tensor { bytes buffer = 1; string dtype = 2; - repeated NamedInt shape = 3; + string specId = 3; + repeated NamedInt shape = 4; } message PredictRequest { diff --git a/tests/test_converters.py b/tests/test_converters.py index c775ede3..33586a7b 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -7,11 +7,12 @@ NamedExplicitOutputShape, NamedImplicitOutputShape, NamedParametrizedShape, + Tensor, input_shape_to_pb_input_shape, numpy_to_pb_tensor, output_shape_to_pb_output_shape, pb_tensor_to_numpy, - pb_tensor_to_xarray, + pb_tensor_to_tensor, xarray_to_pb_tensor, ) from tiktorch.proto import inference_pb2 @@ -27,6 +28,16 @@ def _numpy_to_pb_tensor(arr): return parsed +def to_pb_tensor(spec_id: str, arr: xr.DataArray): + """ + Makes sure that tensor was serialized/deserialized + """ + tensor = xarray_to_pb_tensor(spec_id, arr) + parsed = inference_pb2.Tensor() + parsed.ParseFromString(tensor.SerializeToString()) + return parsed + + class TestNumpyToPBTensor: def test_should_serialize_to_tensor_type(self): arr = np.arange(9) @@ -97,18 +108,9 @@ def test_should_same_data(self, shape): class TestXarrayToPBTensor: - def to_pb_tensor(self, arr): - """ - Makes sure that tensor was serialized/deserialized - """ - tensor = xarray_to_pb_tensor(arr) - parsed = inference_pb2.Tensor() - parsed.ParseFromString(tensor.SerializeToString()) - return parsed - def test_should_serialize_to_tensor_type(self): xarr = xr.DataArray(np.arange(8).reshape((2, 4)), dims=("x", "y")) - pb_tensor = self.to_pb_tensor(xarr) + pb_tensor = to_pb_tensor("input0", xarr) assert isinstance(pb_tensor, inference_pb2.Tensor) assert len(pb_tensor.shape) == 2 dim1 = pb_tensor.shape[0] @@ -123,66 +125,56 @@ def test_should_serialize_to_tensor_type(self): @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)]) def test_should_have_shape(self, shape): arr = xr.DataArray(np.zeros(shape)) - tensor = self.to_pb_tensor(arr) + tensor = to_pb_tensor("input0", arr) assert tensor.shape assert list(shape) == [dim.size for dim in tensor.shape] def test_should_have_serialized_bytes(self): arr = xr.DataArray(np.arange(9, dtype=np.uint8)) expected = bytes(arr.data) - tensor = self.to_pb_tensor(arr) + tensor = to_pb_tensor("input0", arr) assert expected == tensor.buffer class TestPBTensorToXarray: - def to_pb_tensor(self, arr): - """ - Makes sure that tensor was serialized/deserialized - """ - tensor = xarray_to_pb_tensor(arr) - parsed = inference_pb2.Tensor() - parsed.ParseFromString(tensor.SerializeToString()) - return parsed - def test_should_raise_on_empty_dtype(self): tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)]) with pytest.raises(ValueError): - pb_tensor_to_xarray(tensor) + pb_tensor_to_tensor(tensor) def test_should_raise_on_empty_shape(self): tensor = inference_pb2.Tensor(dtype="int64", shape=[]) with pytest.raises(ValueError): - pb_tensor_to_xarray(tensor) + pb_tensor_to_tensor(tensor) - def test_should_return_ndarray(self): + def test_should_return_tensor(self): arr = xr.DataArray(np.arange(9)) - parsed = self.to_pb_tensor(arr) - result_arr = pb_tensor_to_xarray(parsed) - - assert isinstance(result_arr, xr.DataArray) + parsed = to_pb_tensor("input0", arr) + result_tensor = pb_tensor_to_tensor(parsed) + assert isinstance(result_tensor, Tensor) @pytest.mark.parametrize("np_dtype,dtype_str", [(np.int64, "int64"), (np.uint8, "uint8"), (np.float32, "float32")]) def test_should_have_same_dtype(self, np_dtype, dtype_str): arr = xr.DataArray(np.arange(9, dtype=np_dtype)) - tensor = self.to_pb_tensor(arr) - result_arr = pb_tensor_to_xarray(tensor) + pb_tensor = to_pb_tensor("input0", arr) + tensor = pb_tensor_to_tensor(pb_tensor) - assert arr.dtype == result_arr.dtype + assert arr.dtype == tensor.data.dtype @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)]) def test_should_same_shape(self, shape): arr = xr.DataArray(np.zeros(shape)) - tensor = self.to_pb_tensor(arr) - result_arr = pb_tensor_to_xarray(tensor) - assert arr.shape == result_arr.shape + pb_tensor = to_pb_tensor("input0", arr) + tensor = pb_tensor_to_tensor(pb_tensor) + assert arr.shape == tensor.data.shape @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)]) def test_should_same_data(self, shape): arr = xr.DataArray(np.random.random(shape)) - tensor = self.to_pb_tensor(arr) - result_arr = pb_tensor_to_xarray(tensor) - assert_array_equal(arr, result_arr) + pb_tensor = to_pb_tensor("input0", arr) + tensor = pb_tensor_to_tensor(pb_tensor) + assert_array_equal(arr, tensor.data) class TestShapeConversions: diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index dbfb2c52..cbf64477 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -158,7 +158,7 @@ def test_call_predict(self, grpc_stub, bioimageio_dummy_model_bytes): model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes)) arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y")) expected = arr + 1 - input_tensors = [converters.xarray_to_pb_tensor(arr)] + input_tensors = [converters.xarray_to_pb_tensor("input", arr)] res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) grpc_stub.CloseModelSession(model) diff --git a/tiktorch/converters.py b/tiktorch/converters.py index 4093bec3..7fc69ba0 100644 --- a/tiktorch/converters.py +++ b/tiktorch/converters.py @@ -33,6 +33,25 @@ class NamedImplicitOutputShape: halo: NamedShape +@dataclasses.dataclass(frozen=True) +class Tensor: + spec_id: str + data: xr.DataArray + + def __hash__(self): + return hash(self.spec_id) + + def __eq__(self, other): + if isinstance(other, Tensor): + return self.spec_id == other.spec_id + return False + + def equals(self, other): + if not isinstance(other, Tensor): + return False + return self.__dict__.items() == other.__dict__.items() + + def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor: if axistags: shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)] @@ -41,9 +60,9 @@ def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array)) -def xarray_to_pb_tensor(array: xr.DataArray) -> inference_pb2.Tensor: +def xarray_to_pb_tensor(spec_id: str, array: xr.DataArray) -> inference_pb2.Tensor: shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, array.dims)] - return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array.data)) + return inference_pb2.Tensor(specId=spec_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array.data)) def name_int_tuples_to_pb_NamedInts(name_int_tuples) -> inference_pb2.NamedInts: @@ -93,7 +112,7 @@ def output_shape_to_pb_output_shape( raise TypeError(f"Conversion not supported for type {type(output_shape)}") -def pb_tensor_to_xarray(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor: +def pb_tensor_to_tensor(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor: if not tensor.dtype: raise ValueError("Tensor dtype is not specified") @@ -102,7 +121,7 @@ def pb_tensor_to_xarray(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor: data = np.frombuffer(tensor.buffer, dtype=tensor.dtype).reshape(*[dim.size for dim in tensor.shape]) - return xr.DataArray(data, dims=[d.name for d in tensor.shape]) + return Tensor(spec_id=tensor.specId, data=xr.DataArray(data, dims=[d.name for d in tensor.shape])) def pb_tensor_to_numpy(tensor: inference_pb2.Tensor) -> np.ndarray: diff --git a/tiktorch/proto/inference_pb2.py b/tiktorch/proto/inference_pb2.py index 8ac8db2e..4f682116 100644 --- a/tiktorch/proto/inference_pb2.py +++ b/tiktorch/proto/inference_pb2.py @@ -20,7 +20,7 @@ syntax='proto3', serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"A\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x18\n\x05shape\x18\x03 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3' + serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"Q\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x0e\n\x06specId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3' ) @@ -745,8 +745,15 @@ is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), _descriptor.FieldDescriptor( - name='shape', full_name='Tensor.shape', index=2, - number=3, type=11, cpp_type=10, label=3, + name='specId', full_name='Tensor.specId', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='shape', full_name='Tensor.shape', index=3, + number=4, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, @@ -764,7 +771,7 @@ oneofs=[ ], serialized_start=1178, - serialized_end=1243, + serialized_end=1259, ) @@ -809,8 +816,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1245, - serialized_end=1330, + serialized_start=1261, + serialized_end=1346, ) @@ -841,8 +848,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1332, - serialized_end=1375, + serialized_start=1348, + serialized_end=1391, ) @@ -866,8 +873,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1377, - serialized_end=1384, + serialized_start=1393, + serialized_end=1400, ) @@ -898,8 +905,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1386, - serialized_end=1416, + serialized_start=1402, + serialized_end=1432, ) @@ -942,8 +949,8 @@ create_key=_descriptor._internal_create_key, fields=[]), ], - serialized_start=1418, - serialized_end=1512, + serialized_start=1434, + serialized_end=1528, ) _DEVICE.fields_by_name['status'].enum_type = _DEVICE_STATUS @@ -1152,8 +1159,8 @@ index=0, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=1515, - serialized_end=1841, + serialized_start=1531, + serialized_end=1857, methods=[ _descriptor.MethodDescriptor( name='CreateModelSession', @@ -1228,8 +1235,8 @@ index=1, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=1843, - serialized_end=1914, + serialized_start=1859, + serialized_end=1930, methods=[ _descriptor.MethodDescriptor( name='Ping', diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 3a119549..db0f2c17 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -75,9 +75,13 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) - arrs = [converters.pb_tensor_to_xarray(tensor) for tensor in request.tensors] - res = session.client.forward(arrs) - pb_tensors = [converters.xarray_to_pb_tensor(res_tensor) for res_tensor in res] + tensors = set([converters.pb_tensor_to_tensor(tensor) for tensor in request.tensors]) + res = session.client.forward(tensors) + output_spec_ids = [spec.name for spec in session.client.model.output_specs] + assert len(output_spec_ids) == len(res) + pb_tensors = [ + converters.xarray_to_pb_tensor(spec_id, res_tensor) for spec_id, res_tensor in zip(output_spec_ids, res) + ] return inference_pb2.PredictResponse(tensors=pb_tensors) def _getModelSession(self, context, modelSessionId: str) -> ISession: From b1587c756c54f818b2b477b63dbb2bf6ffd87918 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Sat, 10 Aug 2024 23:41:37 +0200 Subject: [PATCH 3/8] Add test data for parameterized input --- tests/data/dummy/Dummy.model.yaml | 5 +- tests/data/dummy_param/Dummy.model_param.yaml | 57 ++++++++++++++++++ tests/data/dummy_param/dummy.md | 0 tests/data/dummy_param/dummy.py | 7 +++ tests/data/dummy_param/dummy_in.npy | Bin 0 -> 65664 bytes tests/data/dummy_param/dummy_out.npy | Bin 0 -> 65664 bytes tests/data/dummy_param/environment.yaml | 0 tests/data/dummy_param/weights | Bin 0 -> 232 bytes 8 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/data/dummy_param/Dummy.model_param.yaml create mode 100644 tests/data/dummy_param/dummy.md create mode 100644 tests/data/dummy_param/dummy.py create mode 100644 tests/data/dummy_param/dummy_in.npy create mode 100644 tests/data/dummy_param/dummy_out.npy create mode 100644 tests/data/dummy_param/environment.yaml create mode 100644 tests/data/dummy_param/weights diff --git a/tests/data/dummy/Dummy.model.yaml b/tests/data/dummy/Dummy.model.yaml index a27a1d4d..1aa8fe90 100644 --- a/tests/data/dummy/Dummy.model.yaml +++ b/tests/data/dummy/Dummy.model.yaml @@ -41,15 +41,16 @@ inputs: data_range: [-inf, inf] shape: [1, 1, 128, 128] + outputs: - name: output axes: bcyx data_type: float32 data_range: [0, 1] shape: - reference_tensor: input # FIXME(m-novikov) ignoring for now + reference_tensor: input scale: [1, 1, 1, 1] offset: [0, 0, 0, 0] - halo: [0, 0, 32, 32] # Should be moved to outputs + halo: [0, 0, 32, 32] type: model diff --git a/tests/data/dummy_param/Dummy.model_param.yaml b/tests/data/dummy_param/Dummy.model_param.yaml new file mode 100644 index 00000000..87ddd885 --- /dev/null +++ b/tests/data/dummy_param/Dummy.model_param.yaml @@ -0,0 +1,57 @@ +format_version: 0.3.3 +language: python +framework: pytorch + +name: UNet2DNucleiBroad +description: A 2d U-Net pretrained on broad nucleus dataset. +cite: + - text: "Ronneberger, Olaf et al. U-net: Convolutional networks for biomedical image segmentation. MICCAI 2015." + doi: https://doi.org/10.1007/978-3-319-24574-4_28 +authors: + - name: "ilastik-team" + affiliation: "EMBL Heidelberg" + +documentation: dummy.md +tags: [pytorch, nucleus-segmentation] +license: MIT +git_repo: https://github.com/ilastik/tiktorch +covers: [] + +source: dummy.py::Dummy +sha256: 00ffb1647cf7ec524892206dce6258d9da498fe040c62838f31b501a09bfd573 +timestamp: 2019-12-11T12:22:32Z # ISO 8601 + +test_inputs: [dummy_in.npy] +test_outputs: [dummy_out.npy] + +weights: + pytorch_state_dict: + source: ./weights + sha256: 518cb80bad2eb3ec3dfbe6bab74920951391ce8fb24e15cf59b9b9f052a575a6 + authors: + - name: "ilastik-team" + affiliation: "EMBL Heidelberg" + + +# TODO double check inputs/outputs +inputs: + - name: param + axes: bcyx + data_type: float32 + data_range: [-inf, inf] + shape: + min: [1, 1, 64, 64] + step: [0, 0, 2, 1] + +outputs: + - name: output + axes: bcyx + data_type: float32 + data_range: [0, 1] + shape: + reference_tensor: param + scale: [1, 1, 1, 1] + offset: [0, 0, 0, 0] + halo: [0, 0, 8, 8] + +type: model diff --git a/tests/data/dummy_param/dummy.md b/tests/data/dummy_param/dummy.md new file mode 100644 index 00000000..e69de29b diff --git a/tests/data/dummy_param/dummy.py b/tests/data/dummy_param/dummy.py new file mode 100644 index 00000000..195e98e3 --- /dev/null +++ b/tests/data/dummy_param/dummy.py @@ -0,0 +1,7 @@ +from torch import nn + + +class Dummy(nn.Module): + def forward(self, input): + x = input + return x + 1 diff --git a/tests/data/dummy_param/dummy_in.npy b/tests/data/dummy_param/dummy_in.npy new file mode 100644 index 0000000000000000000000000000000000000000..96a78a7b87dbeef0611f2cf9c97f5b0f72cf8396 GIT binary patch literal 65664 zcmeIuF$%&!6a>)NdW!8ic9?hT-PR-do;;I-_!YTjq~sd zf7UR|-Yeq=3^2d|0}L?000Rs#zyJdbFu(u<3^2d|0}L?000Rs#zyJdbFu(u<3^2d| z0}L?000Rs#zyJdbFu(u<3^2d|0}L?000Rs#zyJdbFu(u<3^2d|0}L?000Rs#zyJdb zFu(u<3^2d|0}L?000Rs#zyJdbFu(u<3^2d|0}L?000Rs#zyJdbFu(u<3^2d|0}L?0 z00Rs#zyJdbFu(u<3^2d|0}L?000Rs#zyJdbFu(u<3^2d|0}L?000Rs#zyJdbFu(u< z3^2d|0}L?000Rs#zyJdbFu(u<3^2d|0}L?000Rs#zyJdbFu(u<3^2d|0}L?000Rs# zzyJdbFu(u<3^2d|0}L?000Rs#zyJdbFu(u<3^2d|0}L?000Rs#zyJdbFu(u<3^2d| P0}L?000Rs#zyJeZi2B~R literal 0 HcmV?d00001 diff --git a/tests/data/dummy_param/environment.yaml b/tests/data/dummy_param/environment.yaml new file mode 100644 index 00000000..e69de29b diff --git a/tests/data/dummy_param/weights b/tests/data/dummy_param/weights new file mode 100644 index 0000000000000000000000000000000000000000..da14f34253e8e85c58d498b94b7ed3b933de0dc6 GIT binary patch literal 232 zcmXwz!D<3Q42EZCXI)$ndhM-8z4#R3Wf7&PTx8kq1|1n^%uEz4EO;zFLm!~8)<-Dx z(0J%2B;TL>e{T9-x!=#_&&%O!^E|)l)pss+AT@<2rPL_~4qb4~1!JA5ye?% zhYkAPwx09{R08>W!Y0{wOq35~rcek>`w*FmT0<_A^-QDMfO&WXzoW7?=d)yX663IA zNlDc@_875W*p4ewvscOn(lC~r=7`+(Ew_~KAA$jaii&|oP)hgzXjx||r8hC&lA8Yk D#+O8X literal 0 HcmV?d00001 From 127d1621bb9e5fba23edaa46a1cff06d3894fd47 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Sat, 10 Aug 2024 23:41:58 +0200 Subject: [PATCH 4/8] Use bioimageio prediction pipeline as model from client and server --- tests/conftest.py | 10 ++ tests/test_rpc/test_mp.py | 8 +- .../test_grpc/test_inference_servicer.py | 37 +++++- tiktorch/rpc/mp.py | 23 ++-- tiktorch/server/grpc/inference_servicer.py | 6 +- tiktorch/server/session/process.py | 124 +++++++++++++++--- tiktorch/server/session/rpc_interface.py | 17 ++- tiktorch/server/session_manager.py | 18 ++- 8 files changed, 200 insertions(+), 43 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 31929f3e..2d840f59 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,6 +115,16 @@ def bioimageio_dummy_model_filepath(data_path, tmpdir): @pytest.fixture def bioimageio_dummy_model_bytes(data_path): rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY / "Dummy.model.yaml" + return _bioimageio_package(rdf_source) + + +@pytest.fixture +def bioimageio_dummy_param_model_bytes(data_path): + rdf_source = data_path / "dummy_param" / "Dummy.model_param.yaml" + return _bioimageio_package(rdf_source) + + +def _bioimageio_package(rdf_source): data = io.BytesIO() export_resource_package(rdf_source, output_path=data) return data diff --git a/tests/test_rpc/test_mp.py b/tests/test_rpc/test_mp.py index 981f4553..ec539cac 100644 --- a/tests/test_rpc/test_mp.py +++ b/tests/test_rpc/test_mp.py @@ -64,7 +64,7 @@ def client(log_queue): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(ITestApi, child, timeout=10) + client = create_client(iface_cls=ITestApi, conn=child, timeout=10) yield client @@ -108,7 +108,7 @@ def __getattr__(self, name): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(ITestApi, SlowConn(child)) + client = create_client(iface_cls=ITestApi, conn=SlowConn(child)) client.fast_compute(2, 2) @@ -121,7 +121,7 @@ def test_future_timeout(client: ITestApi, log_queue): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(ITestApi, child, timeout=0.001) + client = create_client(iface_cls=ITestApi, conn=child, timeout=0.001) with pytest.raises(TimeoutError): client.compute(1, 2) @@ -256,7 +256,7 @@ def _spawn(iface_cls, srv_cls): p = mp.Process(target=_run_srv, args=(srv_cls, parent, log_queue)) p.start() - data["client"] = client = create_client(iface_cls, child) + data["client"] = client = create_client(iface_cls=iface_cls, conn=child) data["process"] = p return client diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index cbf64477..c45d9f3d 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -156,25 +156,56 @@ def test_call_fails_with_unknown_model_session_id(self, grpc_stub): def test_call_predict(self, grpc_stub, bioimageio_dummy_model_bytes): model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes)) - arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y")) + arr = xr.DataArray(np.arange(128 * 128).reshape(1, 1, 128, 128), dims=("b", "c", "x", "y")) expected = arr + 1 - input_tensors = [converters.xarray_to_pb_tensor("input", arr)] + input_spec_id = "input" + output_spec_id = "output" + input_tensors = [converters.xarray_to_pb_tensor(input_spec_id, arr)] res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) grpc_stub.CloseModelSession(model) assert len(res.tensors) == 1 + assert res.tensors[0].specId == output_spec_id assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0])) + def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimageio_dummy_model_bytes): + model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes)) + arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y")) + input_tensors = [converters.xarray_to_pb_tensor("input", arr)] + with pytest.raises(grpc.RpcError): + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.CloseModelSession(model) + + @pytest.mark.parametrize("shape", [(1, 1, 64, 32), (1, 1, 32, 64), (1, 1, 64, 32), (0, 1, 64, 64), (1, 0, 64, 64)]) + def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes): + model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes)) + arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y")) + input_tensors = [converters.xarray_to_pb_tensor("param", arr)] + with pytest.raises(grpc.RpcError): + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.CloseModelSession(model) + + @pytest.mark.parametrize("shape", [(1, 1, 64, 64), (1, 1, 66, 65), (1, 1, 68, 66), (1, 1, 70, 67)]) + def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes): + model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes)) + arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y")) + input_tensors = [converters.xarray_to_pb_tensor("param", arr)] + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.CloseModelSession(model) + @pytest.mark.skip def test_call_predict_tf(self, grpc_stub, bioimageio_dummy_tensorflow_model_bytes): model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_tensorflow_model_bytes)) arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y")) expected = arr * -1 - input_tensors = [converters.xarray_to_pb_tensor(arr)] + input_spec_id = "input" + output_spec_id = "output" + input_tensors = [converters.xarray_to_pb_tensor(input_spec_id, arr)] res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) grpc_stub.CloseModelSession(model) assert len(res.tensors) == 1 + assert res.tensors[0].specId == output_spec_id assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0])) diff --git a/tiktorch/rpc/mp.py b/tiktorch/rpc/mp.py index bab3ab1f..a51cf924 100644 --- a/tiktorch/rpc/mp.py +++ b/tiktorch/rpc/mp.py @@ -5,7 +5,7 @@ from functools import wraps from multiprocessing.connection import Connection from threading import Event, Thread -from typing import Any, Optional, Type, TypeVar +from typing import Any, Dict, Optional, Type, TypeVar from uuid import uuid4 from .exceptions import Shutdown @@ -72,9 +72,8 @@ def __call__(self, *args, **kwargs) -> Any: return self._client._invoke(self._method_name, *args, **kwargs) -def create_client(iface_cls: Type[T], conn: Connection, timeout=None) -> T: +def create_client(iface_cls: Type[T], conn: Connection, api_kwargs: Optional[Dict[str, any]] = None, timeout=None) -> T: client = MPClient(iface_cls.__name__, conn, timeout) - get_exposed_methods(iface_cls) def _make_method(method): class MethodWrapper: @@ -98,12 +97,15 @@ def __call__(self, *args, **kwargs) -> Any: return MethodWrapper() class _Client(iface_cls): - pass + def __init__(self, kwargs: Optional[Dict]): + kwargs = kwargs or {} + super().__init__(**kwargs) - for method_name, method in get_exposed_methods(iface_cls).items(): + exposed_methods = get_exposed_methods(iface_cls) + for method_name, method in exposed_methods.items(): setattr(_Client, method_name, _make_method(method)) - return _Client() + return _Client(api_kwargs) class MPClient: @@ -190,7 +192,7 @@ def _shutdown(self, exc): class Message: def __init__(self, id_): - self.id = id + self.id = id_ class Signal: @@ -200,20 +202,19 @@ def __init__(self, payload): class MethodCall(Message): def __init__(self, id_, method_name, args, kwargs): - self.id = id_ + super().__init__(id_) self.method_name = method_name self.args = args self.kwargs = kwargs class Cancellation(Message): - def __init__(self, id_): - self.id = id_ + pass class MethodReturn(Message): def __init__(self, id_, result: Result): - self.id = id_ + super().__init__(id_) self.result = result diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index db0f2c17..5e61b46b 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -6,7 +6,7 @@ from tiktorch.proto import inference_pb2, inference_pb2_grpc from tiktorch.server.data_store import IDataStore from tiktorch.server.device_pool import DeviceStatus, IDevicePool -from tiktorch.server.session.process import start_model_session_process +from tiktorch.server.session.process import ShapeValidator, start_model_session_process from tiktorch.server.session_manager import ISession, SessionManager @@ -36,7 +36,7 @@ def CreateModelSession( lease.terminate() raise - session = self.__session_manager.create_session() + session = self.__session_manager.create_session(client) session.on_close(lease.terminate) session.on_close(client.shutdown) @@ -76,6 +76,8 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) tensors = set([converters.pb_tensor_to_tensor(tensor) for tensor in request.tensors]) + shape_validator = ShapeValidator(session.client.model) + shape_validator.check_tensors(tensors) res = session.client.forward(tensors) output_spec_ids = [spec.name for spec in session.client.model.output_specs] assert len(output_spec_ids) == len(res) diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index 2147df46..72e08b14 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -1,38 +1,112 @@ import multiprocessing as _mp -import os import pathlib import tempfile import uuid from concurrent.futures import Future from multiprocessing.connection import Connection -from typing import List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Set, Tuple -import numpy +import numpy as np from bioimageio.core import load_resource_description from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline +from bioimageio.core.resource_io import nodes +from bioimageio.core.resource_io.nodes import ParametrizedInputShape from tiktorch import log from tiktorch.rpc import Shutdown from tiktorch.rpc import mp as _mp_rpc from tiktorch.rpc.mp import MPServer +from ...converters import Tensor from .backend import base from .rpc_interface import IRPCModelSession -class ModelSessionProcess(IRPCModelSession): - def __init__(self, model_zip: bytes, devices: List[str]) -> None: - _tmp_file = tempfile.NamedTemporaryFile(suffix=".zip", delete=False) - _tmp_file.write(model_zip) - _tmp_file.close() - model = load_resource_description(pathlib.Path(_tmp_file.name)) - os.unlink(_tmp_file.name) - self._model: PredictionPipeline = create_prediction_pipeline(bioimageio_model=model, devices=devices) +class ShapeValidator: + def __init__(self, model: PredictionPipeline): + self._model = model + + def check_tensors(self, tensors: Set[Tensor]): + for tensor in tensors: + self.check_shape(tensor.spec_id, tensor.data.dims, tensor.data.shape) + + def check_shape(self, spec_id: str, axes: Tuple[str, ...], shape: Tuple[int, ...]): + shape = self._get_axes_with_size(axes, shape) + spec = self._get_input_spec(spec_id) + if isinstance(spec.shape, list): + self._check_shape_explicit(spec, shape) + elif isinstance(spec.shape, ParametrizedInputShape): + self._check_shape_parameterized(spec, shape) + else: + raise ValueError(f"Unexpected shape {spec.shape}") + + def _get_input_spec(self, spec_id: str) -> nodes.InputTensor: + self._check_spec_exists(spec_id) + specs = [spec for spec in self._model.input_specs if spec.name == spec_id] + assert len(specs) == 1, "ids of tensor specs should be unique" + return specs[0] + + def _check_spec_exists(self, spec_id: str): + spec_names = [spec.name for spec in self._model.input_specs] + if spec_id not in spec_names: + raise ValueError(f"Spec {spec_id} doesn't exist for specs {spec_names}") + + def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]): + assert self._is_shape_explicit(spec) + reference_shape = {name: size for name, size in zip(spec.axes, spec.shape)} + if reference_shape != tensor_shape: + raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}") + + def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]): + assert isinstance(spec.shape, ParametrizedInputShape) + if not self._is_shape(tensor_shape.values()): + raise ValueError(f"Invalid shape's sizes {tensor_shape}") + + min_shape = self._get_axes_with_size(spec.axes, tuple(spec.shape.min)) + step = self._get_axes_with_size(spec.axes, tuple(spec.shape.step)) + assert min_shape.keys() == step.keys() + if tensor_shape.keys() != min_shape.keys(): + raise ValueError(f"Incompatible axes for tensor {tensor_shape} and spec {spec}") + + tensor_shapes_arr = np.array(list(tensor_shape.values())) + min_shape_arr = np.array(list(min_shape.values())) + step_arr = np.array(list(step.values())) + diff = tensor_shapes_arr - min_shape_arr + if any(size < 0 for size in diff): + raise ValueError(f"Tensor shape {tensor_shape} smaller than min shape {min_shape}") + + non_zero_idx = np.nonzero(step_arr) + multipliers = diff[non_zero_idx] / step_arr[non_zero_idx] + multiplier = np.unique(multipliers) + if len(multiplier) == 1 and self._is_natural_number(multiplier[0]): + return + raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}") + + def _is_natural_number(self, n) -> bool: + return np.floor(n) == np.ceil(n) and n >= 0 + + def _is_shape(self, shape: Iterator[int]) -> bool: + return all(self._is_natural_number(dim) for dim in shape) + + def _get_axes_with_size(self, axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]: + assert len(axes) == len(shape) + return {name: size for name, size in zip(axes, shape)} + + def _is_shape_explicit(self, spec: nodes.InputTensor) -> bool: + return isinstance(spec.shape, list) + + +class ModelSessionProcess(IRPCModelSession[PredictionPipeline]): + def __init__(self, model: PredictionPipeline) -> None: + super().__init__(model) self._datasets = {} self._worker = base.SessionBackend(self._model) + self._shape_validator = ShapeValidator(self._model) - def forward(self, input_tensors: numpy.ndarray) -> Future: - res = self._worker.forward(input_tensors) + def forward(self, input_tensors: Set[Tensor]) -> Future: + self._shape_validator.check_tensors(input_tensors) + tensors_data = [tensor.data for tensor in input_tensors] + res = self._worker.forward(tensors_data) return res def create_dataset(self, mean, stddev): @@ -46,7 +120,7 @@ def shutdown(self) -> Shutdown: def _run_model_session_process( - conn: Connection, model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None + conn: Connection, prediction_pipeline: PredictionPipeline, log_queue: Optional[_mp.Queue] = None ): try: # from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667 @@ -60,7 +134,7 @@ def _run_model_session_process( if log_queue: log.configure(log_queue) - session_proc = ModelSessionProcess(model_zip, devices) + session_proc = ModelSessionProcess(prediction_pipeline) srv = MPServer(session_proc, conn) srv.listen() @@ -69,10 +143,26 @@ def start_model_session_process( model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None ) -> Tuple[_mp.Process, IRPCModelSession]: client_conn, server_conn = _mp.Pipe() + prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices) proc = _mp.Process( target=_run_model_session_process, name="ModelSessionProcess", - kwargs={"conn": server_conn, "devices": devices, "log_queue": log_queue, "model_zip": model_zip}, + kwargs={ + "conn": server_conn, + "log_queue": log_queue, + "prediction_pipeline": prediction_pipeline, + }, ) proc.start() - return proc, _mp_rpc.create_client(IRPCModelSession, client_conn) + # here create the prediction pipeline, share it to the model session class and the client + return proc, _mp_rpc.create_client( + iface_cls=IRPCModelSession, api_kwargs={"model": prediction_pipeline}, conn=client_conn + ) + + +def _get_prediction_pipeline_from_model_bytes(model_zip: bytes, devices: List[str]) -> PredictionPipeline: + with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file: + _tmp_file.write(model_zip) + temp_file_path = pathlib.Path(_tmp_file.name) + model = load_resource_description(temp_file_path) + return create_prediction_pipeline(bioimageio_model=model, devices=devices) diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py index 2e505d18..b75a9a81 100644 --- a/tiktorch/server/session/rpc_interface.py +++ b/tiktorch/server/session/rpc_interface.py @@ -1,11 +1,22 @@ -from typing import List +from typing import Generic, List, Set, TypeVar +from tiktorch.converters import Tensor from tiktorch.rpc import RPCInterface, Shutdown, exposed from tiktorch.tiktypes import TikTensorBatch from tiktorch.types import ModelState +ModelType = TypeVar("ModelType") + + +class IRPCModelSession(RPCInterface, Generic[ModelType]): + def __init__(self, model: ModelType): + super().__init__() + self._model = model + + @property + def model(self): + return self._model -class IRPCModelSession(RPCInterface): @exposed def shutdown(self) -> Shutdown: raise NotImplementedError @@ -43,5 +54,5 @@ def create_dataset_description(self, mean, stddev) -> str: raise NotImplementedError @exposed - def forward(self, input_tensors): + def forward(self, input_tensors: Set[Tensor]): raise NotImplementedError diff --git a/tiktorch/server/session_manager.py b/tiktorch/server/session_manager.py index 37bc07ab..86a8a1ea 100644 --- a/tiktorch/server/session_manager.py +++ b/tiktorch/server/session_manager.py @@ -7,6 +7,8 @@ from typing import Callable, Dict, List, Optional from uuid import uuid4 +from tiktorch.server.session import IRPCModelSession + logger = getLogger(__name__) @@ -24,6 +26,11 @@ def id(self) -> str: """ ... + @property + @abc.abstractmethod + def client(self) -> IRPCModelSession: + ... + @abc.abstractmethod def on_close(self, handler: CloseCallback) -> None: """ @@ -36,9 +43,14 @@ def on_close(self, handler: CloseCallback) -> None: class _Session(ISession): - def __init__(self, id_: str, manager: SessionManager) -> None: + def __init__(self, id_: str, client: IRPCModelSession, manager: SessionManager) -> None: self.__id = id_ self.__manager = manager + self.__client = client + + @property + def client(self) -> IRPCModelSession: + return self.__client @property def id(self) -> str: @@ -53,13 +65,13 @@ class SessionManager: Manages session lifecycle (create/close) """ - def create_session(self) -> ISession: + def create_session(self, client: IRPCModelSession) -> ISession: """ Creates new session with unique id """ with self.__lock: session_id = uuid4().hex - session = _Session(session_id, manager=self) + session = _Session(session_id, client=client, manager=self) self.__session_by_id[session_id] = session logger.info("Created session %s", session.id) return session From d838567b363cc5652298b9c89a8c06487d6dea58 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Tue, 13 Aug 2024 16:05:40 +0200 Subject: [PATCH 5/8] Implement code review suggestions - Use the class `Sample` to associate xr.DataArray with tensor id coming from the model - Do the checks only from the client side - Add checking about axes validity - Add tests --- proto/inference.proto | 2 +- tests/conftest.py | 18 +++- tests/test_converters.py | 91 ++++++++++++++++--- .../test_grpc/test_inference_servicer.py | 53 ++++++++--- tiktorch/converters.py | 36 ++++---- tiktorch/proto/inference_pb2.py | 34 +++---- tiktorch/server/grpc/inference_servicer.py | 21 ++--- tiktorch/server/session/process.py | 49 +++++----- tiktorch/server/session/rpc_interface.py | 6 +- 9 files changed, 202 insertions(+), 108 deletions(-) diff --git a/proto/inference.proto b/proto/inference.proto index 1959374d..39845dc5 100644 --- a/proto/inference.proto +++ b/proto/inference.proto @@ -120,7 +120,7 @@ message NamedFloat { message Tensor { bytes buffer = 1; string dtype = 2; - string specId = 3; + string tensorId = 3; repeated NamedInt shape = 4; } diff --git a/tests/conftest.py b/tests/conftest.py index 2d840f59..1118935a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,9 @@ TEST_DATA = "data" TEST_BIOIMAGEIO_ZIPFOLDER = "unet2d" TEST_BIOIMAGEIO_ONNX = "unet2d_onnx" -TEST_BIOIMAGEIO_DUMMY = "dummy" +TEST_BIOIMAGEIO_DUMMY_EXPLICIT = "dummy" +TEST_BIOIMAGEIO_DUMMY_EXPLICIT_RDF = f"{TEST_BIOIMAGEIO_DUMMY_EXPLICIT}/Dummy.model.yaml" +TEST_BIOIMAGEIO_DUMMY_PARAM_RDF = "dummy_param/Dummy.model_param.yaml" TEST_BIOIMAGEIO_TENSORFLOW_DUMMY = "dummy_tensorflow" TEST_BIOIMAGEIO_TORCHSCRIPT = "unet2d_torchscript" @@ -98,7 +100,7 @@ def bioimageio_model_zipfile(bioimageio_model_bytes): @pytest.fixture def bioimageio_dummy_model_filepath(data_path, tmpdir): - bioimageio_net_dir = Path(data_path) / TEST_BIOIMAGEIO_DUMMY + bioimageio_net_dir = Path(data_path) / TEST_BIOIMAGEIO_DUMMY_EXPLICIT path = tmpdir / "dummy_model.zip" with ZipFile(path, mode="w") as zip_model: @@ -113,17 +115,23 @@ def bioimageio_dummy_model_filepath(data_path, tmpdir): @pytest.fixture -def bioimageio_dummy_model_bytes(data_path): - rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY / "Dummy.model.yaml" +def bioimageio_dummy_explicit_model_bytes(data_path): + rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY_EXPLICIT_RDF return _bioimageio_package(rdf_source) @pytest.fixture def bioimageio_dummy_param_model_bytes(data_path): - rdf_source = data_path / "dummy_param" / "Dummy.model_param.yaml" + rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY_PARAM_RDF return _bioimageio_package(rdf_source) +@pytest.fixture(params=[(TEST_BIOIMAGEIO_DUMMY_PARAM_RDF, "param"), (TEST_BIOIMAGEIO_DUMMY_EXPLICIT_RDF, "input")]) +def bioimageio_dummy_model(request, data_path): + path, tensor_id = request.param + yield _bioimageio_package(data_path / path), tensor_id + + def _bioimageio_package(rdf_source): data = io.BytesIO() export_resource_package(rdf_source, output_path=data) diff --git a/tests/test_converters.py b/tests/test_converters.py index 33586a7b..8d6e4c09 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -7,12 +7,12 @@ NamedExplicitOutputShape, NamedImplicitOutputShape, NamedParametrizedShape, - Tensor, + Sample, input_shape_to_pb_input_shape, numpy_to_pb_tensor, output_shape_to_pb_output_shape, pb_tensor_to_numpy, - pb_tensor_to_tensor, + pb_tensor_to_xarray, xarray_to_pb_tensor, ) from tiktorch.proto import inference_pb2 @@ -28,11 +28,11 @@ def _numpy_to_pb_tensor(arr): return parsed -def to_pb_tensor(spec_id: str, arr: xr.DataArray): +def to_pb_tensor(tensor_id: str, arr: xr.DataArray): """ Makes sure that tensor was serialized/deserialized """ - tensor = xarray_to_pb_tensor(spec_id, arr) + tensor = xarray_to_pb_tensor(tensor_id, arr) parsed = inference_pb2.Tensor() parsed.ParseFromString(tensor.SerializeToString()) return parsed @@ -141,40 +141,40 @@ class TestPBTensorToXarray: def test_should_raise_on_empty_dtype(self): tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)]) with pytest.raises(ValueError): - pb_tensor_to_tensor(tensor) + pb_tensor_to_xarray(tensor) def test_should_raise_on_empty_shape(self): tensor = inference_pb2.Tensor(dtype="int64", shape=[]) with pytest.raises(ValueError): - pb_tensor_to_tensor(tensor) + pb_tensor_to_xarray(tensor) - def test_should_return_tensor(self): + def test_should_return_xarray(self): arr = xr.DataArray(np.arange(9)) parsed = to_pb_tensor("input0", arr) - result_tensor = pb_tensor_to_tensor(parsed) - assert isinstance(result_tensor, Tensor) + result_tensor = pb_tensor_to_xarray(parsed) + assert isinstance(result_tensor, xr.DataArray) @pytest.mark.parametrize("np_dtype,dtype_str", [(np.int64, "int64"), (np.uint8, "uint8"), (np.float32, "float32")]) def test_should_have_same_dtype(self, np_dtype, dtype_str): arr = xr.DataArray(np.arange(9, dtype=np_dtype)) pb_tensor = to_pb_tensor("input0", arr) - tensor = pb_tensor_to_tensor(pb_tensor) + result_arr = pb_tensor_to_xarray(pb_tensor) - assert arr.dtype == tensor.data.dtype + assert arr.dtype == result_arr.dtype @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)]) def test_should_same_shape(self, shape): arr = xr.DataArray(np.zeros(shape)) pb_tensor = to_pb_tensor("input0", arr) - tensor = pb_tensor_to_tensor(pb_tensor) - assert arr.shape == tensor.data.shape + result_arr = pb_tensor_to_xarray(pb_tensor) + assert arr.shape == result_arr.shape @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)]) def test_should_same_data(self, shape): arr = xr.DataArray(np.random.random(shape)) pb_tensor = to_pb_tensor("input0", arr) - tensor = pb_tensor_to_tensor(pb_tensor) - assert_array_equal(arr, tensor.data) + result_arr = pb_tensor_to_xarray(pb_tensor) + assert_array_equal(arr, result_arr) class TestShapeConversions: @@ -268,3 +268,64 @@ def test_parametrized_input_shape(self, min_shape, axes, step): assert [(d.name, d.size) for d in pb_shape.stepShape.namedInts] == [ (name, size) for name, size in zip(axes, step) ] + + +class TestSample: + def test_create_sample_from_pb_tensors(self): + arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32) + tensor_1 = inference_pb2.Tensor( + dtype="int64", + tensorId="input1", + buffer=bytes(arr_1), + shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)], + ) + + arr_2 = np.arange(64 * 64, dtype=int).reshape(64, 64) + tensor_2 = inference_pb2.Tensor( + dtype="int64", + tensorId="input2", + buffer=bytes(arr_2), + shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)], + ) + + sample = Sample.from_pb_tensors([tensor_1, tensor_2]) + assert len(sample.tensors) == 2 + assert sample.tensors["input1"].equals(xr.DataArray(arr_1, dims=["x", "y"])) + assert sample.tensors["input2"].equals(xr.DataArray(arr_2, dims=["x", "y"])) + + def test_create_sample_from_raw_data(self): + arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32) + tensor_1 = xr.DataArray(arr_1, dims=["x", "y"]) + arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64) + tensor_2 = xr.DataArray(arr_2, dims=["x", "y"]) + tensors_ids = ["input1", "input2"] + actual_sample = Sample.from_raw_data(tensors_ids, [tensor_1, tensor_2]) + + expected_dict = {tensors_ids[0]: tensor_1, tensors_ids[1]: tensor_2} + expected_sample = Sample(expected_dict) + assert actual_sample == expected_sample + + def test_sample_to_pb_tensors(self): + arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32) + tensor_1 = xr.DataArray(arr_1, dims=["x", "y"]) + arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64) + tensor_2 = xr.DataArray(arr_2, dims=["x", "y"]) + tensors_ids = ["input1", "input2"] + sample = Sample.from_raw_data(tensors_ids, [tensor_1, tensor_2]) + + pb_tensor_1 = inference_pb2.Tensor( + dtype="int64", + tensorId="input1", + buffer=bytes(arr_1), + shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)], + ) + pb_tensor_2 = inference_pb2.Tensor( + dtype="int64", + tensorId="input2", + buffer=bytes(arr_2), + shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)], + ) + expected_tensors = [pb_tensor_1, pb_tensor_2] + + actual_tensors = sample.to_pb_tensors() + assert expected_tensors == actual_tensors diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index c45d9f3d..864c4d6a 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -49,8 +49,8 @@ def test_model_session_creation(self, grpc_stub, bioimageio_model_bytes): assert model.id grpc_stub.CloseModelSession(model) - def test_model_session_creation_using_upload_id(self, grpc_stub, data_store, bioimageio_dummy_model_bytes): - id_ = data_store.put(bioimageio_dummy_model_bytes.getvalue()) + def test_model_session_creation_using_upload_id(self, grpc_stub, data_store, bioimageio_dummy_explicit_model_bytes): + id_ = data_store.put(bioimageio_dummy_explicit_model_bytes.getvalue()) rq = inference_pb2.CreateModelSessionRequest(model_uri=f"upload://{id_}", deviceIds=["cpu"]) model = grpc_stub.CreateModelSession(rq) @@ -154,30 +154,33 @@ def test_call_fails_with_unknown_model_session_id(self, grpc_stub): assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code() assert "model-session with id myid1 doesn't exist" in e.value.details() - def test_call_predict(self, grpc_stub, bioimageio_dummy_model_bytes): - model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes)) + def test_call_predict_valid_explicit(self, grpc_stub, bioimageio_dummy_explicit_model_bytes): + model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_explicit_model_bytes)) arr = xr.DataArray(np.arange(128 * 128).reshape(1, 1, 128, 128), dims=("b", "c", "x", "y")) expected = arr + 1 - input_spec_id = "input" - output_spec_id = "output" - input_tensors = [converters.xarray_to_pb_tensor(input_spec_id, arr)] + input_tensor_id = "input" + output_tensor_id = "output" + input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) grpc_stub.CloseModelSession(model) assert len(res.tensors) == 1 - assert res.tensors[0].specId == output_spec_id + assert res.tensors[0].tensorId == output_tensor_id assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0])) - def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimageio_dummy_model_bytes): - model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes)) + def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimageio_dummy_explicit_model_bytes): + model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_explicit_model_bytes)) arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y")) input_tensors = [converters.xarray_to_pb_tensor("input", arr)] with pytest.raises(grpc.RpcError): grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) grpc_stub.CloseModelSession(model) - @pytest.mark.parametrize("shape", [(1, 1, 64, 32), (1, 1, 32, 64), (1, 1, 64, 32), (0, 1, 64, 64), (1, 0, 64, 64)]) + @pytest.mark.parametrize( + "shape", + [(1, 1, 64, 32), (1, 1, 32, 64), (1, 1, 64, 32), (0, 1, 64, 64), (1, 0, 64, 64)], + ) def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes): model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes)) arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y")) @@ -186,6 +189,26 @@ def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioima grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) grpc_stub.CloseModelSession(model) + def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimageio_dummy_model): + model_bytes, _ = bioimageio_dummy_model + model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + arr = xr.DataArray(np.arange(32 * 32).reshape(32, 32), dims=("x", "y")) + input_tensors = [converters.xarray_to_pb_tensor("invalidTensorName", arr)] + with pytest.raises(grpc.RpcError) as error: + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + assert error.value.details().startswith("Exception calling application: Spec invalidTensorName doesn't exist") + grpc_stub.CloseModelSession(model) + + def test_call_predict_invalid_axes(self, grpc_stub, bioimageio_dummy_model): + model_bytes, tensor_id = bioimageio_dummy_model + model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + arr = xr.DataArray(np.arange(32 * 32).reshape(32, 32), dims=("invalidAxis", "y")) + input_tensors = [converters.xarray_to_pb_tensor(tensor_id, arr)] + with pytest.raises(grpc.RpcError) as error: + grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + assert error.value.details().startswith("Exception calling application: Incompatible axes") + grpc_stub.CloseModelSession(model) + @pytest.mark.parametrize("shape", [(1, 1, 64, 64), (1, 1, 66, 65), (1, 1, 68, 66), (1, 1, 70, 67)]) def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes): model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes)) @@ -199,13 +222,13 @@ def test_call_predict_tf(self, grpc_stub, bioimageio_dummy_tensorflow_model_byte model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_tensorflow_model_bytes)) arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y")) expected = arr * -1 - input_spec_id = "input" - output_spec_id = "output" - input_tensors = [converters.xarray_to_pb_tensor(input_spec_id, arr)] + input_tensor_id = "input" + output_tensor_id = "output" + input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) grpc_stub.CloseModelSession(model) assert len(res.tensors) == 1 - assert res.tensors[0].specId == output_spec_id + assert res.tensors[0].tensorId == output_tensor_id assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0])) diff --git a/tiktorch/converters.py b/tiktorch/converters.py index 7fc69ba0..1cc488da 100644 --- a/tiktorch/converters.py +++ b/tiktorch/converters.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import dataclasses -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union import numpy as np import xarray as xr @@ -34,22 +36,20 @@ class NamedImplicitOutputShape: @dataclasses.dataclass(frozen=True) -class Tensor: - spec_id: str - data: xr.DataArray +class Sample: + tensors: Dict[str, xr.DataArray] - def __hash__(self): - return hash(self.spec_id) + @classmethod + def from_pb_tensors(cls, pb_tensors: List[inference_pb2.Tensor]) -> Sample: + return Sample({tensor.tensorId: pb_tensor_to_xarray(tensor) for tensor in pb_tensors}) - def __eq__(self, other): - if isinstance(other, Tensor): - return self.spec_id == other.spec_id - return False + @classmethod + def from_raw_data(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]): + assert len(tensor_ids) == len(tensors_data) + return Sample({tensor_id: tensor_data for tensor_id, tensor_data in zip(tensor_ids, tensors_data)}) - def equals(self, other): - if not isinstance(other, Tensor): - return False - return self.__dict__.items() == other.__dict__.items() + def to_pb_tensors(self) -> List[inference_pb2.Tensor]: + return [xarray_to_pb_tensor(tensor_id, res_tensor) for tensor_id, res_tensor in self.tensors.items()] def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor: @@ -60,9 +60,9 @@ def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array)) -def xarray_to_pb_tensor(spec_id: str, array: xr.DataArray) -> inference_pb2.Tensor: +def xarray_to_pb_tensor(tensor_id: str, array: xr.DataArray) -> inference_pb2.Tensor: shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, array.dims)] - return inference_pb2.Tensor(specId=spec_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array.data)) + return inference_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array.data)) def name_int_tuples_to_pb_NamedInts(name_int_tuples) -> inference_pb2.NamedInts: @@ -112,7 +112,7 @@ def output_shape_to_pb_output_shape( raise TypeError(f"Conversion not supported for type {type(output_shape)}") -def pb_tensor_to_tensor(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor: +def pb_tensor_to_xarray(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor: if not tensor.dtype: raise ValueError("Tensor dtype is not specified") @@ -121,7 +121,7 @@ def pb_tensor_to_tensor(tensor: inference_pb2.Tensor) -> inference_pb2.Tensor: data = np.frombuffer(tensor.buffer, dtype=tensor.dtype).reshape(*[dim.size for dim in tensor.shape]) - return Tensor(spec_id=tensor.specId, data=xr.DataArray(data, dims=[d.name for d in tensor.shape])) + return xr.DataArray(data, dims=[d.name for d in tensor.shape]) def pb_tensor_to_numpy(tensor: inference_pb2.Tensor) -> np.ndarray: diff --git a/tiktorch/proto/inference_pb2.py b/tiktorch/proto/inference_pb2.py index 4f682116..dc5c6c9c 100644 --- a/tiktorch/proto/inference_pb2.py +++ b/tiktorch/proto/inference_pb2.py @@ -20,7 +20,7 @@ syntax='proto3', serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"Q\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x0e\n\x06specId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3' + serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"S\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x10\n\x08tensorId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3' ) @@ -745,7 +745,7 @@ is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), _descriptor.FieldDescriptor( - name='specId', full_name='Tensor.specId', index=2, + name='tensorId', full_name='Tensor.tensorId', index=2, number=3, type=9, cpp_type=9, label=1, has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, @@ -771,7 +771,7 @@ oneofs=[ ], serialized_start=1178, - serialized_end=1259, + serialized_end=1261, ) @@ -816,8 +816,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1261, - serialized_end=1346, + serialized_start=1263, + serialized_end=1348, ) @@ -848,8 +848,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1348, - serialized_end=1391, + serialized_start=1350, + serialized_end=1393, ) @@ -873,8 +873,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1393, - serialized_end=1400, + serialized_start=1395, + serialized_end=1402, ) @@ -905,8 +905,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1402, - serialized_end=1432, + serialized_start=1404, + serialized_end=1434, ) @@ -949,8 +949,8 @@ create_key=_descriptor._internal_create_key, fields=[]), ], - serialized_start=1434, - serialized_end=1528, + serialized_start=1436, + serialized_end=1530, ) _DEVICE.fields_by_name['status'].enum_type = _DEVICE_STATUS @@ -1159,8 +1159,8 @@ index=0, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=1531, - serialized_end=1857, + serialized_start=1533, + serialized_end=1859, methods=[ _descriptor.MethodDescriptor( name='CreateModelSession', @@ -1235,8 +1235,8 @@ index=1, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=1859, - serialized_end=1930, + serialized_start=1861, + serialized_end=1932, methods=[ _descriptor.MethodDescriptor( name='Ping', diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 5e61b46b..9cada0e9 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -2,11 +2,11 @@ import grpc -from tiktorch import converters +from tiktorch.converters import Sample from tiktorch.proto import inference_pb2, inference_pb2_grpc from tiktorch.server.data_store import IDataStore from tiktorch.server.device_pool import DeviceStatus, IDevicePool -from tiktorch.server.session.process import ShapeValidator, start_model_session_process +from tiktorch.server.session.process import InputTensorValidator, start_model_session_process from tiktorch.server.session_manager import ISession, SessionManager @@ -75,16 +75,13 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) - tensors = set([converters.pb_tensor_to_tensor(tensor) for tensor in request.tensors]) - shape_validator = ShapeValidator(session.client.model) - shape_validator.check_tensors(tensors) - res = session.client.forward(tensors) - output_spec_ids = [spec.name for spec in session.client.model.output_specs] - assert len(output_spec_ids) == len(res) - pb_tensors = [ - converters.xarray_to_pb_tensor(spec_id, res_tensor) for spec_id, res_tensor in zip(output_spec_ids, res) - ] - return inference_pb2.PredictResponse(tensors=pb_tensors) + input_sample = Sample.from_pb_tensors(request.tensors) + tensor_validator = InputTensorValidator(session.client.model) + tensor_validator.check_tensors(input_sample) + res = session.client.forward(input_sample) + output_tensor_ids = [tensor.name for tensor in session.client.model.output_specs] + output_sample = Sample.from_raw_data(output_tensor_ids, res) + return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors()) def _getModelSession(self, context, modelSessionId: str) -> ISession: if not modelSessionId: diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index 72e08b14..91272f83 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -4,7 +4,7 @@ import uuid from concurrent.futures import Future from multiprocessing.connection import Connection -from typing import Dict, Iterator, List, Optional, Set, Tuple +from typing import Dict, Iterator, List, Optional, Tuple import numpy as np from bioimageio.core import load_resource_description @@ -17,22 +17,25 @@ from tiktorch.rpc import mp as _mp_rpc from tiktorch.rpc.mp import MPServer -from ...converters import Tensor +from ...converters import Sample from .backend import base from .rpc_interface import IRPCModelSession -class ShapeValidator: +class InputTensorValidator: def __init__(self, model: PredictionPipeline): self._model = model - def check_tensors(self, tensors: Set[Tensor]): - for tensor in tensors: - self.check_shape(tensor.spec_id, tensor.data.dims, tensor.data.shape) + def check_tensors(self, sample: Sample): + for tensor_id, tensor in sample.tensors.items(): + self.check_shape(tensor_id, tensor.dims, tensor.shape) - def check_shape(self, spec_id: str, axes: Tuple[str, ...], shape: Tuple[int, ...]): + def _get_input_tensors_with_names(self) -> Dict[str, nodes.InputTensor]: + return {tensor.name: tensor for tensor in self._model.input_specs} + + def check_shape(self, tensor_id: str, axes: Tuple[str, ...], shape: Tuple[int, ...]): shape = self._get_axes_with_size(axes, shape) - spec = self._get_input_spec(spec_id) + spec = self._get_input_spec(tensor_id) if isinstance(spec.shape, list): self._check_shape_explicit(spec, shape) elif isinstance(spec.shape, ParametrizedInputShape): @@ -40,20 +43,21 @@ def check_shape(self, spec_id: str, axes: Tuple[str, ...], shape: Tuple[int, ... else: raise ValueError(f"Unexpected shape {spec.shape}") - def _get_input_spec(self, spec_id: str) -> nodes.InputTensor: - self._check_spec_exists(spec_id) - specs = [spec for spec in self._model.input_specs if spec.name == spec_id] + def _get_input_spec(self, tensor_id: str) -> nodes.InputTensor: + self._check_spec_exists(tensor_id) + specs = [spec for spec in self._model.input_specs if spec.name == tensor_id] assert len(specs) == 1, "ids of tensor specs should be unique" return specs[0] - def _check_spec_exists(self, spec_id: str): + def _check_spec_exists(self, tensor_id: str): spec_names = [spec.name for spec in self._model.input_specs] - if spec_id not in spec_names: - raise ValueError(f"Spec {spec_id} doesn't exist for specs {spec_names}") + if tensor_id not in spec_names: + raise ValueError(f"Spec {tensor_id} doesn't exist for specs {spec_names}") def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]): assert self._is_shape_explicit(spec) reference_shape = {name: size for name, size in zip(spec.axes, spec.shape)} + self._check_same_axes(reference_shape, tensor_shape) if reference_shape != tensor_shape: raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}") @@ -64,9 +68,7 @@ def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict min_shape = self._get_axes_with_size(spec.axes, tuple(spec.shape.min)) step = self._get_axes_with_size(spec.axes, tuple(spec.shape.step)) - assert min_shape.keys() == step.keys() - if tensor_shape.keys() != min_shape.keys(): - raise ValueError(f"Incompatible axes for tensor {tensor_shape} and spec {spec}") + self._check_same_axes(tensor_shape, min_shape) tensor_shapes_arr = np.array(list(tensor_shape.values())) min_shape_arr = np.array(list(min_shape.values())) @@ -82,8 +84,12 @@ def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict return raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}") + def _check_same_axes(self, source: Dict[str, int], target: Dict[str, int]): + if source.keys() != target.keys(): + raise ValueError(f"Incompatible axes for tensor {target} and reference {source}") + def _is_natural_number(self, n) -> bool: - return np.floor(n) == np.ceil(n) and n >= 0 + return n % 1 == 0.0 and n >= 0 def _is_shape(self, shape: Iterator[int]) -> bool: return all(self._is_natural_number(dim) for dim in shape) @@ -101,11 +107,10 @@ def __init__(self, model: PredictionPipeline) -> None: super().__init__(model) self._datasets = {} self._worker = base.SessionBackend(self._model) - self._shape_validator = ShapeValidator(self._model) + self._shape_validator = InputTensorValidator(self._model) - def forward(self, input_tensors: Set[Tensor]) -> Future: - self._shape_validator.check_tensors(input_tensors) - tensors_data = [tensor.data for tensor in input_tensors] + def forward(self, sample: Sample) -> Future: + tensors_data = [sample.tensors[tensor.name] for tensor in self.model.input_specs] res = self._worker.forward(tensors_data) return res diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py index b75a9a81..bb414138 100644 --- a/tiktorch/server/session/rpc_interface.py +++ b/tiktorch/server/session/rpc_interface.py @@ -1,6 +1,6 @@ -from typing import Generic, List, Set, TypeVar +from typing import Generic, List, TypeVar -from tiktorch.converters import Tensor +from tiktorch.converters import Sample from tiktorch.rpc import RPCInterface, Shutdown, exposed from tiktorch.tiktypes import TikTensorBatch from tiktorch.types import ModelState @@ -54,5 +54,5 @@ def create_dataset_description(self, mean, stddev) -> str: raise NotImplementedError @exposed - def forward(self, input_tensors: Set[Tensor]): + def forward(self, input_tensors: Sample): raise NotImplementedError From 1902067f7ffa0b9e142e87f9e8d10089f2845665 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Thu, 15 Aug 2024 10:16:40 +0200 Subject: [PATCH 6/8] Decouple client and server for model session processes --- tests/test_rpc/test_mp.py | 10 ++-- tiktorch/rpc/mp.py | 24 ++++++---- tiktorch/server/grpc/inference_servicer.py | 10 ++-- tiktorch/server/session/process.py | 54 ++++++++++++---------- tiktorch/server/session_manager.py | 10 ++-- 5 files changed, 60 insertions(+), 48 deletions(-) diff --git a/tests/test_rpc/test_mp.py b/tests/test_rpc/test_mp.py index ec539cac..2011c54e 100644 --- a/tests/test_rpc/test_mp.py +++ b/tests/test_rpc/test_mp.py @@ -8,7 +8,7 @@ from tiktorch import log from tiktorch.rpc import RPCFuture, RPCInterface, Shutdown, exposed -from tiktorch.rpc.mp import FutureStore, MPServer, create_client +from tiktorch.rpc.mp import FutureStore, MPServer, create_client_api class ITestApi(RPCInterface): @@ -64,7 +64,7 @@ def client(log_queue): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(iface_cls=ITestApi, conn=child, timeout=10) + client = create_client_api(iface_cls=ITestApi, conn=child, timeout=10) yield client @@ -108,7 +108,7 @@ def __getattr__(self, name): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(iface_cls=ITestApi, conn=SlowConn(child)) + client = create_client_api(iface_cls=ITestApi, conn=SlowConn(child)) client.fast_compute(2, 2) @@ -121,7 +121,7 @@ def test_future_timeout(client: ITestApi, log_queue): p = mp.Process(target=_srv, args=(parent, log_queue)) p.start() - client = create_client(iface_cls=ITestApi, conn=child, timeout=0.001) + client = create_client_api(iface_cls=ITestApi, conn=child, timeout=0.001) with pytest.raises(TimeoutError): client.compute(1, 2) @@ -256,7 +256,7 @@ def _spawn(iface_cls, srv_cls): p = mp.Process(target=_run_srv, args=(srv_cls, parent, log_queue)) p.start() - data["client"] = client = create_client(iface_cls=iface_cls, conn=child) + data["client"] = client = create_client_api(iface_cls=iface_cls, conn=child) data["process"] = p return client diff --git a/tiktorch/rpc/mp.py b/tiktorch/rpc/mp.py index a51cf924..94f3a638 100644 --- a/tiktorch/rpc/mp.py +++ b/tiktorch/rpc/mp.py @@ -1,3 +1,4 @@ +import dataclasses import logging import queue import threading @@ -5,9 +6,11 @@ from functools import wraps from multiprocessing.connection import Connection from threading import Event, Thread -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Generic, List, Optional, Type, TypeVar from uuid import uuid4 +from bioimageio.core.resource_io import nodes + from .exceptions import Shutdown from .interface import get_exposed_methods from .types import RPCFuture, isfutureret @@ -72,7 +75,7 @@ def __call__(self, *args, **kwargs) -> Any: return self._client._invoke(self._method_name, *args, **kwargs) -def create_client(iface_cls: Type[T], conn: Connection, api_kwargs: Optional[Dict[str, any]] = None, timeout=None) -> T: +def create_client_api(iface_cls: Type[T], conn: Connection, timeout=None) -> T: client = MPClient(iface_cls.__name__, conn, timeout) def _make_method(method): @@ -96,16 +99,21 @@ def __call__(self, *args, **kwargs) -> Any: return MethodWrapper() - class _Client(iface_cls): - def __init__(self, kwargs: Optional[Dict]): - kwargs = kwargs or {} - super().__init__(**kwargs) + class _Api: + pass exposed_methods = get_exposed_methods(iface_cls) for method_name, method in exposed_methods.items(): - setattr(_Client, method_name, _make_method(method)) + setattr(_Api, method_name, _make_method(method)) + + return _Api() + - return _Client(api_kwargs) +@dataclasses.dataclass(frozen=True) +class Client(Generic[T]): + api: T + input_specs: List[nodes.InputTensor] + output_specs: List[nodes.OutputTensor] class MPClient: diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 9cada0e9..2ac2b632 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -38,7 +38,7 @@ def CreateModelSession( session = self.__session_manager.create_session(client) session.on_close(lease.terminate) - session.on_close(client.shutdown) + session.on_close(client.api.shutdown) return inference_pb2.ModelSession(id=session.id) @@ -46,7 +46,7 @@ def CreateDatasetDescription( self, request: inference_pb2.CreateDatasetDescriptionRequest, context ) -> inference_pb2.DatasetDescription: session = self._getModelSession(context, request.modelSessionId) - id = session.client.create_dataset_description(mean=request.mean, stddev=request.stddev) + id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) return inference_pb2.DatasetDescription(id=id) def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty: @@ -76,10 +76,10 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) input_sample = Sample.from_pb_tensors(request.tensors) - tensor_validator = InputTensorValidator(session.client.model) + tensor_validator = InputTensorValidator(session.client.input_specs) tensor_validator.check_tensors(input_sample) - res = session.client.forward(input_sample) - output_tensor_ids = [tensor.name for tensor in session.client.model.output_specs] + res = session.client.api.forward(input_sample) + output_tensor_ids = [tensor.name for tensor in session.client.output_specs] output_sample = Sample.from_raw_data(output_tensor_ids, res) return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors()) diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index 91272f83..ef17f69e 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -15,7 +15,7 @@ from tiktorch import log from tiktorch.rpc import Shutdown from tiktorch.rpc import mp as _mp_rpc -from tiktorch.rpc.mp import MPServer +from tiktorch.rpc.mp import Client, MPServer from ...converters import Sample from .backend import base @@ -23,18 +23,18 @@ class InputTensorValidator: - def __init__(self, model: PredictionPipeline): - self._model = model + def __init__(self, input_specs: List[nodes.InputTensor]): + self._input_specs = input_specs def check_tensors(self, sample: Sample): for tensor_id, tensor in sample.tensors.items(): self.check_shape(tensor_id, tensor.dims, tensor.shape) def _get_input_tensors_with_names(self) -> Dict[str, nodes.InputTensor]: - return {tensor.name: tensor for tensor in self._model.input_specs} + return {tensor.name: tensor for tensor in self._input_specs} def check_shape(self, tensor_id: str, axes: Tuple[str, ...], shape: Tuple[int, ...]): - shape = self._get_axes_with_size(axes, shape) + shape = self.get_axes_with_size(axes, shape) spec = self._get_input_spec(tensor_id) if isinstance(spec.shape, list): self._check_shape_explicit(spec, shape) @@ -45,30 +45,30 @@ def check_shape(self, tensor_id: str, axes: Tuple[str, ...], shape: Tuple[int, . def _get_input_spec(self, tensor_id: str) -> nodes.InputTensor: self._check_spec_exists(tensor_id) - specs = [spec for spec in self._model.input_specs if spec.name == tensor_id] + specs = [spec for spec in self._input_specs if spec.name == tensor_id] assert len(specs) == 1, "ids of tensor specs should be unique" return specs[0] def _check_spec_exists(self, tensor_id: str): - spec_names = [spec.name for spec in self._model.input_specs] + spec_names = [spec.name for spec in self._input_specs] if tensor_id not in spec_names: raise ValueError(f"Spec {tensor_id} doesn't exist for specs {spec_names}") def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]): - assert self._is_shape_explicit(spec) + assert self.is_shape_explicit(spec) reference_shape = {name: size for name, size in zip(spec.axes, spec.shape)} - self._check_same_axes(reference_shape, tensor_shape) + self.check_same_axes(reference_shape, tensor_shape) if reference_shape != tensor_shape: raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}") def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]): assert isinstance(spec.shape, ParametrizedInputShape) - if not self._is_shape(tensor_shape.values()): + if not self.is_shape(tensor_shape.values()): raise ValueError(f"Invalid shape's sizes {tensor_shape}") - min_shape = self._get_axes_with_size(spec.axes, tuple(spec.shape.min)) - step = self._get_axes_with_size(spec.axes, tuple(spec.shape.step)) - self._check_same_axes(tensor_shape, min_shape) + min_shape = self.get_axes_with_size(spec.axes, tuple(spec.shape.min)) + step = self.get_axes_with_size(spec.axes, tuple(spec.shape.step)) + self.check_same_axes(tensor_shape, min_shape) tensor_shapes_arr = np.array(list(tensor_shape.values())) min_shape_arr = np.array(list(min_shape.values())) @@ -80,25 +80,30 @@ def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict non_zero_idx = np.nonzero(step_arr) multipliers = diff[non_zero_idx] / step_arr[non_zero_idx] multiplier = np.unique(multipliers) - if len(multiplier) == 1 and self._is_natural_number(multiplier[0]): + if len(multiplier) == 1 and self.is_natural_number(multiplier[0]): return raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}") - def _check_same_axes(self, source: Dict[str, int], target: Dict[str, int]): + @staticmethod + def check_same_axes(source: Dict[str, int], target: Dict[str, int]): if source.keys() != target.keys(): raise ValueError(f"Incompatible axes for tensor {target} and reference {source}") - def _is_natural_number(self, n) -> bool: + @staticmethod + def is_natural_number(n) -> bool: return n % 1 == 0.0 and n >= 0 - def _is_shape(self, shape: Iterator[int]) -> bool: - return all(self._is_natural_number(dim) for dim in shape) + @staticmethod + def is_shape(shape: Iterator[int]) -> bool: + return all(InputTensorValidator.is_natural_number(dim) for dim in shape) - def _get_axes_with_size(self, axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]: + @staticmethod + def get_axes_with_size(axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]: assert len(axes) == len(shape) return {name: size for name, size in zip(axes, shape)} - def _is_shape_explicit(self, spec: nodes.InputTensor) -> bool: + @staticmethod + def is_shape_explicit(spec: nodes.InputTensor) -> bool: return isinstance(spec.shape, list) @@ -107,7 +112,6 @@ def __init__(self, model: PredictionPipeline) -> None: super().__init__(model) self._datasets = {} self._worker = base.SessionBackend(self._model) - self._shape_validator = InputTensorValidator(self._model) def forward(self, sample: Sample) -> Future: tensors_data = [sample.tensors[tensor.name] for tensor in self.model.input_specs] @@ -146,7 +150,7 @@ def _run_model_session_process( def start_model_session_process( model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None -) -> Tuple[_mp.Process, IRPCModelSession]: +) -> Tuple[_mp.Process, Client[IRPCModelSession]]: client_conn, server_conn = _mp.Pipe() prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices) proc = _mp.Process( @@ -159,9 +163,9 @@ def start_model_session_process( }, ) proc.start() - # here create the prediction pipeline, share it to the model session class and the client - return proc, _mp_rpc.create_client( - iface_cls=IRPCModelSession, api_kwargs={"model": prediction_pipeline}, conn=client_conn + api = _mp_rpc.create_client_api(iface_cls=IRPCModelSession, conn=client_conn) + return proc, Client( + input_specs=prediction_pipeline.input_specs, output_specs=prediction_pipeline.output_specs, api=api ) diff --git a/tiktorch/server/session_manager.py b/tiktorch/server/session_manager.py index 86a8a1ea..b12b2ad0 100644 --- a/tiktorch/server/session_manager.py +++ b/tiktorch/server/session_manager.py @@ -7,7 +7,7 @@ from typing import Callable, Dict, List, Optional from uuid import uuid4 -from tiktorch.server.session import IRPCModelSession +from tiktorch.rpc.mp import Client logger = getLogger(__name__) @@ -28,7 +28,7 @@ def id(self) -> str: @property @abc.abstractmethod - def client(self) -> IRPCModelSession: + def client(self) -> Client: ... @abc.abstractmethod @@ -43,13 +43,13 @@ def on_close(self, handler: CloseCallback) -> None: class _Session(ISession): - def __init__(self, id_: str, client: IRPCModelSession, manager: SessionManager) -> None: + def __init__(self, id_: str, client: Client, manager: SessionManager) -> None: self.__id = id_ self.__manager = manager self.__client = client @property - def client(self) -> IRPCModelSession: + def client(self) -> Client: return self.__client @property @@ -65,7 +65,7 @@ class SessionManager: Manages session lifecycle (create/close) """ - def create_session(self, client: IRPCModelSession) -> ISession: + def create_session(self, client: Client) -> ISession: """ Creates new session with unique id """ From fae27411ff2e3416be4f9af9953ff2b7978ffb60 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Thu, 15 Aug 2024 11:48:51 +0200 Subject: [PATCH 7/8] Rename method `from_raw_data` to `from_xr_tensors` --- tests/test_converters.py | 4 ++-- tiktorch/converters.py | 2 +- tiktorch/server/grpc/inference_servicer.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_converters.py b/tests/test_converters.py index 8d6e4c09..be268e42 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -299,7 +299,7 @@ def test_create_sample_from_raw_data(self): arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64) tensor_2 = xr.DataArray(arr_2, dims=["x", "y"]) tensors_ids = ["input1", "input2"] - actual_sample = Sample.from_raw_data(tensors_ids, [tensor_1, tensor_2]) + actual_sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2]) expected_dict = {tensors_ids[0]: tensor_1, tensors_ids[1]: tensor_2} expected_sample = Sample(expected_dict) @@ -311,7 +311,7 @@ def test_sample_to_pb_tensors(self): arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64) tensor_2 = xr.DataArray(arr_2, dims=["x", "y"]) tensors_ids = ["input1", "input2"] - sample = Sample.from_raw_data(tensors_ids, [tensor_1, tensor_2]) + sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2]) pb_tensor_1 = inference_pb2.Tensor( dtype="int64", diff --git a/tiktorch/converters.py b/tiktorch/converters.py index 1cc488da..80ac8af4 100644 --- a/tiktorch/converters.py +++ b/tiktorch/converters.py @@ -44,7 +44,7 @@ def from_pb_tensors(cls, pb_tensors: List[inference_pb2.Tensor]) -> Sample: return Sample({tensor.tensorId: pb_tensor_to_xarray(tensor) for tensor in pb_tensors}) @classmethod - def from_raw_data(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]): + def from_xr_tensors(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]): assert len(tensor_ids) == len(tensors_data) return Sample({tensor_id: tensor_data for tensor_id, tensor_data in zip(tensor_ids, tensors_data)}) diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 2ac2b632..4034098a 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -80,7 +80,7 @@ def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_p tensor_validator.check_tensors(input_sample) res = session.client.api.forward(input_sample) output_tensor_ids = [tensor.name for tensor in session.client.output_specs] - output_sample = Sample.from_raw_data(output_tensor_ids, res) + output_sample = Sample.from_xr_tensors(output_tensor_ids, res) return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors()) def _getModelSession(self, context, modelSessionId: str) -> ISession: From da79ae011ec206702a99d8d36e3de4fad271685c Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Thu, 15 Aug 2024 14:52:39 +0200 Subject: [PATCH 8/8] Simplify session manager to work only with sessions associated with a bio model client --- tiktorch/converters.py | 2 +- tiktorch/rpc/mp.py | 7 +-- tiktorch/server/grpc/inference_servicer.py | 12 ++--- tiktorch/server/session/process.py | 6 +-- tiktorch/server/session_manager.py | 56 ++++++++-------------- 5 files changed, 33 insertions(+), 50 deletions(-) diff --git a/tiktorch/converters.py b/tiktorch/converters.py index 80ac8af4..b7fee3fe 100644 --- a/tiktorch/converters.py +++ b/tiktorch/converters.py @@ -44,7 +44,7 @@ def from_pb_tensors(cls, pb_tensors: List[inference_pb2.Tensor]) -> Sample: return Sample({tensor.tensorId: pb_tensor_to_xarray(tensor) for tensor in pb_tensors}) @classmethod - def from_xr_tensors(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]): + def from_xr_tensors(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]) -> Sample: assert len(tensor_ids) == len(tensors_data) return Sample({tensor_id: tensor_data for tensor_id, tensor_data in zip(tensor_ids, tensors_data)}) diff --git a/tiktorch/rpc/mp.py b/tiktorch/rpc/mp.py index 94f3a638..0cd69bf4 100644 --- a/tiktorch/rpc/mp.py +++ b/tiktorch/rpc/mp.py @@ -6,11 +6,12 @@ from functools import wraps from multiprocessing.connection import Connection from threading import Event, Thread -from typing import Any, Generic, List, Optional, Type, TypeVar +from typing import Any, List, Optional, Type, TypeVar from uuid import uuid4 from bioimageio.core.resource_io import nodes +from ..server.session import IRPCModelSession from .exceptions import Shutdown from .interface import get_exposed_methods from .types import RPCFuture, isfutureret @@ -110,8 +111,8 @@ class _Api: @dataclasses.dataclass(frozen=True) -class Client(Generic[T]): - api: T +class BioModelClient: + api: IRPCModelSession input_specs: List[nodes.InputTensor] output_specs: List[nodes.OutputTensor] diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 4034098a..f09e0bae 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -7,7 +7,7 @@ from tiktorch.server.data_store import IDataStore from tiktorch.server.device_pool import DeviceStatus, IDevicePool from tiktorch.server.session.process import InputTensorValidator, start_model_session_process -from tiktorch.server.session_manager import ISession, SessionManager +from tiktorch.server.session_manager import Session, SessionManager class InferenceServicer(inference_pb2_grpc.InferenceServicer): @@ -46,7 +46,7 @@ def CreateDatasetDescription( self, request: inference_pb2.CreateDatasetDescriptionRequest, context ) -> inference_pb2.DatasetDescription: session = self._getModelSession(context, request.modelSessionId) - id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) + id = session.bio_model_client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) return inference_pb2.DatasetDescription(id=id) def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty: @@ -76,14 +76,14 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) input_sample = Sample.from_pb_tensors(request.tensors) - tensor_validator = InputTensorValidator(session.client.input_specs) + tensor_validator = InputTensorValidator(session.bio_model_client.input_specs) tensor_validator.check_tensors(input_sample) - res = session.client.api.forward(input_sample) - output_tensor_ids = [tensor.name for tensor in session.client.output_specs] + res = session.bio_model_client.api.forward(input_sample) + output_tensor_ids = [tensor.name for tensor in session.bio_model_client.output_specs] output_sample = Sample.from_xr_tensors(output_tensor_ids, res) return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors()) - def _getModelSession(self, context, modelSessionId: str) -> ISession: + def _getModelSession(self, context, modelSessionId: str) -> Session: if not modelSessionId: context.abort(grpc.StatusCode.FAILED_PRECONDITION, "model-session-id has not been provided by client") diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index ef17f69e..c9e0186e 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -15,7 +15,7 @@ from tiktorch import log from tiktorch.rpc import Shutdown from tiktorch.rpc import mp as _mp_rpc -from tiktorch.rpc.mp import Client, MPServer +from tiktorch.rpc.mp import BioModelClient, MPServer from ...converters import Sample from .backend import base @@ -150,7 +150,7 @@ def _run_model_session_process( def start_model_session_process( model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None -) -> Tuple[_mp.Process, Client[IRPCModelSession]]: +) -> Tuple[_mp.Process, BioModelClient]: client_conn, server_conn = _mp.Pipe() prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices) proc = _mp.Process( @@ -164,7 +164,7 @@ def start_model_session_process( ) proc.start() api = _mp_rpc.create_client_api(iface_cls=IRPCModelSession, conn=client_conn) - return proc, Client( + return proc, BioModelClient( input_specs=prediction_pipeline.input_specs, output_specs=prediction_pipeline.output_specs, api=api ) diff --git a/tiktorch/server/session_manager.py b/tiktorch/server/session_manager.py index b12b2ad0..3807e130 100644 --- a/tiktorch/server/session_manager.py +++ b/tiktorch/server/session_manager.py @@ -1,62 +1,44 @@ from __future__ import annotations -import abc import threading from collections import defaultdict from logging import getLogger from typing import Callable, Dict, List, Optional from uuid import uuid4 -from tiktorch.rpc.mp import Client +from tiktorch.rpc.mp import BioModelClient logger = getLogger(__name__) +CloseCallback = Callable[[], None] + -class ISession(abc.ABC): +class Session: """ session object has unique id Used for resource managent """ - @property - @abc.abstractmethod - def id(self) -> str: - """ - Returns unique id assigned to this session - """ - ... - - @property - @abc.abstractmethod - def client(self) -> Client: - ... - - @abc.abstractmethod - def on_close(self, handler: CloseCallback) -> None: - """ - Register cleanup function to be called when session ends - """ - ... - - -CloseCallback = Callable[[], None] - - -class _Session(ISession): - def __init__(self, id_: str, client: Client, manager: SessionManager) -> None: + def __init__(self, id_: str, bio_model_client: BioModelClient, manager: SessionManager) -> None: self.__id = id_ self.__manager = manager - self.__client = client + self.__bio_model_client = bio_model_client @property - def client(self) -> Client: - return self.__client + def bio_model_client(self) -> BioModelClient: + return self.__bio_model_client @property def id(self) -> str: + """ + Returns unique id assigned to this session + """ return self.__id def on_close(self, handler: CloseCallback) -> None: + """ + Register cleanup function to be called when session ends + """ self.__manager._on_close(self, handler) @@ -65,18 +47,18 @@ class SessionManager: Manages session lifecycle (create/close) """ - def create_session(self, client: Client) -> ISession: + def create_session(self, bio_model_client: BioModelClient) -> Session: """ Creates new session with unique id """ with self.__lock: session_id = uuid4().hex - session = _Session(session_id, client=client, manager=self) + session = Session(session_id, bio_model_client=bio_model_client, manager=self) self.__session_by_id[session_id] = session logger.info("Created session %s", session.id) return session - def get(self, session_id: str) -> Optional[ISession]: + def get(self, session_id: str) -> Optional[Session]: """ Returns existing session with given id if it exists """ @@ -102,10 +84,10 @@ def close_session(self, session_id: str) -> None: def __init__(self) -> None: self.__lock = threading.Lock() - self.__session_by_id: Dict[str, ISession] = {} + self.__session_by_id: Dict[str, Session] = {} self.__close_handlers_by_session_id: Dict[str, List[CloseCallback]] = defaultdict(list) - def _on_close(self, session: ISession, handler: CloseCallback): + def _on_close(self, session: Session, handler: CloseCallback): with self.__lock: logger.debug("Registered close handler %s for session %s", handler, session.id) self.__close_handlers_by_session_id[session.id].append(handler)