diff --git a/proto/inference.proto b/proto/inference.proto
index c6e85629..39845dc5 100644
--- a/proto/inference.proto
+++ b/proto/inference.proto
@@ -86,14 +86,6 @@ message OutputShape {
 
 message ModelSession {
   string id = 1;
-  string name = 2;
-  repeated string inputAxes = 3;
-  repeated string outputAxes = 4;
-  bool hasTraining = 5;
-  repeated InputShape inputShapes = 6;
-  repeated OutputShape outputShapes = 7;
-  repeated string inputNames = 8;
-  repeated string outputNames = 9;
 }
 
 message LogEntry {
@@ -128,7 +120,8 @@ message NamedFloat {
 message Tensor {
   bytes buffer = 1;
   string dtype = 2;
-  repeated NamedInt shape = 3;
+  string tensorId = 3;
+  repeated NamedInt shape = 4;
 }
 
 message PredictRequest {
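
For reference, a minimal sketch of how the re-numbered Tensor message is populated once the bindings are regenerated; it mirrors the test code later in this diff, and the id "input0" and the 3x3 array are illustrative only:

```python
import numpy as np

from tiktorch.proto import inference_pb2

arr = np.arange(9, dtype=np.uint8).reshape(3, 3)
tensor = inference_pb2.Tensor(
    dtype=str(arr.dtype),
    tensorId="input0",  # new field 3; shape moved from field 3 to field 4
    buffer=bytes(arr),
    shape=[inference_pb2.NamedInt(name=d, size=s) for d, s in zip("xy", arr.shape)],
)
```
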
diff --git a/tests/conftest.py b/tests/conftest.py
index 31929f3e..1118935a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,7 +18,9 @@
 TEST_DATA = "data"
 TEST_BIOIMAGEIO_ZIPFOLDER = "unet2d"
 TEST_BIOIMAGEIO_ONNX = "unet2d_onnx"
-TEST_BIOIMAGEIO_DUMMY = "dummy"
+TEST_BIOIMAGEIO_DUMMY_EXPLICIT = "dummy"
+TEST_BIOIMAGEIO_DUMMY_EXPLICIT_RDF = f"{TEST_BIOIMAGEIO_DUMMY_EXPLICIT}/Dummy.model.yaml"
+TEST_BIOIMAGEIO_DUMMY_PARAM_RDF = "dummy_param/Dummy.model_param.yaml"
 TEST_BIOIMAGEIO_TENSORFLOW_DUMMY = "dummy_tensorflow"
 TEST_BIOIMAGEIO_TORCHSCRIPT = "unet2d_torchscript"
 
@@ -98,7 +100,7 @@ def bioimageio_model_zipfile(bioimageio_model_bytes):
 
 @pytest.fixture
 def bioimageio_dummy_model_filepath(data_path, tmpdir):
-    bioimageio_net_dir = Path(data_path) / TEST_BIOIMAGEIO_DUMMY
+    bioimageio_net_dir = Path(data_path) / TEST_BIOIMAGEIO_DUMMY_EXPLICIT
     path = tmpdir / "dummy_model.zip"
 
     with ZipFile(path, mode="w") as zip_model:
@@ -113,8 +115,24 @@ def bioimageio_dummy_model_filepath(data_path, tmpdir):
 
 
 @pytest.fixture
-def bioimageio_dummy_model_bytes(data_path):
-    rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY / "Dummy.model.yaml"
+def bioimageio_dummy_explicit_model_bytes(data_path):
+    rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY_EXPLICIT_RDF
+    return _bioimageio_package(rdf_source)
+
+
+@pytest.fixture
+def bioimageio_dummy_param_model_bytes(data_path):
+    rdf_source = data_path / TEST_BIOIMAGEIO_DUMMY_PARAM_RDF
+    return _bioimageio_package(rdf_source)
+
+
+@pytest.fixture(params=[(TEST_BIOIMAGEIO_DUMMY_PARAM_RDF, "param"), (TEST_BIOIMAGEIO_DUMMY_EXPLICIT_RDF, "input")])
+def bioimageio_dummy_model(request, data_path):
+    path, tensor_id = request.param
+    yield _bioimageio_package(data_path / path), tensor_id
+
+
+def _bioimageio_package(rdf_source):
     data = io.BytesIO()
     export_resource_package(rdf_source, output_path=data)
     return data
diff --git a/tests/data/dummy/Dummy.model.yaml b/tests/data/dummy/Dummy.model.yaml
index a27a1d4d..1aa8fe90 100644
--- a/tests/data/dummy/Dummy.model.yaml
+++ b/tests/data/dummy/Dummy.model.yaml
@@ -41,15 +41,16 @@ inputs:
     data_range: [-inf, inf]
     shape: [1, 1, 128, 128]
 
+
 outputs:
   - name: output
     axes: bcyx
     data_type: float32
     data_range: [0, 1]
     shape:
-        reference_tensor: input   # FIXME(m-novikov) ignoring for now
+        reference_tensor: input
         scale: [1, 1, 1, 1]
         offset: [0, 0, 0, 0]
-    halo: [0, 0, 32, 32]  # Should be moved to outputs
+    halo: [0, 0, 32, 32]
 
 type: model
diff --git a/tests/data/dummy_param/Dummy.model_param.yaml b/tests/data/dummy_param/Dummy.model_param.yaml
new file mode 100644
index 00000000..87ddd885
--- /dev/null
+++ b/tests/data/dummy_param/Dummy.model_param.yaml
@@ -0,0 +1,57 @@
+format_version: 0.3.3
+language: python
+framework: pytorch
+
+name: UNet2DNucleiBroad
+description: A 2d U-Net pretrained on broad nucleus dataset.
+cite:
+    - text: "Ronneberger, Olaf et al. U-net: Convolutional networks for biomedical image segmentation. MICCAI 2015."
+      doi: https://doi.org/10.1007/978-3-319-24574-4_28
+authors:
+  - name: "ilastik-team"
+    affiliation: "EMBL Heidelberg"
+
+documentation: dummy.md
+tags: [pytorch, nucleus-segmentation]
+license: MIT
+git_repo: https://github.com/ilastik/tiktorch
+covers: []
+
+source: dummy.py::Dummy
+sha256: 00ffb1647cf7ec524892206dce6258d9da498fe040c62838f31b501a09bfd573
+timestamp: 2019-12-11T12:22:32Z  # ISO 8601
+
+test_inputs: [dummy_in.npy]
+test_outputs: [dummy_out.npy]
+
+weights:
+  pytorch_state_dict:
+    source: ./weights
+    sha256: 518cb80bad2eb3ec3dfbe6bab74920951391ce8fb24e15cf59b9b9f052a575a6
+    authors:
+     - name: "ilastik-team"
+       affiliation: "EMBL Heidelberg"
+
+
+# TODO double check inputs/outputs
+inputs:
+  - name: param
+    axes: bcyx
+    data_type: float32
+    data_range: [-inf, inf]
+    shape:
+      min: [1, 1, 64, 64]
+      step: [0, 0, 2, 1]
+
+outputs:
+  - name: output
+    axes: bcyx
+    data_type: float32
+    data_range: [0, 1]
+    shape:
+        reference_tensor: param
+        scale: [1, 1, 1, 1]
+        offset: [0, 0, 0, 0]
+    halo: [0, 0, 8, 8]
+
+type: model
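
The parametrized RDF above accepts any input whose shape equals `min + n * step` for a single non-negative integer n. A quick sketch of the sizes this implies; the first four results are exactly the valid shapes exercised in test_inference_servicer.py below:

```python
# Enumerate shapes accepted by min=[1, 1, 64, 64], step=[0, 0, 2, 1].
def valid_shapes(min_shape, step, n_max=3):
    return [tuple(m + n * s for m, s in zip(min_shape, step)) for n in range(n_max + 1)]

print(valid_shapes([1, 1, 64, 64], [0, 0, 2, 1]))
# [(1, 1, 64, 64), (1, 1, 66, 65), (1, 1, 68, 66), (1, 1, 70, 67)]
```
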
diff --git a/tests/data/dummy_param/dummy.md b/tests/data/dummy_param/dummy.md
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/data/dummy_param/dummy.py b/tests/data/dummy_param/dummy.py
new file mode 100644
index 00000000..195e98e3
--- /dev/null
+++ b/tests/data/dummy_param/dummy.py
@@ -0,0 +1,7 @@
+from torch import nn
+
+
+class Dummy(nn.Module):
+    def forward(self, input):
+        # Dummy op exercised by the tests: output = input + 1.
+        return input + 1
diff --git a/tests/data/dummy_param/dummy_in.npy b/tests/data/dummy_param/dummy_in.npy
new file mode 100644
index 00000000..96a78a7b
Binary files /dev/null and b/tests/data/dummy_param/dummy_in.npy differ
diff --git a/tests/data/dummy_param/dummy_out.npy b/tests/data/dummy_param/dummy_out.npy
new file mode 100644
index 00000000..56f76ca7
Binary files /dev/null and b/tests/data/dummy_param/dummy_out.npy differ
diff --git a/tests/data/dummy_param/environment.yaml b/tests/data/dummy_param/environment.yaml
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/data/dummy_param/weights b/tests/data/dummy_param/weights
new file mode 100644
index 00000000..da14f342
Binary files /dev/null and b/tests/data/dummy_param/weights differ
diff --git a/tests/test_converters.py b/tests/test_converters.py
index c775ede3..be268e42 100644
--- a/tests/test_converters.py
+++ b/tests/test_converters.py
@@ -7,6 +7,7 @@
     NamedExplicitOutputShape,
     NamedImplicitOutputShape,
     NamedParametrizedShape,
+    Sample,
     input_shape_to_pb_input_shape,
     numpy_to_pb_tensor,
     output_shape_to_pb_output_shape,
@@ -27,6 +28,16 @@ def _numpy_to_pb_tensor(arr):
     return parsed
 
 
+def to_pb_tensor(tensor_id: str, arr: xr.DataArray) -> inference_pb2.Tensor:
+    """
+    Round-trips the tensor through protobuf serialization and parsing.
+    """
+    tensor = xarray_to_pb_tensor(tensor_id, arr)
+    parsed = inference_pb2.Tensor()
+    parsed.ParseFromString(tensor.SerializeToString())
+    return parsed
+
+
 class TestNumpyToPBTensor:
     def test_should_serialize_to_tensor_type(self):
         arr = np.arange(9)
@@ -97,18 +108,9 @@ def test_should_same_data(self, shape):
 
 
 class TestXarrayToPBTensor:
-    def to_pb_tensor(self, arr):
-        """
-        Makes sure that tensor was serialized/deserialized
-        """
-        tensor = xarray_to_pb_tensor(arr)
-        parsed = inference_pb2.Tensor()
-        parsed.ParseFromString(tensor.SerializeToString())
-        return parsed
-
     def test_should_serialize_to_tensor_type(self):
         xarr = xr.DataArray(np.arange(8).reshape((2, 4)), dims=("x", "y"))
-        pb_tensor = self.to_pb_tensor(xarr)
+        pb_tensor = to_pb_tensor("input0", xarr)
         assert isinstance(pb_tensor, inference_pb2.Tensor)
         assert len(pb_tensor.shape) == 2
         dim1 = pb_tensor.shape[0]
@@ -123,28 +125,19 @@ def test_should_serialize_to_tensor_type(self):
     @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)])
     def test_should_have_shape(self, shape):
         arr = xr.DataArray(np.zeros(shape))
-        tensor = self.to_pb_tensor(arr)
+        tensor = to_pb_tensor("input0", arr)
         assert tensor.shape
         assert list(shape) == [dim.size for dim in tensor.shape]
 
     def test_should_have_serialized_bytes(self):
         arr = xr.DataArray(np.arange(9, dtype=np.uint8))
         expected = bytes(arr.data)
-        tensor = self.to_pb_tensor(arr)
+        tensor = to_pb_tensor("input0", arr)
 
         assert expected == tensor.buffer
 
 
 class TestPBTensorToXarray:
-    def to_pb_tensor(self, arr):
-        """
-        Makes sure that tensor was serialized/deserialized
-        """
-        tensor = xarray_to_pb_tensor(arr)
-        parsed = inference_pb2.Tensor()
-        parsed.ParseFromString(tensor.SerializeToString())
-        return parsed
-
     def test_should_raise_on_empty_dtype(self):
         tensor = inference_pb2.Tensor(dtype="", shape=[inference_pb2.NamedInt(size=1), inference_pb2.NamedInt(size=2)])
         with pytest.raises(ValueError):
@@ -155,33 +148,32 @@ def test_should_raise_on_empty_shape(self):
         with pytest.raises(ValueError):
             pb_tensor_to_xarray(tensor)
 
-    def test_should_return_ndarray(self):
+    def test_should_return_xarray(self):
         arr = xr.DataArray(np.arange(9))
-        parsed = self.to_pb_tensor(arr)
-        result_arr = pb_tensor_to_xarray(parsed)
-
-        assert isinstance(result_arr, xr.DataArray)
+        parsed = to_pb_tensor("input0", arr)
+        result_tensor = pb_tensor_to_xarray(parsed)
+        assert isinstance(result_tensor, xr.DataArray)
 
     @pytest.mark.parametrize("np_dtype,dtype_str", [(np.int64, "int64"), (np.uint8, "uint8"), (np.float32, "float32")])
     def test_should_have_same_dtype(self, np_dtype, dtype_str):
         arr = xr.DataArray(np.arange(9, dtype=np_dtype))
-        tensor = self.to_pb_tensor(arr)
-        result_arr = pb_tensor_to_xarray(tensor)
+        pb_tensor = to_pb_tensor("input0", arr)
+        result_arr = pb_tensor_to_xarray(pb_tensor)
 
         assert arr.dtype == result_arr.dtype
 
     @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)])
     def test_should_same_shape(self, shape):
         arr = xr.DataArray(np.zeros(shape))
-        tensor = self.to_pb_tensor(arr)
-        result_arr = pb_tensor_to_xarray(tensor)
+        pb_tensor = to_pb_tensor("input0", arr)
+        result_arr = pb_tensor_to_xarray(pb_tensor)
         assert arr.shape == result_arr.shape
 
     @pytest.mark.parametrize("shape", [(3, 3), (1,), (1, 1), (18, 20, 1)])
     def test_should_same_data(self, shape):
         arr = xr.DataArray(np.random.random(shape))
-        tensor = self.to_pb_tensor(arr)
-        result_arr = pb_tensor_to_xarray(tensor)
+        pb_tensor = to_pb_tensor("input0", arr)
+        result_arr = pb_tensor_to_xarray(pb_tensor)
         assert_array_equal(arr, result_arr)
 
 
@@ -276,3 +268,64 @@ def test_parametrized_input_shape(self, min_shape, axes, step):
         assert [(d.name, d.size) for d in pb_shape.stepShape.namedInts] == [
             (name, size) for name, size in zip(axes, step)
         ]
+
+
+class TestSample:
+    def test_create_sample_from_pb_tensors(self):
+        arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32)
+        tensor_1 = inference_pb2.Tensor(
+            dtype="int64",
+            tensorId="input1",
+            buffer=bytes(arr_1),
+            shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)],
+        )
+
+        arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
+        tensor_2 = inference_pb2.Tensor(
+            dtype="int64",
+            tensorId="input2",
+            buffer=bytes(arr_2),
+            shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)],
+        )
+
+        sample = Sample.from_pb_tensors([tensor_1, tensor_2])
+        assert len(sample.tensors) == 2
+        assert sample.tensors["input1"].equals(xr.DataArray(arr_1, dims=["x", "y"]))
+        assert sample.tensors["input2"].equals(xr.DataArray(arr_2, dims=["x", "y"]))
+
+    def test_create_sample_from_raw_data(self):
+        arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32)
+        tensor_1 = xr.DataArray(arr_1, dims=["x", "y"])
+        arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
+        tensor_2 = xr.DataArray(arr_2, dims=["x", "y"])
+        tensors_ids = ["input1", "input2"]
+        actual_sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2])
+
+        expected_dict = {tensors_ids[0]: tensor_1, tensors_ids[1]: tensor_2}
+        expected_sample = Sample(expected_dict)
+        assert actual_sample == expected_sample
+
+    def test_sample_to_pb_tensors(self):
+        arr_1 = np.arange(32 * 32, dtype=np.int64).reshape(32, 32)
+        tensor_1 = xr.DataArray(arr_1, dims=["x", "y"])
+        arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
+        tensor_2 = xr.DataArray(arr_2, dims=["x", "y"])
+        tensors_ids = ["input1", "input2"]
+        sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2])
+
+        pb_tensor_1 = inference_pb2.Tensor(
+            dtype="int64",
+            tensorId="input1",
+            buffer=bytes(arr_1),
+            shape=[inference_pb2.NamedInt(name="x", size=32), inference_pb2.NamedInt(name="y", size=32)],
+        )
+        pb_tensor_2 = inference_pb2.Tensor(
+            dtype="int64",
+            tensorId="input2",
+            buffer=bytes(arr_2),
+            shape=[inference_pb2.NamedInt(name="x", size=64), inference_pb2.NamedInt(name="y", size=64)],
+        )
+        expected_tensors = [pb_tensor_1, pb_tensor_2]
+
+        actual_tensors = sample.to_pb_tensors()
+        assert expected_tensors == actual_tensors
diff --git a/tests/test_rpc/test_mp.py b/tests/test_rpc/test_mp.py
index 981f4553..2011c54e 100644
--- a/tests/test_rpc/test_mp.py
+++ b/tests/test_rpc/test_mp.py
@@ -8,7 +8,7 @@
 
 from tiktorch import log
 from tiktorch.rpc import RPCFuture, RPCInterface, Shutdown, exposed
-from tiktorch.rpc.mp import FutureStore, MPServer, create_client
+from tiktorch.rpc.mp import FutureStore, MPServer, create_client_api
 
 
 class ITestApi(RPCInterface):
@@ -64,7 +64,7 @@ def client(log_queue):
     p = mp.Process(target=_srv, args=(parent, log_queue))
     p.start()
 
-    client = create_client(ITestApi, child, timeout=10)
+    client = create_client_api(iface_cls=ITestApi, conn=child, timeout=10)
 
     yield client
 
@@ -108,7 +108,7 @@ def __getattr__(self, name):
     p = mp.Process(target=_srv, args=(parent, log_queue))
     p.start()
 
-    client = create_client(ITestApi, SlowConn(child))
+    client = create_client_api(iface_cls=ITestApi, conn=SlowConn(child))
 
     client.fast_compute(2, 2)
 
@@ -121,7 +121,7 @@ def test_future_timeout(client: ITestApi, log_queue):
     p = mp.Process(target=_srv, args=(parent, log_queue))
     p.start()
 
-    client = create_client(ITestApi, child, timeout=0.001)
+    client = create_client_api(iface_cls=ITestApi, conn=child, timeout=0.001)
 
     with pytest.raises(TimeoutError):
         client.compute(1, 2)
@@ -256,7 +256,7 @@ def _spawn(iface_cls, srv_cls):
         p = mp.Process(target=_run_srv, args=(srv_cls, parent, log_queue))
         p.start()
 
-        data["client"] = client = create_client(iface_cls, child)
+        data["client"] = client = create_client_api(iface_cls=iface_cls, conn=child)
         data["process"] = p
         return client
 
diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py
index 8a5188af..864c4d6a 100644
--- a/tests/test_server/test_grpc/test_inference_servicer.py
+++ b/tests/test_server/test_grpc/test_inference_servicer.py
@@ -47,12 +47,10 @@ def method_requiring_session(self, request, grpc_stub):
     def test_model_session_creation(self, grpc_stub, bioimageio_model_bytes):
         model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_model_bytes))
         assert model.id
-        assert hasattr(model, "outputShapes")
-        assert hasattr(model, "inputShapes")
         grpc_stub.CloseModelSession(model)
 
-    def test_model_session_creation_using_upload_id(self, grpc_stub, data_store, bioimageio_dummy_model_bytes):
-        id_ = data_store.put(bioimageio_dummy_model_bytes.getvalue())
+    def test_model_session_creation_using_upload_id(self, grpc_stub, data_store, bioimageio_dummy_explicit_model_bytes):
+        id_ = data_store.put(bioimageio_dummy_explicit_model_bytes.getvalue())
 
         rq = inference_pb2.CreateModelSessionRequest(model_uri=f"upload://{id_}", deviceIds=["cpu"])
         model = grpc_stub.CreateModelSession(rq)
@@ -156,27 +154,81 @@ def test_call_fails_with_unknown_model_session_id(self, grpc_stub):
         assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code()
         assert "model-session with id myid1 doesn't exist" in e.value.details()
 
-    def test_call_predict(self, grpc_stub, bioimageio_dummy_model_bytes):
-        model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_model_bytes))
-        arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
+    def test_call_predict_valid_explicit(self, grpc_stub, bioimageio_dummy_explicit_model_bytes):
+        model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_explicit_model_bytes))
+        arr = xr.DataArray(np.arange(128 * 128).reshape(1, 1, 128, 128), dims=("b", "c", "x", "y"))
         expected = arr + 1
-        input_tensors = [converters.xarray_to_pb_tensor(arr)]
+        input_tensor_id = "input"
+        output_tensor_id = "output"
+        input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
         res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
 
         grpc_stub.CloseModelSession(model)
 
         assert len(res.tensors) == 1
+        assert res.tensors[0].tensorId == output_tensor_id
         assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0]))
 
+    def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimageio_dummy_explicit_model_bytes):
+        model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_explicit_model_bytes))
+        arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
+        input_tensors = [converters.xarray_to_pb_tensor("input", arr)]
+        with pytest.raises(grpc.RpcError):
+            grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
+        grpc_stub.CloseModelSession(model)
+
+    @pytest.mark.parametrize(
+        "shape",
+        [(1, 1, 64, 32), (1, 1, 32, 64), (0, 1, 64, 64), (1, 0, 64, 64)],
+    )
+    def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes):
+        model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes))
+        arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y"))
+        input_tensors = [converters.xarray_to_pb_tensor("param", arr)]
+        with pytest.raises(grpc.RpcError):
+            grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
+        grpc_stub.CloseModelSession(model)
+
+    def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimageio_dummy_model):
+        model_bytes, _ = bioimageio_dummy_model
+        model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
+        arr = xr.DataArray(np.arange(32 * 32).reshape(32, 32), dims=("x", "y"))
+        input_tensors = [converters.xarray_to_pb_tensor("invalidTensorName", arr)]
+        with pytest.raises(grpc.RpcError) as error:
+            grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
+        assert error.value.details().startswith("Exception calling application: Spec invalidTensorName doesn't exist")
+        grpc_stub.CloseModelSession(model)
+
+    def test_call_predict_invalid_axes(self, grpc_stub, bioimageio_dummy_model):
+        model_bytes, tensor_id = bioimageio_dummy_model
+        model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
+        arr = xr.DataArray(np.arange(32 * 32).reshape(32, 32), dims=("invalidAxis", "y"))
+        input_tensors = [converters.xarray_to_pb_tensor(tensor_id, arr)]
+        with pytest.raises(grpc.RpcError) as error:
+            grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
+        assert error.value.details().startswith("Exception calling application: Incompatible axes")
+        grpc_stub.CloseModelSession(model)
+
+    @pytest.mark.parametrize("shape", [(1, 1, 64, 64), (1, 1, 66, 65), (1, 1, 68, 66), (1, 1, 70, 67)])
+    def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimageio_dummy_param_model_bytes):
+        model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_param_model_bytes))
+        arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("b", "c", "x", "y"))
+        input_tensors = [converters.xarray_to_pb_tensor("param", arr)]
+        grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
+        grpc_stub.CloseModelSession(model)
+
     @pytest.mark.skip
     def test_call_predict_tf(self, grpc_stub, bioimageio_dummy_tensorflow_model_bytes):
         model = grpc_stub.CreateModelSession(valid_model_request(bioimageio_dummy_tensorflow_model_bytes))
         arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("b", "c", "x", "y"))
         expected = arr * -1
-        input_tensors = [converters.xarray_to_pb_tensor(arr)]
+        input_tensor_id = "input"
+        output_tensor_id = "output"
+        input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
         res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
 
         grpc_stub.CloseModelSession(model)
 
         assert len(res.tensors) == 1
+        assert res.tensors[0].tensorId == output_tensor_id
         assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensors[0]))
diff --git a/tests/test_server/test_modelinfo.py b/tests/test_server/test_modelinfo.py
deleted file mode 100644
index 4fbb72c5..00000000
--- a/tests/test_server/test_modelinfo.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import random
-from unittest import mock
-
-import pytest
-from bioimageio.spec.shared.raw_nodes import ImplicitOutputShape, ParametrizedInputShape
-from marshmallow import missing
-
-from tiktorch.converters import NamedExplicitOutputShape, NamedImplicitOutputShape, NamedParametrizedShape
-from tiktorch.server.session.process import ModelInfo
-
-
-@pytest.fixture
-def implicit_output_spec():
-    """output spec with ImplicitOutputShape"""
-    shape = ImplicitOutputShape(
-        reference_tensor="blah",
-        scale=[1.0] + [float(random.randint(0, 2**32)) for _ in range(4)],
-        offset=[0.0] + [float(random.randint(0, 2**32)) for _ in range(4)],
-    )
-    output_spec = mock.Mock(axes=("x", "y"), shape=shape, halo=[5, 12])
-    output_spec.name = "implicit_out"
-    return output_spec
-
-
-@pytest.fixture
-def parametrized_input_spec():
-    shape = ParametrizedInputShape(
-        min=[random.randint(0, 2**32) for _ in range(5)], step=[float(random.randint(0, 2**32)) for _ in range(5)]
-    )
-    input_spec = mock.Mock(axes=("b", "x", "y", "z", "c"), shape=shape)
-    input_spec.name = "param_in"
-    return input_spec
-
-
-@pytest.fixture
-def explicit_input_spec():
-    input_shape = [random.randint(0, 2**32) for _ in range(3)]
-    input_spec = mock.Mock(axes=("b", "x", "y"), shape=input_shape)
-    input_spec.name = "explicit_in"
-    return input_spec
-
-
-@pytest.fixture
-def explicit_output_spec():
-    output_shape = [random.randint(0, 2**32) for _ in range(3)]
-    halo = [0] + [random.randint(0, 2**32) for _ in range(2)]
-    output_spec = mock.Mock(axes=("b", "x", "y"), shape=output_shape, halo=halo)
-    output_spec.name = "explicit_out"
-    return output_spec
-
-
-def test_model_info_explicit_shapes(explicit_input_spec, explicit_output_spec):
-    prediction_pipeline = mock.Mock(input_specs=[explicit_input_spec], output_specs=[explicit_output_spec], name="bleh")
-
-    model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline)
-
-    assert model_info.input_axes == ["".join(explicit_input_spec.axes)]
-    assert model_info.output_axes == ["".join(explicit_output_spec.axes)]
-    assert len(model_info.input_shapes) == 1
-    assert len(model_info.output_shapes) == 1
-    assert isinstance(model_info.input_shapes[0], list)
-    assert model_info.input_shapes[0] == [(ax, s) for ax, s in zip(explicit_input_spec.axes, explicit_input_spec.shape)]
-    assert isinstance(model_info.output_shapes[0], NamedExplicitOutputShape)
-    assert model_info.output_shapes[0].shape == [
-        (ax, s) for ax, s in zip(explicit_output_spec.axes, explicit_output_spec.shape)
-    ]
-    assert model_info.output_shapes[0].halo == [
-        (ax, s) for ax, s in zip(explicit_output_spec.axes, explicit_output_spec.halo)
-    ]
-    assert model_info.input_names == ["explicit_in"]
-    assert model_info.output_names == ["explicit_out"]
-
-
-def test_model_info_explicit_shapes_missing_halo(explicit_input_spec, explicit_output_spec):
-    explicit_output_spec.halo = missing
-
-    prediction_pipeline = mock.Mock(input_specs=[explicit_input_spec], output_specs=[explicit_output_spec], name="bleh")
-
-    model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline)
-
-    assert model_info.input_axes == ["".join(explicit_input_spec.axes)]
-    assert model_info.output_axes == ["".join(explicit_output_spec.axes)]
-    assert len(model_info.input_shapes) == 1
-    assert len(model_info.output_shapes) == 1
-    assert isinstance(model_info.input_shapes[0], list)
-    assert model_info.input_shapes[0] == [(ax, s) for ax, s in zip(explicit_input_spec.axes, explicit_input_spec.shape)]
-    assert isinstance(model_info.output_shapes[0], NamedExplicitOutputShape)
-    assert model_info.output_shapes[0].shape == [
-        (ax, s) for ax, s in zip(explicit_output_spec.axes, explicit_output_spec.shape)
-    ]
-    assert model_info.output_shapes[0].halo == [(ax, s) for ax, s in zip(explicit_output_spec.axes, [0, 0, 0])]
-
-
-def test_model_info_implicit_shapes(parametrized_input_spec, implicit_output_spec):
-    prediction_pipeline = mock.Mock(
-        input_specs=[parametrized_input_spec], output_specs=[implicit_output_spec], name="bleh"
-    )
-
-    model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline)
-    assert model_info.input_axes == ["".join(parametrized_input_spec.axes)]
-    assert model_info.output_axes == ["".join(implicit_output_spec.axes)]
-    assert len(model_info.input_shapes) == 1
-    assert len(model_info.output_shapes) == 1
-    assert isinstance(model_info.input_shapes[0], NamedParametrizedShape)
-    assert model_info.input_shapes[0].min_shape == [
-        (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.min)
-    ]
-    assert model_info.input_shapes[0].step_shape == [
-        (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.step)
-    ]
-    assert isinstance(model_info.output_shapes[0], NamedImplicitOutputShape)
-    assert model_info.output_shapes[0].offset == [
-        (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.offset)
-    ]
-    assert model_info.output_shapes[0].scale == [
-        (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.scale)
-    ]
-    assert model_info.output_shapes[0].halo == [
-        (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.halo)
-    ]
-    assert model_info.output_shapes[0].reference_tensor == implicit_output_spec.shape.reference_tensor
-
-    assert model_info.input_names == ["param_in"]
-    assert model_info.output_names == ["implicit_out"]
-
-
-def test_model_info_implicit_shapes_missing_halo(parametrized_input_spec, implicit_output_spec):
-    implicit_output_spec.halo = missing
-    prediction_pipeline = mock.Mock(
-        input_specs=[parametrized_input_spec], output_specs=[implicit_output_spec], name="bleh"
-    )
-
-    model_info = ModelInfo.from_prediction_pipeline(prediction_pipeline)
-    assert model_info.input_axes == ["".join(parametrized_input_spec.axes)]
-    assert model_info.output_axes == ["".join(implicit_output_spec.axes)]
-    assert len(model_info.input_shapes) == 1
-    assert len(model_info.output_shapes) == 1
-    assert isinstance(model_info.input_shapes[0], NamedParametrizedShape)
-    assert model_info.input_shapes[0].min_shape == [
-        (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.min)
-    ]
-    assert model_info.input_shapes[0].step_shape == [
-        (ax, s) for ax, s in zip(parametrized_input_spec.axes, parametrized_input_spec.shape.step)
-    ]
-    assert isinstance(model_info.output_shapes[0], NamedImplicitOutputShape)
-    assert model_info.output_shapes[0].offset == [
-        (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.offset)
-    ]
-    assert model_info.output_shapes[0].scale == [
-        (ax, s) for ax, s in zip(implicit_output_spec.axes, implicit_output_spec.shape.scale)
-    ]
-    assert model_info.output_shapes[0].halo == [(ax, s) for ax, s in zip(implicit_output_spec.axes, [0, 0, 0, 0, 0])]
-    assert model_info.output_shapes[0].reference_tensor == implicit_output_spec.shape.reference_tensor
diff --git a/tiktorch/converters.py b/tiktorch/converters.py
index 4093bec3..b7fee3fe 100644
--- a/tiktorch/converters.py
+++ b/tiktorch/converters.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import dataclasses
-from typing import List, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import numpy as np
 import xarray as xr
@@ -33,6 +35,23 @@ class NamedImplicitOutputShape:
     halo: NamedShape
 
 
+@dataclasses.dataclass(frozen=True)
+class Sample:
+    tensors: Dict[str, xr.DataArray]
+
+    @classmethod
+    def from_pb_tensors(cls, pb_tensors: List[inference_pb2.Tensor]) -> Sample:
+        return Sample({tensor.tensorId: pb_tensor_to_xarray(tensor) for tensor in pb_tensors})
+
+    @classmethod
+    def from_xr_tensors(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]) -> Sample:
+        assert len(tensor_ids) == len(tensors_data)
+        return Sample({tensor_id: tensor_data for tensor_id, tensor_data in zip(tensor_ids, tensors_data)})
+
+    def to_pb_tensors(self) -> List[inference_pb2.Tensor]:
+        return [xarray_to_pb_tensor(tensor_id, res_tensor) for tensor_id, res_tensor in self.tensors.items()]
+
+
 def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor:
     if axistags:
         shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)]
@@ -41,9 +60,9 @@ def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor
     return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array))
 
 
-def xarray_to_pb_tensor(array: xr.DataArray) -> inference_pb2.Tensor:
+def xarray_to_pb_tensor(tensor_id: str, array: xr.DataArray) -> inference_pb2.Tensor:
     shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, array.dims)]
-    return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array.data))
+    return inference_pb2.Tensor(tensorId=tensor_id, dtype=str(array.dtype), shape=shape, buffer=bytes(array.data))
 
 
 def name_int_tuples_to_pb_NamedInts(name_int_tuples) -> inference_pb2.NamedInts:
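
A short round-trip sketch of the new Sample container, matching the behavior asserted in tests/test_converters.py; the tensor id "input0" is illustrative:

```python
import numpy as np
import xarray as xr

from tiktorch.converters import Sample

a = xr.DataArray(np.zeros((2, 4), dtype="float32"), dims=("x", "y"))
sample = Sample.from_xr_tensors(["input0"], [a])
pb_tensors = sample.to_pb_tensors()            # List[inference_pb2.Tensor], each with tensorId set
restored = Sample.from_pb_tensors(pb_tensors)  # the id -> DataArray mapping survives the round trip
assert restored.tensors["input0"].equals(a)
```
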
diff --git a/tiktorch/proto/inference_pb2.py b/tiktorch/proto/inference_pb2.py
index e098a4df..dc5c6c9c 100644
--- a/tiktorch/proto/inference_pb2.py
+++ b/tiktorch/proto/inference_pb2.py
@@ -20,7 +20,7 @@
   syntax='proto3',
   serialized_options=None,
   create_key=_descriptor._internal_create_key,
-  serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\xd3\x01\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tinputAxes\x18\x03 \x03(\t\x12\x12\n\noutputAxes\x18\x04 \x03(\t\x12\x13\n\x0bhasTraining\x18\x05 \x01(\x08\x12 \n\x0binputShapes\x18\x06 \x03(\x0b\x32\x0b.InputShape\x12\"\n\x0coutputShapes\x18\x07 \x03(\x0b\x32\x0c.OutputShape\x12\x12\n\ninputNames\x18\x08 \x03(\t\x12\x13\n\x0boutputNames\x18\t \x03(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"A\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x18\n\x05shape\x18\x03 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3'
+  serialized_pb=b'\n\x0finference.proto\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"i\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12\x1b\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x9d\x01\n\nInputShape\x12(\n\tshapeType\x18\x01 \x01(\x0e\x32\x15.InputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x1d\n\tstepShape\x18\x04 \x01(\x0b\x32\n.NamedInts\"+\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x10\n\x0cPARAMETRIZED\x10\x01\"\xea\x01\n\x0bOutputShape\x12)\n\tshapeType\x18\x01 \x01(\x0e\x32\x16.OutputShape.ShapeType\x12\x19\n\x05shape\x18\x02 \x01(\x0b\x32\n.NamedInts\x12\x18\n\x04halo\x18\x03 \x01(\x0b\x32\n.NamedInts\x12\x17\n\x0freferenceTensor\x18\x04 \x01(\t\x12\x1b\n\x05scale\x18\x05 \x01(\x0b\x32\x0c.NamedFloats\x12\x1c\n\x06offset\x18\x06 \x01(\x0b\x32\x0c.NamedFloats\"\'\n\tShapeType\x12\x0c\n\x08\x45XPLICIT\x10\x00\x12\x0c\n\x08IMPLICIT\x10\x01\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\x9e\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12\x1e\n\x05level\x18\x02 \x01(\x0e\x32\x0f.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Device\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"S\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x10\n\x08tensorId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"\x07\n\x05\x45mpty\"\x1e\n\tModelInfo\x12\x11\n\tdeviceIds\x18\x01 \x03(\t\"^\n CreateModelSessionChunkedRequest\x12\x1a\n\x04info\x18\x01 \x01(\x0b\x32\n.ModelInfoH\x00\x12\x16\n\x05\x63hunk\x18\x02 \x01(\x0b\x32\x05.BlobH\x00\x42\x06\n\x04\x64\x61ta2\xc6\x02\n\tInference\x12\x41\n\x12\x43reateModelSession\x12\x1a.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12S\n\x18\x43reateDatasetDescription\x12 .CreateDatasetDescriptionRequest\x1a\x13.DatasetDescription\"\x00\x12 \n\x07GetLogs\x12\x06.Empty\x1a\t.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3'
 )
 
 
@@ -140,8 +140,8 @@
   ],
   containing_type=None,
   serialized_options=None,
-  serialized_start=1165,
-  serialized_end=1243,
+  serialized_start=979,
+  serialized_end=1057,
 )
 _sym_db.RegisterEnumDescriptor(_LOGENTRY_LEVEL)
 
@@ -548,62 +548,6 @@
       message_type=None, enum_type=None, containing_type=None,
       is_extension=False, extension_scope=None,
       serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
-    _descriptor.FieldDescriptor(
-      name='name', full_name='ModelSession.name', index=1,
-      number=2, type=9, cpp_type=9, label=1,
-      has_default_value=False, default_value=b"".decode('utf-8'),
-      message_type=None, enum_type=None, containing_type=None,
-      is_extension=False, extension_scope=None,
-      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
-    _descriptor.FieldDescriptor(
-      name='inputAxes', full_name='ModelSession.inputAxes', index=2,
-      number=3, type=9, cpp_type=9, label=3,
-      has_default_value=False, default_value=[],
-      message_type=None, enum_type=None, containing_type=None,
-      is_extension=False, extension_scope=None,
-      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
-    _descriptor.FieldDescriptor(
-      name='outputAxes', full_name='ModelSession.outputAxes', index=3,
-      number=4, type=9, cpp_type=9, label=3,
-      has_default_value=False, default_value=[],
-      message_type=None, enum_type=None, containing_type=None,
-      is_extension=False, extension_scope=None,
-      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
-    _descriptor.FieldDescriptor(
-      name='hasTraining', full_name='ModelSession.hasTraining', index=4,
-      number=5, type=8, cpp_type=7, label=1,
-      has_default_value=False, default_value=False,
-      message_type=None, enum_type=None, containing_type=None,
-      is_extension=False, extension_scope=None,
-      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
-    _descriptor.FieldDescriptor(
-      name='inputShapes', full_name='ModelSession.inputShapes', index=5,
-      number=6, type=11, cpp_type=10, label=3,
-      has_default_value=False, default_value=[],
-      message_type=None, enum_type=None, containing_type=None,
-      is_extension=False, extension_scope=None,
-      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
-    _descriptor.FieldDescriptor(
-      name='outputShapes', full_name='ModelSession.outputShapes', index=6,
-      number=7, type=11, cpp_type=10, label=3,
-      has_default_value=False, default_value=[],
-      message_type=None, enum_type=None, containing_type=None,
-      is_extension=False, extension_scope=None,
-      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
-    _descriptor.FieldDescriptor(
-      name='inputNames', full_name='ModelSession.inputNames', index=7,
-      number=8, type=9, cpp_type=9, label=3,
-      has_default_value=False, default_value=[],
-      message_type=None, enum_type=None, containing_type=None,
-      is_extension=False, extension_scope=None,
-      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
-    _descriptor.FieldDescriptor(
-      name='outputNames', full_name='ModelSession.outputNames', index=8,
-      number=9, type=9, cpp_type=9, label=3,
-      has_default_value=False, default_value=[],
-      message_type=None, enum_type=None, containing_type=None,
-      is_extension=False, extension_scope=None,
-      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
   ],
   extensions=[
   ],
@@ -616,8 +560,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=871,
-  serialized_end=1082,
+  serialized_start=870,
+  serialized_end=896,
 )
 
 
@@ -663,8 +607,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1085,
-  serialized_end=1243,
+  serialized_start=899,
+  serialized_end=1057,
 )
 
 
@@ -695,8 +639,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1245,
-  serialized_end=1280,
+  serialized_start=1059,
+  serialized_end=1094,
 )
 
 
@@ -734,8 +678,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1282,
-  serialized_end=1320,
+  serialized_start=1096,
+  serialized_end=1134,
 )
 
 
@@ -773,8 +717,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1322,
-  serialized_end=1362,
+  serialized_start=1136,
+  serialized_end=1176,
 )
 
 
@@ -801,8 +745,15 @@
       is_extension=False, extension_scope=None,
       serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
     _descriptor.FieldDescriptor(
-      name='shape', full_name='Tensor.shape', index=2,
-      number=3, type=11, cpp_type=10, label=3,
+      name='tensorId', full_name='Tensor.tensorId', index=2,
+      number=3, type=9, cpp_type=9, label=1,
+      has_default_value=False, default_value=b"".decode('utf-8'),
+      message_type=None, enum_type=None, containing_type=None,
+      is_extension=False, extension_scope=None,
+      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
+    _descriptor.FieldDescriptor(
+      name='shape', full_name='Tensor.shape', index=3,
+      number=4, type=11, cpp_type=10, label=3,
       has_default_value=False, default_value=[],
       message_type=None, enum_type=None, containing_type=None,
       is_extension=False, extension_scope=None,
@@ -819,8 +770,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1364,
-  serialized_end=1429,
+  serialized_start=1178,
+  serialized_end=1261,
 )
 
 
@@ -865,8 +816,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1431,
-  serialized_end=1516,
+  serialized_start=1263,
+  serialized_end=1348,
 )
 
 
@@ -897,8 +848,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1518,
-  serialized_end=1561,
+  serialized_start=1350,
+  serialized_end=1393,
 )
 
 
@@ -922,8 +873,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1563,
-  serialized_end=1570,
+  serialized_start=1395,
+  serialized_end=1402,
 )
 
 
@@ -954,8 +905,8 @@
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1572,
-  serialized_end=1602,
+  serialized_start=1404,
+  serialized_end=1434,
 )
 
 
@@ -998,8 +949,8 @@
       create_key=_descriptor._internal_create_key,
     fields=[]),
   ],
-  serialized_start=1604,
-  serialized_end=1698,
+  serialized_start=1436,
+  serialized_end=1530,
 )
 
 _DEVICE.fields_by_name['status'].enum_type = _DEVICE_STATUS
@@ -1023,8 +974,6 @@
 _OUTPUTSHAPE.fields_by_name['scale'].message_type = _NAMEDFLOATS
 _OUTPUTSHAPE.fields_by_name['offset'].message_type = _NAMEDFLOATS
 _OUTPUTSHAPE_SHAPETYPE.containing_type = _OUTPUTSHAPE
-_MODELSESSION.fields_by_name['inputShapes'].message_type = _INPUTSHAPE
-_MODELSESSION.fields_by_name['outputShapes'].message_type = _OUTPUTSHAPE
 _LOGENTRY.fields_by_name['level'].enum_type = _LOGENTRY_LEVEL
 _LOGENTRY_LEVEL.containing_type = _LOGENTRY
 _DEVICES.fields_by_name['devices'].message_type = _DEVICE
@@ -1210,8 +1159,8 @@
   index=0,
   serialized_options=None,
   create_key=_descriptor._internal_create_key,
-  serialized_start=1701,
-  serialized_end=2027,
+  serialized_start=1533,
+  serialized_end=1859,
   methods=[
   _descriptor.MethodDescriptor(
     name='CreateModelSession',
@@ -1286,8 +1235,8 @@
   index=1,
   serialized_options=None,
   create_key=_descriptor._internal_create_key,
-  serialized_start=2029,
-  serialized_end=2100,
+  serialized_start=1861,
+  serialized_end=1932,
   methods=[
   _descriptor.MethodDescriptor(
     name='Ping',
diff --git a/tiktorch/rpc/mp.py b/tiktorch/rpc/mp.py
index bab3ab1f..0cd69bf4 100644
--- a/tiktorch/rpc/mp.py
+++ b/tiktorch/rpc/mp.py
@@ -1,3 +1,4 @@
+import dataclasses
 import logging
 import queue
 import threading
@@ -5,9 +6,12 @@
 from functools import wraps
 from multiprocessing.connection import Connection
 from threading import Event, Thread
-from typing import Any, Optional, Type, TypeVar
+from typing import Any, List, Optional, Type, TypeVar
 from uuid import uuid4
 
+from bioimageio.core.resource_io import nodes
+
+from ..server.session import IRPCModelSession
 from .exceptions import Shutdown
 from .interface import get_exposed_methods
 from .types import RPCFuture, isfutureret
@@ -72,9 +76,8 @@ def __call__(self, *args, **kwargs) -> Any:
         return self._client._invoke(self._method_name, *args, **kwargs)
 
 
-def create_client(iface_cls: Type[T], conn: Connection, timeout=None) -> T:
+def create_client_api(iface_cls: Type[T], conn: Connection, timeout=None) -> T:
     client = MPClient(iface_cls.__name__, conn, timeout)
-    get_exposed_methods(iface_cls)
 
     def _make_method(method):
         class MethodWrapper:
@@ -97,13 +100,21 @@ def __call__(self, *args, **kwargs) -> Any:
 
         return MethodWrapper()
 
-    class _Client(iface_cls):
+    class _Api:
         pass
 
-    for method_name, method in get_exposed_methods(iface_cls).items():
-        setattr(_Client, method_name, _make_method(method))
+    exposed_methods = get_exposed_methods(iface_cls)
+    for method_name, method in exposed_methods.items():
+        setattr(_Api, method_name, _make_method(method))
+
+    return _Api()
+
 
-    return _Client()
+@dataclasses.dataclass(frozen=True)
+class BioModelClient:
+    api: IRPCModelSession
+    input_specs: List[nodes.InputTensor]
+    output_specs: List[nodes.OutputTensor]
 
 
 class MPClient:
@@ -190,7 +201,7 @@ def _shutdown(self, exc):
 
 class Message:
     def __init__(self, id_):
-        self.id = id
+        self.id = id_
 
 
 class Signal:
@@ -200,20 +211,19 @@ def __init__(self, payload):
 
 class MethodCall(Message):
     def __init__(self, id_, method_name, args, kwargs):
-        self.id = id_
+        super().__init__(id_)
         self.method_name = method_name
         self.args = args
         self.kwargs = kwargs
 
 
 class Cancellation(Message):
-    def __init__(self, id_):
-        self.id = id_
+    pass
 
 
 class MethodReturn(Message):
     def __init__(self, id_, result: Result):
-        self.id = id_
+        super().__init__(id_)
         self.result = result
 
 
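
Usage of the renamed factory, following tests/test_rpc/test_mp.py; `IMyApi` is a hypothetical interface standing in for any RPCInterface subclass with exposed methods, and spawning the server process is elided:

```python
import multiprocessing as mp

from tiktorch.rpc import RPCInterface, exposed
from tiktorch.rpc.mp import create_client_api

class IMyApi(RPCInterface):
    @exposed
    def compute(self, a: int, b: int) -> int:
        raise NotImplementedError

parent_conn, child_conn = mp.Pipe()
# ... start an MPServer serving an IMyApi implementation on parent_conn in another process ...
client = create_client_api(iface_cls=IMyApi, conn=child_conn, timeout=10)
result = client.compute(1, 2)  # proxied over the pipe
```
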
diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py
index 5e188787..f09e0bae 100644
--- a/tiktorch/server/grpc/inference_servicer.py
+++ b/tiktorch/server/grpc/inference_servicer.py
@@ -2,12 +2,12 @@
 
 import grpc
 
-from tiktorch import converters
+from tiktorch.converters import Sample
 from tiktorch.proto import inference_pb2, inference_pb2_grpc
 from tiktorch.server.data_store import IDataStore
 from tiktorch.server.device_pool import DeviceStatus, IDevicePool
-from tiktorch.server.session.process import start_model_session_process
-from tiktorch.server.session_manager import ISession, SessionManager
+from tiktorch.server.session.process import InputTensorValidator, start_model_session_process
+from tiktorch.server.session_manager import Session, SessionManager
 
 
 class InferenceServicer(inference_pb2_grpc.InferenceServicer):
@@ -36,37 +36,17 @@ def CreateModelSession(
             lease.terminate()
             raise
 
-        session = self.__session_manager.create_session()
+        session = self.__session_manager.create_session(client)
         session.on_close(lease.terminate)
-        session.on_close(client.shutdown)
-        session.client = client
+        session.on_close(client.api.shutdown)
 
-        try:
-            model_info = session.client.get_model_info()
-        except Exception:
-            lease.terminate()
-            raise
-
-        pb_input_shapes = [converters.input_shape_to_pb_input_shape(shape) for shape in model_info.input_shapes]
-        pb_output_shapes = [converters.output_shape_to_pb_output_shape(shape) for shape in model_info.output_shapes]
-
-        return inference_pb2.ModelSession(
-            id=session.id,
-            name=model_info.name,
-            inputAxes=model_info.input_axes,
-            outputAxes=model_info.output_axes,
-            inputShapes=pb_input_shapes,
-            hasTraining=False,
-            outputShapes=pb_output_shapes,
-            inputNames=model_info.input_names,
-            outputNames=model_info.output_names,
-        )
+        return inference_pb2.ModelSession(id=session.id)
 
     def CreateDatasetDescription(
         self, request: inference_pb2.CreateDatasetDescriptionRequest, context
     ) -> inference_pb2.DatasetDescription:
         session = self._getModelSession(context, request.modelSessionId)
-        id = session.client.create_dataset_description(mean=request.mean, stddev=request.stddev)
+        id = session.bio_model_client.api.create_dataset_description(mean=request.mean, stddev=request.stddev)
         return inference_pb2.DatasetDescription(id=id)
 
     def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty:
@@ -95,12 +75,15 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De
 
     def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse:
         session = self._getModelSession(context, request.modelSessionId)
-        arrs = [converters.pb_tensor_to_xarray(tensor) for tensor in request.tensors]
-        res = session.client.forward(arrs)
-        pb_tensors = [converters.xarray_to_pb_tensor(res_tensor) for res_tensor in res]
-        return inference_pb2.PredictResponse(tensors=pb_tensors)
-
-    def _getModelSession(self, context, modelSessionId: str) -> ISession:
+        input_sample = Sample.from_pb_tensors(request.tensors)
+        tensor_validator = InputTensorValidator(session.bio_model_client.input_specs)
+        tensor_validator.check_tensors(input_sample)
+        res = session.bio_model_client.api.forward(input_sample)
+        output_tensor_ids = [tensor.name for tensor in session.bio_model_client.output_specs]
+        output_sample = Sample.from_xr_tensors(output_tensor_ids, res)
+        return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors())
+
+    def _getModelSession(self, context, modelSessionId: str) -> Session:
         if not modelSessionId:
             context.abort(grpc.StatusCode.FAILED_PRECONDITION, "model-session-id has not been provided by client")
 
diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py
index 91b00f83..c9e0186e 100644
--- a/tiktorch/server/session/process.py
+++ b/tiktorch/server/session/process.py
@@ -1,106 +1,121 @@
-import dataclasses
 import multiprocessing as _mp
-import os
 import pathlib
 import tempfile
 import uuid
 from concurrent.futures import Future
 from multiprocessing.connection import Connection
-from typing import List, Optional, Tuple, Union
+from typing import Dict, Iterator, List, Optional, Tuple
 
-import numpy
+import numpy as np
 from bioimageio.core import load_resource_description
 from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
-from bioimageio.spec.shared.raw_nodes import ImplicitOutputShape, ParametrizedInputShape
-from marshmallow import missing
+from bioimageio.core.resource_io import nodes
+from bioimageio.core.resource_io.nodes import ParametrizedInputShape
 
 from tiktorch import log
-from tiktorch.converters import NamedExplicitOutputShape, NamedImplicitOutputShape, NamedParametrizedShape, NamedShape
 from tiktorch.rpc import Shutdown
 from tiktorch.rpc import mp as _mp_rpc
-from tiktorch.rpc.mp import MPServer
+from tiktorch.rpc.mp import BioModelClient, MPServer
 
+from ...converters import Sample
 from .backend import base
 from .rpc_interface import IRPCModelSession
 
 
-@dataclasses.dataclass
-class ModelInfo:
-    """Intermediate representation of bioimageio neural network model
-
-    TODO (k-dominik): ModelInfo only used in inference_servicer to convert to
-    protobuf modelinfo.
-
-    """
-
-    name: str
-    input_axes: List[str]  # one per input
-    output_axes: List[str]  # one per output
-    input_shapes: List[Union[NamedShape, NamedParametrizedShape]]  # per input multiple shapes
-    output_shapes: List[Union[NamedExplicitOutputShape, NamedImplicitOutputShape]]
-    input_names: List[str]  # one per input
-    output_names: List[str]  # one per output
-
-    @classmethod
-    def from_prediction_pipeline(cls, prediction_pipeline: PredictionPipeline) -> "ModelInfo":
-        input_shapes = []
-        for input_spec in prediction_pipeline.input_specs:
-            if isinstance(input_spec.shape, ParametrizedInputShape):
-                input_shapes.append(
-                    NamedParametrizedShape(
-                        min_shape=list(map(tuple, zip(input_spec.axes, input_spec.shape.min))),
-                        step_shape=list(map(tuple, zip(input_spec.axes, input_spec.shape.step))),
-                    )
-                )
-            else:
-                input_shapes.append(list(map(tuple, zip(input_spec.axes, input_spec.shape))))
-
-        output_shapes = []
-        for output_spec in prediction_pipeline.output_specs:
-            # halo is not required by spec. We could alternatively make it optional in the
-            # respective grpc message types and handle missing values in ilastik
-            halo = [0 for _ in output_spec.axes] if output_spec.halo == missing else output_spec.halo
-            if isinstance(output_spec.shape, ImplicitOutputShape):
-                output_shapes.append(
-                    NamedImplicitOutputShape(
-                        reference_tensor=output_spec.shape.reference_tensor,
-                        scale=list(map(tuple, zip(output_spec.axes, output_spec.shape.scale))),
-                        offset=list(map(tuple, zip(output_spec.axes, output_spec.shape.offset))),
-                        halo=list(map(tuple, zip(output_spec.axes, halo))),
-                    )
-                )
-            else:  # isinstance(output_spec.shape, ExplicitShape):
-                output_shapes.append(
-                    NamedExplicitOutputShape(
-                        shape=list(map(tuple, zip(output_spec.axes, output_spec.shape))),
-                        halo=list(map(tuple, zip(output_spec.axes, halo))),
-                    )
-                )
-
-        return cls(
-            name=prediction_pipeline.name,
-            input_axes=["".join(input_spec.axes) for input_spec in prediction_pipeline.input_specs],
-            output_axes=["".join(output_spec.axes) for output_spec in prediction_pipeline.output_specs],
-            input_shapes=input_shapes,
-            output_shapes=output_shapes,
-            input_names=[input_spec.name for input_spec in prediction_pipeline.input_specs],
-            output_names=[output_spec.name for output_spec in prediction_pipeline.output_specs],
-        )
-
-
-class ModelSessionProcess(IRPCModelSession):
-    def __init__(self, model_zip: bytes, devices: List[str]) -> None:
-        _tmp_file = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
-        _tmp_file.write(model_zip)
-        _tmp_file.close()
-        model = load_resource_description(pathlib.Path(_tmp_file.name))
-        os.unlink(_tmp_file.name)
-        self._model: PredictionPipeline = create_prediction_pipeline(bioimageio_model=model, devices=devices)
+class InputTensorValidator:
+    def __init__(self, input_specs: List[nodes.InputTensor]):
+        self._input_specs = input_specs
+
+    def check_tensors(self, sample: Sample):
+        for tensor_id, tensor in sample.tensors.items():
+            self.check_shape(tensor_id, tensor.dims, tensor.shape)
+
+    def _get_input_tensors_with_names(self) -> Dict[str, nodes.InputTensor]:
+        return {tensor.name: tensor for tensor in self._input_specs}
+
+    def check_shape(self, tensor_id: str, axes: Tuple[str, ...], shape: Tuple[int, ...]):
+        named_shape = self.get_axes_with_size(axes, shape)
+        spec = self._get_input_spec(tensor_id)
+        if isinstance(spec.shape, list):
+            self._check_shape_explicit(spec, named_shape)
+        elif isinstance(spec.shape, ParametrizedInputShape):
+            self._check_shape_parameterized(spec, named_shape)
+        else:
+            raise ValueError(f"Unexpected shape type {type(spec.shape)} for tensor {tensor_id}")
+
+    def _get_input_spec(self, tensor_id: str) -> nodes.InputTensor:
+        self._check_spec_exists(tensor_id)
+        specs = [spec for spec in self._input_specs if spec.name == tensor_id]
+        assert len(specs) == 1, "ids of tensor specs should be unique"
+        return specs[0]
+
+    def _check_spec_exists(self, tensor_id: str):
+        spec_names = [spec.name for spec in self._input_specs]
+        if tensor_id not in spec_names:
+            raise ValueError(f"Spec {tensor_id} doesn't exist for specs {spec_names}")
+
+    def _check_shape_explicit(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
+        assert self.is_shape_explicit(spec)
+        reference_shape = {name: size for name, size in zip(spec.axes, spec.shape)}
+        self.check_same_axes(reference_shape, tensor_shape)
+        if reference_shape != tensor_shape:
+            raise ValueError(f"Incompatible shapes found {tensor_shape}, expected {reference_shape}")
+
+    def _check_shape_parameterized(self, spec: nodes.InputTensor, tensor_shape: Dict[str, int]):
+        assert isinstance(spec.shape, ParametrizedInputShape)
+        if not self.is_shape(tensor_shape.values()):
+            raise ValueError(f"Invalid shape's sizes {tensor_shape}")
+
+        min_shape = self.get_axes_with_size(spec.axes, tuple(spec.shape.min))
+        step = self.get_axes_with_size(spec.axes, tuple(spec.shape.step))
+        self.check_same_axes(tensor_shape, min_shape)
+
+        tensor_shapes_arr = np.array(list(tensor_shape.values()))
+        min_shape_arr = np.array(list(min_shape.values()))
+        step_arr = np.array(list(step.values()))
+        diff = tensor_shapes_arr - min_shape_arr
+        if any(size < 0 for size in diff):
+            raise ValueError(f"Tensor shape {tensor_shape} smaller than min shape {min_shape}")
+
+        # axes with step == 0 are fixed: the tensor must match the min shape there
+        if any(diff[step_arr == 0] != 0):
+            raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")
+
+        # all parametrized axes must share one natural multiplier n, i.e.
+        # shape == min_shape + n * step
+        non_zero_idx = np.nonzero(step_arr)
+        multipliers = np.unique(diff[non_zero_idx] / step_arr[non_zero_idx])
+        if len(multipliers) == 0 or (len(multipliers) == 1 and self.is_natural_number(multipliers[0])):
+            return
+        raise ValueError(f"Tensor shape {tensor_shape} not valid for spec {spec}")
+
+    @staticmethod
+    def check_same_axes(source: Dict[str, int], target: Dict[str, int]):
+        if source.keys() != target.keys():
+            raise ValueError(f"Incompatible axes for tensor {target} and reference {source}")
+
+    @staticmethod
+    def is_natural_number(n) -> bool:
+        return n % 1 == 0.0 and n >= 0
+
+    @staticmethod
+    def is_shape(shape: Iterable[int]) -> bool:
+        return all(InputTensorValidator.is_natural_number(dim) for dim in shape)
+
+    @staticmethod
+    def get_axes_with_size(axes: Tuple[str, ...], shape: Tuple[int, ...]) -> Dict[str, int]:
+        assert len(axes) == len(shape)
+        return {name: size for name, size in zip(axes, shape)}
+
+    @staticmethod
+    def is_shape_explicit(spec: nodes.InputTensor) -> bool:
+        return isinstance(spec.shape, list)
+
+
+class ModelSessionProcess(IRPCModelSession[PredictionPipeline]):
+    def __init__(self, model: PredictionPipeline) -> None:
+        super().__init__(model)
         self._datasets = {}
         self._worker = base.SessionBackend(self._model)
 
-    def forward(self, input_tensors: numpy.ndarray) -> Future:
-        res = self._worker.forward(input_tensors)
+    def forward(self, sample: Sample) -> Future:
+        tensors_data = [sample.tensors[tensor.name] for tensor in self.model.input_specs]
+        res = self._worker.forward(tensors_data)
         return res
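+    # e.g. (sketch): `fut = session.forward(sample)` returns a Future that the
+    # backend worker resolves; callers block with `fut.result()`. Tensors are
+    # ordered to match `input_specs`, since the pipeline takes them positionally.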
 
     def create_dataset(self, mean, stddev):
@@ -108,16 +123,13 @@ def create_dataset(self, mean, stddev):
         self._datasets[id_] = {"mean": mean, "stddev": stddev}
         return id_
 
-    def get_model_info(self) -> ModelInfo:
-        return ModelInfo.from_prediction_pipeline(self._model)
-
     def shutdown(self) -> Shutdown:
         self._worker.shutdown()
         return Shutdown()
 
 
 def _run_model_session_process(
-    conn: Connection, model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
+    conn: Connection, prediction_pipeline: PredictionPipeline, log_queue: Optional[_mp.Queue] = None
 ):
     try:
         # from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
@@ -131,19 +143,35 @@ def _run_model_session_process(
     if log_queue:
         log.configure(log_queue)
 
-    session_proc = ModelSessionProcess(model_zip, devices)
+    session_proc = ModelSessionProcess(prediction_pipeline)
     srv = MPServer(session_proc, conn)
     srv.listen()
 
 
 def start_model_session_process(
     model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
-) -> Tuple[_mp.Process, IRPCModelSession]:
+) -> Tuple[_mp.Process, BioModelClient]:
     client_conn, server_conn = _mp.Pipe()
+    prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_zip, devices)
     proc = _mp.Process(
         target=_run_model_session_process,
         name="ModelSessionProcess",
-        kwargs={"conn": server_conn, "devices": devices, "log_queue": log_queue, "model_zip": model_zip},
+        kwargs={
+            "conn": server_conn,
+            "log_queue": log_queue,
+            "prediction_pipeline": prediction_pipeline,
+        },
     )
     proc.start()
-    return proc, _mp_rpc.create_client(IRPCModelSession, client_conn)
+    api = _mp_rpc.create_client_api(iface_cls=IRPCModelSession, conn=client_conn)
+    return proc, BioModelClient(
+        input_specs=prediction_pipeline.input_specs, output_specs=prediction_pipeline.output_specs, api=api
+    )
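+# Minimal usage sketch (illustrative; the zip path and device list are
+# assumptions, not part of this patch):
+#
+#     with open("model.zip", "rb") as f:
+#         proc, client = start_model_session_process(f.read(), devices=["cpu"])
+#     fut = client.api.forward(sample)  # client.api proxies IRPCModelSession
+#     ...
+#     client.api.shutdown()
+#     proc.join()
+#
+# Note that the pipeline is built in the parent and handed to the child via
+# the process kwargs, so the specs are available here without a round trip.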
+
+
+def _get_prediction_pipeline_from_model_bytes(model_zip: bytes, devices: List[str]) -> PredictionPipeline:
+    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file:
+        _tmp_file.write(model_zip)
+        temp_file_path = pathlib.Path(_tmp_file.name)
+    model = load_resource_description(temp_file_path)
+    temp_file_path.unlink()  # remove the temporary zip once the description is loaded
+    return create_prediction_pipeline(bioimageio_model=model, devices=devices)
diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py
index a01dd02a..bb414138 100644
--- a/tiktorch/server/session/rpc_interface.py
+++ b/tiktorch/server/session/rpc_interface.py
@@ -1,11 +1,22 @@
-from typing import List
+from typing import Generic, List, TypeVar
 
+from tiktorch.converters import Sample
 from tiktorch.rpc import RPCInterface, Shutdown, exposed
 from tiktorch.tiktypes import TikTensorBatch
 from tiktorch.types import ModelState
 
+ModelType = TypeVar("ModelType")
+
+
+class IRPCModelSession(RPCInterface, Generic[ModelType]):
+    def __init__(self, model: ModelType):
+        super().__init__()
+        self._model = model
+
+    @property
+    def model(self):
+        return self._model
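+    # Binding ModelType gives subclasses a typed `model`, e.g.
+    # IRPCModelSession[PredictionPipeline] in tiktorch/server/session/process.py.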
 
-class IRPCModelSession(RPCInterface):
     @exposed
     def shutdown(self) -> Shutdown:
         raise NotImplementedError
@@ -43,9 +54,5 @@ def create_dataset_description(self, mean, stddev) -> str:
         raise NotImplementedError
 
     @exposed
-    def forward(self, input_tensors):
-        raise NotImplementedError
-
-    @exposed
-    def get_model_info(self):
+    def forward(self, sample: Sample):
         raise NotImplementedError
diff --git a/tiktorch/server/session_manager.py b/tiktorch/server/session_manager.py
index 37bc07ab..3807e130 100644
--- a/tiktorch/server/session_manager.py
+++ b/tiktorch/server/session_manager.py
@@ -1,50 +1,44 @@
 from __future__ import annotations
 
-import abc
 import threading
 from collections import defaultdict
 from logging import getLogger
 from typing import Callable, Dict, List, Optional
 from uuid import uuid4
 
+from tiktorch.rpc.mp import BioModelClient
+
 logger = getLogger(__name__)
 
+CloseCallback = Callable[[], None]
+
 
-class ISession(abc.ABC):
+class Session:
     """
     Session object with a unique id.
     Used for resource management.
     """
 
+    def __init__(self, id_: str, bio_model_client: BioModelClient, manager: SessionManager) -> None:
+        self.__id = id_
+        self.__manager = manager
+        self.__bio_model_client = bio_model_client
+
+    @property
+    def bio_model_client(self) -> BioModelClient:
+        return self.__bio_model_client
+
     @property
-    @abc.abstractmethod
     def id(self) -> str:
         """
         Returns unique id assigned to this session
         """
-        ...
+        return self.__id
 
-    @abc.abstractmethod
     def on_close(self, handler: CloseCallback) -> None:
         """
         Register cleanup function to be called when session ends
         """
-        ...
-
-
-CloseCallback = Callable[[], None]
-
-
-class _Session(ISession):
-    def __init__(self, id_: str, manager: SessionManager) -> None:
-        self.__id = id_
-        self.__manager = manager
-
-    @property
-    def id(self) -> str:
-        return self.__id
-
-    def on_close(self, handler: CloseCallback) -> None:
         self.__manager._on_close(self, handler)
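+    # e.g. (sketch, handler wiring illustrative):
+    #
+    #     session = manager.create_session(bio_model_client)
+    #     session.on_close(lambda: bio_model_client.api.shutdown())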
 
 
@@ -53,18 +47,18 @@ class SessionManager:
     Manages session lifecycle (create/close)
     """
 
-    def create_session(self) -> ISession:
+    def create_session(self, bio_model_client: BioModelClient) -> Session:
         """
         Creates new session with unique id
         """
         with self.__lock:
             session_id = uuid4().hex
-            session = _Session(session_id, manager=self)
+            session = Session(session_id, bio_model_client=bio_model_client, manager=self)
             self.__session_by_id[session_id] = session
             logger.info("Created session %s", session.id)
             return session
 
-    def get(self, session_id: str) -> Optional[ISession]:
+    def get(self, session_id: str) -> Optional[Session]:
         """
         Returns the session with the given id, or None if it does not exist
         """
@@ -90,10 +84,10 @@ def close_session(self, session_id: str) -> None:
 
     def __init__(self) -> None:
         self.__lock = threading.Lock()
-        self.__session_by_id: Dict[str, ISession] = {}
+        self.__session_by_id: Dict[str, Session] = {}
         self.__close_handlers_by_session_id: Dict[str, List[CloseCallback]] = defaultdict(list)
 
-    def _on_close(self, session: ISession, handler: CloseCallback):
+    def _on_close(self, session: Session, handler: CloseCallback):
         with self.__lock:
             logger.debug("Registered close handler %s for session %s", handler, session.id)
             self.__close_handlers_by_session_id[session.id].append(handler)