Commit

Update README with installation steps.
nshepperd committed May 4, 2024
1 parent d43cbca commit 565ee42
Showing 3 changed files with 63 additions and 43 deletions.
66 changes: 33 additions & 33 deletions .github/workflows/publish.yml
@@ -109,36 +109,36 @@ jobs:
name: ${{env.wheel_name}}
path: ./wheelhouse/${{env.wheel_name}}

publish_package:
name: Publish package
needs: [build_wheels]

runs-on: ubuntu-latest
permissions:
id-token: write

steps:
- uses: actions/checkout@v3

- uses: actions/setup-python@v4
with:
python-version: '3.10'

- name: Install dependencies
run: |
pip install setuptools==68.0.0
pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
pip install ninja packaging wheel pybind11
- name: Build core package
run: |
CUDA_HOME=/ python setup.py sdist --dist-dir=dist
- name: Retrieve release distributions
uses: actions/download-artifact@v4
with:
path: dist/
merge-multiple: true

- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
# publish_package:
# name: Publish package
# needs: [build_wheels]

# runs-on: ubuntu-latest
# permissions:
# id-token: write

# steps:
# - uses: actions/checkout@v3

# - uses: actions/setup-python@v4
# with:
# python-version: '3.10'

# - name: Install dependencies
# run: |
# pip install setuptools==68.0.0
# pip install git+https://github.com/nshepperd/setuptools-cuda-cpp
# pip install ninja packaging wheel pybind11

# - name: Build core package
# run: |
# CUDA_HOME=/ python setup.py sdist --dist-dir=dist

# - name: Retrieve release distributions
# uses: actions/download-artifact@v4
# with:
# path: dist/
# merge-multiple: true

# - name: Publish release distributions to PyPI
# uses: pypa/gh-action-pypi-publish@release/v1
35 changes: 26 additions & 9 deletions README.md
@@ -3,43 +3,60 @@ This repository provides a jax binding to <https://github.com/Dao-AILab/flash-attention>.

Please see [Tri Dao's repo](https://github.com/Dao-AILab/flash-attention) for more information about flash attention.

## Usage

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite (see below) and credit FlashAttention if you use it.

## Installation and features
## Installation

Requirements:
- CUDA 11.8 and above.
- Linux. Same story as with the pytorch repo; I haven't tested compilation of the jax bindings on Windows.
- JAX >= `0.4.24`. The custom sharding used for ring attention requires somewhat advanced features.

To install: For now, download the appropriate release from the releases page and install it with pip.
To install: `pip install flash-attn-jax` will get the latest release from PyPI. This gives you the CUDA 12.3 build. If you want the CUDA 11.8 build, you can install it from the releases page (but according to JAX's documentation, CUDA 11.8 will stop being supported by newer versions of JAX).

### Installing from source

Flash attention takes a long time to compile unless you have a powerful machine. If you want to compile from source anyway, I use `cibuildwheel` to build the releases, and you can do the same. Something like this (for Python 3.12):

```sh
git clone https://github.com/nshepperd/flash-attn-jax
cd flash-attn-jax
cibuildwheel --only cp312-manylinux_x86_64  # cibuildwheel may need superuser privileges on some systems, since it runs the build inside docker
```

This will create a wheel in the `wheelhouse` directory. You can then install it with `pip install wheelhouse/flash_attn_jax-0.2.0-cp312-cp312-manylinux_x86_64.whl`. Alternatively, you can use setup.py to build and install the wheel, in which case you need the CUDA toolkit installed.

## Usage

Interface: `src/flash_attn_jax/flash.py`

```py
from flash_attn_jax import flash_mha

# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
flash_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1,-1))
```

Accepts q,k,v with shape `[n, l, h, d]`, and returns `[n, l, h, d]`. `softmax_scale` is the
multiplier for the softmax, defaulting to `1/sqrt(d)`. Set window_size
to positive values for sliding window attention.
This supports multi-query and grouped-query attention (when `hk != h`). The `softmax_scale` is the multiplier for the softmax, defaulting to `1/sqrt(d)`. Set `window_size` to positive values for sliding window attention.
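For concreteness, a minimal sketch of a call with grouped-query shapes. The shapes, dtype, and window values here are illustrative assumptions (flash attention kernels generally want half-precision inputs on a CUDA GPU, and I'm assuming the upstream `(left, right)` window convention):

```py
import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

# Illustrative shapes: batch n=2, length l=1024, h=8 query heads
# sharing hk=2 kv heads (grouped-query), head dimension d=64.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 1024, 8, 64), dtype=jnp.float16)
k = jax.random.normal(kk, (2, 1024, 2, 64), dtype=jnp.float16)
v = jax.random.normal(kv, (2, 1024, 2, 64), dtype=jnp.float16)

out = flash_mha(q, k, v, is_causal=True)        # causal masking
out = flash_mha(q, k, v, window_size=(256, 0))  # sliding window: each query attends to the previous 256 keys
print(out.shape)  # (2, 1024, 8, 64)
```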

### Now Supports Ring Attention

Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm:
Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).

```py
os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_collectives=true'
#...
with Mesh(devices, axis_names=('len',)) as mesh:
sharding = NamedSharding(mesh, P(None,'len',None)) # n l d
sharding = NamedSharding(mesh, P(None,'len')) # n l
tokens = jax.device_put(tokens, sharding)
# invoke your jax.jit'd transformer.forward
```
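Filling in the surrounding boilerplate, a self-contained sketch might look like the following. The mesh construction, token shapes, and the jitted `forward` are assumptions for illustration, not part of this repo's API:

```py
import os
# Set XLA flags before importing jax so they take effect.
os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_collectives=true'

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()  # e.g. 8 GPUs on one host
tokens = jnp.zeros((2, 8192), dtype=jnp.int32)  # placeholder batch of token ids, shape [n, l]

with Mesh(devices, axis_names=('len',)) as mesh:
    # Shard the length axis across devices; flash_mha then sees q/k/v
    # sharded along l and dispatches to the ring attention algorithm.
    sharding = NamedSharding(mesh, P(None, 'len'))
    tokens = jax.device_put(tokens, sharding)
    # out = jax.jit(forward)(params, tokens)  # your transformer forward pass
```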

It's not entirely reliable at hiding the communication latency though, depending on the whims of the XLA optimizer. I'm waiting for https://github.com/google/jax/issues/20864 to be fixed, after which I can make it better.

### GPU support

FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -1,2 +1,5 @@
[build-system]
requires = ["setuptools", "wheel", "setuptools-cuda-cpp @ git+https://github.com/nshepperd/setuptools-cuda-cpp", "packaging", "pybind11"]
requires = ["setuptools", "wheel", "setuptools-cuda-cpp @ git+https://github.com/nshepperd/setuptools-cuda-cpp", "packaging", "pybind11"]

[tool.cibuildwheel]
manylinux-x86_64-image = "sameli/manylinux_2_28_x86_64_cuda_12.3"
