Enable CK Attention for Navi31
- Enables CK Attention for Navi31
- Requires this branch of Flash Attention:
- https://github.com/ROCm/flash-attention/tree/howiejay/navi_support
hyoon1 committed Nov 19, 2024
1 parent 48726bf commit a0c1e1c
Showing 1 changed file with 78 additions and 18 deletions.
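
For context, the dispatch added below only takes effect when a flash_attn package built from the howiejay/navi_support branch is importable and the GPU reports a gfx11 (Navi3x) architecture. The pre-flight check here is an illustrative sketch, not part of the commit; the gcnArchName property is an assumption about ROCm builds of PyTorch, and the gfx11 test mirrors the has_device_capability(110) threshold used in the diff.

# Illustrative pre-flight check (not part of this commit): does the CK
# flash-attention path have what it needs on this machine?
import importlib.util

import torch


def ck_attention_would_be_used() -> bool:
    # The howiejay/navi_support branch installs under the usual
    # "flash_attn" package name, which is all the guarded import checks.
    if importlib.util.find_spec("flash_attn") is None:
        return False
    # Must be a ROCm build of PyTorch with a visible GPU.
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    # Navi31 reports gfx1100; vLLM maps this to device capability 110.
    # (gcnArchName is assumed to be present on ROCm device properties.)
    arch = torch.cuda.get_device_properties(0).gcnArchName
    return arch.startswith("gfx11")


if __name__ == "__main__":
    print("CK attention path available:", ck_attention_would_be_used())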
96 changes: 78 additions & 18 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -18,6 +18,15 @@
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

+try:
+    from flash_attn import flash_attn_func  # noqa: F401
+    from flash_attn import flash_attn_varlen_func  # noqa: F401
+    flash_attn_available = True
+except ModuleNotFoundError:
+    flash_attn_func = None
+    flash_attn_varlen_func = None
+    flash_attn_available = False
+
 logger = init_logger(__name__)

 _PARTITION_SIZE_ROCM = 256
@@ -498,16 +507,18 @@ def __init__(
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# if not using triton, navi21/navi10 do not use flash-attn
# either
if not current_platform.has_device_capability(90):
self.use_naive_attn = True
else:
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func
if flash_attn_available:
if current_platform.has_device_capability(110):
self.attn_func = _ck_attention
else:
self.attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend")
except ModuleNotFoundError:
else:
self.use_naive_attn = True

if self.use_naive_attn:
@@ -704,19 +715,38 @@ def forward(
                             attn_masks,
                         )
                 else:
-                    out = self.attn_func(
-                        q=query,
-                        k=key,
-                        v=value,
-                        cu_seqlens_q=query_seq_start_loc,
-                        cu_seqlens_k=key_seq_start_loc,
-                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                        max_seqlen_k=key_max_seq_len,
-                        softmax_scale=self.scale,
-                        causal=True,
-                        window_size=self.sliding_window,
-                        alibi_slopes=self.alibi_slopes,
-                    )
+                    if _ON_NAVI:
+                        query = query.view(
+                            (num_prefill_tokens, self.num_heads, -1))
+                        key = key.view(
+                            (num_prefill_tokens, self.num_kv_heads, -1))
+                        value = value.view(
+                            (num_prefill_tokens, self.num_kv_heads, -1))
+                        out = self.attn_func(
+                            query,
+                            key,
+                            value,
+                            prefill_meta.seq_lens,
+                            num_prefill_tokens,
+                            self.num_heads,
+                            self.head_size,
+                            self.scale,
+                            attn_masks,
+                        )
+                    else:
+                        out = self.attn_func(
+                            q=query,
+                            k=key,
+                            v=value,
+                            cu_seqlens_q=query_seq_start_loc,
+                            cu_seqlens_k=key_seq_start_loc,
+                            max_seqlen_q=prefill_meta.max_prefill_seq_len,
+                            max_seqlen_k=key_max_seq_len,
+                            softmax_scale=self.scale,
+                            causal=True,
+                            window_size=self.sliding_window,
+                            alibi_slopes=self.alibi_slopes,
+                        )

                 # common code for prefill
                 assert output[:num_prefill_tokens].shape == out.shape
@@ -872,6 +902,36 @@ def _sdpa_attention(
     return output


+def _ck_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    seq_lens: List[int],
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    scale: float,
+    attn_masks: Optional[List[torch.Tensor]] = None,
+) -> torch.Tensor:
+    start = 0
+    output = torch.empty((num_tokens, num_heads, head_size),
+                         dtype=query.dtype,
+                         device=query.device)
+
+    for i, seq_len in enumerate(seq_lens):
+        end = start + seq_len
+        sub_out = flash_attn_func(query[start:end, :, :].unsqueeze(0),
+                                  key[start:end, :, :].unsqueeze(0),
+                                  value[start:end, :, :].unsqueeze(0),
+                                  dropout_p=0.0,
+                                  softmax_scale=scale,
+                                  causal=attn_masks is None)
+        output[start:end, :, :] = sub_out.squeeze(0)
+        start = end
+
+    return output
+
+
 def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                      block_size: int, gqa_ratio: int,
                                      max_seq_len: int) -> bool:
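
As a usage note, the new _ck_attention helper takes packed, unpadded tensors of shape (num_tokens, num_heads, head_size) plus the per-sequence lengths, and calls flash_attn_func once per sequence. The driver below is a hedged sketch rather than part of the commit: it assumes the navi_support flash-attention build is installed and a Navi3x GPU is available, and the shapes and dtypes are arbitrary example values.

# Illustrative driver for the per-sequence CK attention loop (not part of
# the commit). Requires the howiejay/navi_support flash-attention build
# and a ROCm GPU; shapes and dtypes are example values only.
import torch

from vllm.attention.backends.rocm_flash_attn import _ck_attention

seq_lens = [5, 3]                   # two prompts of different lengths
num_tokens = sum(seq_lens)          # packed (unpadded) token count
num_heads, head_size = 8, 128

q = torch.randn(num_tokens, num_heads, head_size,
                dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# attn_masks=None makes each per-sequence flash_attn_func call causal,
# matching the prefill path that reaches this helper when _ON_NAVI is set.
out = _ck_attention(q, k, v, seq_lens, num_tokens, num_heads, head_size,
                    scale=head_size**-0.5, attn_masks=None)
print(out.shape)                    # torch.Size([8, 8, 128])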
