Skip to content

Commit

Permalink
Fix: MI100 Support By Bypassing Custom Paged Attention (vllm-project#…
Browse files Browse the repository at this point in the history
…9560)

Signed-off-by: qishuai <[email protected]>
  • Loading branch information
MErkinSag authored and FerdinandZhong committed Oct 29, 2024
1 parent 56711d5 commit 5a4894e
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
logger = init_logger(__name__)

_PARTITION_SIZE_ROCM = 512
_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(arch in _GPU_ARCH
for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"])


class ROCmFlashAttentionBackend(AttentionBackend):
Expand Down Expand Up @@ -662,7 +665,8 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int) -> bool:
# rocm custom page attention not support on navi (gfx1*)
return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
return (_ON_MI250_MI300 and not _ON_NAVI
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)

0 comments on commit 5a4894e

Please sign in to comment.