JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.
Building JAX involves two steps:
-
Building or installing jaxlib, the C++ support library for jax.
-
Installing the jax Python package.
git clone https://github.com/jax-ml/jax
cd jax
- A C++ compiler: Clang 18.
bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)"
wget https://apt.llvm.org/llvm.sh
chmod +x llvm.sh
sudo ./llvm.sh 18 # installs LLVM 18(includes clang-18)
- Python: for running the build helper script.
"Building" refers to the process of compiling source code and packaging it into an executable or installable software product (like a wheel file).
-
Compilation: Convert source code into intermediate object files (.o).
-
Linking: Combine object files into executables and libraries (.so).
-
Packaging: Bundle executables, libraries, configuration files, and dependencies into a distributable format (.whl).
# JAX uses a Python script (build/build.py) to build automatically.
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt
pip install dist/*.whl
jaxlib: the core C++ support library for jax.jax-cuda-plugin: a plugin to extend jax's capabilities for CUDA devices.jax-cuda-pjrt: an implementation of Google's Portable Runtime (PjRt) interface specifically designed for CUDA devices.
jaxlib-0.7.0.dev20250628-cp311-cp311-manylinux2014_x86_64.whl
├── jaxlib
│ ├── cpu
│ │ ├── _lapack.so
│ │ └── _sparse.so
│ ├── cpu_feature_guard.so
│ ├── cpu_sparse.py
│ ├── gpu_common_utils.py
│ ├── gpu_linalg.py
│ ├── gpu_prng.py
│ ├── gpu_rnn.py
│ ├── gpu_solver.py
│ ├── gpu_sparse.py
│ ├── gpu_triton.py
│ ├── hlo_helpers.py
│ ├── include
│ │ └── xla
│ │ └── ffi
│ │ └── api
│ │ ├── api.h
│ │ ├── c_api.h
│ │ └── ffi.h
│ ├── __init__.py
│ ├── _jax.so
│ ├── lapack.py
│ ├── libjax_common.so
│ ├── mlir
│ │ ├── dialects
│ │ │ ├── _arith_enum_gen.py
│ │ │ ├── _arith_ops_gen.py
│ │ │ ├── arith.py
│ │ │ ├── _builtin_ops_gen.py
│ │ │ ├── builtin.py
│ │ │ ├── _cf_ops_gen.py
│ │ │ ├── cf.py
│ │ │ ├── _chlo_ops_gen.py
│ │ │ ├── chlo.py
│ │ │ ├── _func_ops_gen.py
│ │ │ ├── func.py
│ │ │ ├── gpu
│ │ │ │ ├── __init__.py
│ │ │ │ └── passes
│ │ │ │ └── __init__.py
│ │ │ ├── _gpu_enum_gen.py
│ │ │ ├── _gpu_ops_gen.py
│ │ │ ├── _llvm_enum_gen.py
│ │ │ ├── _llvm_ops_gen.py
│ │ │ ├── llvm.py
│ │ │ ├── _math_ops_gen.py
│ │ │ ├── math.py
│ │ │ ├── _memref_ops_gen.py
│ │ │ ├── memref.py
│ │ │ ├── _mhlo_ops_gen.py
│ │ │ ├── mhlo.py
│ │ │ ├── _nvgpu_enum_gen.py
│ │ │ ├── _nvgpu_ops_gen.py
│ │ │ ├── nvgpu.py
│ │ │ ├── _nvvm_enum_gen.py
│ │ │ ├── _nvvm_ops_gen.py
│ │ │ ├── nvvm.py
│ │ │ ├── _ods_common.py
│ │ │ ├── _scf_ops_gen.py
│ │ │ ├── scf.py
│ │ │ ├── _sdy_enums_gen.py
│ │ │ ├── _sdy_ops_gen.py
│ │ │ ├── sdy.py
│ │ │ ├── _sparse_tensor_enum_gen.py
│ │ │ ├── _sparse_tensor_ops_gen.py
│ │ │ ├── sparse_tensor.py
│ │ │ ├── _stablehlo_ops_gen.py
│ │ │ ├── stablehlo.py
│ │ │ ├── _vector_enum_gen.py
│ │ │ ├── _vector_ops_gen.py
│ │ │ └── vector.py
│ │ ├── extras
│ │ │ └── meta.py
│ │ ├── ir.py
│ │ ├── ir.pyi
│ │ ├── _mlir_libs
│ │ │ ├── _chlo.so
│ │ │ ├── __init__.py
│ │ │ ├── _mlirDialectsGPU.so
│ │ │ ├── _mlirDialectsLLVM.so
│ │ │ ├── _mlirDialectsNVGPU.so
│ │ │ ├── _mlirDialectsSparseTensor.so
│ │ │ ├── _mlirGPUPasses.so
│ │ │ ├── _mlirHlo.so
│ │ │ ├── _mlir.so
│ │ │ ├── _mlirSparseTensorPasses.so
│ │ │ ├── _mosaic_gpu_ext.so
│ │ │ ├── register_jax_dialects.so
│ │ │ ├── _sdy.so
│ │ │ ├── _stablehlo.so
│ │ │ ├── _tpu_ext.so
│ │ │ ├── _triton_ext.pyi
│ │ │ └── _triton_ext.so
│ │ ├── passmanager.py
│ │ └── passmanager.pyi
│ ├── mosaic
│ │ ├── dialect
│ │ │ └── gpu
│ │ │ ├── _mosaic_gpu_gen_enums.py
│ │ │ └── _mosaic_gpu_gen_ops.py
│ │ └── python
│ │ ├── layout_defs.py
│ │ ├── mosaic_gpu.py
│ │ ├── _tpu_gen.py
│ │ └── tpu.py
│ ├── plugin_support.py
│ ├── _pretty_printer.so
│ ├── _profiler.so
│ ├── py.typed
│ ├── triton
│ │ ├── dialect.py
│ │ ├── __init__.py
│ │ ├── _triton_enum_gen.py
│ │ └── _triton_ops_gen.py
│ ├── utils.so
│ ├── version.py
│ ├── weakref_lru_cache.pyi
│ ├── weakref_lru_cache.so
│ └── xla_client.py
└── jaxlib-0.7.0.dev20250628.dist-info
├── LICENSE.txt
├── METADATA
├── RECORD
├── top_level.txt
└── WHEEL
jax_cuda12_plugin-0.7.0.dev20250628-cp311-cp311-manylinux2014_x86_64.whl
├── jax_cuda12_plugin
│ ├── cuda_plugin_extension.so
│ ├── _hybrid.so
│ ├── libmosaic_gpu_runtime.so
│ ├── _linalg.so
│ ├── _mosaic_gpu_ext.so
│ ├── _prng.so
│ ├── _rnn.so
│ ├── _solver.so
│ ├── _sparse.so
│ ├── _triton.so
│ ├── version.py
│ └── _versions.so
└── jax_cuda12_plugin-0.7.0.dev20250628.dist-info
├── METADATA
├── RECORD
├── top_level.txt
└── WHEEL
jax_cuda12_pjrt-0.7.0.dev20250628-py3-none-manylinux2014_x86_64.whl
├── jax_cuda12_pjrt-0.7.0.dev20250628.dist-info
│ ├── entry_points.txt
│ ├── METADATA
│ ├── RECORD
│ ├── top_level.txt
│ └── WHEEL
└── jax_plugins
└── xla_cuda12
├── __init__.py
├── version.py
└── xla_cuda_plugin.so
pip install -e .