Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
LiuXiaoxuanPKU committed Oct 12, 2024
1 parent 250e26a commit 3e23a4c
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions vllm/attention/backends/placeholder_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
# Maximum query length in the batch.
max_query_len: Optional[int]

# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int]
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int]

# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
Expand Down Expand Up @@ -140,7 +137,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
slot_mapping=slot_mapping,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
decode_query_len=0,
max_decode_query_len=0,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
Expand Down Expand Up @@ -172,7 +169,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
slot_mapping=slot_mapping,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
decode_query_len=self.decode_query_len,
max_decode_query_len=self.max_decode_query_len,
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
Expand Down Expand Up @@ -256,9 +253,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
decode_query_len = max(decode_query_lens)
max_decode_query_len = max(decode_query_lens)
else:
decode_query_len = 1
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
Expand Down Expand Up @@ -304,7 +301,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
decode_query_len=decode_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
Expand Down

0 comments on commit 3e23a4c

Please sign in to comment.