JAX-based sampler #517
Hi, this sure sounds useful. We had plans to implement batch submission, essentially generalizing the …
Hi Yannik, that's great to hear. It's not urgent; something in a few weeks would be fine. Going to/from … It's interesting that they had better luck with pymc3's ABC-SMC. I tried that first, and I think yours was faster. Also, re-implementing simulations to be JAX-friendly isn't trivial, but it's a lot more approachable than Theano...
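One common step in making a simulator JAX-friendly is replacing a Python loop over time steps with jax.lax.scan, so that jit compiles a single loop body instead of unrolling every iteration into the trace. A minimal sketch, assuming a hypothetical AR(1) simulator (ar1_loop and ar1_scan are illustrative names, not code from this thread):

```python
import numpy as np
import jax
import jax.numpy as jnp


def ar1_loop(phi, noise):
    """Reference version: plain Python loop and NumPy, evaluated eagerly."""
    x, out = 0.0, []
    for eps_t in noise:
        x = phi * x + eps_t
        out.append(x)
    return np.array(out)


@jax.jit
def ar1_scan(phi, noise):
    """Same recursion x_t = phi * x_{t-1} + eps_t, expressed with lax.scan."""
    def step(x_prev, eps_t):
        x_t = phi * x_prev + eps_t
        return x_t, x_t  # (carry for the next step, value stacked into the output)

    # Initial carry matches the dtype of the per-step outputs.
    _, xs = jax.lax.scan(step, jnp.zeros((), dtype=noise.dtype), noise)
    return xs


noise = jax.random.normal(jax.random.PRNGKey(0), shape=(100,))
print(np.allclose(ar1_loop(0.9, np.asarray(noise)), ar1_scan(0.9, noise), atol=1e-5))  # True
```

The noise is drawn outside the recursion here, which keeps the scanned step a pure function of its inputs.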
Sounds good! I will let you know when I get to work on this, hopefully in the next few weeks. If you need an urgent solution, a simple implementation sidestepping the … (line 799 in commit e4fdc78)
Yes, what is faster depends on the problem at hand (as well as on the implemented algorithms). For most of our problems, Theano/Aesara/JAX are not options, as the simulators are dedicated C++/R routines. However, it would be good if pyABC also handled those efficiently, as there have already been a few such applications.
Feature description
It would be great to have a sampler that's both compatible with JAX-jitted functions and leverages JAX's parallelization tools.
Motivation/Application
I have a slow objective function that becomes significantly (orders of magnitude) faster with JAX tools, in particular jit and vmap. However, JAX's multithreading clashes with pyABC's multithreaded samplers, and pickling the jitted function doesn't behave either. Oddly, this isn't an issue with relatively simple versions of my objective function, which can use pyABC's default samplers, but more complex versions only work with pyABC's SingleCoreSampler.
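The jit-plus-vmap pattern mentioned above can be sketched as follows. The model is a hypothetical toy (a damped oscillator on a fixed time grid), not the poster's objective function; it only shows how vmap turns a single-parameter function into a batched one and how jit compiles the batched result.

```python
import jax
import jax.numpy as jnp


def model(theta):
    """Hypothetical toy model: trajectory for one parameter vector theta = (k, c)."""
    k, c = theta[0], theta[1]
    t = jnp.linspace(0.0, 10.0, 200)
    return jnp.exp(-c * t) * jnp.cos(jnp.sqrt(k) * t)


# vmap maps the single-parameter model over a leading batch axis;
# jit compiles the batched function once per input shape and dtype.
batched_model = jax.jit(jax.vmap(model))

thetas = jnp.array([[1.0, 0.1], [4.0, 0.2], [9.0, 0.3]])
trajectories = batched_model(thetas)
print(trajectories.shape)  # (3, 200)
```

Each row of trajectories is the model evaluated at the corresponding row of thetas, computed in one compiled call rather than a Python loop.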
I'd like to write an extension of the SingleCoreSampler that relies solely on JAX for vectorization/parallelization/multithreading. I have some ideas on how to get started, but I'd like some pointers. This would work best as a batch-sampling system, where an array of samples is submitted and the model function is mapped across the array using vmap or pmap. This evaluation could itself be jitted as well. My questions are:
1. … submit_one across the batch work?
2. … submit_one that JAX might not like (namely numpy operations)?