
Support multimodal data for neuron and tpu
DarkLight1337 committed Jul 3, 2024
1 parent aaa0f1f commit cc540c3
Showing 2 changed files with 71 additions and 11 deletions.
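
Both files apply the same pattern: build an input mapper from MULTIMODAL_REGISTRY, run it over each sequence group's multi_modal_data while preparing inputs, batch the per-sequence results with MultiModalInputs.batch, and unpack the batched mapping into the model call as keyword arguments. Below is a minimal, self-contained sketch of that flow; input_mapper, batch, and the "image"/"pixel_values" keys are hypothetical stand-ins, not vLLM's real mapper or batching rules.

import torch

# Hypothetical stand-in for MULTIMODAL_REGISTRY.create_input_mapper(...).
def input_mapper(mm_data: dict) -> dict:
    # Map raw multi-modal data (here, a single image) to model kwargs.
    return {"pixel_values": mm_data["image"].unsqueeze(0)}

# Hypothetical stand-in for MultiModalInputs.batch(...).
def batch(inputs_list: list) -> dict:
    # Concatenate same-keyed tensors across sequences along the batch dim.
    return {
        key: torch.cat([inputs[key] for inputs in inputs_list], dim=0)
        for key in inputs_list[0]
    }

seq_mm_data = [{"image": torch.rand(3, 224, 224)} for _ in range(2)]
multi_modal_kwargs = batch([input_mapper(d) for d in seq_mm_data])
# The runners below then do the equivalent of:
#     hidden_states = model(input_ids, positions, **multi_modal_kwargs)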
37 changes: 31 additions & 6 deletions vllm/worker/neuron_model_runner.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Union)

 import torch
 from torch import nn
@@ -9,6 +10,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader.neuron import get_neuron_model
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
@@ -29,6 +32,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
     input_positions: Optional[torch.Tensor] = None
     input_block_ids: Optional[torch.Tensor] = None
     sampling_metadata: Optional["SamplingMetadata"] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None

     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -65,6 +69,10 @@ def __init__(
         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()

+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.

@@ -76,13 +84,15 @@ def load_model(self) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
+            str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         input_block_ids: List[int] = []

         seq_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -102,6 +112,12 @@ def _prepare_prompt(
             assert len(block_table) == 1
             input_block_ids.append(block_table[0])

+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                # Process multi-modal data
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         max_seq_len = max(seq_lens)
         assert max_seq_len > 0
         input_tokens = make_tensor_with_pad(input_tokens,
@@ -118,7 +134,11 @@ def _prepare_prompt(
                                             dtype=torch.long,
                                             device=self.device)

-        return input_tokens, input_positions, input_block_ids, seq_lens
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, input_block_ids, seq_lens,
+                multi_modal_kwargs)

     def _prepare_decode(
         self,
@@ -184,8 +204,9 @@ def prepare_model_input(
         is_prompt = seq_group_metadata_list[0].is_prompt
         # Prepare input tensors.
         if is_prompt:
-            (input_tokens, input_positions, input_block_ids,
-             seq_lens) = self._prepare_prompt(seq_group_metadata_list)
+            (input_tokens, input_positions, input_block_ids, seq_lens,
+             multi_modal_kwargs
+             ) = self._prepare_prompt(seq_group_metadata_list)
         else:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
@@ -203,7 +224,8 @@ def prepare_model_input(
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,
-                                   sampling_metadata=sampling_metadata)
+                                   sampling_metadata=sampling_metadata,
+                                   multi_modal_kwargs=multi_modal_kwargs)

     @torch.inference_mode()
     def execute_model(
@@ -217,10 +239,13 @@ def execute_model(
             raise ValueError(
                 "NeuronModelRunner does not support multi-step execution.")

+        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+
         hidden_states = self.model(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             input_block_ids=model_input.input_block_ids,
+            **multi_modal_kwargs,
         )

         # Compute the logits.
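Note that on the Neuron side only the prompt path produces multi-modal kwargs, and ModelInputForNeuron.multi_modal_kwargs defaults to None, so execute_model normalizes the value with `or {}` before unpacking it. A small runnable illustration of why that guard is needed (run_model and echo are illustrative helpers, not vLLM code):

def run_model(model, input_ids, multi_modal_kwargs=None):
    # Unpacking None with ** raises TypeError, so fall back to an empty dict.
    multi_modal_kwargs = multi_modal_kwargs or {}
    return model(input_ids, **multi_modal_kwargs)

echo = lambda ids, **kwargs: (ids, kwargs)
run_model(echo, [1, 2, 3])                                      # text-only request
run_model(echo, [1, 2, 3], {"pixel_values": "batched tensor"})  # multi-modal prompt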
45 changes: 40 additions & 5 deletions vllm/worker/tpu_model_runner.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional, Tuple
+from typing import List, Mapping, Optional, Tuple

 import numpy as np
 import torch
@@ -12,6 +12,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                             MultiModalInputs)
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
@@ -66,6 +68,10 @@ def __init__(
             False,
         )

+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
     def load_model(self) -> None:
         self.device = self.device_config.device

@@ -193,12 +199,14 @@ def warmup_model(
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []

         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -224,6 +232,11 @@ def _prepare_prompt(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)

+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -261,17 +274,24 @@ def _prepare_prompt(
             block_tables=None,
             context_lens=None,
         )
-        return input_tokens, input_positions, attn_metadata, prompt_lens
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_kwargs)

     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []

         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
@@ -297,6 +317,11 @@ def _prepare_decode(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])

+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -330,7 +355,12 @@ def _prepare_decode(
             block_tables=block_tables,
             context_lens=context_lens,
         )
-        return input_tokens, input_positions, attn_metadata, input_lens
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, attn_metadata, input_lens,
+                multi_modal_kwargs)

     def _prepare_sample(
         self,
@@ -483,6 +513,7 @@ def forward(
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
+        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
@@ -496,6 +527,8 @@ def forward(
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
+            multi_modal_kwargs: Keyword arguments from multi-modal data to
+                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -535,11 +568,13 @@ def forward(
         slot_mapping = slot_mapping.flatten()
         attn_metadata.slot_mapping = slot_mapping

+        multi_modal_kwargs = multi_modal_kwargs or {}
         hidden_states = self.model(
             token_ids,
             position_ids,
             kv_caches,
             attn_metadata,
+            **multi_modal_kwargs,
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
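On the model side, the unpacked kwargs arrive as ordinary named parameters: a multi-modal model's forward can declare, say, pixel_values and receives it whenever the batched mapping contains that key, while text-only batches simply omit it. A runnable sketch under that assumption (ToyModel and its trivial fusion are illustrative, not a vLLM model):

import torch
from torch import nn

class ToyModel(nn.Module):
    # Illustrative forward: the optional multi-modal input arrives by keyword.
    def forward(self, token_ids, pixel_values=None):
        text_feature = token_ids.float().mean()
        if pixel_values is None:
            return text_feature
        # Trivial text/image fusion, just to show the kwarg being consumed.
        return text_feature + pixel_values.mean()

model = ToyModel()
tokens = torch.tensor([[1, 2, 3]])
mm_kwargs = {"pixel_values": torch.rand(1, 3, 224, 224)}
model(tokens, **mm_kwargs)  # prompt step carrying an image
model(tokens)               # text-only or decode step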
