Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX Support to NPBench and Implement JAX Benchmarks #31

Open
wants to merge 118 commits into
base: main
Choose a base branch
from

Conversation

hardik01shah
Copy link

Overview:

This PR introduces JAX as a supported framework in NPBench, providing JAX implementations for all existing benchmarks. This enhancement allows NPBench to compare the performance of JAX against other frameworks, for all implemented benchmarks in NPBench.


Changes Introduced:

  1. Framework Configuration:
    • Added framework_info/jax.json to define framework-specific configuration settings for JAX.
  2. Benchmark Implementations:
    • Added JAX implementations for all benchmarks in the directory structure:
      • npbench/benchmarks/<bench_parent>/<bench_name>_jax.py.
    • Support for benchmarks with existing JAX library implementations.
      • npbench/benchmarks/<bench_parent>/<bench_name>_jax_lib.py
  3. Infrastructure Updates:
    • Updated npbench/infrastructure/__init__.py to include JAX framework initialization.
    • Added npbench/infrastructure/jax_framework.py to handle:
      • Pre- and post-processing for JAX array arguments (e.g., using block_until_ready() and jnp.array()).
  4. Bug Fix:
    • Resolved a minor bug in jax_framework.py related to incorrect naming of the current framework during benchmarking.

Motivation:

The addition of JAX support enhances NPBench by:

  • Allowing direct comparison of JAX’s performance with other frameworks across diverse scientific Python benchmarks.
  • Extending the utility of NPBench to the JAX community for performance evaluation and optimization.

Testing and Validation:

  • Verified the correctness of JAX implementations by validating the implementations with the Numpy impplementations.
  • Ensured smooth integration of JAX into the benchmarking pipeline.

Contributors:

hardik01shah and others added 30 commits October 29, 2024 10:24
Define the JaxFramework class for implementing block_until_ready() calls after kernel computation for correctness of profiling, convert np.ndarray to jax Array in the copy_func()
Define the JaxFramework class for implementing block_until_ready() calls after kernel computation for correctness of profiling, convert np.ndarray to jax Array in the copy_func()
Copy link
Contributor

@alexnick83 alexnick83 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good.

@@ -0,0 +1,106 @@
# Copyright 2021 ETH Zurich and the NPBench authors. All rights reserved.
import pathlib
import jax.numpy as jnp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have to put this in a try/except block, otherwise JAX becomes a mandatory dependency.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

Copy link
Collaborator

@tbennun tbennun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One change/clarification on my end, otherwise good to go!

Comment on lines 7 to 12
def spmv(A_row, A_col, A_val, x):
dim = A_row.size - 1 # needed because for the "paper" test size, scipy auto-infers the dims wrong
matrix_in_csr_format = scipy.sparse.csr_matrix((A_val, A_col, A_row), shape=(dim, dim))
matrix_in_bcoo_format = jax_sparse.BCOO.from_scipy_sparse(matrix_in_csr_format)

return matrix_in_bcoo_format @ x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify this implementation?

  1. I am not sure the use of scipy or library calls would be allowed here without the _lib.py suffix (like you added for other applications)
  2. This implementation changes the sparsity representation, which is BCOO instead of CSR

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We observed a very significant speedup (40x compared to numpy) with the current implementation that uses JAX's internal BCOO format and hence used that as the final implementation.

But you are right, having it in the lib implementation is better for fair comparison. Done.
The BCOO implementation has been added to spmv_jax_lib.py and spmv_jax.py has been reverted to a previous version that mimicked the numpy implementation in JAX.

@hardik01shah hardik01shah requested a review from tbennun February 4, 2025 09:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants