From f22655c036e4e61a7b09748e7aa7e146a16ae64d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Mlyn=C3=A1=C5=99?= <47664722+mlynatom@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:54:08 +0100 Subject: [PATCH] Add RemBERT ONNX support (#2108) * ONNX config for RemBERT added * added RemBERT to TasksManager * rembert added to exporters_utils * RemBERT added to test modelling tasks * changed rembert model * added RemBERT to test utils * Added RemBERT to documentation * Apply suggestions from code review --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/model_configs.py | 4 ++++ optimum/exporters/tasks.py | 9 +++++++++ tests/exporters/exporters_utils.py | 3 ++- tests/onnxruntime/test_modeling.py | 5 +++++ tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 6 files changed, 22 insertions(+), 1 deletion(-) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 2eaada7dadd..57005b85678 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -83,6 +83,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - PoolFormer - Qwen2(Qwen1.5) - RegNet +- RemBERT - ResNet - Roberta - Roformer diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index bca7cf24acf..b39d19ec782 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -162,6 +162,10 @@ class SplinterOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 11 +class RemBertOnnxConfig(BertOnnxConfig): + DEFAULT_ONNX_OPSET = 11 + + class DistilBertOnnxConfig(BertOnnxConfig): DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for transformers>=4.46.0 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index c50fa5cdfa4..0a3758e97cf 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -431,6 +431,15 @@ class TasksManager: onnx="BertOnnxConfig", tflite="BertTFLiteConfig", ), + "rembert": supported_tasks_mapping( + "fill-mask", + "feature-extraction", + "text-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx="RemBertOnnxConfig", + ), # For big-bird and bigbird-pegasus being unsupported, refer to model_configs.py # "big-bird": supported_tasks_mapping( # "feature-extraction", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index c56132c384c..32156d9eebf 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -138,6 +138,7 @@ "phi3": "Xenova/tiny-random-Phi3ForCausalLM", "pix2struct": "fxmarty/pix2struct-tiny-random", # "rembert": "google/rembert", + "rembert": "hf-internal-testing/tiny-random-RemBertModel", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "qwen2": "fxmarty/tiny-dummy-qwen2", "regnet": "hf-internal-testing/tiny-random-RegNetModel", @@ -257,7 +258,7 @@ "owlv2": "google/owlv2-base-patch16", "owlvit": "google/owlvit-base-patch32", "perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing. - # "rembert": "google/rembert", + "rembert": "google/rembert", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "regnet": "facebook/regnet-y-040", "resnet": "microsoft/resnet-50", diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index c4340dcd8b6..8f52ef45180 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1312,6 +1312,7 @@ class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin): "squeezebert", "xlm_qa", "xlm_roberta", + "rembert", ] FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} @@ -1502,6 +1503,7 @@ class ORTModelForMaskedLMIntegrationTest(ORTModelTestMixin): "squeezebert", "xlm", "xlm_roberta", + "rembert", ] FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} @@ -1682,6 +1684,7 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin): "squeezebert", "xlm", "xlm_roberta", + "rembert", ] FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} @@ -1882,6 +1885,7 @@ class ORTModelForTokenClassificationIntegrationTest(ORTModelTestMixin): "squeezebert", "xlm", "xlm_roberta", + "rembert", ] FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} @@ -2227,6 +2231,7 @@ class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin): "squeezebert", "xlm", "xlm_roberta", + "rembert", ] FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index ba8f6cc4abc..cccecd53817 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -135,6 +135,7 @@ "pix2struct": "fxmarty/pix2struct-tiny-random", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "qwen2": "fxmarty/tiny-dummy-qwen2", + "rembert": "hf-internal-testing/tiny-random-RemBertModel", "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-RobertaModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel",