Skip to content

Commit

Permalink
Add ESM onnx support (#1581)
Browse files Browse the repository at this point in the history
* Add ESM onnx support

* set default opset=12

---------

Co-authored-by: Félix Marty <[email protected]>
  • Loading branch information
xenova and fxmarty authored Dec 13, 2023
1 parent 9294d76 commit be87d2c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Supported architectures:
- Donut-Swin
- Electra
- Encoder Decoder
- ESM
- Falcon
- Flaubert
- GPT-2
Expand Down
14 changes: 14 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,20 @@ class DebertaV2OnnxConfig(DebertaOnnxConfig):
pass


class EsmOnnxConfig(TextEncoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
ATOL_FOR_VALIDATION = 1e-4
DEFAULT_ONNX_OPSET = 12

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
dynamic_axis = {0: "batch_size", 1: "sequence_length"}
return {
"input_ids": dynamic_axis,
"attention_mask": dynamic_axis,
}


class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,13 @@ class TasksManager:
"text2text-generation-with-past",
onnx="EncoderDecoderOnnxConfig",
),
"esm": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
"text-classification",
"token-classification",
onnx="EsmOnnxConfig",
),
"falcon": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
],
"mohitsha/tiny-random-testing-bert2gpt2": ["text2text-generation", "text2text-generation-with-past"],
},
"esm": "hf-internal-testing/tiny-random-EsmModel",
"falcon": {
"fxmarty/really-tiny-falcon-testing": [
"feature-extraction",
Expand Down

0 comments on commit be87d2c

Please sign in to comment.