From 98f47f2a4032f8c395268de80858c64ffcfc60fa Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Nov 2024 09:01:02 -0800 Subject: [PATCH] [V1] Optimize the CPU overheads in FlashAttention custom op (#10733) Signed-off-by: Woosuk Kwon --- vllm/v1/attention/backends/flash_attn.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5f8535eaa303f..e618edf7d35bf 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -135,6 +135,13 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the CPU + # overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + output = torch.empty_like(query) torch.ops.vllm.unified_v1_flash_attention( output, @@ -153,7 +160,7 @@ def forward( self.alibi_slopes, self.logits_soft_cap, ) - return output + return output.view(-1, self.num_heads * self.head_size) def unified_v1_flash_attention( @@ -184,11 +191,6 @@ def unified_v1_flash_attention( attn_metadata: FlashAttentionMetadata = current_metadata num_actual_tokens = attn_metadata.num_actual_tokens - # Reshape the query, key, and value tensors. - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - # Reshape the input keys and values and store them in the cache. key_cache = kv_cache[0] value_cache = kv_cache[1] @@ -218,8 +220,7 @@ def unified_v1_flash_attention( block_table=attn_metadata.block_table, softcap=logits_soft_cap, ) - attn_output = attn_output.view(num_actual_tokens, -1) - # TODO(woosuk): Optimize this. + # TODO(woosuk): Remove this unnecessary copy. output[:num_actual_tokens].copy_(attn_output)