Commit 9a93973

[Bugfix] Fix Mamba multistep (#11071)
Signed-off-by: Tyler Michael Smith <[email protected]>
1 parent 134810b commit 9a93973

File tree

2 files changed: +66 -2 lines changed

vllm/attention/backends/placeholder_attn.py (+63 -1)
@@ -11,7 +11,8 @@
 from vllm.multimodal import MultiModalPlaceholderMap

 if TYPE_CHECKING:
-    from vllm.worker.model_runner import ModelInputForGPUBuilder
+    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
+                                          ModelInputForGPUWithSamplingMetadata)

 # Placeholder attention backend for models like Mamba and embedding models that
 # lack attention.
@@ -186,6 +187,67 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
         )
         return self._cached_decode_metadata

+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
+                     sampled_token_ids: Optional[torch.Tensor],
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
+        """
+        Update metadata in-place to advance one decode step.
+        """
+        # When using cudagraph, num_seqs is padded to the next captured
+        # batch size, but num_queries tracks the actual number of requests in
+        # the batch. For --enforce-eager mode, num_seqs == num_queries.
+        if num_seqs != num_queries:
+            assert num_seqs > num_queries
+            assert self.use_cuda_graph
+
+        assert not turn_prefills_into_decodes, \
+            ("Multi-Step + Chunked-Prefill is not supported for attention-free "
+             "models. turn_prefills_into_decodes is a "
+             "Multi-Step + Chunked-Prefill specific parameter.")
+
+        assert self.seq_lens is not None
+        assert self.max_decode_seq_len == max(self.seq_lens)
+
+        assert self.num_prefills == 0
+        assert self.num_prefill_tokens == 0
+        assert self.num_decode_tokens == num_seqs
+
+        assert self.seq_lens is not None
+        assert len(self.seq_lens) == num_seqs
+        assert self.seq_lens_tensor is not None
+        assert self.seq_lens_tensor.shape == (num_seqs, )
+        assert self.max_query_len == 1
+        assert self.max_prefill_seq_len == 0
+
+        assert self.query_start_loc is not None
+        assert self.query_start_loc.shape == (num_queries + 1, )
+        assert self.seq_start_loc is not None
+        assert self.seq_start_loc.shape == (num_seqs + 1, )
+
+        assert self.context_lens_tensor is not None
+        assert self.context_lens_tensor.shape == (num_queries, )
+
+        assert self.block_tables is not None
+
+        # Update query lengths. Note that we update only queries and not seqs,
+        # since tensors may be padded due to the captured cuda graph batch size.
+        for i in range(num_queries):
+            self.seq_lens[i] += 1
+        self.max_decode_seq_len = max(self.seq_lens)
+
+        # Update sequences, masking off padded entries beyond num_queries.
+        device = self.seq_lens_tensor.device
+        mask = torch.arange(self.seq_lens_tensor.size(0),
+                            device=device) < num_queries
+        self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
+        if sampled_token_ids is not None:
+            model_input.input_tokens.masked_scatter_(
+                mask, sampled_token_ids[:num_queries])

 class PlaceholderAttentionMetadataBuilder(
         AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
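To illustrate the masked, in-place advance performed above, here is a minimal standalone sketch in plain PyTorch. The 4-slot padded batch, the tensor values, and any names not appearing in the diff are illustrative assumptions, not vLLM code; only the mask-and-scatter pattern mirrors advance_step.

import torch

# Assume a CUDA-graph-padded batch: 4 metadata slots, but only 3 live requests.
num_seqs, num_queries = 4, 3

seq_lens_tensor = torch.tensor([5, 7, 6, 0], dtype=torch.int32)  # per-slot lengths
input_tokens = torch.zeros(num_seqs, dtype=torch.long)           # next-step input buffer
sampled_token_ids = torch.tensor([[11], [22], [33]])             # one sampled token per live request

# Only the first num_queries slots are real requests; padded slots must not advance.
mask = torch.arange(seq_lens_tensor.size(0)) < num_queries

# Advance sequence lengths in place for live requests only.
seq_lens_tensor += mask.to(seq_lens_tensor.dtype)

# Scatter the freshly sampled tokens into the padded input buffer.
input_tokens.masked_scatter_(mask, sampled_token_ids[:num_queries])

print(seq_lens_tensor.tolist())  # [6, 8, 7, 0]
print(input_tokens.tolist())     # [11, 22, 33, 0]

In the committed advance_step, the Python-side seq_lens list is also bumped for the first num_queries entries so that max_decode_seq_len stays consistent with seq_lens_tensor.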

vllm/worker/multi_step_model_runner.py (+3 -1)
@@ -29,7 +29,9 @@

 logger = init_logger(__name__)

-MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
+MULTI_STEP_ATTENTION_BACKENDS = [
+    "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
+]
 MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]

 def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
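The body of _get_supported_attention_backends is not shown in this diff; the following is a plausible sketch of how the two lists might gate multi-step support, written as an assumption for illustration rather than the actual vLLM implementation.

from typing import List

MULTI_STEP_ATTENTION_BACKENDS = [
    "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]

def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
        -> List[str]:
    # Chunked prefill narrows the set of backends usable with multi-step;
    # otherwise any whitelisted backend is accepted, now including the
    # NO_ATTENTION placeholder backend used by attention-free models.
    if chunked_prefill_enabled:
        return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
    return MULTI_STEP_ATTENTION_BACKENDS

Whitelisting NO_ATTENTION is what makes the placeholder backend's new advance_step reachable from the multi-step runner for attention-free models such as Mamba.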

Comments (0)