Investigate within-chain parallelization in numpyro #151

Closed
dylanhmorris opened this issue Jun 4, 2024 · 7 comments
Labels
pyrenew related to pyrenew internals request New feature or request
Comments

@dylanhmorris
Collaborator

dylanhmorris commented Jun 4, 2024

Goal

Understand what numpyro support exists for within-chain parallelization.

Details

  • The plate context can be used to mark conditionally independent sampling operations.
  • Are there ways then to ask the compiler to parallelize things within a single MCMC chain, as is possible in Stan?

See also

Related discussion in EpiAware CDCgov/Rt-without-renewal#254

@dylanhmorris dylanhmorris self-assigned this Jun 4, 2024
@gvegayon gvegayon added this to the Backlog milestone Jun 4, 2024
@dylanhmorris dylanhmorris modified the milestones: Backlog, 🦩 Kakuru Jun 4, 2024
@damonbayer damonbayer modified the milestones: 🦩 Kakuru, Backlog Jun 10, 2024
@damonbayer damonbayer added request New feature or request pyrenew related to pyrenew internals and removed development task labels Jul 12, 2024
@dylanhmorris
Collaborator Author

@seabbs reports that it looks possible with JAX but is not implemented in numpyro's bundled samplers. So if we want or need this in the future, the solution may be to do sampling via blackjax or bayeux (see also #361).

In general, perhaps worth considering whether to separate PPL choice and sampler choice.

@seabbs
Collaborator

seabbs commented Oct 1, 2024

*but that he does not think it is possible based on current evidence (I am heavily caveating this, as I really think it must be possible; I just can't find any evidence).

@seabbs
Collaborator

seabbs commented Oct 1, 2024

Okay some actual evidence: jax-ml/jax#1408

I think numpyro is using pmap to run multiple chains, but that doesn't mean it couldn't be used within one chain (unless you can't stack them, which I would hope you can but probably can't). That thread also shows a use of vmap in blackjax, but it's not clear that it (1) would work in numpyro or (2) can be stacked with pmap.

So I am still leaning to no it can't right now but I think above gives some room for exploration that I didn't have a handle on before.
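As a starting point for that exploration, the stacking question can be probed in plain JAX, outside numpyro. The sketch below is an assumption-laden toy, not numpyro's or blackjax's implementation: the function names and the Gaussian log-likelihood are hypothetical. It maps an outer `pmap` over "chains" (one per device) and an inner `vmap` over conditionally independent observations within each chain.

```python
import jax
import jax.numpy as jnp


def log_lik_one(theta, y):
    # Log-density of one observation under Normal(theta, 1), up to a constant.
    return -0.5 * (y - theta) ** 2


def chain_log_lik(theta, ys):
    # Inner vmap: batch over conditionally independent observations
    # within a single chain, then sum the terms.
    return jnp.sum(jax.vmap(log_lik_one, in_axes=(None, 0))(theta, ys))


# pmap needs at most one mapped element per device, so tie the number of
# hypothetical chains to the number of local devices.
n_chains = jax.local_device_count()
thetas = jnp.linspace(-1.0, 1.0, n_chains)  # one parameter value per chain
ys = jnp.arange(6.0)  # made-up data, shared by all chains

# Outer pmap: one chain per device; the data is broadcast (in_axes=None).
per_chain = jax.pmap(chain_log_lik, in_axes=(0, None))(thetas, ys)
print(per_chain.shape)
```

This only shows that `vmap` nests inside `pmap` in JAX itself. Note that `vmap` is vectorization rather than multi-core parallelism, so it does not settle whether a single numpyro chain can use several devices.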

@seabbs
Collaborator

seabbs commented Oct 1, 2024

The numpyro docs for setting the XLA options for this say they don't fully understand what it will do, so that is encouraging.

@seabbs
Collaborator

seabbs commented Oct 1, 2024

@damonbayer
Collaborator

https://num.pyro.ai/en/latest/utilities.html#numpyro.util.set_host_devices

We already use this to enable between-chain parallelization.
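For context, a rough sketch of what that setup amounts to, assuming (as the numpyro docs describe) that the helper works by setting the `xla_force_host_platform_device_count` XLA flag. The device count of 4 is a hypothetical value, and the flag has to be set before JAX is first imported in the process.

```python
import os

# Hypothetical number of host (CPU) devices to expose to JAX.
n_devices = 4
flag = f"--xla_force_host_platform_device_count={n_devices}"

# Append to any existing XLA_FLAGS rather than overwriting them.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = " ".join(part for part in (existing, flag) if part)

# Import JAX only after the flag is in place.
import jax

# On a CPU-only backend this should report the forced device count;
# on other backends it reports that backend's devices instead.
print(jax.local_device_count())
```

Each forced host device can then take one chain under `pmap`, which is the between-chain parallelization mentioned above.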

@seabbs
Collaborator

seabbs commented Oct 2, 2024

It's the warning that I was drawing attention to. The key question is whether it is hard-coded to only allow across-chain parallelization or whether you can stack.

4 participants