Skip to content

Commit

Permalink
Fixed inference tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Feb 16, 2025
1 parent 1c86d5e commit f939695
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@ def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None:
"""Test successful initialization of StaticModelPipeline."""
target: list[str] | list[list[str]]
if mock_inference_pipeline.multilabel:
target = [["a", "b"]]
if isinstance(mock_inference_pipeline.classes_[0], str):
target = [["a", "b"]]
else:
target = [[0, 1]] # type: ignore
else:
target = ["b"]
if isinstance(mock_inference_pipeline.classes_[0], str):
target = ["b"]
else:
target = [1] # type: ignore
assert mock_inference_pipeline.predict("dog").tolist() == target
assert mock_inference_pipeline.predict(["dog"]).tolist() == target

Expand All @@ -32,9 +38,15 @@ def test_roundtrip_save(mock_inference_pipeline: StaticModelPipeline) -> None:
loaded = StaticModelPipeline.from_pretrained(temp_dir)
target: list[str] | list[list[str]]
if mock_inference_pipeline.multilabel:
target = [["a", "b"]]
if isinstance(mock_inference_pipeline.classes_[0], str):
target = [["a", "b"]]
else:
target = [[0, 1]] # type: ignore
else:
target = ["b"]
if isinstance(mock_inference_pipeline.classes_[0], str):
target = ["b"]
else:
target = [1] # type: ignore
assert loaded.predict("dog").tolist() == target
assert loaded.predict(["dog"]).tolist() == target
assert loaded.predict_proba("dog").argmax() == 1
Expand Down

0 comments on commit f939695

Please sign in to comment.