Add ONNX export support for DinoV2, Hiera, Maskformer, PVT, SigLIP, SwinV2, VitMAE, and VitMSN models #2001

Merged Dec 19, 2024 (28 commits)

Commits
62174bd - Add support for siglip models (xenova, Dec 4, 2023)
3860ac3 - cleanup (xenova, Dec 4, 2023)
3e23538 - remove submodule (xenova, Dec 4, 2023)
be9c707 - Add ONNX export for DinoV2 models (xenova, Dec 9, 2023)
c4d6bc2 - Use height and width from preprocessor (xenova, Dec 9, 2023)
cb8d362 - formatting (xenova, Dec 9, 2023)
94c3329 - Remove attention mask from model input (xenova, Dec 23, 2023)
9db1428 - Merge branch 'main' into xenova-add-siglip (xenova, Jan 26, 2024)
8d4b09e - Add ONNX export support for Hiera models (xenova, Aug 29, 2024)
b96bb61 - Add ONNX export support for SwinV2 (xenova, Aug 29, 2024)
95336c0 - Merge remote-tracking branch 'origin/add-dino' into add-hiera-onnx (xenova, Aug 29, 2024)
d4321b6 - Merge remote-tracking branch 'origin/xenova-add-siglip' into add-hier… (xenova, Aug 29, 2024)
fe140c6 - Upgrade Siglip to opset=14 (xenova, Aug 29, 2024)
09ae91a - Add VQA task (xenova, Aug 30, 2024)
96afc91 - Add ONNX export support for Maskformer (xenova, Aug 30, 2024)
844aa66 - Add ONNX export support for PVT (xenova, Aug 30, 2024)
de07c7a - Add ONNX export support for ViTMAE and ViTMSN (xenova, Aug 30, 2024)
398d07a - Add siglip unit tests (xenova, Nov 14, 2024)
86706d1 - Add vit-mae unit tests (xenova, Nov 14, 2024)
2fa69b4 - Merge branch 'main' into add-hiera-onnx (xenova, Nov 14, 2024)
8ad2e3a - Code formatting (xenova, Nov 14, 2024)
55a19cb - Add maskformer to list of supported models (xenova, Nov 14, 2024)
fd15bd3 - Formatting (xenova, Nov 14, 2024)
a47cd96 - merge main in branch (echarlaix, Dec 18, 2024)
7f0cb92 - fix typo (echarlaix, Dec 18, 2024)
7a2e94a - remove vit-mae masked-im task (echarlaix, Dec 18, 2024)
01929b2 - remove vit-msn masked-im task (echarlaix, Dec 18, 2024)
3fa346c - fix output names for maskformer export (echarlaix, Dec 18, 2024)
docs/source/exporters/onnx/overview.mdx (+8, -0)
@@ -39,6 +39,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Decision Transformer
- Deit
- Detr
- DINOv2
- DistilBert
- Donut-Swin
- Electra
@@ -53,6 +54,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- GPT-NeoX
- OPT
- GroupVit
- Hiera
- Hubert
- IBert
- LayoutLM
@@ -64,6 +66,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- M2-M100
- Marian
- MarkupLM
- MaskFormer
- MBart
- MGP-STR
- Mistral
@@ -84,6 +87,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Phi3
- Pix2Struct
- PoolFormer
- PVT
- Qwen2(Qwen1.5)
- RegNet
- RemBERT
@@ -95,17 +99,21 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- SEW
- SEW-D
- Speech2Text
- SigLIP
- SpeechT5
- Splinter
- SqueezeBert
- Swin
- SwinV2
- T5
- Table Transformer
- TROCR
- UniSpeech
- UniSpeech SAT
- Vision Encoder Decoder
- Vit
- VitMAE
- VitMSN
- Wav2Vec2
- Wav2Vec2 Conformer
- WavLM
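With these entries in the overview table, the new architectures go through the standard export entry points. A minimal sketch using the Python export API (the checkpoint name, output directory, and task are illustrative, not part of this diff):

    from optimum.exporters.onnx import main_export

    # Export a DINOv2 checkpoint to ONNX; any Hub checkpoint of a newly
    # supported architecture (hiera, pvt, swinv2, ...) can be substituted.
    main_export(
        "facebook/dinov2-base",
        output="dinov2_onnx",
        task="feature-extraction",
    )

The optimum-cli export onnx command is the equivalent one-liner.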
optimum/exporters/onnx/model_configs.py (+118, -0)
@@ -847,6 +847,65 @@ class ConvNextV2OnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 11


class HieraOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 11


class PvtOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 11


class VitMAEOnnxConfig(ViTOnnxConfig):
# torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported.
# Support for this operator was added in version 14, try exporting with this version.
DEFAULT_ONNX_OPSET = 14


class VitMSNOnnxConfig(ViTOnnxConfig):
# torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported.
# Support for this operator was added in version 14, try exporting with this version.
DEFAULT_ONNX_OPSET = 14


class Dinov2DummyInputGenerator(DummyVisionInputGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = DEFAULT_DUMMY_SHAPES["width"],
height: int = DEFAULT_DUMMY_SHAPES["height"],
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
num_channels=num_channels,
width=width,
height=height,
**kwargs,
)

from transformers.onnx.utils import get_preprocessor

preprocessor = get_preprocessor(normalized_config._name_or_path)
if preprocessor is not None and hasattr(preprocessor, "crop_size"):
self.height = preprocessor.crop_size.get("height", self.height)
self.width = preprocessor.crop_size.get("width", self.width)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
input_ = super().generate(
input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
)
return input_


class Dinov2OnnxConfig(ViTOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (Dinov2DummyInputGenerator,)


class MobileViTOnnxConfig(ViTOnnxConfig):
ATOL_FOR_VALIDATION = 1e-4
DEFAULT_ONNX_OPSET = 11
@@ -888,6 +947,10 @@ class SwinOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 11


class SwinV2OnnxConfig(SwinOnnxConfig):
pass


class Swin2srOnnxConfig(SwinOnnxConfig):
pass

@@ -923,6 +986,28 @@ class MobileNetV2OnnxConfig(MobileNetV1OnnxConfig):
pass


class MaskFormerOnnxConfig(ViTOnnxConfig):
# torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::einsum' to ONNX opset version 11 is not supported.
# Support for this operator was added in version 12, try exporting with this version.
DEFAULT_ONNX_OPSET = 12

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if self.task == "image-segmentation":
return {
"class_queries_logits": {0: "batch_size", 1: "num_queries"},
"masks_queries_logits": {0: "batch_size", 1: "num_queries", 2: "height", 3: "width"},
}
else:
return super().outputs

@property
def torch_to_onnx_output_map(self) -> Dict[str, str]:
return {
"transformer_decoder_last_hidden_state": "last_hidden_state",
}


class DonutSwinOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 11

@@ -1115,6 +1200,39 @@ def patch_model_for_export(
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class SiglipNormalizedConfig(CLIPNormalizedConfig):
pass


class SiglipOnnxConfig(CLIPOnnxConfig):
NORMALIZED_CONFIG_CLASS = SiglipNormalizedConfig
# torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 13 is not supported.
# Support for this operator was added in version 14, try exporting with this version.
DEFAULT_ONNX_OPSET = 14

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {
"input_ids": {0: "text_batch_size", 1: "sequence_length"},
"pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"},
# NOTE: No attention_mask
}


class SiglipTextWithProjectionOnnxConfig(CLIPTextWithProjectionOnnxConfig):
pass


class SiglipTextOnnxConfig(CLIPTextOnnxConfig):
pass


class SiglipVisionModelOnnxConfig(CLIPVisionModelOnnxConfig):
# torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported.
# Support for this operator was added in version 14, try exporting with this version.
DEFAULT_ONNX_OPSET = 14


class UNetOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-4
# The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
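A practical consequence of the SiglipOnnxConfig.inputs definition above: the exported graph takes only input_ids and pixel_values, with no attention_mask, so text must be padded to a fixed length (SigLIP is trained with padding="max_length"). A sketch of running such an export with ONNX Runtime, assuming the model was exported to siglip_onnx/ and using an illustrative checkpoint:

    import onnxruntime as ort
    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
    inputs = processor(
        text=["a photo of a cat"],
        images=Image.open("cat.jpg"),
        padding="max_length",  # fixed-length padding instead of a mask
        return_tensors="np",
    )

    session = ort.InferenceSession("siglip_onnx/model.onnx")
    # With a CLIP-style config the first outputs are expected to be
    # logits_per_image and logits_per_text (the order is an assumption here).
    outputs = session.run(
        None,
        {"input_ids": inputs["input_ids"], "pixel_values": inputs["pixel_values"]},
    )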
optimum/exporters/tasks.py (+68, -2)
@@ -209,7 +209,12 @@ class TasksManager:
"feature-extraction": "AutoModel",
"fill-mask": "AutoModelForMaskedLM",
"image-classification": "AutoModelForImageClassification",
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"image-segmentation": (
"AutoModelForImageSegmentation",
"AutoModelForSemanticSegmentation",
"AutoModelForInstanceSegmentation",
"AutoModelForUniversalSegmentation",
),
"image-to-image": "AutoModelForImageToImage",
"image-to-text": ("AutoModelForVision2Seq", "AutoModel"),
"mask-generation": "AutoModel",
@@ -224,6 +229,7 @@ class TasksManager:
"text2text-generation": "AutoModelForSeq2SeqLM",
"text-classification": "AutoModelForSequenceClassification",
"token-classification": "AutoModelForTokenClassification",
"visual-question-answering": "AutoModelForVisualQuestionAnswering",
"zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
"zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
}
@@ -307,6 +313,7 @@ class TasksManager:
"vision2seq-lm": "image-to-text",
"zero-shot-classification": "text-classification",
"image-feature-extraction": "feature-extraction",
"pretraining": "feature-extraction",
# for backward compatibility and testing (where
# model task and model type are still the same)
"stable-diffusion": "text-to-image",
@@ -601,6 +608,11 @@ class TasksManager:
"image-segmentation",
onnx="DetrOnnxConfig",
),
"dinov2": supported_tasks_mapping(
"feature-extraction",
"image-classification",
onnx="Dinov2OnnxConfig",
),
"distilbert": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
@@ -732,6 +744,11 @@ class TasksManager:
"feature-extraction",
onnx="GroupViTOnnxConfig",
),
"hiera": supported_tasks_mapping(
"feature-extraction",
"image-classification",
onnx="HieraOnnxConfig",
),
"hubert": supported_tasks_mapping(
"feature-extraction",
"automatic-speech-recognition",
@@ -813,6 +830,11 @@ class TasksManager:
"question-answering",
onnx="MarkupLMOnnxConfig",
),
"maskformer": supported_tasks_mapping(
"feature-extraction",
"image-segmentation",
onnx="MaskFormerOnnxConfig",
),
"mbart": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
@@ -1011,6 +1033,11 @@ class TasksManager:
"image-classification",
onnx="PoolFormerOnnxConfig",
),
"pvt": supported_tasks_mapping(
"feature-extraction",
"image-classification",
onnx="PvtOnnxConfig",
),
"regnet": supported_tasks_mapping(
"feature-extraction",
"image-classification",
@@ -1070,6 +1097,23 @@ class TasksManager:
"audio-classification",
onnx="SEWDOnnxConfig",
),
"siglip": supported_tasks_mapping(
"feature-extraction",
"zero-shot-image-classification",
onnx="SiglipOnnxConfig",
),
"siglip-text-model": supported_tasks_mapping(
"feature-extraction",
onnx="SiglipTextOnnxConfig",
),
"siglip-text-with-projection": supported_tasks_mapping(
"feature-extraction",
onnx="SiglipTextWithProjectionOnnxConfig",
),
"siglip-vision-model": supported_tasks_mapping(
"feature-extraction",
onnx="SiglipVisionModelOnnxConfig",
),
"speech-to-text": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
@@ -1102,6 +1146,12 @@ class TasksManager:
"masked-im",
onnx="SwinOnnxConfig",
),
"swinv2": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"masked-im",
onnx="SwinV2OnnxConfig",
),
"swin2sr": supported_tasks_mapping(
"feature-extraction",
"image-to-image",
@@ -1148,7 +1198,19 @@ class TasksManager:
onnx="VisionEncoderDecoderOnnxConfig",
),
"vit": supported_tasks_mapping(
"feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig"
"feature-extraction",
"image-classification",
"masked-im",
onnx="ViTOnnxConfig",
),
"vit-mae": supported_tasks_mapping(
"feature-extraction",
onnx="VitMAEOnnxConfig",
),
"vit-msn": supported_tasks_mapping(
"feature-extraction",
"image-classification",
onnx="VitMSNOnnxConfig",
),
"vits": supported_tasks_mapping(
"text-to-audio",
@@ -1232,6 +1294,10 @@ class TasksManager:
"unet-2d-condition",
"vae-encoder",
"vae-decoder",
"clip-text-model",
"clip-text-with-projection",
"siglip-text-model",
"siglip-text-with-projection",
# redundant model types
"trocr", # same as vision-encoder-decoder
}
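As a quick sanity check that the mappings above are registered, TasksManager can be queried per model type (a sketch; the model type strings match the keys added in this diff):

    from optimum.exporters.tasks import TasksManager

    for model_type in ("dinov2", "hiera", "maskformer", "pvt",
                       "siglip", "swinv2", "vit-mae", "vit-msn"):
        tasks = TasksManager.get_supported_tasks_for_model_type(
            model_type, exporter="onnx"
        )
        print(model_type, sorted(tasks))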
optimum/onnxruntime/modeling_ort.py (+2, -2)
@@ -1696,7 +1696,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForImageClassification(ORTModel):
"""
ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit.
ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, dinov2, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, swinv2, vit.
"""

auto_model_class = AutoModelForImageClassification
@@ -1784,7 +1784,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForSemanticSegmentation(ORTModel):
"""
ONNX Model for semantic-segmentation, with an all-MLP decode head on top e.g. for ADE20k, CityScapes. This class officially supports segformer.
ONNX Model for semantic-segmentation, with an all-MLP decode head on top e.g. for ADE20k, CityScapes. This class officially supports maskformer, segformer.
"""

auto_model_class = AutoModelForSemanticSegmentation
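With dinov2 and swinv2 added to the officially supported list, the usual ORTModel workflow applies unchanged. A minimal sketch (checkpoint and image path are illustrative):

    from PIL import Image
    from transformers import AutoImageProcessor
    from optimum.onnxruntime import ORTModelForImageClassification

    model_id = "facebook/dinov2-base-imagenet1k-1-layer"  # example checkpoint
    model = ORTModelForImageClassification.from_pretrained(model_id, export=True)
    processor = AutoImageProcessor.from_pretrained(model_id)

    inputs = processor(Image.open("cat.jpg"), return_tensors="pt")
    logits = model(**inputs).logits
    print(logits.argmax(-1))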
optimum/onnxruntime/utils.py (+1, -0)
@@ -178,6 +178,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config
"clip",
"vit",
"swin",
"swinv2",
]
model_type = model_type.replace("_", "-")
if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
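Since "swinv2" now passes check_optimization_supported_model, an exported SwinV2 model should be accepted by ORTOptimizer. A sketch under that assumption (checkpoint, save directory, and optimization level are illustrative):

    from optimum.onnxruntime import ORTModelForImageClassification, ORTOptimizer
    from optimum.onnxruntime.configuration import OptimizationConfig

    model = ORTModelForImageClassification.from_pretrained(
        "microsoft/swinv2-tiny-patch4-window8-256", export=True
    )
    optimizer = ORTOptimizer.from_pretrained(model)
    optimizer.optimize(
        save_dir="swinv2_optimized",
        optimization_config=OptimizationConfig(optimization_level=2),
    )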
optimum/utils/normalized_config.py (+4, -0)
@@ -204,8 +204,10 @@ class NormalizedConfigManager:
'data2vec-text',
'data2vec-vision',
'detr',
'dinov2',
'flaubert',
'groupvit',
'hiera',
'ibert',
'layoutlm',
'layoutlmv3',
@@ -216,6 +218,8 @@ class NormalizedConfigManager:
'owlvit',
'perceiver',
'roformer',
'segformer',
'siglip',
'squeezebert',
'table-transformer',
"""