Commit

Merge branch 'main' into upstream_merge_24_10_21
gshtras authored Oct 23, 2024
2 parents 87e3970 + 69d5e1d commit be448fb
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -607,18 +607,17 @@ def forward(
    assert attn_metadata.num_encoder_tokens is not None
    num_prefill_tokens = attn_metadata.num_encoder_tokens

output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]

# QKV for prefill.
query = query[:num_prefill_tokens]

if key is not None and value is not None:
    key = key[:num_prefill_tokens]
    value = value[:num_prefill_tokens]

if prefill_meta := attn_metadata.prefill_metadata:
    output = torch.empty_like(query)

    # Prompt run.
    # normal attention and DECODER
    if attn_type == AttentionType.DECODER and (
@@ -735,7 +734,6 @@ def forward(
if decode_meta := attn_metadata.decode_metadata:
    # Decoding run.
    # Whether to use rocm custom paged attention or not
    output = torch.empty_like(decode_query)
    num_seqs, num_heads, head_size = decode_query.shape
    block_size = value_cache.shape[3]
    gqa_ratio = num_heads // self.num_kv_heads
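
Both hunks touch where the attention output buffer is allocated relative to the prefill/decode split of query. Below is a minimal, hypothetical sketch of the single-buffer pattern, assuming one torch.empty_like(query) allocation before slicing, with each branch writing into its own slice of that buffer; split_and_attend and the identity writes are illustrative placeholders, not vLLM kernels or APIs.

import torch

def split_and_attend(query: torch.Tensor, num_prefill_tokens: int) -> torch.Tensor:
    # Allocate one output buffer for all tokens before splitting the query.
    output = torch.empty_like(query)

    # Query for decode; prefill tokens are the leading slice.
    decode_query = query[num_prefill_tokens:]
    prefill_query = query[:num_prefill_tokens]

    if prefill_query.numel() > 0:
        # Stand-in for the prefill attention kernel: write into the
        # prefill slice of the shared output buffer.
        output[:num_prefill_tokens] = prefill_query
    if decode_query.numel() > 0:
        # Stand-in for the paged-attention decode kernel: write into the
        # decode slice of the same buffer.
        output[num_prefill_tokens:] = decode_query
    return output

# Example: 10 prefill tokens followed by 2 decode tokens, 4 heads of size 64.
q = torch.randn(12, 4, 64)
assert split_and_attend(q, num_prefill_tokens=10).shape == q.shape

Writing both phases into slices of one preallocated tensor avoids a second allocation and any later stitching of the prefill and decode results.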
