From 313900a35e176634728da8162ae845bad242c10e Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Fri, 15 Nov 2024 09:02:48 +0000 Subject: [PATCH] fix multi modal chunked prefill Signed-off-by: jiang1.li --- vllm/attention/backends/torch_sdpa.py | 2 +- vllm/worker/cpu_model_runner.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 910418c6dbc6b..b3981edf3de01 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -404,7 +404,7 @@ def forward( output = torch.empty_like(query) if prefill_meta := attn_metadata.prefill_metadata: assert attn_metadata.seq_lens is not None - if not prefill_meta.prefill_metadata: + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore self._run_sdpa_forward(output, query, key, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 0204061ce1f85..635377dbe0fc8 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -263,12 +263,12 @@ def build(self) -> ModelInputForCPU: def _build_input_data(self): for seq_group_metadata in self.seq_group_metadata_list: for seq_id, seq_data in seq_group_metadata.seq_data.items(): + self._compute_input_tokens(self.input_data, seq_group_metadata, + seq_data, seq_id) if (seq_group_metadata.is_prompt and seq_group_metadata.multi_modal_data): self._compute_multi_modal_input(seq_group_metadata, seq_data) - self._compute_input_tokens(self.input_data, seq_group_metadata, - seq_data, seq_id) def _compute_input_tokens(self, data: ModelInputData, seq_group_metadata: SequenceGroupMetadata, @@ -377,10 +377,8 @@ def _compute_input_tokens(self, data: ModelInputData, def _compute_multi_modal_input(self, seq_group_metadata: SequenceGroupMetadata, seq_data: SequenceData): - assert not self.chunked_prefill, \ - "multi-model on CPU does not support chunked-prefill." computed_len = seq_data.get_num_computed_tokens() - seq_len = seq_data.get_len() + seq_len = self.input_data.seq_lens[-1] # NOTE: mm_data only includes the subset of multi-modal items that # intersect with the current prefill positions. @@ -400,6 +398,9 @@ def _compute_multi_modal_input(self, # special processing for mrope position deltas. if self.runner.model_config.uses_mrope: + assert not self.chunked_prefill, \ + "MROPE on CPU does not support chunked-prefill." + image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, (