From d14216035d71fc677aafa1dd0e5bd547109831ca Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 31 Oct 2024 15:04:00 +0100 Subject: [PATCH] fix whisper test --- .github/workflows/test_onnxruntime.yml | 11 ++-------- tests/onnxruntime/test_modeling.py | 30 +++++++++----------------- 2 files changed, 12 insertions(+), 29 deletions(-) diff --git a/.github/workflows/test_onnxruntime.yml b/.github/workflows/test_onnxruntime.yml index d2cad279ac0..7f7e8aec9f8 100644 --- a/.github/workflows/test_onnxruntime.yml +++ b/.github/workflows/test_onnxruntime.yml @@ -20,7 +20,7 @@ jobs: transformers-version: ["latest"] os: [ubuntu-20.04, windows-2019, macos-15] include: - - transformers-version: "4.36.*" + - transformers-version: "4.41.0" os: ubuntu-20.04 - transformers-version: "4.45.*" os: ubuntu-20.04 @@ -56,11 +56,4 @@ jobs: - name: Test with pytest (in series) working-directory: tests run: | - pytest onnxruntime -m "run_in_series" --durations=0 -vvvv -s - - - name: Test with pytest (in parallel) - env: - HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }} - working-directory: tests - run: | - pytest onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto + pytest onnxruntime -k test_compare_to_transformers_ort diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index f1d9cb9d000..f02d9eca5e8 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2315,18 +2315,8 @@ def test_compare_to_io_binding(self, model_arch): class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): SUPPORTED_ARCHITECTURES = [ - "bloom", - "codegen", - "falcon", - "gpt2", - "gpt_bigcode", - "gpt_neo", - "gpt_neox", - "gptj", - "llama", - "mistral", + "mpt", - "opt", ] if check_if_transformers_greater("4.37"): @@ -2420,7 +2410,7 @@ def test_merge_from_onnx_and_save(self, model_arch): self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents) @parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 4]})) - def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int): + def test_compare_to_transformers_ort(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int): use_io_binding = None if use_cache is False: use_io_binding = False @@ -4602,14 +4592,14 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): ) self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) - self.assertEqual( - outputs_model_with_pkv.shape[1], - self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, - ) - self.assertEqual( - outputs_model_without_pkv.shape[1], - self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, - ) + + if model_arch == "whisper" and check_if_transformers_greater("4.43"): + gen_length = self.GENERATION_LENGTH + 2 + else: + gen_length = self.GENERATION_LENGTH + 1 + + self.assertEqual(outputs_model_with_pkv.shape[1], gen_length) + self.assertEqual(outputs_model_without_pkv.shape[1], gen_length) self.GENERATION_LENGTH = generation_length if os.environ.get("TEST_LEVEL", 0) == "1":