Commit

fix
echarlaix committed Dec 9, 2024
1 parent d8caf8c · commit 00d0b32
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions tests/onnxruntime/test_modeling.py
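
Every hunk below makes the same change: once a pipeline test has run its assertions, the pipe object is explicitly deleted and gc.collect() is called, so the underlying inference session is released before the next parameterized case allocates its own. A minimal standalone sketch of the pattern, assuming a plain transformers pipeline (the task and function name here are illustrative, not from the diff; the input string is borrowed from the tests):

import gc

from transformers import pipeline

def run_one_case():
    # Build a pipeline the way the tests do, run it, and check the output shape.
    pipe = pipeline("text-classification")  # placeholder task; the tests cover many
    outputs = pipe("My Name is Philipp and I live in Nuremberg.")
    assert outputs[0]["score"] >= 0.0
    assert isinstance(outputs[0]["label"], str)

    # The two lines this commit adds to each test: drop the last reference and
    # force a collection pass so the session memory is freed right away instead
    # of lingering across dozens of parameterized runs.
    del pipe
    gc.collect()

Without the explicit del, the local pipe would only become unreachable when the test frame is torn down, which can keep a model-sized allocation alive for longer than necessary under @parameterized.expand.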
@@ -1388,10 +1388,11 @@ def test_pipeline_model_is_none(self):
question = "Whats my name?"
context = "My Name is Philipp and I live in Nuremberg."
outputs = pipe(question, context)

# compare model output class
self.assertGreaterEqual(outputs["score"], 0.0)
self.assertIsInstance(outputs["answer"], str)
+ del pipe
+ gc.collect()

@parameterized.expand(
grid_parameters(
@@ -1576,6 +1577,8 @@ def test_pipeline_model_is_none(self):
# compare model output class
self.assertGreaterEqual(outputs[0]["score"], 0.0)
self.assertIsInstance(outputs[0]["token_str"], str)
+ del pipe
+ gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_torch_gpu
@@ -1756,6 +1759,8 @@ def test_pipeline_model_is_none(self):
# compare model output class
self.assertGreaterEqual(outputs[0]["score"], 0.0)
self.assertIsInstance(outputs[0]["label"], str)
+ del pipe
+ gc.collect()

@parameterized.expand(
grid_parameters(
@@ -1955,6 +1960,8 @@ def test_pipeline_model_is_none(self):

# compare model output class
self.assertTrue(all(item["score"] > 0.0 for item in outputs))
+ del pipe
+ gc.collect()

@parameterized.expand(
grid_parameters(
@@ -2117,6 +2124,8 @@ def test_pipeline_model_is_none(self):

# compare model output class
self.assertTrue(all(all(isinstance(item, float) for item in row) for row in outputs[0]))
+ del pipe
+ gc.collect()

@parameterized.expand(
grid_parameters(
@@ -2561,10 +2570,11 @@ def test_pipeline_model_is_none(self):
pipe = pipeline("text-generation")
text = "My Name is Philipp and i live"
outputs = pipe(text)

# compare model output class
self.assertIsInstance(outputs[0]["generated_text"], str)
self.assertTrue(len(outputs[0]["generated_text"]) > len(text))
+ del pipe
+ gc.collect()

@parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
@require_torch_gpu
@@ -2976,6 +2986,8 @@ def test_pipeline_model_is_none(self):
# compare model output class
self.assertGreaterEqual(outputs[0]["score"], 0.0)
self.assertTrue(isinstance(outputs[0]["label"], str))
+ del pipe
+ gc.collect()

@parameterized.expand(
grid_parameters(
@@ -3142,6 +3154,8 @@ def test_pipeline_model_is_none(self):
# compare model output class
self.assertTrue(outputs[0]["mask"] is not None)
self.assertTrue(isinstance(outputs[0]["label"], str))
+ del pipe
+ gc.collect()

# TODO: enable TensorrtExecutionProvider test once https://github.com/huggingface/optimum/issues/798 is fixed
@parameterized.expand(
@@ -3327,6 +3341,8 @@ def test_pipeline_model_is_none(self):
# compare model output class
self.assertGreaterEqual(outputs[0]["score"], 0.0)
self.assertIsInstance(outputs[0]["label"], str)
+ del pipe
+ gc.collect()

@parameterized.expand(
grid_parameters(
@@ -3970,6 +3986,8 @@ def test_pipeline_model_is_none(self):
outputs = pipe(text, min_length=1, max_length=2)
# compare model output class
self.assertIsInstance(outputs[0]["translation_text"], str)
+ del pipe
+ gc.collect()

@parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
@require_torch_gpu
