From 5af6b7a5f73524738dfc2eee368ff462dd2c6361 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Tue, 10 Dec 2024 16:15:24 -0500
Subject: [PATCH] [Bugfix] Fix Mamba multistep

Signed-off-by: Tyler Michael Smith
---
 vllm/attention/backends/placeholder_attn.py | 64 ++++++++++++++++++++-
 vllm/worker/multi_step_model_runner.py      |  4 +-
 2 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py
index 888adbffb8578..658039bfc3365 100644
--- a/vllm/attention/backends/placeholder_attn.py
+++ b/vllm/attention/backends/placeholder_attn.py
@@ -11,7 +11,8 @@
 from vllm.multimodal import MultiModalPlaceholderMap
 
 if TYPE_CHECKING:
-    from vllm.worker.model_runner import ModelInputForGPUBuilder
+    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
+                                          ModelInputForGPUWithSamplingMetadata)
 
 # Placeholder attention backend for models like Mamba and embedding models that
 # lack attention.
@@ -186,6 +187,67 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
         )
         return self._cached_decode_metadata
 
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
+                     sampled_token_ids: Optional[torch.Tensor],
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
+        """
+        Update metadata in-place to advance one decode step.
+        """
+        # When using CUDA graphs, num_seqs is padded to the next captured
+        # batch size, but num_queries tracks the actual number of requests
+        # in the batch. For --enforce-eager mode, num_seqs == num_queries.
+        if num_seqs != num_queries:
+            assert num_seqs > num_queries
+            assert self.use_cuda_graph
+
+        assert not turn_prefills_into_decodes, \
+            ("Multi-Step + Chunked-Prefill is not supported for "
+             "attention-free models. turn_prefills_into_decodes is a "
+             "Multi-Step + Chunked-Prefill specific parameter.")
+
+        assert self.seq_lens is not None
+        assert self.max_decode_seq_len == max(self.seq_lens)
+
+        assert self.num_prefills == 0
+        assert self.num_prefill_tokens == 0
+        assert self.num_decode_tokens == num_seqs
+
+        assert self.seq_lens is not None
+        assert len(self.seq_lens) == num_seqs
+        assert self.seq_lens_tensor is not None
+        assert self.seq_lens_tensor.shape == (num_seqs, )
+        assert self.max_query_len == 1
+        assert self.max_prefill_seq_len == 0
+
+        assert self.query_start_loc is not None
+        assert self.query_start_loc.shape == (num_queries + 1, )
+        assert self.seq_start_loc is not None
+        assert self.seq_start_loc.shape == (num_seqs + 1, )
+
+        assert self.context_lens_tensor is not None
+        assert self.context_lens_tensor.shape == (num_queries, )
+
+        assert self.block_tables is not None
+
+        # Update query lengths. Note that we only update the first num_queries
+        # entries, since the tensors may be padded to the captured batch size.
+        for i in range(num_queries):
+            self.seq_lens[i] += 1
+        self.max_decode_seq_len = max(self.seq_lens)
+
+        # Update sequence lengths, masking off padded entries beyond num_queries
+        device = self.seq_lens_tensor.device
+        mask = torch.arange(self.seq_lens_tensor.size(0),
+                            device=device) < num_queries
+        self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
+        if sampled_token_ids is not None:
+            model_input.input_tokens.masked_scatter_(
+                mask, sampled_token_ids[:num_queries])
+
 
 class PlaceholderAttentionMetadataBuilder(
         AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index 3ca0d88a42183..e08a61e31fe42 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -29,7 +29,9 @@
 
 logger = init_logger(__name__)
 
-MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
+MULTI_STEP_ATTENTION_BACKENDS = [
+    "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
+]
 MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
 
 def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
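
Note for reviewers: the padding-aware update at the end of advance_step() can be
exercised in isolation. The snippet below is a minimal sketch with made-up sizes
and standalone tensors (not vLLM's metadata objects), illustrating how only the
first num_queries entries of a CUDA-graph-padded batch are advanced while the
padded tail is left untouched.

import torch

# Hypothetical sizes: batch padded to 8 for a captured CUDA graph,
# with only 5 real requests (num_queries) in flight.
num_seqs, num_queries = 8, 5
seq_lens_tensor = torch.full((num_seqs, ), 10, dtype=torch.int32)
input_tokens = torch.zeros(num_seqs, dtype=torch.long)
sampled_token_ids = torch.arange(100, 100 + num_queries)  # stand-in sampled ids

# Only the first num_queries entries are real; padded entries stay untouched.
mask = torch.arange(seq_lens_tensor.size(0)) < num_queries
seq_lens_tensor += mask.to(seq_lens_tensor.dtype)
input_tokens.masked_scatter_(mask, sampled_token_ids[:num_queries])

print(seq_lens_tensor)  # [11, 11, 11, 11, 11, 10, 10, 10]
print(input_tokens)     # [100, 101, 102, 103, 104, 0, 0, 0]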