Skip to content

Commit

Permalink
Rename method from_raw_data to from_xr_tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
thodkatz committed Aug 15, 2024
1 parent 1902067 commit fae2741
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def test_create_sample_from_raw_data(self):
arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
tensor_2 = xr.DataArray(arr_2, dims=["x", "y"])
tensors_ids = ["input1", "input2"]
actual_sample = Sample.from_raw_data(tensors_ids, [tensor_1, tensor_2])
actual_sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2])

expected_dict = {tensors_ids[0]: tensor_1, tensors_ids[1]: tensor_2}
expected_sample = Sample(expected_dict)
Expand All @@ -311,7 +311,7 @@ def test_sample_to_pb_tensors(self):
arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
tensor_2 = xr.DataArray(arr_2, dims=["x", "y"])
tensors_ids = ["input1", "input2"]
sample = Sample.from_raw_data(tensors_ids, [tensor_1, tensor_2])
sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2])

pb_tensor_1 = inference_pb2.Tensor(
dtype="int64",
Expand Down
2 changes: 1 addition & 1 deletion tiktorch/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def from_pb_tensors(cls, pb_tensors: List[inference_pb2.Tensor]) -> Sample:
return Sample({tensor.tensorId: pb_tensor_to_xarray(tensor) for tensor in pb_tensors})

@classmethod
def from_raw_data(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]):
def from_xr_tensors(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]):
assert len(tensor_ids) == len(tensors_data)
return Sample({tensor_id: tensor_data for tensor_id, tensor_data in zip(tensor_ids, tensors_data)})

Expand Down
2 changes: 1 addition & 1 deletion tiktorch/server/grpc/inference_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_p
tensor_validator.check_tensors(input_sample)
res = session.client.api.forward(input_sample)
output_tensor_ids = [tensor.name for tensor in session.client.output_specs]
output_sample = Sample.from_raw_data(output_tensor_ids, res)
output_sample = Sample.from_xr_tensors(output_tensor_ids, res)
return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors())

def _getModelSession(self, context, modelSessionId: str) -> ISession:
Expand Down

0 comments on commit fae2741

Please sign in to comment.