Skip to content

Commit

Permalink
Fix fp16 ONNX export test (#1373)
Browse files Browse the repository at this point in the history
fix test
  • Loading branch information
fxmarty authored Sep 8, 2023
1 parent b7a11b9 commit 058c72d
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tests/exporters/onnx/test_exporters_onnx_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,14 @@ def test_stable_diffusion(self):
@slow
@pytest.mark.run_slow
def test_export_on_fp16(
self, test_name: str, model_type: str, model_name: str, task: str, monolith: bool, no_post_process: bool
self,
test_name: str,
model_type: str,
model_name: str,
task: str,
variant: str,
monolith: bool,
no_post_process: bool,
):
# TODO: refer to https://github.com/pytorch/pytorch/issues/95377
if model_type == "yolos":
Expand All @@ -446,7 +453,7 @@ def test_export_on_fp16(
if model_type == "ibert":
self.skipTest("ibert can not be supported in fp16")

self._onnx_export(model_name, task, monolith, no_post_process, fp16=True)
self._onnx_export(model_name, task, monolith, no_post_process, variant=variant, fp16=True, device="cuda")

@parameterized.expand(
[
Expand Down

0 comments on commit 058c72d

Please sign in to comment.