PyTorch implementation of TimesFM #129
nice! thanks for this! Did you have a chance to compare the performance of the original and the new implementation?
This is awesome, @tangh18! 🎉 When running … Are there any ideas about how to resolve this? JAX version is …
Thanks for your attention! I ran tests on an A100 server: 45 runs, each with a batch size of 32 and a context length of 2880. PyTorch took 3.89 seconds, while JAX completed the task in 1.61 seconds. However, this comparison may not be entirely fair: the PyTorch version does not handle padding, and its implementation is not optimized.
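Timings like these are sensitive to how they are taken: JAX compiles on the first call, and both JAX and PyTorch dispatch GPU work asynchronously, so a fair comparison needs warm-up calls and an explicit synchronization before the clock stops. A minimal stdlib harness, assuming a hypothetical zero-argument `forecast_fn` that runs one batch, and a `sync` callback such as `lambda _: torch.cuda.synchronize()` or `lambda out: jax.block_until_ready(out)`:

```python
import time
from typing import Callable, Optional


def time_runs(
    forecast_fn: Callable[[], object],
    n_runs: int = 45,
    warmup: int = 1,
    sync: Optional[Callable[[object], None]] = None,
) -> float:
    """Return total wall-clock seconds for n_runs calls of forecast_fn.

    forecast_fn is a hypothetical callable that forecasts one batch
    (e.g. batch size 32, context length 2880). `sync` should block until
    asynchronous device work has finished; without it the timer may stop
    before the GPU is done.
    """
    # Warm-up calls absorb one-off costs (JIT compilation, CUDA initialisation)
    # that would otherwise be counted in the measured total.
    for _ in range(warmup):
        out = forecast_fn()
        if sync is not None:
            sync(out)

    start = time.perf_counter()
    for _ in range(n_runs):
        out = forecast_fn()
        if sync is not None:
            sync(out)
    return time.perf_counter() - start
```

Reporting the result as `total / n_runs` (seconds per batch) also removes any ambiguity about whether a figure is a total or a per-run time.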
The code in …
Thanks for your super cool contribution! I think what @sebastianpinedaar means is performance in terms of the model's accuracy. Have you been able to reproduce the accuracy metrics reported in the original paper?
Thanks for the info, @tangh18! As @melopeo pointed out, I was rather curious about an MAE and scaled MAE comparison between the original JAX TimesFM and the PyTorch version, at least on a couple of datasets, as a sanity check. Although there are some differences in preprocessing, hopefully the gap is not too large.
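For such a check, both implementations would be scored with the same metric code on the same windows. A minimal sketch of MAE and a scaled MAE, assuming the scaling baseline is a naive forecast that repeats the last context value (other conventions, such as seasonal naive or MASE-style in-sample scaling, exist; the function names here are illustrative):

```python
from typing import Sequence


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error over one forecast horizon."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def scaled_mae(
    context: Sequence[float], y_true: Sequence[float], y_pred: Sequence[float]
) -> float:
    """Model MAE divided by the MAE of a naive last-value forecast.

    Assumption: the naive baseline repeats context[-1] over the horizon.
    A value below 1 means the model beats that baseline.
    """
    naive = [context[-1]] * len(y_true)
    baseline = mae(y_true, naive)
    if baseline == 0:
        raise ZeroDivisionError("naive forecast is perfect; scaled MAE undefined")
    return mae(y_true, y_pred) / baseline
```

For example, with context `[1, 2, 3]`, targets `[4, 5]` and predictions `[4, 6]`, the model MAE is 0.5, the naive MAE is 1.5, and the scaled MAE is 1/3.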
Although I used the official PyTorch implementation, I applied the same weight conversion process and was able to reproduce results like these on ETTh1 (horizon 96, context 512). Hopefully this helps!
{'mse': np.float32(0.4324413),
'smape': np.float32(0.7251805),
'mae': np.float32(0.40476117),
'wape': np.float32(11.708796),
'nrmse': np.float32(19.022907),
'num_elements': 20160,
'abs_sum': np.float32(696.91077),
'dataset': 'etth1',
'freq': 'h',
'pred_len': 96,
'context_len': 512}
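The aggregate fields in that output are mutually consistent: with N = `num_elements` and `abs_sum` = Σ|y|, `wape` = `mae`·N/`abs_sum` and `nrmse` = √`mse`/(`abs_sum`/N) reproduce the reported 11.7088 and 19.0229. A minimal sketch of metric definitions consistent with those numbers, assuming this is how the evaluation script computes them (`smape` is left out because its exact definition cannot be recovered from the output; `summary_metrics` is a hypothetical name):

```python
import math
from typing import Dict, Sequence


def summary_metrics(y_true: Sequence[float], y_pred: Sequence[float]) -> Dict[str, float]:
    """Point-forecast metrics; wape/nrmse definitions inferred from the reported values."""
    n = len(y_true)
    errors = [p - t for t, p in zip(y_true, y_pred)]
    mse = sum(e * e for e in errors) / n
    mae = sum(abs(e) for e in errors) / n
    abs_sum = sum(abs(t) for t in y_true)
    return {
        "mse": mse,
        "mae": mae,
        # Total absolute error over total absolute target (a ratio, not a percentage).
        "wape": mae * n / abs_sum,
        # RMSE normalised by the mean absolute target value.
        "nrmse": math.sqrt(mse) / (abs_sum / n),
        "num_elements": n,
        "abs_sum": abs_sum,
    }


# Cross-check the inferred definitions against the reported ETTh1 aggregates.
reported = {"mse": 0.4324413, "mae": 0.40476117, "num_elements": 20160, "abs_sum": 696.91077}
n = reported["num_elements"]
print(round(reported["mae"] * n / reported["abs_sum"], 4))               # wape
print(round(math.sqrt(reported["mse"]) / (reported["abs_sum"] / n), 4))  # nrmse
```

Because ETTh1 is normalized to values near zero, `abs_sum` is small, which is why `wape` and `nrmse` are much larger than `mae` here.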
Well, that was my misunderstanding. I may try ETTh1 when I have time, as @TeddyHuang-00 did with the official PyTorch implementation. As a sanity check, I have verified several times that, given the same input with no padding and the same normalization method, the outputs of the JAX version and my implementation are almost identical. Thanks for pointing that out, @sebastianpinedaar @melopeo!
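A check like this can be made repeatable by normalizing the input once, feeding the identical array to both implementations, and comparing outputs with a tolerance rather than exact equality, since float32 kernels in different frameworks rarely agree bit for bit. A minimal sketch, assuming plain per-series z-normalization (TimesFM's own normalization may differ; what matters is that both sides use the same one) and hypothetical helper names:

```python
import math
from typing import List, Sequence, Tuple


def normalize(context: Sequence[float], eps: float = 1e-8) -> Tuple[List[float], float, float]:
    """Z-normalise one context window; return (normalised, mean, std).

    Assumption: population std over the whole window. Use the same scheme
    for both implementations under comparison.
    """
    mean = sum(context) / len(context)
    std = math.sqrt(sum((x - mean) ** 2 for x in context) / len(context))
    return [(x - mean) / (std + eps) for x in context], mean, std


def outputs_match(
    out_a: Sequence[float], out_b: Sequence[float], rtol: float = 1e-4, atol: float = 1e-5
) -> Tuple[bool, float]:
    """Element-wise closeness test using the numpy.allclose rule
    |a - b| <= atol + rtol * |b|; also returns the max absolute difference."""
    if len(out_a) != len(out_b):
        raise ValueError("outputs have different lengths")
    diffs = [abs(a - b) for a, b in zip(out_a, out_b)]
    ok = all(d <= atol + rtol * abs(b) for d, b in zip(diffs, out_b))
    return ok, max(diffs)
```

Returning the maximum difference alongside the boolean makes it easy to see whether a failure is a rounding-level discrepancy or a real bug such as a transposed weight.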
I haven't learned how to build models with JAX. If I'm using the official PyTorch version of the model, how do I convert the weights from the JAX checkpoints into a format that can be loaded into the PyTorch model?
https://gist.github.com/TeddyHuang-00/fc2238f6f5956a9906c8c206edef2603
You're welcome 😄
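The gist above is the concrete answer. In general terms, such a conversion usually does two things: flatten the nested JAX parameter tree into flat names, and transpose dense kernels, because Flax `nn.Dense` stores kernels as `(in_features, out_features)` while PyTorch `nn.Linear.weight` is `(out_features, in_features)`. A stdlib sketch of those two steps, assuming the checkpoint follows the same kernel layout; the `kernel` to `weight` renaming is hypothetical, and real TimesFM checkpoints need an explicit key map onto the target model's `state_dict()` names:

```python
from typing import Any, Dict, List


def flatten_tree(tree: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested parameter dict into {'a.b.c': leaf} entries."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_tree(value, name))
        else:
            flat[name] = value
    return flat


def transpose(matrix: List[List[float]]) -> List[List[float]]:
    """Swap the two axes of a 2-D list."""
    return [list(row) for row in zip(*matrix)]


def jax_to_torch_layout(tree: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical mapping: '<layer>.kernel' -> '<layer>.weight' (transposed);
    every other leaf (e.g. '<layer>.bias') is copied unchanged."""
    out: Dict[str, Any] = {}
    for name, value in flatten_tree(tree).items():
        if name.endswith(".kernel"):
            # (in_features, out_features) -> (out_features, in_features)
            out[name[: -len(".kernel")] + ".weight"] = transpose(value)
        else:
            out[name] = value
    return out
```

For example, `{"dense": {"kernel": [[1, 2, 3], [4, 5, 6]], "bias": [0, 0, 0]}}` becomes `{"dense.weight": [[1, 4], [2, 5], [3, 6]], "dense.bias": [0, 0, 0]}`. With real arrays, each leaf would then be wrapped with `torch.from_numpy(...)` before calling `model.load_state_dict(...)`.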
I implemented a PyTorch version of TimesFM here. It includes the essential components required to run the model. Hope it helps. :D
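One component the thread singles out is padding, which the benchmark comment notes this port does not yet handle. TimesFM splits the context into fixed-length input patches (32 time points in the published model), so a context whose length is not a multiple of the patch length has to be padded and the padded positions marked. A minimal sketch, assuming left padding with zeros and a mask where 1 marks a padded position (check this convention against the reference implementation; `patchify` is a hypothetical name):

```python
from typing import List, Sequence, Tuple

PATCH_LEN = 32  # TimesFM's input patch length; kept configurable here


def patchify(
    context: Sequence[float], patch_len: int = PATCH_LEN
) -> Tuple[List[List[float]], List[List[int]]]:
    """Left-pad a context to a multiple of patch_len and split it into patches.

    Returns (patches, masks); in masks, 1 marks a padded position and 0 a real
    observation (assumed convention).
    """
    if patch_len <= 0:
        raise ValueError("patch_len must be positive")
    n_pad = (-len(context)) % patch_len  # padding needed to reach a full patch
    values = [0.0] * n_pad + [float(x) for x in context]
    mask = [1] * n_pad + [0] * len(context)
    patches = [values[i : i + patch_len] for i in range(0, len(values), patch_len)]
    masks = [mask[i : i + patch_len] for i in range(0, len(mask), patch_len)]
    return patches, masks
```

A context of length 70 gets 26 padded zeros and becomes 3 patches of 32; later stages can then use the mask to keep padded values out of normalization statistics and attention.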