Skip to content

Commit

Permalink
fix multi modal chunked prefill
Browse files Browse the repository at this point in the history
Signed-off-by: jiang1.li <[email protected]>
  • Loading branch information
bigPYJ1151 committed Nov 19, 2024
1 parent ec15a2e commit 313900a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def forward(
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if not prefill_meta.prefill_metadata:
if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore
self._run_sdpa_forward(output,
query,
key,
Expand Down
11 changes: 6 additions & 5 deletions vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,12 @@ def build(self) -> ModelInputForCPU:
def _build_input_data(self):
for seq_group_metadata in self.seq_group_metadata_list:
for seq_id, seq_data in seq_group_metadata.seq_data.items():
self._compute_input_tokens(self.input_data, seq_group_metadata,
seq_data, seq_id)
if (seq_group_metadata.is_prompt
and seq_group_metadata.multi_modal_data):
self._compute_multi_modal_input(seq_group_metadata,
seq_data)
self._compute_input_tokens(self.input_data, seq_group_metadata,
seq_data, seq_id)

def _compute_input_tokens(self, data: ModelInputData,
seq_group_metadata: SequenceGroupMetadata,
Expand Down Expand Up @@ -377,10 +377,8 @@ def _compute_input_tokens(self, data: ModelInputData,
def _compute_multi_modal_input(self,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData):
assert not self.chunked_prefill, \
"multi-model on CPU does not support chunked-prefill."
computed_len = seq_data.get_num_computed_tokens()
seq_len = seq_data.get_len()
seq_len = self.input_data.seq_lens[-1]

# NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions.
Expand All @@ -400,6 +398,9 @@ def _compute_multi_modal_input(self,

# special processing for mrope position deltas.
if self.runner.model_config.uses_mrope:
assert not self.chunked_prefill, \
"MROPE on CPU does not support chunked-prefill."

image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
Expand Down

0 comments on commit 313900a

Please sign in to comment.