
JAX-based sampler #517

Open
JohnGoertz opened this issue Nov 22, 2021 · 3 comments

@JohnGoertz

Feature description
It would be great to have a sampler that's both compatible with JAX-jitted functions and leverages JAX's parallelization tools.

Motivation/Application
I have a slow objective function that is made significantly (orders of magnitude) faster using JAX tools, in particular jit and vmap. However, JAX's multithreading clashes with pyABC's multithreaded samplers, and pickling the jitted function doesn't behave well either. Oddly, this isn't an issue with relatively simple versions of my objective function; those can use pyABC's default samplers. More complex versions, however, only work with pyABC's SingleCoreSampler.

I'd like to write an extension of the SingleCoreSampler that relies solely on JAX for vectorization/parallelization/multithreading. I have some ideas on how to get started, but I'd like some pointers. This would work best as a batch-sampling system, where an array of samples is submitted and the model function is mapped across the array using vmap or pmap. This evaluation could itself be jitted as well. My questions are:

  • How could I get the sampler to create a batch of samples?
  • Would mapping submit_one across the batch work?
  • Do you know if there's anything that happens to the model function when it's assigned to submit_one that JAX might not like (e.g., NumPy operations)?
  • How do I return the samples after evaluation?
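The batch-evaluation idea described above can be sketched as follows. This is a minimal illustration, not pyABC code: the toy model and all names here are hypothetical, and the point is only how jit and vmap compose over a batch of proposals.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy simulator: maps one parameter vector to two summary statistics.
def model(theta):
    x = jnp.sin(theta) * theta
    return jnp.stack([x.sum(), (x ** 2).sum()])

# Vectorize over the leading batch axis, then JIT the whole batched evaluation.
batched_model = jax.jit(jax.vmap(model))

thetas = jnp.ones((100, 3))   # a batch of 100 proposals, 3 parameters each
stats = batched_model(thetas) # shape (100, 2), evaluated in one fused call
```

On accelerators, jax.pmap could replace the outer vmap to spread the batch across devices; the composition pattern stays the same.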
@yannikschaelte
Member

yannikschaelte commented Nov 22, 2021

Hi, this sure sounds useful. We had plans to implement batch submission, essentially generalizing simulate_one to simulate_n, which would require some internal rewiring. What is your timeline on this? We would probably have time in the next couple of weeks, though not immediately.

In principle, simulate_many would be stand-alone and return a List of model simulations (Dict[str, Union[np.ndarray, pd.DataFrame]]) as well as distance values etc. These would be in List format, as the data storage handles single simulations as separate entities, without vectorization. Would that be an issue? Conceptually, rewriting the storage to provide a single simulation matrix per generation would also be possible, but it would require further changes, as pyABC was designed not for batching but for simulation-heavy single-CPU simulators. The batch size n would require some tuning, e.g. based on a prediction of the acceptance rate. See also the discussion in #351.
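The List-of-Dicts return format mentioned above is easy to produce from a batched array. A minimal sketch, assuming a vmapped simulator has already returned one (n_batch, n_stats) array (the key "y" is illustrative, not pyABC's actual schema):

```python
import numpy as np

# Hypothetical batched output: 3 simulations with 2 summary statistics each,
# as would come back from a single vmapped/jitted evaluation.
stats = np.arange(6.0).reshape(3, 2)

# The storage handles each simulation as its own entity, so a simulate_many
# wrapper unpacks the batch row by row into a List of Dicts.
results = [{"y": np.asarray(row)} for row in stats]
```

Going the other way (stacking a list of per-simulation dicts back into one array for the next batched call) is the symmetric one-liner with np.stack.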

@JohnGoertz
Author

Hi Yannik, that's great to hear. It's not urgent; something in a few weeks would be fine. Going to/from List/Dict and JAX-numpy arrays shouldn't be too much of an issue; the jitted evaluator would have to sit inside that conversion, but that shouldn't introduce much overhead. I can definitely see how batch size would have to be tuned to each application, but I have at least some intuition for what should work. For instance, if I am requesting a population size of 1000, a batch size of 100 should significantly speed things up without leading to too much wasted computation. Initially at least, this would be sequential sampling from pyABC's perspective, making 100 proposals at a time, so you don't have to worry about the distributed case where 999 samples have been accepted and then ten workers each submit 100 proposals just to get that last one.
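The sequential batched acceptance loop described above can be sketched in a few lines. Everything here is a toy stand-in (a normal prior and a distance-threshold acceptance rule), just to show the batch-until-n-accepted control flow:

```python
import numpy as np

rng = np.random.default_rng(0)

def propose(batch_size):
    # Hypothetical prior: standard normal proposals with 2 parameters.
    return rng.normal(size=(batch_size, 2))

def accept(thetas, eps=1.5):
    # Hypothetical acceptance rule: keep proposals within distance eps of zero.
    d = np.linalg.norm(thetas, axis=1)
    return thetas[d < eps]

def sample_population(n_pop=1000, batch_size=100):
    # Sequentially submit batches until the requested population is filled,
    # then truncate the (slight) overshoot from the final batch.
    accepted, n_total = [], 0
    while n_total < n_pop:
        batch = accept(propose(batch_size))
        accepted.append(batch)
        n_total += len(batch)
    return np.concatenate(accepted)[:n_pop]

pop = sample_population()
```

Because batches are submitted one at a time, at most one batch of excess simulations is wasted, which matches the 1000-population / 100-batch intuition above.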

It's interesting that they had better luck with pymc3's ABC-SMC; I tried that first, and I think yours was faster. Also, re-implementing simulations to be JAX-friendly isn't trivial, but it's a lot more approachable than Theano...

@yannikschaelte
Member

Sounds good! I will let you know when I get to work on this, hopefully in the next few weeks. If you need an urgent solution, a simple implementation sidestepping the simulate_one calls in

sample = self.sampler.sample_until_n_accepted(...)

with a dedicated sampler (which implements batch size and result merging) should be straightforward; a sustainable solution, however, will take a little longer. Agreed, it should not be too much work. Only simultaneously speeding up the storage format, which at the moment can be the bottleneck for fast simulators, may add a bit of complexity.
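A dedicated sampler of the kind suggested above might look roughly like this. This is a self-contained sketch, not pyABC's actual Sampler interface: the class name, constructor, and the (particle, accepted) tuples returned by simulate_batch are all assumptions chosen to show the batch-size and result-merging responsibilities.

```python
import numpy as np

class BatchSampler:
    """Illustrative batch sampler: evaluate whole batches at once and
    merge the accepted results until n particles are collected."""

    def __init__(self, simulate_batch, batch_size=100):
        self.simulate_batch = simulate_batch  # evaluates one full batch per call
        self.batch_size = batch_size

    def sample_until_n_accepted(self, n):
        accepted = []
        while len(accepted) < n:
            # simulate_batch yields (particle, accepted_flag) pairs for a batch.
            for particle, ok in self.simulate_batch(self.batch_size):
                if ok:
                    accepted.append(particle)
        return accepted[:n]

# Toy batch simulator: scalar parameters, accepted when |theta| < 1.
rng = np.random.default_rng(1)

def simulate_batch(size):
    thetas = rng.normal(size=size)
    return [(t, abs(t) < 1.0) for t in thetas]

sampler = BatchSampler(simulate_batch, batch_size=50)
pop = sampler.sample_until_n_accepted(200)
```

In a real integration, simulate_batch would wrap the jitted/vmapped model evaluation, and the returned particles would carry the per-simulation dicts and distance values the storage expects.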

Yes, what is faster depends on the problem at hand (as well as the implemented algorithms). For most of our problems, Theano/Aesara/JAX are not options, as the simulators are dedicated C++/R routines; however, it will be good if pyABC also handles those efficiently, as there have already been a few such applications.
