-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
try to switch from numpy/scipy to jax in some part of basis interp #166
Conversation
2bb898b
to
92bac30
Compare
I just added a change on the solver. I found it was 2% faster to use Cholesky as a solver. Even if it's not a lot, I am taking it. It does not change the result when I tested it on HSC. |
I got a comment from @jmeyers314 yesterday on performance from JAX on CPU vs number of core used. I did some test and long story short JAX will be faster than numpy/scipy if the number of core used is greater than one. Here are some numbers for reference:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good Pierre-François. I just have a few comments/suggestions. Mostly about making the jax stuff optional for now. But I wouldn't be averse to making JAX required down the line, after we've had a chance to try it out in various contexts to see how the results compare (both on speed and fidelity).
8efa203
to
c76e3ac
Compare
Ok sounds good. I implemented your comments. Let me know if the change looks good or if it needs some additional changes. |
Just added the warning in the log. |
Thanks PF! Looks good. Now we can try it in a few different use cases and get some real work comparisons. My guess is the speed will depend a fair bit on the kind of hardware being used. |
This Pull Request switch some part of the linear algebra and tensor/vector reshaping from numpy/scipy to JAX in
BasisInterp
.While the math are unchanged it accelerate by a factor ~2 the whole computation time while using the current state of the arts in PSF modeling using Piff (
PixelGrid
+BasisInterp
). Here is one example run at S3DF on one visit of HSC:And here is a notebook to compare both (classic numpy/scipy vs jax) on a single ccd from HSC.
In order to make pass the test with this branch I had to disable tests in python 3.7 and python 3.8 because I was not able to make JAX work under those version. I found it was ok to disable both as python 3.7 looks to not be supported and python 3.8 will not be supported anymore in October by referring to this.