From a00e9c45b406d780b4f490b73db08146f17d913c Mon Sep 17 00:00:00 2001 From: bas Date: Sat, 14 Oct 2023 11:19:17 +0200 Subject: [PATCH] Fix perceiver tests and dummy inputs for ONNX --- optimum/exporters/onnx/model_configs.py | 53 ++++++++++++++------ tests/onnxruntime/test_modeling.py | 21 +++----- tests/onnxruntime/utils_onnxruntime_tests.py | 2 + 3 files changed, 47 insertions(+), 29 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 73308e24a5d..2666d971555 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -998,12 +998,37 @@ class Data2VecAudioOnnxConfig(AudioOnnxConfig): class PerceiverDummyInputGenerator(DummyVisionInputGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedVisionConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], + width: int = DEFAULT_DUMMY_SHAPES["width"], + height: int = DEFAULT_DUMMY_SHAPES["height"], + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + num_channels=num_channels, + width=width, + height=height, + **kwargs, + ) + + from transformers.onnx.utils import get_preprocessor + + preprocessor = get_preprocessor(normalized_config._name_or_path) + if preprocessor is not None and hasattr(preprocessor, "size"): + self.height = preprocessor.size.get("height", self.height) + self.width = preprocessor.size.get("width", self.width) + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): input_ = super().generate( input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype ) - # if input_name == "pixel_values": - # input_ = input_[None, :] return input_ @@ -1038,22 +1063,22 @@ def inputs_name(self): @property def inputs(self) -> Dict[str, Dict[int, str]]: - # TODO: validate that. - dynamic_axis = {0: "batch_size", 1: "sequence_length"} - return { - self.inputs_name: dynamic_axis, - # TODO: should we add the attention_mask? - # This breaks things for image-classification, suspected bug is the DummyInputGenerators not having the - # same num_channels / sequence_length. - # "attention_mask": dynamic_axis, - } + if self.inputs_name in ["input_ids", "inputs"]: + dynamic_axis = {0: "batch_size", 1: "sequence_length"} + return { + "input_ids": dynamic_axis, + "attention_mask": dynamic_axis, + } + else: + dynamic_axis = {0: "batch_size", 1: "sequence_length", 2: "width", 3: "height"} + return { + "pixel_values": dynamic_axis, + } def generate_dummy_inputs(self, framework: str = "pt", **kwargs): self.is_generating_dummy_inputs = True dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) - specialized_inputs_name = self.inputs_name - self.is_generating_dummy_inputs = True - dummy_inputs[self.inputs_name] = dummy_inputs.pop(specialized_inputs_name) + dummy_inputs[self.inputs_name] = dummy_inputs.pop(self.inputs_name) return dummy_inputs diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index e6868c2fa7d..3f87f87fbf2 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1239,7 +1239,7 @@ class ORTModelForMaskedLMIntegrationTest(ORTModelTestMixin): "flaubert", "ibert", "mobilebert", - # "perceiver", + "perceiver_text", "roberta", "roformer", "squeezebert", @@ -1247,10 +1247,7 @@ class ORTModelForMaskedLMIntegrationTest(ORTModelTestMixin): "xlm_roberta", ] - ARCH_MODEL_MAP = { - # TODO: fix non passing test - # "perceiver": "hf-internal-testing/tiny-random-language_perceiver", - } + ARCH_MODEL_MAP = {} # TODO remove FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} ORTMODEL_CLASS = ORTModelForMaskedLM @@ -1401,7 +1398,7 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin): "mbart", "mobilebert", "nystromformer", - # "perceiver", + "perceiver_text", "roberta", "roformer", "squeezebert", @@ -1409,10 +1406,7 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin): "xlm_roberta", ] - ARCH_MODEL_MAP = { - # TODO: fix non passing test - # "perceiver": "hf-internal-testing/tiny-random-language_perceiver", - } + ARCH_MODEL_MAP = {} # TODO remove FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} ORTMODEL_CLASS = ORTModelForSequenceClassification @@ -2404,7 +2398,7 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin): "mobilenet_v1", "mobilenet_v2", "mobilevit", - # "perceiver", + "perceiver_vision", "poolformer", "resnet", "segformer", @@ -2412,10 +2406,7 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin): "vit", ] - ARCH_MODEL_MAP = { - # TODO: fix non passing test - # "perceiver": "hf-internal-testing/tiny-random-vision_perceiver_conv", - } + ARCH_MODEL_MAP = {} # TODO remove FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} ORTMODEL_CLASS = ORTModelForImageClassification diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 09603c5e1e8..93a956cbf8f 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -83,6 +83,8 @@ "mt5": "lewtun/tiny-random-mt5", "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel", "pegasus": "hf-internal-testing/tiny-random-PegasusModel", + "perceiver_text": "hf-internal-testing/tiny-random-language_perceiver", + "perceiver_vision": "hf-internal-testing/tiny-random-vision_perceiver_conv", "pix2struct": "fxmarty/pix2struct-tiny-random", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", "resnet": "hf-internal-testing/tiny-random-resnet",