Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try to switch from numpy/scipy to jax in some part of basis interp #166

Merged
merged 15 commits into from
Jul 30, 2024

Conversation

PFLeget
Copy link
Collaborator

@PFLeget PFLeget commented Jun 28, 2024

This Pull Request switch some part of the linear algebra and tensor/vector reshaping from numpy/scipy to JAX in BasisInterp.

While the math are unchanged it accelerate by a factor ~2 the whole computation time while using the current state of the arts in PSF modeling using Piff (PixelGrid + BasisInterp). Here is one example run at S3DF on one visit of HSC:

acceleration_jax

And here is a notebook to compare both (classic numpy/scipy vs jax) on a single ccd from HSC.

In order to make pass the test with this branch I had to disable tests in python 3.7 and python 3.8 because I was not able to make JAX work under those version. I found it was ok to disable both as python 3.7 looks to not be supported and python 3.8 will not be supported anymore in October by referring to this.

@PFLeget PFLeget force-pushed the dev/pleget/add_jax_in_BasisInterp branch 2 times, most recently from 2bb898b to 92bac30 Compare July 8, 2024 12:55
@PFLeget PFLeget marked this pull request as ready for review July 8, 2024 13:47
@PFLeget
Copy link
Collaborator Author

PFLeget commented Jul 15, 2024

I just added a change on the solver. I found it was 2% faster to use Cholesky as a solver. Even if it's not a lot, I am taking it. It does not change the result when I tested it on HSC.

@PFLeget
Copy link
Collaborator Author

PFLeget commented Jul 19, 2024

I got a comment from @jmeyers314 yesterday on performance from JAX on CPU vs number of core used. I did some test and long story short JAX will be faster than numpy/scipy if the number of core used is greater than one. Here are some numbers for reference:

# core = 1
# N_stars = 40
# numpy --> 7.35 sec
# jax --> 9.82 sec

# core = 2
# N_stars = 40
# numpy --> 6.72 sec
# jax --> 4.44 sec

# core = 2
# N_stars = 80
# numpy --> 15.56 sec
# jax --> 9.19 sec

# core = 4 
# N_stars = 40
# numpy --> 6.23 sec
# jax --> 3.04 sec 

# core = 4 
# N_stars = 80
# numpy --> 14.49 sec
# jax --> 6.33 sec

Copy link
Owner

@rmjarvis rmjarvis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good Pierre-François. I just have a few comments/suggestions. Mostly about making the jax stuff optional for now. But I wouldn't be averse to making JAX required down the line, after we've had a chance to try it out in various contexts to see how the results compare (both on speed and fidelity).

piff/basis_interp.py Outdated Show resolved Hide resolved
piff/basis_interp.py Outdated Show resolved Hide resolved
piff/basis_interp.py Outdated Show resolved Hide resolved
piff/basis_interp.py Outdated Show resolved Hide resolved
piff/basis_interp.py Outdated Show resolved Hide resolved
piff/basis_interp.py Outdated Show resolved Hide resolved
piff/basis_interp.py Outdated Show resolved Hide resolved
requirements.txt Outdated Show resolved Hide resolved
@PFLeget PFLeget force-pushed the dev/pleget/add_jax_in_BasisInterp branch from 8efa203 to c76e3ac Compare July 30, 2024 01:42
@PFLeget
Copy link
Collaborator Author

PFLeget commented Jul 30, 2024

Ok sounds good. I implemented your comments. Let me know if the change looks good or if it needs some additional changes.

piff/basis_interp.py Outdated Show resolved Hide resolved
@PFLeget
Copy link
Collaborator Author

PFLeget commented Jul 30, 2024

Just added the warning in the log.

@rmjarvis
Copy link
Owner

Thanks PF! Looks good. Now we can try it in a few different use cases and get some real work comparisons. My guess is the speed will depend a fair bit on the kind of hardware being used.

@rmjarvis rmjarvis merged commit bcf1d13 into main Jul 30, 2024
9 checks passed
@rmjarvis rmjarvis deleted the dev/pleget/add_jax_in_BasisInterp branch July 30, 2024 13:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants