Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove interface model info #215

Merged
merged 8 commits into from
Aug 16, 2024
Prev Previous commit
Next Next commit
Rename method from_raw_data to from_xr_tensors
thodkatz committed Aug 15, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit fae27411ff2e3416be4f9af9953ff2b7978ffb60
4 changes: 2 additions & 2 deletions tests/test_converters.py
Original file line number Diff line number Diff line change
@@ -299,7 +299,7 @@ def test_create_sample_from_raw_data(self):
arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
tensor_2 = xr.DataArray(arr_2, dims=["x", "y"])
tensors_ids = ["input1", "input2"]
actual_sample = Sample.from_raw_data(tensors_ids, [tensor_1, tensor_2])
actual_sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2])

expected_dict = {tensors_ids[0]: tensor_1, tensors_ids[1]: tensor_2}
expected_sample = Sample(expected_dict)
@@ -311,7 +311,7 @@ def test_sample_to_pb_tensors(self):
arr_2 = np.arange(64 * 64, dtype=np.int64).reshape(64, 64)
tensor_2 = xr.DataArray(arr_2, dims=["x", "y"])
tensors_ids = ["input1", "input2"]
sample = Sample.from_raw_data(tensors_ids, [tensor_1, tensor_2])
sample = Sample.from_xr_tensors(tensors_ids, [tensor_1, tensor_2])

pb_tensor_1 = inference_pb2.Tensor(
dtype="int64",
2 changes: 1 addition & 1 deletion tiktorch/converters.py
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ def from_pb_tensors(cls, pb_tensors: List[inference_pb2.Tensor]) -> Sample:
return Sample({tensor.tensorId: pb_tensor_to_xarray(tensor) for tensor in pb_tensors})

@classmethod
def from_raw_data(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]):
def from_xr_tensors(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def from_xr_tensors(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]):
def from_xr_tensors(cls, tensor_ids: List[str], tensors_data: List[xr.DataArray]) -> Sample:

for completeness

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I have to add pyright in the pre-commit hooks, so we can catch these as well. Thank you for point it out :)

assert len(tensor_ids) == len(tensors_data)
return Sample({tensor_id: tensor_data for tensor_id, tensor_data in zip(tensor_ids, tensors_data)})

2 changes: 1 addition & 1 deletion tiktorch/server/grpc/inference_servicer.py
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_p
tensor_validator.check_tensors(input_sample)
res = session.client.api.forward(input_sample)
output_tensor_ids = [tensor.name for tensor in session.client.output_specs]
output_sample = Sample.from_raw_data(output_tensor_ids, res)
output_sample = Sample.from_xr_tensors(output_tensor_ids, res)
return inference_pb2.PredictResponse(tensors=output_sample.to_pb_tensors())

def _getModelSession(self, context, modelSessionId: str) -> ISession: