Enable CK Attention for Navi31 #285

Open
wants to merge 6 commits into base: develop
Changes from 5 commits
98 changes: 79 additions & 19 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -18,6 +18,15 @@
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

try:
from flash_attn import flash_attn_func # noqa: F401
from flash_attn import flash_attn_varlen_func # noqa: F401
flash_attn_available = True
except ModuleNotFoundError:
flash_attn_func = None
flash_attn_varlen_func = None
flash_attn_available = False

logger = init_logger(__name__)

_PARTITION_SIZE_ROCM = 256
@@ -509,16 +518,18 @@ def __init__(
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# if not using triton, navi21/navi10 do not use flash-attn
# either
if not current_platform.has_device_capability(90):
self.use_naive_attn = True
else:
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func
if flash_attn_available:
if current_platform.has_device_capability(110):
Collaborator:
This check isn't equivalent to is_navi.
On CUDA, device_capability is meant to increase with each new architecture and to differentiate by new feature support, such as FP8.
On ROCm it is the first digit of the gfx target * 10 plus the second digit, which doesn't mean much, especially for any future architectures.

Author:
updated

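For context, a minimal sketch of the distinction raised in the comment above; the helper below and the use of `gcnArchName` are illustrative assumptions, not code from this PR:

```python
import torch

# On ROCm the reported "capability" is derived from the gfx target rather than
# a feature level: e.g. gfx90a (MI200) -> (9, 0) -> 90, gfx1100 (Navi31) ->
# (11, 0) -> 110. A ">= 110" threshold therefore does not literally mean
# "this is a Navi3x part" for future architectures.
def capability_to_int(major: int, minor: int) -> int:
    return major * 10 + minor

# A name-based check is closer to an explicit is_navi test (hypothetical
# helper; gcnArchName is the architecture string exposed by ROCm PyTorch,
# e.g. "gfx1100" on Navi31, possibly with feature suffixes appended).
def is_navi3x(device_index: int = 0) -> bool:
    arch = torch.cuda.get_device_properties(device_index).gcnArchName
    return arch.startswith("gfx11")
```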
self.attn_func = _ck_attention
else:
self.attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend")
except ModuleNotFoundError:
else:
self.use_naive_attn = True

if self.use_naive_attn:
@@ -722,20 +733,39 @@ def forward(
attn_masks,
)
else:
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=query_seq_start_loc,
cu_seqlens_k=key_seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=key_max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)
if _ON_NAVI:
query = query.view(
(num_prefill_tokens, self.num_heads, -1))
key = key.view(
(num_prefill_tokens, self.num_kv_heads, -1))
value = value.view(
(num_prefill_tokens, self.num_kv_heads, -1))
out = self.attn_func(
query,
key,
value,
prefill_meta.seq_lens,
num_prefill_tokens,
self.num_heads,
self.head_size,
self.scale,
attn_masks,
)
else:
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=query_seq_start_loc,
cu_seqlens_k=key_seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=key_max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)

# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
@@ -891,6 +921,36 @@ def _sdpa_attention(
return output


def _ck_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
seq_lens: List[int],
num_tokens: int,
num_heads: int,
head_size: int,
scale: float,
attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
start = 0
output = torch.empty((num_tokens, num_heads, head_size),
dtype=query.dtype,
device=query.device)

for i, seq_len in enumerate(seq_lens):
end = start + seq_len
sub_out = flash_attn_func(query[start:end, :, :].unsqueeze(0),
key[start:end, :, :].unsqueeze(0),
value[start:end, :, :].unsqueeze(0),
dropout_p=0.0,
softmax_scale=scale,
causal=attn_masks is None)
output[start:end, :, :] = sub_out.squeeze(0)
start = end

return output
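
For reference, a hedged usage sketch of `_ck_attention` as added above: the inputs are packed prefill tensors of shape `(num_tokens, num_heads, head_size)`, each sequence is sliced out and run through `flash_attn_func` as a batch of one, and causal masking applies when no `attn_masks` are given. The shapes and values below are illustrative only and assume a ROCm GPU with `flash_attn` installed:

```python
import torch

# Two prefill sequences of lengths 5 and 3 packed along the token dimension
# (illustrative shapes; not part of the diff).
num_heads, num_kv_heads, head_size = 8, 8, 128
seq_lens = [5, 3]
num_tokens = sum(seq_lens)

query = torch.randn(num_tokens, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
key = torch.randn(num_tokens, num_kv_heads, head_size,
                  dtype=torch.float16, device="cuda")
value = torch.randn(num_tokens, num_kv_heads, head_size,
                    dtype=torch.float16, device="cuda")

# Each sequence attends only to itself; causal masking is used because
# attn_masks is None.
out = _ck_attention(query, key, value, seq_lens, num_tokens,
                    num_heads, head_size, scale=head_size ** -0.5)
assert out.shape == (num_tokens, num_heads, head_size)
```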


def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int) -> bool: