From f230cc2ca6614dd4eecf3af9f12c3ddbcf83036e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 31 Jul 2024 10:38:45 +0800 Subject: [PATCH] [Bugfix] Fix broadcasting logic for `multi_modal_kwargs` (#6836) --- .buildkite/test-pipeline.yaml | 5 +- .../dev/multimodal/multimodal_index.rst | 2 + .../distributed/test_multimodal_broadcast.py | 9 +- tests/distributed/test_parallel_state.py | 57 ----------- tests/models/test_llava_next.py | 96 ++++++++++++------- vllm/distributed/parallel_state.py | 46 +++------ vllm/multimodal/__init__.py | 6 +- vllm/multimodal/base.py | 62 ++++++++---- vllm/spec_decode/draft_model_runner.py | 4 +- vllm/utils.py | 51 +++++++++- vllm/worker/cpu_model_runner.py | 27 +++--- vllm/worker/embedding_model_runner.py | 16 +++- vllm/worker/model_runner.py | 14 +-- vllm/worker/neuron_model_runner.py | 17 ++-- vllm/worker/openvino_model_runner.py | 26 ++--- vllm/worker/xpu_model_runner.py | 27 +++--- 16 files changed, 254 insertions(+), 211 deletions(-) delete mode 100644 tests/distributed/test_parallel_state.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 91418e5ec1752..9ec9ec12bfcfe 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -56,7 +56,6 @@ steps: fast_check: true commands: - pytest -v -s core - - pytest -v -s distributed/test_parallel_state.py - label: Distributed Comm Ops Test #mirror_hardwares: [amd] @@ -90,13 +89,13 @@ steps: - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py + - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py + - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 9784f4cc2e088..f70fd03e259ff 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -44,6 +44,8 @@ Base 
Classes .. autodata:: vllm.multimodal.BatchedTensors +.. autodata:: vllm.multimodal.BatchedTensorInputs + .. autoclass:: vllm.multimodal.MultiModalDataBuiltins :members: :show-inheritance: diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py index 8e0e8ecd675eb..a99917f586949 100644 --- a/tests/distributed/test_multimodal_broadcast.py +++ b/tests/distributed/test_multimodal_broadcast.py @@ -19,10 +19,10 @@ model = os.environ["TEST_DIST_MODEL"] -if model.startswith("llava-hf/llava"): +if model.startswith("llava-hf/llava-1.5"): from ..models.test_llava import models, run_test -elif model.startswith("microsoft/Phi-3-vision"): - from ..models.test_phi3v import models, run_test +elif model.startswith("llava-hf/llava-v1.6"): + from ..models.test_llava_next import models, run_test else: raise NotImplementedError(f"Unsupported model: {model}") @@ -45,7 +45,8 @@ def test_models(hf_runner, vllm_runner, image_assets, vllm_runner, image_assets, model=models[0], - size_factors=[1.0], + # So that LLaVA-NeXT processor may return nested list + size_factors=[0.25, 0.5, 1.0], dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, diff --git a/tests/distributed/test_parallel_state.py b/tests/distributed/test_parallel_state.py deleted file mode 100644 index 3adcf6b61046d..0000000000000 --- a/tests/distributed/test_parallel_state.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any, Dict - -import pytest -import torch - -from vllm.distributed.parallel_state import (_split_tensor_dict, - _update_nested_dict) - - -def test_split_tensor_dict(): - test_dict = { - "key_a": "a", - "key_b": torch.arange(8, dtype=torch.float32), - "key_c": { - "key_1": torch.arange(5, dtype=torch.float32), - "key_2": torch.tensor([], dtype=torch.float32), - "key_3": 123, - }, - "key_d": {}, - } - metadata_list, tensor_list = _split_tensor_dict(test_dict) - assert len(metadata_list) == 6 - assert torch.allclose(tensor_list[0], test_dict["key_b"]) - assert torch.allclose(tensor_list[1], test_dict["key_c"]["key_1"]) - assert torch.allclose(tensor_list[2], test_dict["key_c"]["key_2"]) - - -def test_split_tensor_dict_invalid_key(): - test_dict = { - "a%b": "a", - } - with pytest.raises(AssertionError): - _split_tensor_dict(test_dict) - - -def test_update_nested_dict(): - flattened_keys_values = [("key1%key2%key3", "value1"), - ("key1%key2%key4", "value2"), - ("key1%key5", "value3"), ("key6%key7", "value4"), - ("key8", "value5")] - res: Dict[str, Any] = {} - - for flat_key, value in flattened_keys_values: - _update_nested_dict(res, flat_key, value) - assert res == { - "key1": { - "key2": { - "key3": "value1", - "key4": "value2" - }, - "key5": "value3" - }, - "key6": { - "key7": "value4" - }, - "key8": "value5" - } diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 2f200c13ea001..9c64f39eb6d08 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -1,14 +1,12 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type import pytest from transformers import AutoConfig, AutoTokenizer -from vllm.model_executor.models.llava_next import ( - get_llava_next_image_feature_size) from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS +from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from .utils import check_logprobs_close pytestmark = pytest.mark.vlm @@ -27,6 +25,8 @@ IMAGE_TOKEN_ID = 32000 
+models = ["llava-hf/llava-v1.6-vicuna-7b-hf"] + def vllm_to_hf_output(vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], @@ -50,34 +50,19 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, return hf_output_ids, hf_output_str, out_logprobs -@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-vicuna-7b-hf"]) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype, max_tokens, num_logprobs) -> None: - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test is under tests/images. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): images = [asset.pil_image for asset in image_assets] inputs_per_image = [( @@ -89,6 +74,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, with vllm_runner(model, dtype=dtype, max_model_len=4096, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: vllm_outputs_per_image = [ vllm_model.generate_greedy_logprobs(prompts, @@ -122,9 +109,54 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ) +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, + dtype, max_tokens, num_logprobs) -> None: + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding vision language config as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. 
+ """ + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + @pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144), (183, 488, 776)]) def test_image_feature_size(height_and_width_and_result): + # Avoid initializing CUDA too early in distributed tests + from vllm.model_executor.models.llava_next import ( + get_llava_next_image_feature_size) + height, width, result = height_and_width_and_result config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") assert get_llava_next_image_feature_size(config, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4116b1729d188..bf7a7de0724af 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -45,22 +45,16 @@ class GraphCaptureContext: def _split_tensor_dict( - tensor_dict: Dict[str, Union[torch.Tensor, Any]], - prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + tensor_dict: Dict[str, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. If the value is a tensor, it is replaced by its metadata. 2. A list of tensors. - - If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its - metadata will be "key1%key2". """ metadata_list: List[Tuple[str, Any]] = [] - tensor_list = [] + tensor_list: List[torch.Tensor] = [] for key, value in tensor_dict.items(): - assert "%" not in key, ( - "Avoid having '%' in key " - "as it is used as a separator for nested entries.") if isinstance(value, torch.Tensor): # Note: we cannot use `value.device` here, # because it contains not only the device type but also the device @@ -68,31 +62,13 @@ def _split_tensor_dict( # receiving side will set the device index. device = value.device.type metadata_list.append( - (prefix + key, TensorMetadata(device, value.dtype, - value.size()))) + (key, TensorMetadata(device, value.dtype, value.size()))) tensor_list.append(value) - elif isinstance(value, dict): - if len(value) == 0: - metadata_list.append((prefix + key, value)) - inner_metadata_list, inner_tensor_list = _split_tensor_dict( - value, prefix + key + "%") - metadata_list.extend(inner_metadata_list) - tensor_list.extend(inner_tensor_list) else: - metadata_list.append((prefix + key, value)) + metadata_list.append((key, value)) return metadata_list, tensor_list -def _update_nested_dict(nested_dict, flattened_key, value): - key_splits = flattened_key.split("%") - cur_dict = nested_dict - for k in key_splits[:-1]: - if k not in cur_dict: - cur_dict[k] = {} - cur_dict = cur_dict[k] - cur_dict[key_splits[-1]] = value - - class GroupCoordinator: """ PyTorch ProcessGroup wrapper for a group of processes. @@ -566,7 +542,7 @@ def broadcast_tensor_dict( device=value.device) if tensor.numel() == 0: # Skip broadcasting empty tensors. 
- _update_nested_dict(tensor_dict, key, tensor) + tensor_dict[key] = tensor continue if tensor.is_cpu: # use metadata_group for CPU tensors @@ -583,9 +559,9 @@ def broadcast_tensor_dict( group=group, async_op=True) async_handles.append(handle) - _update_nested_dict(tensor_dict, key, tensor) + tensor_dict[key] = tensor else: - _update_nested_dict(tensor_dict, key, value) + tensor_dict[key] = value for async_handle in async_handles: async_handle.wait() return tensor_dict @@ -661,7 +637,7 @@ def recv_tensor_dict( device=value.device) if tensor.numel() == 0: # Skip broadcasting empty tensors. - _update_nested_dict(tensor_dict, key, tensor) + tensor_dict[key] = tensor continue if tensor.is_cpu: # use metadata_group for CPU tensors @@ -673,9 +649,9 @@ def recv_tensor_dict( torch.distributed.recv(tensor, src=self.ranks[src], group=group) - _update_nested_dict(tensor_dict, key, tensor) + tensor_dict[key] = tensor else: - _update_nested_dict(tensor_dict, key, value) + tensor_dict[key] = value return tensor_dict def barrier(self): diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 0e3b35d425cb7..456e41ebfad03 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,5 +1,6 @@ -from .base import (BatchedTensors, MultiModalDataBuiltins, MultiModalDataDict, - MultiModalInputs, MultiModalPlugin, NestedTensors) +from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins, + MultiModalDataDict, MultiModalInputs, MultiModalPlugin, + NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -12,6 +13,7 @@ """ __all__ = [ + "BatchedTensorInputs", "BatchedTensors", "MultiModalDataBuiltins", "MultiModalDataDict", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 5abd0ad61cdf9..f13885ef0dab0 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -9,10 +9,12 @@ import torch.types from PIL import Image from torch import nn +from typing_extensions import TypeAlias from vllm.config import ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger +from vllm.utils import JSONTree, json_map_leaves logger = init_logger(__name__) @@ -22,11 +24,16 @@ Currently only supports up to singly nested list of tensors. """ -BatchedTensors = Union[GenericSequence[NestedTensors], NestedTensors] +BatchedTensors: TypeAlias = JSONTree[torch.Tensor] """ -If each input tensor in the batch has the same size, this is a single batched -tensor; otherwise, this is a list of :class:`NestedTensors` with one element -per item in the batch. +A nested JSON structure of tensors which have been batched via +:meth:`MultiModalInputs.batch`. +""" + +BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]] +""" +A dictionary containing nested tensors which have been batched via +:meth:`MultiModalInputs.batch`. """ if sys.version_info < (3, 9): @@ -46,14 +53,17 @@ class MultiModalInputs(_MultiModalInputsBase): """ @staticmethod - def try_concat( + def _try_concat( tensors: List[NestedTensors], - *, - device: torch.types.Device, - ) -> BatchedTensors: + ) -> Union[GenericSequence[NestedTensors], NestedTensors]: + """ + If each input tensor in the batch has the same shape, return a single + batched tensor; otherwise, return a list of :class:`NestedTensors` with + one element per item in the batch. 
+ """ # may be list rather than tensors if isinstance(tensors[0], list): - return [[t.to(device=device) for t in tensor[0]] + return [[t for t in tensor[0]] for tensor in cast(List[List[torch.Tensor]], tensors)] tensors_ = cast(List[torch.Tensor], tensors) @@ -62,18 +72,21 @@ def try_concat( for tensor in tensors_: if tensor.shape[1:] != unbatched_shape: - return [ - tensor.squeeze(0).to(device=device) for tensor in tensors_ - ] + return [tensor.squeeze(0) for tensor in tensors_] - return torch.cat(tensors_, dim=0).to(device=device) + return torch.cat(tensors_, dim=0) @staticmethod - def batch( - inputs_list: List["MultiModalInputs"], - device: torch.types.Device, - ) -> Dict[str, BatchedTensors]: - """Batch multiple inputs together into a dictionary.""" + def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: + """ + Batch multiple inputs together into a dictionary. + + The resulting dictionary has the same keys as the inputs. + If the corresponding value from each input is a tensor and they all + share the same shape, the output value is a single batched tensor; + otherwise, the output value is a list containing the original value + from each input. + """ if len(inputs_list) == 0: return {} @@ -90,9 +103,18 @@ def batch( item_lists[k].append(v) return { - k: MultiModalInputs.try_concat(item_list, device=device) + k: MultiModalInputs._try_concat(item_list) for k, item_list in item_lists.items() - } + } # type: ignore + + @staticmethod + def as_kwargs( + batched_inputs: BatchedTensorInputs, + *, + device: torch.types.Device, + ) -> BatchedTensorInputs: + return json_map_leaves(lambda x: x.to(device, non_blocking=True), + batched_inputs) class MultiModalDataBuiltins(TypedDict, total=False): diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 95071ecb6c8da..0b755600ae824 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -15,6 +15,7 @@ ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger +from vllm.multimodal import MultiModalInputs from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SamplerOutput) from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, @@ -323,7 +324,8 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **multi_modal_kwargs, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), ) # Compute the logits. diff --git a/vllm/utils.py b/vllm/utils.py index b7589ca50ba5b..38e1782a51ab9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -17,7 +17,7 @@ from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, - Union) + Union, overload) import numpy as np import numpy.typing as npt @@ -53,6 +53,7 @@ P = ParamSpec('P') K = TypeVar("K") T = TypeVar("T") +U = TypeVar("U") class _Sentinel: @@ -712,6 +713,54 @@ def merge_dicts(dict1: Dict[K, List[T]], return dict(merged_dict) +JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], + Tuple["JSONTree[T]", ...], T] +"""A nested JSON structure where the leaves need not be JSON-serializable.""" + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: Dict[str, JSONTree[T]], +) -> Dict[str, JSONTree[U]]: + ... 
+ + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: List[JSONTree[T]], +) -> List[JSONTree[U]]: + ... + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: Tuple[JSONTree[T], ...], +) -> Tuple[JSONTree[U], ...]: + ... + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: JSONTree[T], +) -> JSONTree[U]: + ... + + +def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: + if isinstance(value, dict): + return {k: json_map_leaves(func, v) for k, v in value.items()} + elif isinstance(value, list): + return [json_map_leaves(func, v) for v in value] + elif isinstance(value, tuple): + return tuple(json_map_leaves(func, v) for v in value) + else: + return func(value) + + def flatten_2d_lists(lists: List[List[T]]) -> List[T]: """Flatten a list of lists to a single list.""" return [item for sublist in lists for item in sublist] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index c1dee444da512..e22e152a8a8ad 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,6 +1,5 @@ from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, - Type, Union) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch from torch import nn @@ -12,7 +11,7 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -41,7 +40,7 @@ class CPUModelInput(ModelRunnerInputBase): input_positions: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None sampling_metadata: Optional["SamplingMetadata"] = None - multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None virtual_engine: Optional[int] = None def as_broadcastable_tensor_dict( @@ -136,7 +135,7 @@ def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - Mapping[str, BatchedTensors]]: + BatchedTensorInputs]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -214,8 +213,7 @@ def _prepare_prompt( slot_mapping=slot_mapping, ) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.device) + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) @@ -361,11 +359,16 @@ def execute_model( model_executable = self.model execute_model_kwargs = { - "input_ids": model_input.input_tokens, - "positions": model_input.input_positions, - "kv_caches": kv_caches, - "attn_metadata": model_input.attn_metadata, - **(model_input.multi_modal_kwargs or {}), + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + "kv_caches": + kv_caches, + "attn_metadata": + model_input.attn_metadata, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), } hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 
e919dbd18d9df..72ab96cf3c2e1 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -8,6 +8,7 @@ PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MultiModalInputs from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) @@ -99,11 +100,16 @@ def execute_model( kv_caches = [None] * num_layers execute_model_kwargs = { - "input_ids": model_input.input_tokens, - "positions": model_input.input_positions, - "kv_caches": kv_caches, - "attn_metadata": model_input.attn_metadata, - **(model_input.multi_modal_kwargs or {}), + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + "kv_caches": + kv_caches, + "attn_metadata": + model_input.attn_metadata, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), } hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4010c45e10267..de999b11d91b5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -4,8 +4,8 @@ import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, - Tuple, Type, TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -40,7 +40,7 @@ from vllm.model_executor.models.interfaces import (supports_lora, supports_vision) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest @@ -94,7 +94,7 @@ class ModelInputForGPU(ModelRunnerInputBase): attn_metadata: Optional["AttentionMetadata"] = None prompt_adapter_mapping: Optional[PromptAdapterMapping] = None prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None - multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 @@ -608,8 +608,7 @@ def build(self) -> ModelInputForGPU: data.multi_modal_inputs for data in self.inter_data_list if data.multi_modal_inputs is not None ] - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.runner.device) + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) return self.model_input_cls( input_tokens=input_tokens_tensor, @@ -1361,7 +1360,8 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **multi_modal_kwargs, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), **seqlen_agnostic_kwargs) # Compute the logits in the last pipeline stage. 
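For reference, a minimal usage sketch of the new two-step flow the model runners follow after this change (not part of this patch; shapes are illustrative and assume vLLM with this change installed). Multimodal inputs are now batched without any device placement at input-preparation time, and only moved to the target device when the forward-pass kwargs are built:

    import torch
    from vllm.multimodal import MultiModalInputs

    # Two single-image inputs whose per-item shapes match.
    a = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)})
    b = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)})

    # Step 1: batch on the CPU side (no device argument anymore).
    batched = MultiModalInputs.batch([a, b])
    assert batched["pixel_values"].shape == (2, 3, 336, 336)

    # Step 2: move every tensor leaf to the device only when building the
    # model kwargs, mirroring what execute_model() now does via as_kwargs().
    kwargs = MultiModalInputs.as_kwargs(batched, device="cpu")
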
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 243e2ece56fe5..6448e5ff4ac5e 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,6 +1,5 @@ from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, - Union) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -10,7 +9,7 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -32,7 +31,7 @@ class ModelInputForNeuron(ModelRunnerInputBase): input_positions: Optional[torch.Tensor] = None input_block_ids: Optional[torch.Tensor] = None sampling_metadata: Optional["SamplingMetadata"] = None - multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -84,8 +83,8 @@ def load_model(self) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[ - str, BatchedTensors]]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], + BatchedTensorInputs]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] @@ -134,8 +133,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.device) + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) return (input_tokens, input_positions, input_block_ids, seq_lens, multi_modal_kwargs) @@ -244,7 +242,8 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **(model_input.multi_modal_kwargs or {}), + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), ) # Compute the logits. 
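The as_kwargs() helper used above is a thin wrapper around the json_map_leaves() utility added to vllm/utils.py, which applies a function to every leaf of a nested dict/list/tuple while preserving the structure. A minimal sketch of its behavior (not part of this patch; the example tree is made up):

    import torch
    from vllm.utils import json_map_leaves

    tree = {
        "pixel_values": [torch.zeros(3, 336, 336), torch.zeros(3, 672, 336)],
        "image_sizes": torch.tensor([[336, 336], [672, 336]]),
    }

    # Each tensor leaf is transformed; the nesting is left untouched.
    converted = json_map_leaves(lambda t: t.to(torch.float16), tree)
    assert converted["image_sizes"].dtype == torch.float16
    assert isinstance(converted["pixel_values"], list)
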
diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 6281cec09825f..a1d09a2f9e53e 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -1,4 +1,4 @@ -from typing import List, Mapping, NamedTuple, Optional, Tuple +from typing import List, NamedTuple, Optional, Tuple import openvino as ov import torch @@ -12,7 +12,7 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader.openvino import get_model -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -25,7 +25,7 @@ class ModelInput(NamedTuple): attn_metadata: Optional[OpenVINOAttentionMetadata] seq_lens: List[int] query_lens: List[int] - multi_modal_kwargs: Mapping[str, BatchedTensors] + multi_modal_kwargs: BatchedTensorInputs @classmethod def empty(cls, device): @@ -265,8 +265,7 @@ def _prepare_model_input( max_context_len=max_context_len_tensor, ) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.device) + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) return ModelInput( input_tokens, @@ -281,7 +280,7 @@ def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata, - SamplingMetadata, Mapping[str, BatchedTensors]]: + SamplingMetadata, BatchedTensorInputs]: # Prepare input tensors. ( input_tokens, @@ -324,11 +323,16 @@ def execute_model( model_executable = self.model execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, - "kv_caches": kv_caches, - "attn_metadata": attn_metadata, - **(multi_modal_kwargs or {}), + "input_ids": + input_tokens, + "positions": + input_positions, + "kv_caches": + kv_caches, + "attn_metadata": + attn_metadata, + **MultiModalInputs.as_kwargs(multi_modal_kwargs or {}, + device=self.device), } hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 98462f0f7f38e..112e494fadede 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,6 +1,5 @@ from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, - Type, Union) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -14,7 +13,7 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.interfaces import supports_vision -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, @@ -49,7 +48,7 @@ class ModelInputForXPU(ModelRunnerInputBase): input_positions: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None sampling_metadata: Optional["SamplingMetadata"] = None - multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -376,11 +375,16 @@ def execute_model( model_executable = self.model 
execute_model_kwargs = { - "input_ids": model_input.input_tokens, - "positions": model_input.input_positions, - "kv_caches": kv_caches, - "attn_metadata": model_input.attn_metadata, - **(model_input.multi_modal_kwargs or {}), + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + "kv_caches": + kv_caches, + "attn_metadata": + model_input.attn_metadata, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), } hidden_states = model_executable(**execute_model_kwargs) @@ -404,7 +408,7 @@ def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - Mapping[str, BatchedTensors]]: + BatchedTensorInputs]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -496,8 +500,7 @@ def _prepare_prompt( block_tables=torch.tensor([], device=self.device, dtype=torch.int), ) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.device) + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs)
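Lastly, the reason the LLaVA-NeXT distributed test now uses multi-scale size factors: when the per-item tensors do not share a common shape, MultiModalInputs.batch() falls back to a list of per-item tensors rather than a single stacked tensor, and a nested structure of this kind is exactly what the fixed broadcasting logic has to carry between workers. A minimal sketch of that fallback (not part of this patch; shapes are illustrative):

    import torch
    from vllm.multimodal import MultiModalInputs

    small = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)})
    large = MultiModalInputs({"pixel_values": torch.zeros(1, 3, 672, 336)})

    batched = MultiModalInputs.batch([small, large])
    # Mismatched shapes -> a list of per-item tensors (batch dim squeezed out).
    assert isinstance(batched["pixel_values"], list)
    assert batched["pixel_values"][0].shape == (3, 336, 336)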