From 2e2747b07ab4c83f967a6fae9f528acc39b318aa Mon Sep 17 00:00:00 2001
From: LiuXiaoxuanPKU
Date: Fri, 11 Oct 2024 14:40:28 -0700
Subject: [PATCH 1/4] v1

---
 tests/spec_decode/test_scorer.py       | 52 ++++++++++++++++++------
 vllm/attention/backends/flash_attn.py  | 55 ++++++++++++++++----------
 vllm/spec_decode/mqa_scorer.py         | 42 ++++++++++++++++----
 vllm/spec_decode/spec_decode_worker.py |  6 ---
 4 files changed, 108 insertions(+), 47 deletions(-)

diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py
index 5f703b03ab7fe..e579c8b38db91 100644
--- a/tests/spec_decode/test_scorer.py
+++ b/tests/spec_decode/test_scorer.py
@@ -1,3 +1,6 @@
+import random
+from typing import List
+
 import pytest
 import torch
 
@@ -10,31 +13,45 @@ from .utils import create_batch, create_worker
 
 
-def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
+def create_proposal(propose_lens: List[int], vocab_size: int,
                     device: str) -> SpeculativeProposals:
-    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
+    batch_size = len(propose_lens)
+    max_propose_len = max(propose_lens)
+    proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size),
                                 device=device)
-    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
-    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
+
+    proposal_token_ids = torch.full((batch_size, max_propose_len),
+                                    fill_value=-1,
+                                    device=device)
+    for i in range(batch_size):
+        proposal_token_ids[i][:propose_lens[i]] = torch.argmax(
+            proposal_probs[i][:propose_lens[i]], dim=-1)
+
+    propose_lens = torch.tensor(propose_lens, device=device)
     return SpeculativeProposals(proposal_token_ids, proposal_probs,
-                                proposal_lens)
+                                propose_lens)
 
 
 def assert_score_equal(score1: SpeculativeScores,
                        score2: SpeculativeScores) -> None:
     assert torch.allclose(score1.probs, score2.probs)
     assert torch.allclose(score1.logprobs, score2.logprobs)
-    assert torch.equal(score1.token_ids, score2.token_ids)
+    assert torch.equal(
+        score1.token_ids,
+        score2.token_ids), f"{score1.token_ids}, {score2.token_ids}"
 
 
 @pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
 @pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
-@pytest.mark.parametrize('propose_len', [1, 3, 5])
+@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
+@pytest.mark.parametrize('mixed_propose_len', [True])
 @pytest.mark.parametrize('device', ['cuda'])
-def test_scoroer(model_name: str, batch_size: int, propose_len: int,
-                 device: str) -> None:
+def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
+                mixed_propose_len: bool, device: str) -> None:
     """
-    Compare the batch expansion scorer and mqa scorer return the same score
+    Compare that the batch expansion scorer and the MQA scorer return the
+    same scores. We test both the case where all queries share the same
+    propose length and the case where propose lengths differ across queries.
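+    Zero-length proposals are generated for a random subset of requests to
+    produce the mixed-length case.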
""" seed = 0 block_size = 32 @@ -46,13 +63,22 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int, should_modify_greedy_probs_inplace = True vocab_size = scorer_worker.vocab_size - proposals = create_proposal(batch_size, propose_len, vocab_size, device) + + if not mixed_propose_len: + propose_lens = [max_propose_len] * batch_size + else: + non_zero_cnt = random.randint(0, batch_size) + propose_lens = [max_propose_len + ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt) + random.shuffle(propose_lens) + + proposals = create_proposal(propose_lens, vocab_size, device) seq_group_metadatalist, _, _ = create_batch(batch_size, - propose_len, + max_propose_len, block_size=block_size, num_gpu_blocks=num_gpu_blocks) requests = ExecuteModelRequest(seq_group_metadatalist, - num_lookahead_slots=propose_len) + num_lookahead_slots=max_propose_len) batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device, vocab_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bba80262e52d3..98f65f9844894 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -115,7 +115,7 @@ class FlashAttentionMetadata(AttentionMetadata): # Currently, we require that all requests have the same number of query # tokens during the decoding phase. When speculavie decoding is enabled, # decode_query_len might be greater than 1. In all other cases, it is 1. - decode_query_len: Optional[int] + decode_query_len: int # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -126,11 +126,11 @@ class FlashAttentionMetadata(AttentionMetadata): # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] + query_start_loc: torch.Tensor # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] + seq_start_loc: torch.Tensor # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] @@ -206,8 +206,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, + query_start_loc=self.query_start_loc[self.num_prefills:], + seq_start_loc=self.seq_start_loc[self.num_prefills:], context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, @@ -714,20 +714,36 @@ def unified_flash_attention( if decode_meta := attn_metadata.decode_metadata: # Decoding run. - _, num_head, head_dim = decode_query.shape - decode_query = decode_query.reshape(-1, decode_meta.decode_query_len, - num_head, head_dim) - decode_output = flash_attn_with_kvcache( - q=decode_query, - k_cache=key_cache, - v_cache=value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=softmax_scale, - causal=True, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - ).squeeze(1) + # Use flash_attn_varlen_func kernel for speculative decoding + # because different queries might have different lengths. 
+ if decode_meta.decode_query_len > 1: + decode_output = flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.decode_query_len, + cu_seqlens_k=decode_meta.seq_start_loc, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + ) + else: + # Use flash_attn_with_kvcache for normal decoding. + decode_output = flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ).squeeze(1) if prefill_output is None: assert decode_output is not None @@ -739,7 +755,6 @@ def unified_flash_attention( # Chunked prefill does not work with speculative decoding. # Therefore, the query length for decode should be 1 in chunked prefill. assert decode_meta is not None - assert decode_meta.decode_query_len == 1 decode_output = decode_output.squeeze(1) output = torch.cat([prefill_output, decode_output], dim=0) return output.view(num_tokens, hidden_size) diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py index 59f2a4191a8b2..f35a8a0ab8be3 100644 --- a/vllm/spec_decode/mqa_scorer.py +++ b/vllm/spec_decode/mqa_scorer.py @@ -18,6 +18,7 @@ def score_proposals( target_seq_id_start = max( get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1 all_proposal_tokens = proposals.proposal_token_ids.tolist() + all_proposal_lengths = proposals.proposal_lens.tolist() for i, seq_group_metadata in enumerate( execute_model_req.seq_group_metadata_list): seq_data_dict = seq_group_metadata.seq_data @@ -27,7 +28,8 @@ def score_proposals( seq_data: SequenceData = seq_data_dict[seq_id] prompt_token_ids = seq_data.get_prompt_token_ids() output_token_ids = seq_data.get_output_token_ids() - proposal_token_ids = all_proposal_tokens[i] + proposal_token_ids = all_proposal_tokens[ + i][:all_proposal_lengths[i]] new_output_token_ids = [*output_token_ids, *proposal_token_ids] target_seq_id = target_seq_id_start + i @@ -62,18 +64,42 @@ def score_proposals( target_sampler_output = target_sampler_output[0] - bs, k = proposals.proposal_token_ids.shape - all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1) - - all_probs = target_sampler_output.sampled_token_probs.reshape( - bs, k + 1, self._vocab_size) - all_logprobs = target_sampler_output.logprobs.reshape( - bs, k + 1, self._vocab_size) + k = execute_model_req.num_lookahead_slots + bs = len(execute_model_req.seq_group_metadata_list) + target_token_ids = target_sampler_output.sampled_token_ids + target_probs = target_sampler_output.sampled_token_probs + target_logprobs = target_sampler_output.logprobs + # If all requests have the same number of query tokens, we can avoid + # the for loop to build output for better performance. 
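+        # min(all_proposal_lengths) == k means every request proposed exactly
+        # k tokens, so the flat sampler output is a dense bs * (k + 1) block
+        # that can be reshaped directly without per-request slicing.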
+ if min(all_proposal_lengths) == k: + bs, _ = proposals.proposal_token_ids.shape + all_tokens = target_token_ids.reshape(bs, k + 1) + all_probs = target_probs.reshape(bs, k + 1, self._vocab_size) + all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size) + else: + all_tokens = target_token_ids.new_full(size=(bs, k + 1), + fill_value=-1) + all_probs = target_probs.new_zeros(*all_tokens.shape, + self._vocab_size) + all_logprobs = target_logprobs.new_full(size=all_probs.shape, + fill_value=-float("inf")) + target_token_ids = target_token_ids.flatten() + start_loc = 0 + for i, proposed_len in enumerate(all_proposal_lengths): + output_len = proposed_len + 1 + end_loc = start_loc + output_len + all_tokens[ + i, :output_len] = target_token_ids[start_loc:end_loc] + all_probs[i, :output_len] = target_probs[start_loc:end_loc] + all_logprobs[ + i, :output_len] = target_logprobs[start_loc:end_loc] + start_loc = end_loc hidden_states = None if target_sampler_output.hidden_states is not None: hidden_states = target_sampler_output.hidden_states.reshape( bs, (k + 1), -1) + return SpeculativeScores(probs=all_probs, token_ids=all_tokens, logprobs=all_logprobs, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index a67715290a515..0a43fd9091f49 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -188,12 +188,6 @@ def create_worker( "[Speculative Decoding] Disabling MQA scorer as the " "MQA is only available with flash attn backend.") - if ngram_prompt_lookup_max > 0: - disable_mqa_scorer = True - logger.info( - "[Speculative Decoding] Disabling MQA scorer as the " - "NGramWorker does not support MQA scorer.") - if "model_config" in draft_worker_kwargs and \ draft_worker_kwargs["model_config"].max_model_len < \ scorer_worker.model_config.max_model_len: From 02a48498497b6e843a01efba332179dfd5510591 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 11 Oct 2024 16:27:43 -0700 Subject: [PATCH 2/4] fix tests --- vllm/attention/backends/flash_attn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 98f65f9844894..0ef183faac0e3 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -206,8 +206,10 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=self.query_start_loc[self.num_prefills:], - seq_start_loc=self.seq_start_loc[self.num_prefills:], + query_start_loc=self.query_start_loc[self.num_prefills:] + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, From 9a582c67a5ef5f7a61bea593b42418ddfd163754 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 11 Oct 2024 16:45:10 -0700 Subject: [PATCH 3/4] style --- vllm/attention/backends/blocksparse_attn.py | 5 +---- vllm/attention/backends/flash_attn.py | 12 +++++------- vllm/attention/backends/rocm_flash_attn.py | 5 +---- vllm/attention/backends/xformers.py | 5 +---- 4 files changed, 8 insertions(+), 19 deletions(-) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 57ac152d9edb6..d6955ac394f05 100644 --- 
a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -186,10 +186,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # Number of query tokens for each request in the batch. - # Currently, we require that all requests have the same number of query - # tokens during the decoding phase. When speculavie decoding is enabled, - # decode_query_len might be greater than 1. In all other cases, it is 1. + # Max number of query tokens for among request in the batch. decode_query_len: Optional[int] = None _cached_prefill_metadata: Optional[ diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 0ef183faac0e3..1a35a54ef40e4 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata): # Maximum query length in the batch. max_query_len: Optional[int] - # Number of query tokens for each request in the batch. - # Currently, we require that all requests have the same number of query - # tokens during the decoding phase. When speculavie decoding is enabled, - # decode_query_len might be greater than 1. In all other cases, it is 1. - decode_query_len: int + # Max number of query tokens among request in the batch. + decode_query_len: Optional[int] # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -126,11 +123,11 @@ class FlashAttentionMetadata(AttentionMetadata): # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - query_start_loc: torch.Tensor + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. - seq_start_loc: torch.Tensor + seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] @@ -718,6 +715,7 @@ def unified_flash_attention( # Decoding run. # Use flash_attn_varlen_func kernel for speculative decoding # because different queries might have different lengths. + assert decode_meta.decode_query_len is not None if decode_meta.decode_query_len > 1: decode_output = flash_attn_varlen_func( q=decode_query, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 7456aab8b8d2a..b85682fbb3c3b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -121,10 +121,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # so far). context_lens_tensor: Optional[torch.Tensor] - # Number of query tokens for each request in the batch. - # Currently, we require that all requests have the same number of query - # tokens during the decoding phase. When speculavie decoding is enabled, - # decode_query_len might be greater than 1. In all other cases, it is 1. + # Max number of query tokens among request in the batch. 
decode_query_len: Optional[int] = None _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a3f9ff64f8b8b..5da509f2699b1 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -118,10 +118,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None - # Number of query tokens for each request in the batch. - # Currently, we require that all requests have the same number of query - # tokens during the decoding phase. When speculavie decoding is enabled, - # decode_query_len might be greater than 1. In all other cases, it is 1. + # Max number of query tokens among request in the batch. decode_query_len: Optional[int] = None # (batch_size + 1,). The cumulative subquery lengths of the sequences in From 3cc9d84d43171b8319838cf9a0ba8f1c6f8299d3 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 11 Oct 2024 20:02:17 -0700 Subject: [PATCH 4/4] rename --- vllm/attention/backends/blocksparse_attn.py | 2 +- vllm/attention/backends/flash_attn.py | 18 +++++++++--------- vllm/attention/backends/rocm_flash_attn.py | 2 +- vllm/attention/backends/utils.py | 2 +- vllm/attention/backends/xformers.py | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index d6955ac394f05..c216d195c9e7e 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -187,7 +187,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): use_cuda_graph: bool # Max number of query tokens for among request in the batch. - decode_query_len: Optional[int] = None + max_decode_query_len: Optional[int] = None _cached_prefill_metadata: Optional[ "BlocksparseFlashAttentionMetadata"] = None diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 1a35a54ef40e4..8457bde066eb7 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -112,7 +112,7 @@ class FlashAttentionMetadata(AttentionMetadata): max_query_len: Optional[int] # Max number of query tokens among request in the batch. - decode_query_len: Optional[int] + max_decode_query_len: Optional[int] # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. 
@@ -170,9 +170,9 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, max_decode_seq_len=0, query_start_loc=self.query_start_loc[:self.num_prefills + 1], seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], @@ -199,7 +199,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - decode_query_len=self.decode_query_len, + max_decode_query_len=self.max_decode_query_len, max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -412,9 +412,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] if len(decode_query_lens) > 0: - decode_query_len = max(decode_query_lens) + max_decode_query_len = max(decode_query_lens) else: - decode_query_len = 1 + max_decode_query_len = 1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens @@ -467,7 +467,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, - decode_query_len=decode_query_len, + max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc, @@ -715,14 +715,14 @@ def unified_flash_attention( # Decoding run. # Use flash_attn_varlen_func kernel for speculative decoding # because different queries might have different lengths. - assert decode_meta.decode_query_len is not None - if decode_meta.decode_query_len > 1: + assert decode_meta.max_decode_query_len is not None + if decode_meta.max_decode_query_len > 1: decode_output = flash_attn_varlen_func( q=decode_query, k=key_cache, v=value_cache, cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_q=decode_meta.decode_query_len, + max_seqlen_q=decode_meta.max_decode_query_len, cu_seqlens_k=decode_meta.seq_start_loc, max_seqlen_k=decode_meta.max_decode_seq_len, softmax_scale=softmax_scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b85682fbb3c3b..fe55a82833944 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -122,7 +122,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): context_lens_tensor: Optional[torch.Tensor] # Max number of query tokens among request in the batch. 
- decode_query_len: Optional[int] = None + max_decode_query_len: Optional[int] = None _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 2b8c373178ab3..53e3a53badeae 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -313,7 +313,7 @@ def graph_capture_get_metadata_for_batch( seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=1, - decode_query_len=1, + max_decode_query_len=1, max_prefill_seq_len=0, max_decode_seq_len=self.runner.max_seq_len_to_capture, query_start_loc=None, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 5da509f2699b1..9ad7c41e48b68 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -119,7 +119,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): max_query_len: Optional[int] = None # Max number of query tokens among request in the batch. - decode_query_len: Optional[int] = None + max_decode_query_len: Optional[int] = None # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length
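
For illustration, the snippet below is a minimal standalone sketch (not part of
this patch; the helper name pad_ragged_scores is invented) of what the slow
path added to MQAScorer.score_proposals does: when proposal lengths differ
across requests, the target model's flat output is scattered back into padded
(batch_size, k + 1) tensors, with -1 token ids and zero probabilities filling
the unused slots.

from typing import List, Tuple

import torch


def pad_ragged_scores(flat_token_ids: torch.Tensor, flat_probs: torch.Tensor,
                      proposal_lens: List[int],
                      k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Request i contributes proposal_lens[i] + 1 scored tokens to the flat
    # output; unused slots keep sentinel values (-1 token id, zero probs).
    bs = len(proposal_lens)
    vocab_size = flat_probs.shape[-1]
    all_tokens = flat_token_ids.new_full((bs, k + 1), -1)
    all_probs = flat_probs.new_zeros(bs, k + 1, vocab_size)
    start = 0
    for i, proposed_len in enumerate(proposal_lens):
        output_len = proposed_len + 1
        all_tokens[i, :output_len] = flat_token_ids[start:start + output_len]
        all_probs[i, :output_len] = flat_probs[start:start + output_len]
        start += output_len
    return all_tokens, all_probs


if __name__ == "__main__":
    k, vocab_size = 3, 8
    proposal_lens = [3, 0, 3]  # the middle request proposed no tokens
    total = sum(p + 1 for p in proposal_lens)
    flat_ids = torch.arange(total)
    flat_probs = torch.rand(total, vocab_size)
    tokens, probs = pad_ragged_scores(flat_ids, flat_probs, proposal_lens, k)
    print(tokens)  # rows with shorter proposals are right-padded with -1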