From be87d2cee205df71cb7273e506658a69f8bc3716 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 13 Dec 2023 15:42:53 +0200 Subject: [PATCH] Add ESM onnx support (#1581) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add ESM onnx support * set default opset=12 --------- Co-authored-by: FĂ©lix Marty <9808326+fxmarty@users.noreply.github.com> --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/model_configs.py | 14 ++++++++++++++ optimum/exporters/tasks.py | 7 +++++++ tests/exporters/exporters_utils.py | 1 + 4 files changed, 23 insertions(+) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 897705745ba..82b30f1e13b 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -42,6 +42,7 @@ Supported architectures: - Donut-Swin - Electra - Encoder Decoder +- ESM - Falcon - Flaubert - GPT-2 diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 04a6759a4fc..db210507d68 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -183,6 +183,20 @@ class DebertaV2OnnxConfig(DebertaOnnxConfig): pass +class EsmOnnxConfig(TextEncoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + ATOL_FOR_VALIDATION = 1e-4 + DEFAULT_ONNX_OPSET = 12 + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + dynamic_axis = {0: "batch_size", 1: "sequence_length"} + return { + "input_ids": dynamic_axis, + "attention_mask": dynamic_axis, + } + + class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head") diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index a9386e4e6d9..3241b3a822a 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -525,6 +525,13 @@ class TasksManager: "text2text-generation-with-past", onnx="EncoderDecoderOnnxConfig", ), + "esm": supported_tasks_mapping( + "feature-extraction", + "fill-mask", + "text-classification", + "token-classification", + onnx="EsmOnnxConfig", + ), "falcon": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index cff0cf38127..c738ca5389b 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -74,6 +74,7 @@ ], "mohitsha/tiny-random-testing-bert2gpt2": ["text2text-generation", "text2text-generation-with-past"], }, + "esm": "hf-internal-testing/tiny-random-EsmModel", "falcon": { "fxmarty/really-tiny-falcon-testing": [ "feature-extraction",