Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PatchTST (time-series-forecasting) #1048

Merged
merged 2 commits into from
Nov 25, 2024
Merged

Conversation

xenova
Copy link
Collaborator

@xenova xenova commented Nov 21, 2024

This PR adds support for PatchTST models (time-series-forecasting).

ONNX export support is being added here: huggingface/optimum#2101

Export example:

git clone https://github.com/huggingface/transformers.js.git
cd transformers.js
pip install -q -r scripts/requirements.txt
pip install --upgrade git+https://github.com/huggingface/optimum@add-patchtst-onnx

followed by

python -m scripts.convert --quantize --model_id ibm/test-patchtst

Closes #1047

Example usage:

import { PatchTSTForPrediction, Tensor } from '@huggingface/transformers';

const model_id = "onnx-community/test-patchtst";
const model = await PatchTSTForPrediction.from_pretrained(model_id, { dtype: "fp32" });

// Example input
const dims = [64, 512, 7];
const prod = dims.reduce((a, b) => a * b, 1);
const past_values = new Tensor('float32',
    Float32Array.from({ length: prod }, (_, i) => i / prod),
    dims,
);
const { prediction_outputs } = await model({ past_values });
console.log(prediction_outputs);

outputs match pytorch within 1e-5 atol.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@xenova
Copy link
Collaborator Author

xenova commented Nov 23, 2024

Example usage for PatchTSMixerForPrediction:

import { PatchTSMixerForPrediction, Tensor } from '@huggingface/transformers';

const model_id = "onnx-community/granite-timeseries-patchtsmixer";
const model = await PatchTSMixerForPrediction.from_pretrained(model_id, { dtype: "fp32" });

const dims = [64, 512, 7];
const prod = dims.reduce((a, b) => a * b, 1);
const past_values = new Tensor('float32',
    Float32Array.from({ length: prod }, (_, i) => i / prod),
    dims,
);
const { prediction_outputs } = await model({ past_values });
console.log(prediction_outputs);

@xenova xenova merged commit c9f12e5 into main Nov 25, 2024
4 checks passed
@xenova xenova deleted the add-patchtst branch November 25, 2024 20:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

PatchTSTModel, PatchTSTConfig, & Trainer
2 participants