From 600dc5144a67f75ca1df0d5bf35bea606b178b92 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 5 Dec 2023 13:13:27 +0200 Subject: [PATCH] Add `convnextv2` onnx export (#1560) --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/model_configs.py | 4 ++++ optimum/exporters/tasks.py | 5 +++++ optimum/onnxruntime/modeling_ort.py | 2 +- optimum/utils/normalized_config.py | 1 + tests/exporters/exporters_utils.py | 1 + tests/onnxruntime/test_modeling.py | 1 + tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 8 files changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 0ea17b6afec..0a5da755a3b 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -30,6 +30,7 @@ Supported architectures: - CodeGen - ConvBert - ConvNext +- ConvNextV2 - Data2VecAudio - Data2VecText - Data2VecVision diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index f4d50ad58d4..a58f42dca4b 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -704,6 +704,10 @@ class ConvNextOnnxConfig(ViTOnnxConfig): pass +class ConvNextV2OnnxConfig(ViTOnnxConfig): + pass + + class MobileViTOnnxConfig(ViTOnnxConfig): ATOL_FOR_VALIDATION = 1e-4 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 7545c72d6c6..4d3f9f98d05 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -424,6 +424,11 @@ class TasksManager: "image-classification", onnx="ConvNextOnnxConfig", ), + "convnextv2": supported_tasks_mapping( + "feature-extraction", + "image-classification", + onnx="ConvNextV2OnnxConfig", + ), "cvt": supported_tasks_mapping("feature-extraction", "image-classification", onnx="CvTOnnxConfig"), "data2vec-text": supported_tasks_mapping( "feature-extraction", diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index c618f8daad3..eb9b5404806 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -1534,7 +1534,7 @@ def forward( @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) class ORTModelForImageClassification(ORTModel): """ - ONNX Model for image-classification tasks. This class officially supports beit, convnext, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit. + ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit. """ auto_model_class = AutoModelForImageClassification diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 7a0af9a1a48..6fa4adcecfd 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -186,6 +186,7 @@ class NormalizedConfigManager: 'clip', 'convbert', 'convnext', + 'convnextv2', 'data2vec-text', 'data2vec-vision', 'detr', diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 6e43b65e34f..9af7806e7f8 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -53,6 +53,7 @@ "clip": "hf-internal-testing/tiny-random-CLIPModel", "convbert": "hf-internal-testing/tiny-random-ConvBertModel", "convnext": "hf-internal-testing/tiny-random-convnext", + "convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model", "codegen": "hf-internal-testing/tiny-random-CodeGenModel", "cvt": "hf-internal-testing/tiny-random-CvTModel", "data2vec-text": "hf-internal-testing/tiny-random-Data2VecTextModel", diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 0db60c289db..bb06e42157b 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2699,6 +2699,7 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin): SUPPORTED_ARCHITECTURES = [ "beit", "convnext", + "convnextv2", "data2vec_vision", "deit", "levit", diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 654e63c6399..8e579879ea6 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -40,6 +40,7 @@ "clip": "hf-internal-testing/tiny-random-CLIPModel", "convbert": "hf-internal-testing/tiny-random-ConvBertModel", "convnext": "hf-internal-testing/tiny-random-convnext", + "convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model", "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel", "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel",