From 527bbdafcd2ff52d93ccdecc3d7a52590b9ff47f Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 21 Nov 2024 22:34:34 +0000 Subject: [PATCH 1/6] Add ONNX export support for `PatchTST` --- optimum/exporters/onnx/base.py | 1 + optimum/exporters/onnx/model_configs.py | 38 +++++++++++++++++++++++++ optimum/exporters/onnx/model_patcher.py | 20 +++++++++++-- optimum/exporters/tasks.py | 5 ++++ optimum/utils/__init__.py | 1 + optimum/utils/normalized_config.py | 5 ++++ 6 files changed, 68 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 7e35691d54b..0bd99a585e0 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -177,6 +177,7 @@ class OnnxConfig(ExportConfig, ABC): "text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}), "text-classification": OrderedDict({"logits": {0: "batch_size"}}), "text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), + "time-series-forecasting": OrderedDict({"prediction_outputs": {0: "batch_size"}}), "token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "visual-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "zero-shot-image-classification": OrderedDict( diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 8984162ee8c..45bbfca2315 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -58,6 +58,7 @@ NormalizedTextAndVisionConfig, NormalizedTextConfig, NormalizedTextConfigWithGQA, + NormalizedTimeSeriesForecastingConfig, NormalizedVisionConfig, check_if_diffusers_greater, check_if_transformers_greater, @@ -2445,3 +2446,40 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. + + +class PatchTSTDummyInputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ("past_values",) + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + **kwargs, + ): + self.task = task + self.normalized_config = normalized_config + + self.batch_size = batch_size + self.context_length = normalized_config.context_length + self.num_input_channels = normalized_config.num_input_channels + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + return self.random_float_tensor( + shape=[self.batch_size, self.context_length, self.num_input_channels], + min_value=-1, + max_value=1, + framework=framework, + dtype=float_dtype, + ) + + +class PatchTSTOnnxConfig(OnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTimeSeriesForecastingConfig + DUMMY_INPUT_GENERATOR_CLASSES = (PatchTSTDummyInputGenerator,) + ATOL_FOR_VALIDATION = 1e-4 + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + return {"past_values": {0: "batch_size", 1: "sequence_length"}} diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index fdfb0e280f5..d9cc224e3bd 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -113,6 +113,20 @@ class PatchingSpec: op_wrapper: Optional[Callable] = None +# An ONNX-export-compatible version of `tensor.unfold`. Without this, we get: +# torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible. +# See https://github.com/pytorch/pytorch/issues/81871 for more information +def onnx_compatible_unfold(self, dimension, size, step): + num_patches = (self.size(dimension) - size) // step + 1 + return torch.stack( + [self[:, i : i + size, :] for i in range(0, num_patches * step, step)], + dim=1, + ).transpose(3, 2) + + +UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)] + + class ModelPatcher: def __init__( self, @@ -122,9 +136,11 @@ def __init__( ): self._model = model - patching_specs = config.PATCHING_SPECS + patching_specs = config.PATCHING_SPECS or [] + patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC) + self._patching_specs = [] - for spec in patching_specs if patching_specs is not None else []: + for spec in patching_specs: final_spec = spec if spec.orig_op is None: final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name)) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index b4bce4696f3..8e94c668c82 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -222,6 +222,7 @@ class TasksManager: "text-generation": "AutoModelForCausalLM", "text2text-generation": "AutoModelForSeq2SeqLM", "text-classification": "AutoModelForSequenceClassification", + "time-series-forecasting": "PatchTSTForPrediction", # TODO: AutoModelForPrediction is not yet supported "token-classification": "AutoModelForTokenClassification", "zero-shot-image-classification": "AutoModelForZeroShotImageClassification", "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection", @@ -911,6 +912,10 @@ class TasksManager: "text-classification", onnx="OPTOnnxConfig", ), + "patchtst": supported_tasks_mapping( + "time-series-forecasting", + onnx="PatchTSTOnnxConfig", + ), "qwen2": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index fb1794af49c..0c60d5f4a34 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -90,5 +90,6 @@ NormalizedTextAndVisionConfig, NormalizedTextConfig, NormalizedTextConfigWithGQA, + NormalizedTimeSeriesForecastingConfig, NormalizedVisionConfig, ) diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 9ceed24c2dd..0e54c310d54 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -77,6 +77,11 @@ def has_attribute(self, attr_name): return True +class NormalizedTimeSeriesForecastingConfig(NormalizedConfig): + NUM_INPUT_CHANNELS = "num_input_channels" + CONTEXT_LENGTH = "context_length" + + class NormalizedTextConfig(NormalizedConfig): VOCAB_SIZE = "vocab_size" HIDDEN_SIZE = "hidden_size" From 767c509c580bf80f64aacda4a40a205503b6d570 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 21 Nov 2024 22:41:15 +0000 Subject: [PATCH 2/6] Add unit test for patchtst --- tests/exporters/exporters_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 31059c403de..9803c5c8d37 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -128,6 +128,7 @@ "opt": "hf-internal-testing/tiny-random-OPTModel", "owlv2": "hf-internal-testing/tiny-random-Owlv2Model", "owlvit": "hf-tiny-model-private/tiny-random-OwlViTModel", + "patchtst": "ibm/test-patchtst", "pegasus": "hf-internal-testing/tiny-random-PegasusModel", "perceiver": { "hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"], @@ -255,6 +256,7 @@ "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel", "owlv2": "google/owlv2-base-patch16", "owlvit": "google/owlvit-base-patch32", + "patchtst": "ibm/test-patchtst", "perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing. # "rembert": "google/rembert", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", From e348d47fecb11b72985c25659d302fc9e0f26793 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 21 Nov 2024 22:42:13 +0000 Subject: [PATCH 3/6] Add listed support for PatchTST --- docs/source/exporters/onnx/overview.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 747e1396fb4..38919cd1da7 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -74,6 +74,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - Musicgen (text-conditional only) - Nystromformer - OWL-ViT +- PatchTST - Pegasus - Perceiver - Phi From eeb31595e200ddd08036c66613a0573eb51da6c2 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Nov 2024 19:47:45 +0000 Subject: [PATCH 4/6] Add ONNX export support for patchtsmixer --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/model_configs.py | 4 ++++ optimum/exporters/tasks.py | 11 ++++++++++- tests/exporters/exporters_utils.py | 2 ++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 38919cd1da7..204d9c51129 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -75,6 +75,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - Nystromformer - OWL-ViT - PatchTST +- PatchTSMixer - Pegasus - Perceiver - Phi diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 45bbfca2315..61ab251d3b9 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -2483,3 +2483,7 @@ class PatchTSTOnnxConfig(OnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: return {"past_values": {0: "batch_size", 1: "sequence_length"}} + + +class PatchTSMixerOnnxConfig(PatchTSTOnnxConfig): + pass diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 8e94c668c82..54164a6491e 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -222,7 +222,6 @@ class TasksManager: "text-generation": "AutoModelForCausalLM", "text2text-generation": "AutoModelForSeq2SeqLM", "text-classification": "AutoModelForSequenceClassification", - "time-series-forecasting": "PatchTSTForPrediction", # TODO: AutoModelForPrediction is not yet supported "token-classification": "AutoModelForTokenClassification", "zero-shot-image-classification": "AutoModelForZeroShotImageClassification", "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection", @@ -315,6 +314,10 @@ class TasksManager: } _CUSTOM_CLASSES = { + ("pt", "patchtsmixer", "feature-extraction"): ("transformers", "PatchTSMixerModel"), + ("pt", "patchtsmixer", "time-series-forecasting"): ("transformers", "PatchTSMixerForPrediction"), + ("pt", "patchtst", "feature-extraction"): ("transformers", "PatchTSTModel"), + ("pt", "patchtst", "time-series-forecasting"): ("transformers", "PatchTSTForPrediction"), ("pt", "pix2struct", "image-to-text"): ("transformers", "Pix2StructForConditionalGeneration"), ("pt", "pix2struct", "visual-question-answering"): ("transformers", "Pix2StructForConditionalGeneration"), ("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"), @@ -913,9 +916,15 @@ class TasksManager: onnx="OPTOnnxConfig", ), "patchtst": supported_tasks_mapping( + "feature-extraction", "time-series-forecasting", onnx="PatchTSTOnnxConfig", ), + "patchtsmixer": supported_tasks_mapping( + "feature-extraction", + "time-series-forecasting", + onnx="PatchTSMixerOnnxConfig", + ), "qwen2": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 9803c5c8d37..a58ee505316 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -129,6 +129,7 @@ "owlv2": "hf-internal-testing/tiny-random-Owlv2Model", "owlvit": "hf-tiny-model-private/tiny-random-OwlViTModel", "patchtst": "ibm/test-patchtst", + "patchtsmixer": "ibm/test-patchtsmixer", "pegasus": "hf-internal-testing/tiny-random-PegasusModel", "perceiver": { "hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"], @@ -257,6 +258,7 @@ "owlv2": "google/owlv2-base-patch16", "owlvit": "google/owlvit-base-patch32", "patchtst": "ibm/test-patchtst", + "patchtsmixer": "ibm/test-patchtsmixer", "perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing. # "rembert": "google/rembert", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", From e2828ffbe5c7fd732a6947b8c216b9aefb01d643 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Nov 2024 22:48:20 +0000 Subject: [PATCH 5/6] Add task=feature-extraction --- optimum/exporters/onnx/model_configs.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 61ab251d3b9..4dc72bbb622 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -2484,6 +2484,13 @@ class PatchTSTOnnxConfig(OnnxConfig): def inputs(self) -> Dict[str, Dict[int, str]]: return {"past_values": {0: "batch_size", 1: "sequence_length"}} + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self.task == "feature-extraction": + return {"last_hidden_state": {0: "batch_size"}} + else: + return super().outputs + class PatchTSMixerOnnxConfig(PatchTSTOnnxConfig): pass From 43c21d020c1525acd83688d8b6552d212d658e1a Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 5 Dec 2024 18:41:38 +0200 Subject: [PATCH 6/6] Fix ONNX compatible unfold --- optimum/exporters/onnx/model_patcher.py | 36 ++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index d9cc224e3bd..46ceddcc2bf 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -116,12 +116,36 @@ class PatchingSpec: # An ONNX-export-compatible version of `tensor.unfold`. Without this, we get: # torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible. # See https://github.com/pytorch/pytorch/issues/81871 for more information -def onnx_compatible_unfold(self, dimension, size, step): - num_patches = (self.size(dimension) - size) // step + 1 - return torch.stack( - [self[:, i : i + size, :] for i in range(0, num_patches * step, step)], - dim=1, - ).transpose(3, 2) +def onnx_compatible_unfold(input_tensor, dimension, size, step): + """ + Custom implementation of torch.unfold without using torch.unfold. + + Args: + input_tensor (torch.Tensor): The input tensor. + dimension (int): The dimension to unfold. + size (int): The size of each slice. + step (int): The step size between slices. + + Returns: + torch.Tensor: The unfolded tensor. + """ + # Compute the shape of the unfolded output + input_size = input_tensor.size(dimension) + num_slices = (input_size - size) // step + 1 + + # Permute dimension to the end for easier indexing + input_tensor = input_tensor.transpose(dimension, -1) + + # Extract slices + slices = [] + for i in range(num_slices): + start = i * step + end = start + size + slices.append(input_tensor[..., start:end]) + + # Stack slices and permute dimensions back + result = torch.stack(slices, dim=-2).transpose(dimension, -2) + return result UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]