From 844aa66e0694a97a06fb9a25bfa582ea3f2de481 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 30 Aug 2024 16:49:19 +0000 Subject: [PATCH] Add ONNX export support for PVT --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/model_configs.py | 4 ++++ optimum/exporters/tasks.py | 5 +++++ tests/exporters/exporters_utils.py | 2 ++ tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 5 files changed, 13 insertions(+) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 164118ba9c1..cf83fbeaba2 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -83,6 +83,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - Phi3 - Pix2Struct - PoolFormer +- PVT - Qwen2(Qwen1.5) - RegNet - ResNet diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 1c11d1e5547..47bfed12674 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -776,6 +776,10 @@ class HieraOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 11 +class PvtOnnxConfig(ViTOnnxConfig): + DEFAULT_ONNX_OPSET = 11 + + class Dinov2DummyInputGenerator(DummyVisionInputGenerator): def __init__( self, diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 192c2dbfb5a..2231d66de08 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -974,6 +974,11 @@ class TasksManager: "image-classification", onnx="PoolFormerOnnxConfig", ), + "pvt": supported_tasks_mapping( + "feature-extraction", + "image-classification", + onnx="PvtOnnxConfig", + ), "regnet": supported_tasks_mapping( "feature-extraction", "image-classification", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index e96e7567257..5d388715e01 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -140,6 +140,7 @@ "pix2struct": "fxmarty/pix2struct-tiny-random", # "rembert": "google/rembert", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", + "pvt": "hf-internal-testing/tiny-random-PvtForImageClassification", "qwen2": "fxmarty/tiny-dummy-qwen2", "regnet": "hf-internal-testing/tiny-random-RegNetModel", "resnet": "hf-internal-testing/tiny-random-resnet", @@ -264,6 +265,7 @@ "perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing. # "rembert": "google/rembert", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", + "pvt": "hf-internal-testing/tiny-random-PvtForImageClassification", "regnet": "facebook/regnet-y-040", "resnet": "microsoft/resnet-50", "roberta": "roberta-base", diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 0e8d42fbcca..947db0d8cd0 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -133,6 +133,7 @@ "phi3": "Xenova/tiny-random-Phi3ForCausalLM", "pix2struct": "fxmarty/pix2struct-tiny-random", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", + "pvt": "hf-internal-testing/tiny-random-PvtForImageClassification", "qwen2": "fxmarty/tiny-dummy-qwen2", "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-RobertaModel",