-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add JAX Support to NPBench and Implement JAX Benchmarks #31
base: main
Are you sure you want to change the base?
Conversation
Define the JaxFramework class for implementing block_until_ready() calls after kernel computation for correctness of profiling, convert np.ndarray to jax Array in the copy_func()
Define the JaxFramework class for implementing block_until_ready() calls after kernel computation for correctness of profiling, convert np.ndarray to jax Array in the copy_func()
Previously >10x slower, now around same runtime as numpy.
Previously was up to 70x slower, now it's faster for smaller sizes and comparable to numpy for bigger ones.
Previously was up to 90x slower, now it's up to 40x faster than numpy.
Previously was up to 3x slower, now comparable to numpy.
Add JAX as a supported framework
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good.
@@ -0,0 +1,106 @@ | |||
# Copyright 2021 ETH Zurich and the NPBench authors. All rights reserved. | |||
import pathlib | |||
import jax.numpy as jnp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You have to put this in a try/except block, otherwise JAX becomes a mandatory dependency.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One change/clarification on my end, otherwise good to go!
npbench/benchmarks/spmv/spmv_jax.py
Outdated
def spmv(A_row, A_col, A_val, x): | ||
dim = A_row.size - 1 # needed because for the "paper" test size, scipy auto-infers the dims wrong | ||
matrix_in_csr_format = scipy.sparse.csr_matrix((A_val, A_col, A_row), shape=(dim, dim)) | ||
matrix_in_bcoo_format = jax_sparse.BCOO.from_scipy_sparse(matrix_in_csr_format) | ||
|
||
return matrix_in_bcoo_format @ x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please clarify this implementation?
- I am not sure the use of scipy or library calls would be allowed here without the
_lib.py
suffix (like you added for other applications) - This implementation changes the sparsity representation, which is BCOO instead of CSR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We observed a very significant speedup (40x compared to numpy) with the current implementation that uses JAX's internal BCOO format and hence used that as the final implementation.
But you are right, having it in the lib
implementation is better for fair comparison. Done.
The BCOO implementation has been added to spmv_jax_lib.py
and spmv_jax.py
has been reverted to a previous version that mimicked the numpy implementation in JAX.
Overview:
This PR introduces JAX as a supported framework in NPBench, providing JAX implementations for all existing benchmarks. This enhancement allows NPBench to compare the performance of JAX against other frameworks, for all implemented benchmarks in NPBench.
Changes Introduced:
framework_info/jax.json
to define framework-specific configuration settings for JAX.npbench/benchmarks/<bench_parent>/<bench_name>_jax.py
.npbench/benchmarks/<bench_parent>/<bench_name>_jax_lib.py
npbench/infrastructure/__init__.py
to include JAX framework initialization.npbench/infrastructure/jax_framework.py
to handle:block_until_ready()
andjnp.array()
).jax_framework.py
related to incorrect naming of the current framework during benchmarking.Motivation:
The addition of JAX support enhances NPBench by:
Testing and Validation:
Contributors: