Add ONNX export support for PatchTST #2101

Status: Open · wants to merge 6 commits into base: main
docs/source/exporters/onnx/overview.mdx (2 additions, 0 deletions)

@@ -74,6 +74,8 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - Musicgen (text-conditional only)
 - Nystromformer
 - OWL-ViT
+- PatchTST
+- PatchTSMixer
 - Pegasus
 - Perceiver
 - Phi
optimum/exporters/onnx/base.py (1 addition, 0 deletions)

@@ -177,6 +177,7 @@ class OnnxConfig(ExportConfig, ABC):
"text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
"text-classification": OrderedDict({"logits": {0: "batch_size"}}),
"text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"time-series-forecasting": OrderedDict({"prediction_outputs": {0: "batch_size"}}),

[Collaborator review comment on the "time-series-forecasting" entry]
I can't find "time-series-forecasting" in https://huggingface.co/datasets/huggingface/transformers-metadata/blob/main/pipeline_tags.json; do you know if this is specific to PatchTST models? (-> PatchTSTXxxForPrediction)

"token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"visual-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"zero-shot-image-classification": OrderedDict(
Expand Down
optimum/exporters/onnx/model_configs.py (49 additions, 0 deletions)

@@ -58,6 +58,7 @@
     NormalizedTextAndVisionConfig,
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
+    NormalizedTimeSeriesForecastingConfig,
     NormalizedVisionConfig,
     check_if_diffusers_greater,
     check_if_transformers_greater,

@@ -2445,3 +2446,51 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig

     DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
+
+
+class PatchTSTDummyInputGenerator(DummyInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("past_values",)
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        **kwargs,
+    ):
+        self.task = task
+        self.normalized_config = normalized_config
+
+        self.batch_size = batch_size
+        self.context_length = normalized_config.context_length
+        self.num_input_channels = normalized_config.num_input_channels
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        return self.random_float_tensor(
+            shape=[self.batch_size, self.context_length, self.num_input_channels],
+            min_value=-1,
+            max_value=1,
+            framework=framework,
+            dtype=float_dtype,
+        )
+
+
+class PatchTSTOnnxConfig(OnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTimeSeriesForecastingConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (PatchTSTDummyInputGenerator,)
+    ATOL_FOR_VALIDATION = 1e-4
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {"past_values": {0: "batch_size", 1: "sequence_length"}}
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        if self.task == "feature-extraction":
+            return {"last_hidden_state": {0: "batch_size"}}
+        else:
+            return super().outputs
+
+
+class PatchTSMixerOnnxConfig(PatchTSTOnnxConfig):
+    pass
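
With these configs registered, the export should be reachable through the standard programmatic entry point. A minimal sketch, assuming the tiny ibm/test-patchtst checkpoint referenced in the tests below (the output path is arbitrary):

from optimum.exporters.onnx import main_export

# Export a PatchTST prediction model to ONNX via the new config.
main_export(
    "ibm/test-patchtst",
    output="patchtst_onnx/",
    task="time-series-forecasting",
)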
optimum/exporters/onnx/model_patcher.py (44 additions, 2 deletions)

@@ -113,6 +113,44 @@ class PatchingSpec:
     op_wrapper: Optional[Callable] = None


+# An ONNX-export-compatible version of `tensor.unfold`. Without this, we get:
+# torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.
+# See https://github.com/pytorch/pytorch/issues/81871 for more information.
+def onnx_compatible_unfold(input_tensor, dimension, size, step):
+    """
+    Custom implementation of torch.unfold without using torch.unfold.
+
+    Args:
+        input_tensor (torch.Tensor): The input tensor.
+        dimension (int): The dimension to unfold.
+        size (int): The size of each slice.
+        step (int): The step size between slices.
+
+    Returns:
+        torch.Tensor: The unfolded tensor.
+    """
+    # Compute the shape of the unfolded output
+    input_size = input_tensor.size(dimension)
+    num_slices = (input_size - size) // step + 1
+
+    # Permute dimension to the end for easier indexing
+    input_tensor = input_tensor.transpose(dimension, -1)
+
+    # Extract slices
+    slices = []
+    for i in range(num_slices):
+        start = i * step
+        end = start + size
+        slices.append(input_tensor[..., start:end])
+
+    # Stack slices and permute dimensions back
+    result = torch.stack(slices, dim=-2).transpose(dimension, -2)
+    return result
+
+
+UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
+
+
 class ModelPatcher:
     def __init__(
         self,
@@ -122,9 +160,11 @@ def __init__(
     ):
         self._model = model

-        patching_specs = config.PATCHING_SPECS
+        patching_specs = config.PATCHING_SPECS or []
+        patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC)

         self._patching_specs = []
-        for spec in patching_specs if patching_specs is not None else []:
+        for spec in patching_specs:
             final_spec = spec
             if spec.orig_op is None:
                 final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
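
As a quick sanity check on the re-implementation (a standalone sketch, not part of the PR; it assumes the function is importable from the patched module), the custom unfold should reproduce torch.Tensor.unfold exactly for positive dimensions:

import torch

from optimum.exporters.onnx.model_patcher import onnx_compatible_unfold

# PatchTST-shaped input: (batch, context_length, num_channels).
# Unfold the time dimension into windows of length 4 with stride 2.
x = torch.randn(2, 10, 3)
expected = x.unfold(1, 4, 2)  # shape (2, 4, 3, 4)
actual = onnx_compatible_unfold(x, dimension=1, size=4, step=2)
assert torch.equal(expected, actual)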
optimum/exporters/tasks.py (14 additions, 0 deletions)

@@ -314,6 +314,10 @@ class TasksManager:
     }

     _CUSTOM_CLASSES = {
+        ("pt", "patchtsmixer", "feature-extraction"): ("transformers", "PatchTSMixerModel"),
+        ("pt", "patchtsmixer", "time-series-forecasting"): ("transformers", "PatchTSMixerForPrediction"),
+        ("pt", "patchtst", "feature-extraction"): ("transformers", "PatchTSTModel"),
+        ("pt", "patchtst", "time-series-forecasting"): ("transformers", "PatchTSTForPrediction"),

[Collaborator review comment on lines +317 to +320]
Shouldn't AutoModel load the models as expected? See https://huggingface.co/datasets/huggingface/transformers-metadata/blob/main/pipeline_tags.json#L666-L667

Suggested change:
-        ("pt", "patchtsmixer", "feature-extraction"): ("transformers", "PatchTSMixerModel"),
-        ("pt", "patchtsmixer", "time-series-forecasting"): ("transformers", "PatchTSMixerForPrediction"),
-        ("pt", "patchtst", "feature-extraction"): ("transformers", "PatchTSTModel"),
-        ("pt", "patchtst", "time-series-forecasting"): ("transformers", "PatchTSTForPrediction"),
+        ("pt", "patchtsmixer", "time-series-forecasting"): ("transformers", "PatchTSMixerForPrediction"),
+        ("pt", "patchtst", "time-series-forecasting"): ("transformers", "PatchTSTForPrediction"),

("pt", "pix2struct", "image-to-text"): ("transformers", "Pix2StructForConditionalGeneration"),
("pt", "pix2struct", "visual-question-answering"): ("transformers", "Pix2StructForConditionalGeneration"),
("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
Expand Down Expand Up @@ -911,6 +915,16 @@ class TasksManager:
"text-classification",
onnx="OPTOnnxConfig",
),
"patchtst": supported_tasks_mapping(
"feature-extraction",
"time-series-forecasting",
onnx="PatchTSTOnnxConfig",
),
"patchtsmixer": supported_tasks_mapping(
"feature-extraction",
"time-series-forecasting",
onnx="PatchTSMixerOnnxConfig",
),
"qwen2": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
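
Once these mappings are in place, task resolution should pick the prediction head through _CUSTOM_CLASSES, since PatchTSTForPrediction is not reachable via plain AutoModel. A sketch, assuming TasksManager.get_model_class_for_task keeps its current signature:

from optimum.exporters.tasks import TasksManager

# Resolve the model class registered for the new task.
model_class = TasksManager.get_model_class_for_task(
    task="time-series-forecasting", model_type="patchtst", framework="pt"
)
print(model_class.__name__)  # expected: PatchTSTForPrediction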
optimum/utils/__init__.py (1 addition, 0 deletions)

@@ -90,5 +90,6 @@
     NormalizedTextAndVisionConfig,
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
+    NormalizedTimeSeriesForecastingConfig,
     NormalizedVisionConfig,
 )
optimum/utils/normalized_config.py (5 additions, 0 deletions)

@@ -77,6 +77,11 @@ def has_attribute(self, attr_name):
         return True


+class NormalizedTimeSeriesForecastingConfig(NormalizedConfig):
+    NUM_INPUT_CHANNELS = "num_input_channels"
+    CONTEXT_LENGTH = "context_length"
+
+
 class NormalizedTextConfig(NormalizedConfig):
     VOCAB_SIZE = "vocab_size"
     HIDDEN_SIZE = "hidden_size"
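
For context, NormalizedConfig resolves these uppercase attributes to fields on the wrapped model config, which is what lets PatchTSTDummyInputGenerator read the input dimensions uniformly. A sketch (PatchTSTConfig field names assumed from transformers):

from transformers import PatchTSTConfig

from optimum.utils import NormalizedTimeSeriesForecastingConfig

config = PatchTSTConfig(context_length=32, num_input_channels=2)
normalized = NormalizedTimeSeriesForecastingConfig(config)

# Attribute access is forwarded to the underlying config fields.
assert normalized.context_length == 32
assert normalized.num_input_channels == 2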
tests/exporters/exporters_utils.py (4 additions, 0 deletions)

@@ -128,6 +128,8 @@
"opt": "hf-internal-testing/tiny-random-OPTModel",
"owlv2": "hf-internal-testing/tiny-random-Owlv2Model",
"owlvit": "hf-tiny-model-private/tiny-random-OwlViTModel",
"patchtst": "ibm/test-patchtst",
"patchtsmixer": "ibm/test-patchtsmixer",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"perceiver": {
"hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"],
Expand Down Expand Up @@ -255,6 +257,8 @@
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"owlv2": "google/owlv2-base-patch16",
"owlvit": "google/owlvit-base-patch32",
"patchtst": "ibm/test-patchtst",
"patchtsmixer": "ibm/test-patchtsmixer",

[Collaborator review comment on lines +260 to +261]
Already included in PYTORCH_EXPORT_MODELS_TINY, so these can be removed.

Suggested change:
-    "patchtst": "ibm/test-patchtst",
-    "patchtsmixer": "ibm/test-patchtsmixer",

"perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing.
# "rembert": "google/rembert",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",