From cc540c3e2b6069587f2afe70c0e1c55f3a2a8fd6 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Wed, 3 Jul 2024 01:49:01 +0000
Subject: [PATCH] Support multimodal data for neuron and tpu

---
 vllm/worker/neuron_model_runner.py | 37 ++++++++++++++++++++----
 vllm/worker/tpu_model_runner.py    | 45 ++++++++++++++++++++++++++----
 2 files changed, 71 insertions(+), 11 deletions(-)

diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index 8b96966be4704..a954681101845 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Union)
 
 import torch
 from torch import nn
@@ -9,6 +10,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader.neuron import get_neuron_model
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@@ -29,6 +32,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
     input_positions: Optional[torch.Tensor] = None
     input_block_ids: Optional[torch.Tensor] = None
     sampling_metadata: Optional["SamplingMetadata"] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -65,6 +69,10 @@ def __init__(
         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.
 
@@ -76,13 +84,15 @@ def load_model(self) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
+            str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         input_block_ids: List[int] = []
 
         seq_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -102,6 +112,12 @@ def _prepare_prompt(
             assert len(block_table) == 1
             input_block_ids.append(block_table[0])
 
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                # Process multi-modal data
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         max_seq_len = max(seq_lens)
         assert max_seq_len > 0
         input_tokens = make_tensor_with_pad(input_tokens,
@@ -118,7 +134,11 @@ def _prepare_prompt(
                                        dtype=torch.long,
                                        device=self.device)
 
-        return input_tokens, input_positions, input_block_ids, seq_lens
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, input_block_ids, seq_lens,
+                multi_modal_kwargs)
 
     def _prepare_decode(
         self,
@@ -184,8 +204,9 @@ def prepare_model_input(
         is_prompt = seq_group_metadata_list[0].is_prompt
         # Prepare input tensors.
         if is_prompt:
-            (input_tokens, input_positions, input_block_ids,
-             seq_lens) = self._prepare_prompt(seq_group_metadata_list)
+            (input_tokens, input_positions, input_block_ids, seq_lens,
+             multi_modal_kwargs
+             ) = self._prepare_prompt(seq_group_metadata_list)
         else:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
@@ -203,7 +224,8 @@ def prepare_model_input(
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,
-                                   sampling_metadata=sampling_metadata)
+                                   sampling_metadata=sampling_metadata,
+                                   multi_modal_kwargs=multi_modal_kwargs)
 
     @torch.inference_mode()
     def execute_model(
@@ -217,10 +239,13 @@ def execute_model(
             raise ValueError(
                 "NeuronModelRunner does not support multi-step execution.")
 
+        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+
         hidden_states = self.model(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             input_block_ids=model_input.input_block_ids,
+            **multi_modal_kwargs,
         )
 
         # Compute the logits.
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index dd08536efc5fb..29928e3a8da08 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional, Tuple
+from typing import List, Mapping, Optional, Tuple
 
 import numpy as np
 import torch
@@ -12,6 +12,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -66,6 +68,10 @@ def __init__(
             False,
         )
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
 
@@ -193,12 +199,14 @@ def warmup_model(
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -224,6 +232,11 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)
 
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -261,17 +274,24 @@ def _prepare_prompt(
             block_tables=None,
             context_lens=None,
         )
-        return input_tokens, input_positions, attn_metadata, prompt_lens
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_kwargs)
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
@@ -297,6 +317,11 @@ def _prepare_decode(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -330,7 +355,12 @@ def _prepare_decode(
             block_tables=block_tables,
             context_lens=context_lens,
         )
-        return input_tokens, input_positions, attn_metadata, input_lens
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, attn_metadata, input_lens,
+                multi_modal_kwargs)
 
     def _prepare_sample(
         self,
@@ -483,6 +513,7 @@ def forward(
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
+        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
@@ -496,6 +527,8 @@ def forward(
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
+            multi_modal_kwargs: Keyword arguments from multi-modal data to
+                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -535,11 +568,13 @@ def forward(
         slot_mapping = slot_mapping.flatten()
         attn_metadata.slot_mapping = slot_mapping
 
+        multi_modal_kwargs = multi_modal_kwargs or {}
         hidden_states = self.model(
             token_ids,
             position_ids,
             kv_caches,
             attn_metadata,
+            **multi_modal_kwargs,
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
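
Note (not part of the patch): a minimal sketch of the data flow these changes add,
using only the APIs the diff itself references (MULTIMODAL_REGISTRY,
MultiModalInputs, SequenceGroupMetadata). The standalone helper name
batch_multi_modal_kwargs and its function form are illustrative only; in the
patch the equivalent logic lives inside the runners' _prepare_prompt() /
_prepare_decode() methods.

    from typing import List, Mapping

    from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                                 MultiModalInputs)
    from vllm.sequence import SequenceGroupMetadata


    def batch_multi_modal_kwargs(
        model_config,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        device,
    ) -> Mapping[str, BatchedTensors]:
        # Per-model input mapper; the patch creates this once in __init__.
        input_mapper = MULTIMODAL_REGISTRY.create_input_mapper(model_config)

        mm_inputs_list: List[MultiModalInputs] = []
        for seq_group_metadata in seq_group_metadata_list:
            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                # Map raw multi-modal data (e.g. images) to model kwargs.
                mm_inputs_list.append(input_mapper(mm_data))

        # Collate the per-sequence kwargs into batched tensors on `device`.
        return MultiModalInputs.batch(mm_inputs_list, device=device)

The batched mapping is then carried to execution (via
ModelInputForNeuron.multi_modal_kwargs on Neuron, or the extra forward()
argument on TPU) and expanded into the model call as `**multi_modal_kwargs`,
falling back to an empty dict when no multi-modal data is present.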