
Port hierarchical multivariate forecasting model from Pyro to NumPyro #2006

Open · wants to merge 10 commits into base: master
Conversation

juanitorduz
Contributor

@juanitorduz juanitorduz commented Mar 11, 2025

I ported https://pyro.ai/examples/forecasting_iii.html from Pyro to NumPyro in https://juanitorduz.github.io/numpyro_hierarchical_forecasting_2/ and suggest adding it to the official docs.

TODO:

  • Improve text and story
  • Code clean up
  • Add path to notebook and image in the docs
  • Port data loader to NumPyro

Note that I need to use my own branch to load the data because of pyro-ppl/pyro#3425


@juanitorduz juanitorduz marked this pull request as draft March 11, 2025 22:37
@juanitorduz juanitorduz marked this pull request as ready for review March 12, 2025 12:01
@juanitorduz
Contributor Author

This one is ready for review :)


review-notebook-app bot commented Mar 12, 2025


fehiepsi commented on 2025-03-12T14:01:06Z
----------------------------------------------------------------

Line #4.    from pyro.contrib.examples.bart import load_bart_od

Could you port the loader to https://github.com/pyro-ppl/numpyro/blob/master/numpyro/examples/datasets.py ?


juanitorduz commented on 2025-03-12T18:11:02Z
----------------------------------------------------------------

I gave it a shot in b4d73cb


review-notebook-app bot commented Mar 12, 2025


fehiepsi commented on 2025-03-12T14:01:07Z
----------------------------------------------------------------

Line #63.        _, pred_levels = scan(

It seems that we don't need to use numpyro scan here. jax.lax.scan is enough.
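For readers following along, a minimal sketch of the `jax.lax.scan` pattern being suggested here (the step function, shapes, and smoothing factor are hypothetical placeholders, not the notebook's actual model):

```python
import jax
import jax.numpy as jnp

def step(level, noise_t):
    # hypothetical exponential-smoothing style level update
    new_level = 0.9 * level + noise_t
    return new_level, new_level  # (carry, per-step output)

init_level = jnp.zeros(3)  # e.g. one level per station
noise = jnp.ones((5, 3))   # 5 time steps
_, pred_levels = jax.lax.scan(step, init_level, noise)
print(pred_levels.shape)  # (5, 3)
```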


juanitorduz commented on 2025-03-12T18:11:57Z
----------------------------------------------------------------

Wow! I did not know that! Changed in e01b54a :)

(is there any difference in performance?)

fehiepsi commented on 2025-03-13T23:40:36Z
----------------------------------------------------------------

NumPyro's scan works for transition functions that involve sample statements; it just does some extra bookkeeping, but I don't think there is any performance difference here.

```python
dst.write(src.read())

if os.path.exists(pkl_file):
    dataset_dict = torch.load(pkl_file, weights_only=False)
```
Contributor Author

@juanitorduz juanitorduz Mar 12, 2025


I think we need to load this pkl file with torch, and it would be fine to get the data from https://github.com/pyro-ppl/datasets/ . Parsing the source files in

https://github.com/pyro-ppl/pyro/blob/50af09d284311d33421c0757468a1e53abf19408/pyro/contrib/examples/bart.py#L24-L34

seems unnecessary ...

Member


Do you think that we can save the file as a npz file? If so, we can upload the data to pyro datasets and load from there.

Contributor Author


I tried it, and `np.savez_compressed("bart.npz", **dataset)` works. The problem is that the file is ~120MB and GitHub's file size limit is 100MB. When I tried to push I got:

```
> git push origin add_bart_npz
Enumerating objects: 4, done.
Counting objects: 100% (4/4), done.
Delta compression using up to 12 threads
Compressing objects: 100% (3/3), done.
Writing objects: 100% (3/3), 114.63 MiB | 5.69 MiB/s, done.
Total 3 (delta 1), reused 0 (delta 0), pack-reused 0
remote: Resolving deltas: 100% (1/1), completed with 1 local object.
remote: error: Trace: ac60d558e2a8e83193ce7b70a11ad5f4c7c18f53c73be387c139e67c741474c3
remote: error: See https://gh.io/lfs for more information.
remote: error: File bart.npz is 115.35 MB; this exceeds GitHub's file size limit of 100.00 MB
remote: error: GH001: Large files detected. You may want to try Git Large File Storage - https://git-lfs.github.com.
To https://github.com/juanitorduz/datasets.git
 ! [remote rejected] add_bart_npz -> add_bart_npz (pre-receive hook declined)
```

Member


I think we can save in lower precision (int8, int16, float16, etc.). I did the same for the covertype dataset.
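As a rough illustration of how much downcasting can help, a sketch with synthetic counts (not the real BART data), comparing compressed sizes in memory:

```python
import io
import numpy as np

# hypothetical ridership counts; values comfortably fit in int16's range
counts = np.random.default_rng(0).integers(0, 300, size=(1000, 50)).astype(np.int64)

def npz_size(arr):
    # measure the compressed archive size without touching disk
    buf = io.BytesIO()
    np.savez_compressed(buf, counts=arr)
    return buf.getbuffer().nbytes

full = npz_size(counts)
small = npz_size(counts.astype(np.int16))
assert small < full  # int16 shrinks the raw bytes 4x, and the archive with it
```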

Contributor Author

@juanitorduz juanitorduz Mar 14, 2025


With int8 I was able to bring it down to 78.7MB, but the limit via the command line is 50MB. How did you push the original bart_full.pkl.bz2 file? I see the commit pyro-ppl/datasets@5550dbe . For me this 50MB limit seems impossible to get around ...

Contributor Author


As an alternative, I could make it publicly available via Dropbox or S3 🤗

Member


Or, instead of bart_full, we could store bart_2011.npz, bart_2012.npz, ... (in int16)
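A sketch of that sharding idea, using synthetic data and in-memory buffers (the file names follow the suggested bart_<year>.npz pattern; the arrays are made up):

```python
import io
import numpy as np

# hypothetical stand-in for the full BART origin-destination counts
rng = np.random.default_rng(0)
years = rng.choice([2011, 2012, 2013], size=2000)
counts = rng.integers(0, 300, size=(2000, 8)).astype(np.int16)

# one compressed shard per year, each well under the full file's size
shards = {}
for y in np.unique(years):
    buf = io.BytesIO()
    np.savez_compressed(buf, counts=counts[years == y])
    shards[f"bart_{y}.npz"] = buf.getbuffer().nbytes
```

The loader would then download and concatenate only the years it needs, keeping each individual file under the upload limit.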

