Update CK FA to 2.5.9post1+ck_tile which includes sliding window support
mawong-amd committed Jul 19, 2024
1 parent 678b929 commit 5dcf77c
Showing 4 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.rocm
@@ -11,7 +11,7 @@ ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
 # Triton FA is used by default on ROCm now so this is unnecessary.
 ARG BUILD_FA="1"
 ARG FA_GFX_ARCHS="gfx90a;gfx942"
-ARG FA_BRANCH="ae7928c"
+ARG FA_BRANCH="23a2b1c2"

 # Whether to build triton on rocm
 ARG BUILD_TRITON="1"
6 changes: 3 additions & 3 deletions docs/source/getting_started/amd-installation.rst
@@ -90,12 +90,11 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor

 Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_

-2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm>`_
+2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/ck_tile>`_

-Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
+Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support>`_

 .. note::
-   - ROCm's Flash-attention-2 (v2.0.4) does not support sliding window attention.
    - You might need to downgrade the "ninja" version to 1.10 as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)

 3. Build vLLM.
@@ -110,5 +109,6 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
 .. tip::

    - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
+   - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
    - To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
    - The ROCm version of PyTorch, ideally, should match the ROCm driver version.
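As a concrete illustration of the tip above, here is a minimal sketch of switching to the CK backend from Python rather than the shell. It assumes a ROCm build of vLLM with CK flash attention installed; the model name is only an example of a sliding-window model, and the environment variable must be set before vLLM selects its attention backend.

    import os

    # Setting this before importing vLLM turns off Triton flash attention,
    # so the CK flash attention backend (with sliding window support) is used.
    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

    from vllm import LLM, SamplingParams

    # Mistral models use sliding window attention; the name is illustrative only.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)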
8 changes: 8 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -275,6 +275,12 @@ def __init__(
                 triton_attention)
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
+            if self.sliding_window != (-1, -1):
+                logger.warning("ROCm Triton FA does not currently support "
+                               "sliding window attention. If using half "
+                               "precision, please try using the ROCm CK "
+                               "FA backend instead by setting the env var "
+                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
         else:
             # if not using triton, navi3x/navi21/navi10 do not use flash-attn
             # either
@@ -434,6 +440,8 @@ def forward(
                     max_seqlen_k=prefill_meta.max_prefill_seq_len,
                     softmax_scale=self.scale,
                     causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
                 )

             # common code for prefill
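For readers following the backend change above, here is a simplified, illustrative sketch (not vLLM's exact code; the helper name is invented) of the convention the new check relies on: the backend keeps the sliding window as a (left, right) tuple, where (-1, -1) means "no sliding window", and that tuple is what the forward() change passes to CK flash attention's window_size argument.

    import logging
    from typing import Tuple

    logger = logging.getLogger(__name__)

    def choose_prefill_attention(use_triton_fa: bool,
                                 sliding_window: Tuple[int, int]) -> str:
        """Simplified mirror of the selection/warning logic in this commit."""
        if use_triton_fa:
            if sliding_window != (-1, -1):
                # Triton FA would silently ignore the window, so warn and
                # point users at CK FA (VLLM_USE_TRITON_FLASH_ATTN=0).
                logger.warning("ROCm Triton FA does not support sliding window "
                               "attention; set VLLM_USE_TRITON_FLASH_ATTN=0 to "
                               "use CK FA instead.")
            return "triton"
        # CK flash attention (2.5.9.post1, ck_tile) accepts window_size and
        # alibi_slopes, which forward() now passes through.
        return "ck"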
13 changes: 7 additions & 6 deletions vllm/model_executor/models/__init__.py
@@ -86,13 +86,14 @@

 # Models partially supported by ROCm.
 # Architecture -> Reason.
+_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
+                    "Triton flash attention. For half-precision SWA support, "
+                    "please use CK flash attention by setting "
+                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
 _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
-    "Qwen2ForCausalLM":
-    "Sliding window attention is not yet supported in ROCm's flash attention",
-    "MistralForCausalLM":
-    "Sliding window attention is not yet supported in ROCm's flash attention",
-    "MixtralForCausalLM":
-    "Sliding window attention is not yet supported in ROCm's flash attention",
+    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
+    "MistralForCausalLM": _ROCM_SWA_REASON,
+    "MixtralForCausalLM": _ROCM_SWA_REASON,
 }


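A hypothetical usage sketch of the table above (vLLM's actual lookup code is not part of this diff, and the helper name here is invented for illustration): architectures are checked against the map and the stored reason is surfaced as a warning.

    from typing import Dict

    _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                        "Triton flash attention. For half-precision SWA support, "
                        "please use CK flash attention by setting "
                        "`VLLM_USE_TRITON_FLASH_ATTN=0`")

    _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
        "Qwen2ForCausalLM": _ROCM_SWA_REASON,
        "MistralForCausalLM": _ROCM_SWA_REASON,
        "MixtralForCausalLM": _ROCM_SWA_REASON,
    }

    def warn_if_partially_supported(architecture: str) -> None:
        # Invented helper: look up the architecture and report the reason, if any.
        reason = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(architecture)
        if reason is not None:
            print(f"{architecture} is only partially supported on ROCm: {reason}")

    warn_if_partially_supported("MistralForCausalLM")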
