diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 35c1748ff8..747c5f3e73 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -410,23 +410,19 @@ def main_export(
         **loading_kwargs,
     )

-    needs_pad_token_id = (
-        task == "text-classification"
-        and getattr(model.config, "pad_token_id", None)
-        and getattr(model.config, "is_decoder", False)
-    )
+    needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None

     if needs_pad_token_id:
         if pad_token_id is not None:
             model.config.pad_token_id = pad_token_id
         else:
-            try:
-                tok = AutoTokenizer.from_pretrained(model_name_or_path)
-                model.config.pad_token_id = tok.pad_token_id
-            except Exception:
+            tok = AutoTokenizer.from_pretrained(model_name_or_path)
+            pad_token_id = getattr(tok, "pad_token_id", None)
+            if pad_token_id is None:
                 raise ValueError(
                     "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
                 )
+            model.config.pad_token_id = pad_token_id

     if "stable-diffusion" in task:
         model_type = "stable-diffusion"
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 799e325cb0..82df63fdcd 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -203,16 +203,6 @@ class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")

-    @property
-    def values_override(self) -> Optional[Dict[str, Any]]:
-        pad_value_override = {}
-        if not getattr(self._config, "pad_token_id", None):
-            pad_value_override = {"pad_token_id": 0}
-        super_values_override = super().values_override
-        if super_values_override:
-            return {**super_values_override, **pad_value_override}
-        return pad_value_override
-

 class GPTJOnnxConfig(GPT2OnnxConfig):
     pass
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 00da74a878..fb4efe6e50 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -603,7 +603,7 @@ class TasksManager:
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             "token-classification",
             onnx="GPT2OnnxConfig",
         ),
@@ -612,7 +612,7 @@
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             "token-classification",
             onnx="GPTBigCodeOnnxConfig",
         ),
@@ -622,7 +622,7 @@
             "text-generation",
             "text-generation-with-past",
             "question-answering",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="GPTJOnnxConfig",
         ),
         "gpt-neo": supported_tasks_mapping(
@@ -630,7 +630,7 @@
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="GPTNeoOnnxConfig",
         ),
         "gpt-neox": supported_tasks_mapping(
@@ -638,6 +638,7 @@
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
+            "text-classification",
             onnx="GPTNeoXOnnxConfig",
         ),
         "groupvit": supported_tasks_mapping(
@@ -734,7 +735,7 @@
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification",
+            "text-classification",
             onnx="MistralOnnxConfig",
         ),
         # TODO: enable once the missing operator is supported.
@@ -782,6 +783,7 @@
         "mpt": supported_tasks_mapping(
             "text-generation",
             "text-generation-with-past",
+            "text-classification",
             onnx="MPTOnnxConfig",
         ),
         "mt5": supported_tasks_mapping(
@@ -818,7 +820,7 @@
             "text-generation",
             "text-generation-with-past",
             "question-answering",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="OPTOnnxConfig",
         ),
         "llama": supported_tasks_mapping(
@@ -826,7 +828,7 @@
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="LlamaOnnxConfig",
         ),
         "pegasus": supported_tasks_mapping(
@@ -849,6 +851,7 @@
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
+            "text-classification",
             onnx="PhiOnnxConfig",
         ),
         "pix2struct": supported_tasks_mapping(
diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py
index d6a89afe41..05a946bbd9 100644
--- a/tests/exporters/onnx/test_exporters_onnx_cli.py
+++ b/tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -185,6 +185,12 @@ def _onnx_export(
         no_dynamic_axes: bool = False,
         model_kwargs: Optional[Dict] = None,
     ):
+        # We need to set this to some value to be able to test the outputs values for batch size > 1.
+        if task == "text-classification":
+            pad_token_id = 0
+        else:
+            pad_token_id = None
+
         with TemporaryDirectory() as tmpdir:
             try:
                 main_export(
@@ -198,6 +204,7 @@
                     no_post_process=no_post_process,
                     _variant=variant,
                     no_dynamic_axes=no_dynamic_axes,
+                    pad_token_id=pad_token_id,
                     model_kwargs=model_kwargs,
                 )
             except MinimumVersionError as e:
diff --git a/tests/exporters/onnx/test_onnx_config_loss.py b/tests/exporters/onnx/test_onnx_config_loss.py
index 667f599b88..1eed7d9b61 100644
--- a/tests/exporters/onnx/test_onnx_config_loss.py
+++ b/tests/exporters/onnx/test_onnx_config_loss.py
@@ -123,9 +123,6 @@ def test_onnx_config_with_loss(self):
         gc.collect()

     def test_onnx_decoder_model_with_config_with_loss(self):
-        self.skipTest(
-            "Skipping due to a bug introduced in transformers with https://github.com/huggingface/transformers/pull/24979, argmax on int64 is not supported by ONNX"
-        )
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Prepare model and dataset
             model_checkpoint = "hf-internal-testing/tiny-random-gpt2"
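
Note (outside the patch): with the task re-enabled above, a decoder checkpoint can be exported for text-classification directly from Python through main_export, the same entry point the updated test helper drives. A minimal sketch, assuming the tiny test checkpoint referenced in the patch and an illustrative output directory; pad_token_id is only needed when neither the model config nor the tokenizer provides a pad token, and if it is omitted in that situation, the ValueError added in __main__.py is raised instead:

    from optimum.exporters.onnx import main_export

    # GPT-2 tokenizers typically define no pad token, so we pass one explicitly,
    # mirroring the value the updated test helper uses for text-classification.
    main_export(
        "hf-internal-testing/tiny-random-gpt2",
        output="tiny_gpt2_text_classification_onnx",  # illustrative path
        task="text-classification",
        pad_token_id=0,
    )

The equivalent CLI invocation would pass the --pad_token_id argument named in the error message, e.g. `optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 --task text-classification --pad_token_id 0 tiny_gpt2_text_classification_onnx`.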