
PyTorch implementation of TimesFM #129

Open
tangh18 opened this issue Aug 27, 2024 · 10 comments

Comments

@tangh18

tangh18 commented Aug 27, 2024

I implemented a PyTorch version of TimesFM here. It includes the essential components required to run the model effectively. Hope it helps. :D

@sebastianpinedaar

Nice, thanks for this! Did you have a chance to compare the performance of the original and the new implementation?

@agnikumar

This is awesome, @tangh18! 🎉

When running python convert_ckpt.py, I'm seeing
jaxlib.xla_extension.XlaRuntimeError: FAILED PRECONDITION: DNN library initialization failed.

Any ideas on how to resolve this? JAX version is 0.4.26, CUDA version is 12.4, and cuDNN version is 8.9.7.29, which should be compatible.

@tangh18
Author

tangh18 commented Sep 2, 2024

> nice! thanks for this! Did you have a chance to compare the performance between the original and the new implementation?

Thanks for your attention! I ran tests on an A100 server, repeating the process 45 times, each run with a batch size of 32 and a context length of 2880. PyTorch took 3.89 seconds, while JAX completed the task in 1.61 seconds. However, this comparison may not be entirely fair: the PyTorch version did not handle padding, and its implementation was not optimized.
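For anyone who wants to reproduce a timing comparison like this, a minimal sketch is below. The patch-mean lambda is purely a hypothetical stand-in for the real forward pass, and the warmup count is an assumption; only the batch size and context length match the numbers above.

```python
import time
import torch

def time_runs(fn, x, n_runs=45, warmup=3):
    """Time n_runs calls of fn(x), excluding a few warmup calls."""
    for _ in range(warmup):
        fn(x)  # warm up to exclude one-off compilation/allocation costs
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(n_runs):
        fn(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    return time.perf_counter() - start

# Hypothetical stand-in for a model forward pass: mean over 32-step patches.
batch = torch.randn(32, 2880)  # batch size 32, context length 2880
elapsed = time_runs(lambda t: t.unfold(-1, 32, 32).mean(-1), batch)
print(f"45 runs took {elapsed:.3f}s")
```

Synchronizing around the timed region matters on GPU, since CUDA kernel launches are asynchronous and the Python call returns before the work finishes.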

@tangh18
Author

tangh18 commented Sep 2, 2024

> This is awesome, @tangh18! 🎉
>
> When running python convert_ckpt.py, I'm seeing jaxlib.xla_extension.XlaRuntimeError: FAILED PRECONDITION: DNN library initialization failed.
>
> JAX version is 0.4.26, CUDA version is 12.4, and cuDNN version is 8.9.7.29, which should be compatible.

The JAX-related code in convert_ckpt.py only creates the model and loads the checkpoint. Have you successfully run TimesFM in JAX? If so, the error you're seeing should not occur.

@melopeo

melopeo commented Sep 2, 2024

> The code in convert_ckpt.py about JAX only involves creating the model and loading the checkpoint. Have you successfully run timesfm in JAX? If so, the issue you're encountering should not occur.

Thanks for your super cool contribution! I think what @sebastianpinedaar meant was performance in terms of model accuracy. Have you been able to reproduce the accuracy metrics reported in the original paper?

@sebastianpinedaar

> The processing time with PyTorch was 3.89 seconds, while JAX completed the task in 1.61 seconds. However, this comparison may not be entirely fair.

Thanks for the info @tangh18! As @melopeo pointed out, I was rather curious about the MAE and scaled MAE comparison between the original JAX TimesFM and the PyTorch version, at least on a couple of datasets, as a sanity check. Although there are some differences in the preprocessing, hopefully the gap is not too big.
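For reference, MAE and a MASE-style scaled MAE can be computed along these lines. This is only a sketch: the scaled-MAE definition here (MAE divided by the in-context seasonal naive error) is a common convention, and the paper's exact scaling may differ.

```python
import numpy as np

def mae(y_true, y_pred):
    """Mean absolute error between forecast and ground truth."""
    return float(np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred))))

def scaled_mae(y_true, y_pred, y_context, season=1):
    """MASE-style scaled MAE: MAE divided by the in-context naive forecast error."""
    c = np.asarray(y_context, dtype=float)
    denom = np.mean(np.abs(c[season:] - c[:-season]))  # seasonal naive error
    return mae(y_true, y_pred) / denom

# Toy check: forecast off by 0.5 on average; context's naive error is 1.5.
print(mae([1.0, 2.0], [1.0, 3.0]))                           # 0.5
print(scaled_mae([1.0, 2.0], [1.0, 3.0], [1.0, 2.0, 4.0]))   # 0.5 / 1.5 ≈ 0.333
```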

@TeddyHuang-00
Contributor

> As @melopeo pointed out, I was rather curious about the MAE and scaled MAE performance comparison between the original jax TimesFM and the pytorch version, at least in a couple of datasets.

Although I used the official PyTorch implementation, I applied the same weight conversion process and was able to reproduce the following results. Hopefully this helps!

{'mse': np.float32(0.4324413),
 'smape': np.float32(0.7251805),
 'mae': np.float32(0.40476117),
 'wape': np.float32(11.708796),
 'nrmse': np.float32(19.022907),
 'num_elements': 20160,
 'abs_sum': np.float32(696.91077),
 'dataset': 'etth1',
 'freq': 'h',
 'pred_len': 96,
 'context_len': 512}

@tangh18
Author

tangh18 commented Sep 10, 2024

Well, that was my misunderstanding. I may try ETTh1 when I get the chance, as @TeddyHuang-00 did with the official PyTorch implementation. As a sanity check, I have verified several times that, with no padding and the same normalization method, the outputs of the JAX version and my implementation are almost identical for the same input. Thanks for pointing that out. @sebastianpinedaar @melopeo
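A parity check of that kind can be sketched as follows. The patch-mean op is a hypothetical stand-in for the actual model forward pass, and the helper name is made up; the pattern is simply: run both implementations on the same input and compare element-wise.

```python
import numpy as np
import torch

def check_parity(ref_out, torch_out, atol=1e-5):
    """Compare a reference (e.g. JAX) output with the PyTorch port's output."""
    ref = np.asarray(ref_out)
    out = torch_out.detach().cpu().numpy()
    max_err = float(np.max(np.abs(ref - out)))
    return max_err, bool(np.allclose(ref, out, atol=atol))

# Hypothetical op standing in for the model: mean over non-overlapping 32-step patches.
x = np.random.randn(4, 2880).astype(np.float32)
ref = x.reshape(4, -1, 32).mean(axis=-1)                    # "reference" result in NumPy
port = torch.from_numpy(x).unfold(-1, 32, 32).mean(dim=-1)  # same op in PyTorch
max_err, ok = check_parity(ref, port)
print(max_err, ok)
```

In float32, tiny discrepancies from different summation orders are expected, which is why an absolute tolerance is used rather than exact equality.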

@guiyang882

> Although using the official PyTorch implementation, I took the same weight conversion processing and was able to reproduce the results like this. Hopefully this helps!

I haven't learned how to build models with JAX. If I'm using the official PyTorch version of the model, how can I convert the weights from the JAX checkpoints into a format that can be loaded into the PyTorch model?

@TeddyHuang-00
Contributor

> I haven't learned how to build models with JAX, and I would like to ask how to convert the weights from the JAX checkpoints into a format that can be loaded into a PyTorch model, if I'm using the PyTorch version of the model provided on the official pytorch models.

https://gist.github.com/TeddyHuang-00/fc2238f6f5956a9906c8c206edef2603

You are welcome 😄
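For readers who want the general idea of such a conversion in miniature: it usually amounts to renaming parameter keys and transposing linear kernels into a PyTorch state_dict. A minimal sketch with entirely hypothetical parameter names follows; the real key names and transpose rules depend on the actual checkpoint layout (see the gist above for the working version).

```python
import numpy as np
import torch

def jax_params_to_state_dict(flat_params, name_map):
    """Build a PyTorch state_dict from a flat dict of JAX/NumPy arrays.

    name_map: {jax_name: (torch_name, needs_transpose)} -- purely illustrative;
    the real mapping depends on the checkpoint layout.
    """
    state_dict = {}
    for jax_name, arr in flat_params.items():
        torch_name, needs_transpose = name_map[jax_name]
        t = torch.from_numpy(np.asarray(arr))
        if needs_transpose:
            # JAX/Flax linear kernels are (in_features, out_features), while
            # torch.nn.Linear.weight is (out_features, in_features).
            t = t.T.contiguous()
        state_dict[torch_name] = t
    return state_dict

# Hypothetical single-parameter example: a (4, 8) JAX kernel becomes an (8, 4) weight.
params = {"ff_layer/linear/w": np.ones((4, 8), dtype=np.float32)}
sd = jax_params_to_state_dict(params, {"ff_layer/linear/w": ("ff.weight", True)})
print(sd["ff.weight"].shape)  # torch.Size([8, 4])
```

The resulting dict can then be passed to `model.load_state_dict(sd)`; running with `strict=True` first is a cheap way to catch any keys the mapping missed.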
