From b42db7ee6b5fa43e41adcbd501a3bd183b589991 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Fri, 11 Oct 2024 11:04:50 +0200 Subject: [PATCH] Fix onnx export CLI for transformers >= 4.45 (#2053) * fix onnx export * add test --- optimum/exporters/onnx/convert.py | 3 ++- tests/exporters/onnx/test_exporters_onnx_cli.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index d72fd7eb21a..565183b38fc 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -26,6 +26,7 @@ import numpy as np import onnx +from transformers.generation import GenerationMixin from transformers.modeling_utils import get_parameter_dtype from transformers.utils import is_tf_available, is_torch_available @@ -1127,7 +1128,7 @@ def onnx_export_from_model( if check_if_transformers_greater("4.44.99"): misplaced_generation_parameters = model.config._get_non_default_generation_parameters() - if model.can_generate() and len(misplaced_generation_parameters) > 0: + if isinstance(model, GenerationMixin) and len(misplaced_generation_parameters) > 0: logger.warning( "Moving the following attributes in the config to the generation config: " f"{misplaced_generation_parameters}. You are seeing this warning because you've set " diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index ed611ade04e..8b186e9307b 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -602,6 +602,14 @@ def test_diffusion(self): check=True, ) + def test_sentence_transformers(self): + with TemporaryDirectory() as tmpdirname: + subprocess.run( + f"python3 -m optimum.exporters.onnx --model sentence-transformers-testing/stsb-bert-tiny-onnx --task feature-extraction {tmpdirname}", + shell=True, + check=True, + ) + def test_legacy(self): with TemporaryDirectory() as tmpdirname: subprocess.run(