re-enable decoder sequence classification (#1679)
* re-enable decoder sequence classification

* update tests

* revert to better pad token handling logic

* minor updates

* format
dwyatte authored Feb 8, 2024
1 parent c05ab93 commit 3988bbd
Showing 5 changed files with 22 additions and 29 deletions.
14 changes: 5 additions & 9 deletions optimum/exporters/onnx/__main__.py
@@ -410,23 +410,19 @@ def main_export(
         **loading_kwargs,
     )
 
-    needs_pad_token_id = (
-        task == "text-classification"
-        and getattr(model.config, "pad_token_id", None)
-        and getattr(model.config, "is_decoder", False)
-    )
+    needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
 
     if needs_pad_token_id:
         if pad_token_id is not None:
             model.config.pad_token_id = pad_token_id
         else:
-            try:
-                tok = AutoTokenizer.from_pretrained(model_name_or_path)
-                model.config.pad_token_id = tok.pad_token_id
-            except Exception:
+            tok = AutoTokenizer.from_pretrained(model_name_or_path)
+            pad_token_id = getattr(tok, "pad_token_id", None)
+            if pad_token_id is None:
                 raise ValueError(
                     "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
                 )
+            model.config.pad_token_id = pad_token_id
 
     if "stable-diffusion" in task:
         model_type = "stable-diffusion"
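
Two things change in this hunk. First, a bug fix: the old predicate required `pad_token_id` to already be truthy (and `is_decoder` to be set, which decoder-only configs such as GPT-2's generally leave at the default `False`), so it effectively never requested a pad token when one was actually missing; the new predicate fires exactly when the config lacks one. Second, the fallback no longer swallows tokenizer errors behind a bare `except`. A minimal sketch of the resulting resolution order (the helper name is illustrative, not part of the codebase):

```python
from typing import Optional

from transformers import AutoTokenizer


def resolve_pad_token_id(model_name_or_path: str, pad_token_id: Optional[int] = None) -> int:
    # 1) An explicit --pad_token_id argument always wins.
    if pad_token_id is not None:
        return pad_token_id
    # 2) Otherwise, fall back to the tokenizer's configured pad token.
    tok = AutoTokenizer.from_pretrained(model_name_or_path)
    pad_token_id = getattr(tok, "pad_token_id", None)
    # 3) If neither is available, fail loudly instead of guessing.
    if pad_token_id is None:
        raise ValueError(
            "Could not infer the pad token id, which is needed in this case, "
            "please provide it with the --pad_token_id argument"
        )
    return pad_token_id
```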
10 changes: 0 additions & 10 deletions optimum/exporters/onnx/model_configs.py
@@ -203,16 +203,6 @@ class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
 
-    @property
-    def values_override(self) -> Optional[Dict[str, Any]]:
-        pad_value_override = {}
-        if not getattr(self._config, "pad_token_id", None):
-            pad_value_override = {"pad_token_id": 0}
-        super_values_override = super().values_override
-        if super_values_override:
-            return {**super_values_override, **pad_value_override}
-        return pad_value_override
-
 
 class GPTJOnnxConfig(GPT2OnnxConfig):
     pass
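
Dropping `values_override` means GPT-2 family exports no longer silently force `pad_token_id: 0`; the id is now resolved once, in `main_export`, from the user or the tokenizer. The distinction matters because decoder sequence-classification heads pool logits at the last non-padding position, and id 0 is an ordinary vocabulary token for GPT-2, so a forced 0 could make real tokens look like padding. A hedged sketch of that pooling idea (it mirrors the general approach in transformers, not its exact code):

```python
import torch


def last_token_indices(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Non-pad positions are 1; the cumulative sum peaks at the last real
    # token, and argmax returns the first index of that peak.
    mask = input_ids.ne(pad_token_id).int()
    return mask.cumsum(dim=-1).argmax(dim=-1)


batch = torch.tensor(
    [
        [15496, 995, 0, 0],  # right-padded row
        [15496, 995, 318, 1049],  # full row, no padding
    ]
)
print(last_token_indices(batch, pad_token_id=0))  # tensor([1, 3])
```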
17 changes: 10 additions & 7 deletions optimum/exporters/tasks.py
@@ -603,7 +603,7 @@ class TasksManager:
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             "token-classification",
             onnx="GPT2OnnxConfig",
         ),
@@ -612,7 +612,7 @@ class TasksManager:
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             "token-classification",
             onnx="GPTBigCodeOnnxConfig",
         ),
@@ -622,22 +622,23 @@ class TasksManager:
             "text-generation",
             "text-generation-with-past",
             "question-answering",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="GPTJOnnxConfig",
         ),
         "gpt-neo": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="GPTNeoOnnxConfig",
         ),
         "gpt-neox": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
+            "text-classification",
             onnx="GPTNeoXOnnxConfig",
         ),
         "groupvit": supported_tasks_mapping(
@@ -734,7 +735,7 @@ class TasksManager:
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification",
+            "text-classification",
             onnx="MistralOnnxConfig",
         ),
         # TODO: enable once the missing operator is supported.
@@ -782,6 +783,7 @@ class TasksManager:
         "mpt": supported_tasks_mapping(
             "text-generation",
             "text-generation-with-past",
+            "text-classification",
             onnx="MPTOnnxConfig",
         ),
         "mt5": supported_tasks_mapping(
@@ -818,15 +820,15 @@ class TasksManager:
             "text-generation",
             "text-generation-with-past",
             "question-answering",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="OPTOnnxConfig",
         ),
         "llama": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="LlamaOnnxConfig",
         ),
         "pegasus": supported_tasks_mapping(
@@ -849,6 +851,7 @@ class TasksManager:
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
+            "text-classification",
             onnx="PhiOnnxConfig",
         ),
         "pix2struct": supported_tasks_mapping(
7 changes: 7 additions & 0 deletions tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -185,6 +185,12 @@ def _onnx_export(
     no_dynamic_axes: bool = False,
     model_kwargs: Optional[Dict] = None,
 ):
+    # We need to set this to some value to be able to test the outputs values for batch size > 1.
+    if task == "text-classification":
+        pad_token_id = 0
+    else:
+        pad_token_id = None
+
     with TemporaryDirectory() as tmpdir:
         try:
             main_export(
@@ -198,6 +204,7 @@ def _onnx_export(
                 no_post_process=no_post_process,
                 _variant=variant,
                 no_dynamic_axes=no_dynamic_axes,
+                pad_token_id=pad_token_id,
                 model_kwargs=model_kwargs,
             )
         except MinimumVersionError as e:
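
The test pins `pad_token_id=0` so that outputs can be compared for batch size > 1; without a pad token, classification inputs cannot be padded into a batch at all. Outside the test suite, the equivalent call looks roughly like this (output directory name is illustrative):

```python
from optimum.exporters.onnx import main_export

main_export(
    "hf-internal-testing/tiny-random-gpt2",  # same tiny checkpoint the tests use
    output="tiny_gpt2_textcls_onnx",
    task="text-classification",
    pad_token_id=0,  # written into model.config when the config has none
)
```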
3 changes: 0 additions & 3 deletions tests/exporters/onnx/test_onnx_config_loss.py
@@ -123,9 +123,6 @@ def test_onnx_config_with_loss(self):
         gc.collect()
 
     def test_onnx_decoder_model_with_config_with_loss(self):
-        self.skipTest(
-            "Skipping due to a bug introduced in transformers with https://github.com/huggingface/transformers/pull/24979, argmax on int64 is not supported by ONNX"
-        )
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Prepare model and dataset
             model_checkpoint = "hf-internal-testing/tiny-random-gpt2"
