Skip to content

cschenjunlin/JAX

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 

Repository files navigation

JAX

JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.

Building JAX from source

Building JAX involves two steps:

  • Building or installing jaxlib, the C++ support library for jax.

  • Installing the jax Python package.

1. Obtain the JAX source code

git clone https://github.com/jax-ml/jax
cd jax

2. Building jaxlib (for CUDA)

prerequisites

  • A C++ compiler: Clang 18.
bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)"
wget https://apt.llvm.org/llvm.sh
chmod +x llvm.sh
sudo ./llvm.sh 18  # installs LLVM 18(includes clang-18)
  • Python: for running the build helper script.

building process

​​"Building" refers to the process of compiling source code and packaging it into an executable or installable software product (like a wheel file).​

  • Compilation: Convert source code into intermediate object files (.o).

  • Linking: Combine object files into executables and libraries (.so).

  • Packaging: Bundle executables, libraries, configuration files, and dependencies into a distributable format (.whl).

# JAX uses a Python script (build/build.py) to build automatically.
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt

installation

pip install dist/*.whl

components to build

  • jaxlib: the core C++ support library for jax.
  • jax-cuda-plugin​: a plugin to extend jax's capabilities for CUDA devices.
  • jax-cuda-pjrt: an implementation of Google's Portable Runtime (PjRt) interface specifically designed for CUDA devices.

outputs of the building

  • jaxlib-0.7.0.dev20250628-cp311-cp311-manylinux2014_x86_64.whl
├── jaxlib
│   ├── cpu
│   │   ├── _lapack.so
│   │   └── _sparse.so
│   ├── cpu_feature_guard.so
│   ├── cpu_sparse.py
│   ├── gpu_common_utils.py
│   ├── gpu_linalg.py
│   ├── gpu_prng.py
│   ├── gpu_rnn.py
│   ├── gpu_solver.py
│   ├── gpu_sparse.py
│   ├── gpu_triton.py
│   ├── hlo_helpers.py
│   ├── include
│   │   └── xla
│   │       └── ffi
│   │           └── api
│   │               ├── api.h
│   │               ├── c_api.h
│   │               └── ffi.h
│   ├── __init__.py
│   ├── _jax.so
│   ├── lapack.py
│   ├── libjax_common.so
│   ├── mlir
│   │   ├── dialects
│   │   │   ├── _arith_enum_gen.py
│   │   │   ├── _arith_ops_gen.py
│   │   │   ├── arith.py
│   │   │   ├── _builtin_ops_gen.py
│   │   │   ├── builtin.py
│   │   │   ├── _cf_ops_gen.py
│   │   │   ├── cf.py
│   │   │   ├── _chlo_ops_gen.py
│   │   │   ├── chlo.py
│   │   │   ├── _func_ops_gen.py
│   │   │   ├── func.py
│   │   │   ├── gpu
│   │   │   │   ├── __init__.py
│   │   │   │   └── passes
│   │   │   │       └── __init__.py
│   │   │   ├── _gpu_enum_gen.py
│   │   │   ├── _gpu_ops_gen.py
│   │   │   ├── _llvm_enum_gen.py
│   │   │   ├── _llvm_ops_gen.py
│   │   │   ├── llvm.py
│   │   │   ├── _math_ops_gen.py
│   │   │   ├── math.py
│   │   │   ├── _memref_ops_gen.py
│   │   │   ├── memref.py
│   │   │   ├── _mhlo_ops_gen.py
│   │   │   ├── mhlo.py
│   │   │   ├── _nvgpu_enum_gen.py
│   │   │   ├── _nvgpu_ops_gen.py
│   │   │   ├── nvgpu.py
│   │   │   ├── _nvvm_enum_gen.py
│   │   │   ├── _nvvm_ops_gen.py
│   │   │   ├── nvvm.py
│   │   │   ├── _ods_common.py
│   │   │   ├── _scf_ops_gen.py
│   │   │   ├── scf.py
│   │   │   ├── _sdy_enums_gen.py
│   │   │   ├── _sdy_ops_gen.py
│   │   │   ├── sdy.py
│   │   │   ├── _sparse_tensor_enum_gen.py
│   │   │   ├── _sparse_tensor_ops_gen.py
│   │   │   ├── sparse_tensor.py
│   │   │   ├── _stablehlo_ops_gen.py
│   │   │   ├── stablehlo.py
│   │   │   ├── _vector_enum_gen.py
│   │   │   ├── _vector_ops_gen.py
│   │   │   └── vector.py
│   │   ├── extras
│   │   │   └── meta.py
│   │   ├── ir.py
│   │   ├── ir.pyi
│   │   ├── _mlir_libs
│   │   │   ├── _chlo.so
│   │   │   ├── __init__.py
│   │   │   ├── _mlirDialectsGPU.so
│   │   │   ├── _mlirDialectsLLVM.so
│   │   │   ├── _mlirDialectsNVGPU.so
│   │   │   ├── _mlirDialectsSparseTensor.so
│   │   │   ├── _mlirGPUPasses.so
│   │   │   ├── _mlirHlo.so
│   │   │   ├── _mlir.so
│   │   │   ├── _mlirSparseTensorPasses.so
│   │   │   ├── _mosaic_gpu_ext.so
│   │   │   ├── register_jax_dialects.so
│   │   │   ├── _sdy.so
│   │   │   ├── _stablehlo.so
│   │   │   ├── _tpu_ext.so
│   │   │   ├── _triton_ext.pyi
│   │   │   └── _triton_ext.so
│   │   ├── passmanager.py
│   │   └── passmanager.pyi
│   ├── mosaic
│   │   ├── dialect
│   │   │   └── gpu
│   │   │       ├── _mosaic_gpu_gen_enums.py
│   │   │       └── _mosaic_gpu_gen_ops.py
│   │   └── python
│   │       ├── layout_defs.py
│   │       ├── mosaic_gpu.py
│   │       ├── _tpu_gen.py
│   │       └── tpu.py
│   ├── plugin_support.py
│   ├── _pretty_printer.so
│   ├── _profiler.so
│   ├── py.typed
│   ├── triton
│   │   ├── dialect.py
│   │   ├── __init__.py
│   │   ├── _triton_enum_gen.py
│   │   └── _triton_ops_gen.py
│   ├── utils.so
│   ├── version.py
│   ├── weakref_lru_cache.pyi
│   ├── weakref_lru_cache.so
│   └── xla_client.py
└── jaxlib-0.7.0.dev20250628.dist-info
    ├── LICENSE.txt
    ├── METADATA
    ├── RECORD
    ├── top_level.txt
    └── WHEEL
  • jax_cuda12_plugin-0.7.0.dev20250628-cp311-cp311-manylinux2014_x86_64.whl
├── jax_cuda12_plugin
│   ├── cuda_plugin_extension.so
│   ├── _hybrid.so
│   ├── libmosaic_gpu_runtime.so
│   ├── _linalg.so
│   ├── _mosaic_gpu_ext.so
│   ├── _prng.so
│   ├── _rnn.so
│   ├── _solver.so
│   ├── _sparse.so
│   ├── _triton.so
│   ├── version.py
│   └── _versions.so
└── jax_cuda12_plugin-0.7.0.dev20250628.dist-info
    ├── METADATA
    ├── RECORD
    ├── top_level.txt
    └── WHEEL
  • jax_cuda12_pjrt-0.7.0.dev20250628-py3-none-manylinux2014_x86_64.whl
├── jax_cuda12_pjrt-0.7.0.dev20250628.dist-info
│   ├── entry_points.txt
│   ├── METADATA
│   ├── RECORD
│   ├── top_level.txt
│   └── WHEEL
└── jax_plugins
    └── xla_cuda12
        ├── __init__.py
        ├── version.py
        └── xla_cuda_plugin.so

3. Installing jax

pip install -e .

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published