-
Notifications
You must be signed in to change notification settings - Fork 332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Full pytorch support #145
Full pytorch support #145
Conversation
"""TimesFM init file.""" | ||
|
||
from .timesfm import TimesFm, freq_map |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was wondering if we should print / log the API change here, just in case people forget to check the new README.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh thats a good idea. What should be log here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about
print("TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, will leave the version upload to you.
"""TimesFM init file.""" | ||
|
||
from .timesfm import TimesFm, freq_map | ||
from timesfm.timesfm_base import freq_map, TimesFmCheckpoint, TimesFmHparams, TimesFmBase | ||
try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am thinking what if one has both jax and torch version installed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think in that case the JAX version will be initialized as TimesFM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sg. Let's add
print("Loaded Jax TimesFM.")
and
print("Loaded PyTorch TimesFM.")
after the imports respectively.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
LGTM. |
No description provided.