From 35caaf221c2bc82c6851cb760d60eebf7b4e5270 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 16:07:49 +0200 Subject: [PATCH] test --- tests/exporters/onnx/test_onnx_export.py | 49 ++++++++++++------------ 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 641e347c83a..11e6a53da36 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -569,10 +569,8 @@ def test_custom_export_official_model(self): assert "decoder_attentions.0" in output_names assert "cross_attentions.0" in output_names - @parameterized.expand( - grid_parameters({"fn_get_submodels": (None, fn_get_submodels_custom), "legacy": (True, False)}) - ) - def test_custom_export_trust_remote(self, test_name, fn_get_submodels, legacy): + @parameterized.expand([(None,), (fn_get_submodels_custom,)]) + def test_custom_export_trust_remote(self, fn_get_submodels): model_id = "fxmarty/tiny-mpt-random-remote-code" config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) onnx_config = CustomMPTOnnxConfig( @@ -583,28 +581,29 @@ def test_custom_export_trust_remote(self, test_name, fn_get_submodels, legacy): ) onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True) - if legacy: - custom_onnx_configs = { - "decoder_model": onnx_config, - "decoder_with_past_model": onnx_config_with_past, - } - else: - custom_onnx_configs = { - "model": onnx_config_with_past, - } + for legacy in (True, False): + if legacy: + custom_onnx_configs = { + "decoder_model": onnx_config, + "decoder_with_past_model": onnx_config_with_past, + } + else: + custom_onnx_configs = { + "model": onnx_config_with_past, + } - with TemporaryDirectory() as tmpdirname: - main_export( - model_id, - output=tmpdirname, - task="text-generation-with-past", - trust_remote_code=True, - custom_onnx_configs=custom_onnx_configs, - no_post_process=True, - fn_get_submodels=partial(fn_get_submodels, legacy=legacy) if fn_get_submodels else None, - legacy=legacy, - opset=14, - ) + with TemporaryDirectory() as tmpdirname: + main_export( + model_id, + output=tmpdirname, + task="text-generation-with-past", + trust_remote_code=True, + custom_onnx_configs=custom_onnx_configs, + no_post_process=True, + fn_get_submodels=partial(fn_get_submodels, legacy=legacy) if fn_get_submodels else None, + legacy=legacy, + opset=14, + ) def test_custom_export_trust_remote_error(self): model_id = "mohitsha/tiny-ernie-random-remote-code"