diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 591bedfa3a6f1..01aaeb3e08822 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -18,6 +18,15 @@
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 
+try:
+    from flash_attn import (flash_attn_func,  # noqa: F401
+                            flash_attn_varlen_func)
+    flash_attn_available = True
+except ModuleNotFoundError:
+    flash_attn_func = None
+    flash_attn_varlen_func = None
+    flash_attn_available = False
+
 logger = init_logger(__name__)
 
 _PARTITION_SIZE_ROCM = 256
@@ -498,16 +507,18 @@ def __init__(
                                "FA backend instead by setting the env var "
                                "`VLLM_USE_TRITON_FLASH_ATTN=0`")
         else:
-            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
+            # if not using triton, navi21/navi10 do not use flash-attn
             # either
             if not current_platform.has_device_capability(90):
                 self.use_naive_attn = True
             else:
-                try:
-                    from flash_attn import flash_attn_varlen_func  # noqa: F401
-                    self.attn_func = flash_attn_varlen_func
+                if flash_attn_available:
+                    if current_platform.has_device_capability(110):
+                        self.attn_func = _ck_attention
+                    else:
+                        self.attn_func = flash_attn_varlen_func
                     logger.debug("Using CK FA in ROCmBackend")
-                except ModuleNotFoundError:
+                else:
                     self.use_naive_attn = True
 
         if self.use_naive_attn:
@@ -704,19 +715,38 @@ def forward(
                         attn_masks,
                     )
                 else:
-                    out = self.attn_func(
-                        q=query,
-                        k=key,
-                        v=value,
-                        cu_seqlens_q=query_seq_start_loc,
-                        cu_seqlens_k=key_seq_start_loc,
-                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                        max_seqlen_k=key_max_seq_len,
-                        softmax_scale=self.scale,
-                        causal=True,
-                        window_size=self.sliding_window,
-                        alibi_slopes=self.alibi_slopes,
-                    )
+                    if _ON_NAVI:
+                        query = query.view(
+                            (num_prefill_tokens, self.num_heads, -1))
+                        key = key.view(
+                            (num_prefill_tokens, self.num_kv_heads, -1))
+                        value = value.view(
+                            (num_prefill_tokens, self.num_kv_heads, -1))
+                        out = self.attn_func(
+                            query,
+                            key,
+                            value,
+                            prefill_meta.seq_lens,
+                            num_prefill_tokens,
+                            self.num_heads,
+                            self.head_size,
+                            self.scale,
+                            attn_masks,
+                        )
+                    else:
+                        out = self.attn_func(
+                            q=query,
+                            k=key,
+                            v=value,
+                            cu_seqlens_q=query_seq_start_loc,
+                            cu_seqlens_k=key_seq_start_loc,
+                            max_seqlen_q=prefill_meta.max_prefill_seq_len,
+                            max_seqlen_k=key_max_seq_len,
+                            softmax_scale=self.scale,
+                            causal=True,
+                            window_size=self.sliding_window,
+                            alibi_slopes=self.alibi_slopes,
+                        )
 
                 # common code for prefill
                 assert output[:num_prefill_tokens].shape == out.shape
@@ -872,6 +902,36 @@ def _sdpa_attention(
     return output
 
 
+def _ck_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    seq_lens: List[int],
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    scale: float,
+    attn_masks: Optional[List[torch.Tensor]] = None,
+) -> torch.Tensor:
+    start = 0
+    output = torch.empty((num_tokens, num_heads, head_size),
+                         dtype=query.dtype,
+                         device=query.device)
+
+    for i, seq_len in enumerate(seq_lens):
+        end = start + seq_len
+        sub_out = flash_attn_func(query[start:end, :, :].unsqueeze(0),
+                                  key[start:end, :, :].unsqueeze(0),
+                                  value[start:end, :, :].unsqueeze(0),
+                                  dropout_p=0.0,
+                                  softmax_scale=scale,
+                                  causal=attn_masks is None)
+        output[start:end, :, :] = sub_out.squeeze(0)
+        start = end
+
+    return output
+
+
 def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                      block_size: int, gqa_ratio: int,
                                      max_seq_len: int) -> bool: