From 5a662f70eeb02dd4855deaf777cc5ea0ddb862bf Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 13 Nov 2024 20:39:03 +0800 Subject: [PATCH] [1/N] Initial prototype for multi-modal processor (#10044) Signed-off-by: DarkLight1337 Signed-off-by: OmerD --- .../models/enabling_multimodal_inputs.rst | 2 +- .../mm_processor_kwargs/test_qwen.py | 2 +- .../{test_base.py => test_inputs.py} | 2 +- tests/multimodal/test_processor_kwargs.py | 37 ++- tests/v1/core/test_prefix_caching.py | 4 +- vllm/config.py | 2 +- vllm/engine/async_llm_engine.py | 4 + vllm/engine/llm_engine.py | 16 +- vllm/engine/multiprocessing/client.py | 6 + vllm/engine/protocol.py | 16 +- vllm/entrypoints/openai/serving_chat.py | 1 - vllm/entrypoints/openai/serving_completion.py | 1 - vllm/inputs/__init__.py | 12 +- vllm/inputs/data.py | 99 ++++++- vllm/inputs/preprocess.py | 143 +++++++-- vllm/inputs/registry.py | 56 +++- vllm/model_executor/models/chatglm.py | 4 +- vllm/model_executor/models/fuyu.py | 3 +- vllm/model_executor/models/h2ovl.py | 3 +- vllm/model_executor/models/internvl.py | 3 +- vllm/model_executor/models/llava.py | 2 +- vllm/model_executor/models/minicpmv.py | 3 +- vllm/model_executor/models/phi3v.py | 2 +- vllm/model_executor/models/pixtral.py | 3 +- vllm/model_executor/models/qwen.py | 3 +- vllm/model_executor/models/qwen2_vl.py | 6 +- vllm/model_executor/models/utils.py | 2 +- vllm/multimodal/__init__.py | 10 +- vllm/multimodal/audio.py | 12 +- vllm/multimodal/base.py | 188 ++---------- vllm/multimodal/image.py | 10 +- vllm/multimodal/inputs.py | 225 +++++++++++++++ vllm/multimodal/processing.py | 273 ++++++++++++++++++ vllm/multimodal/registry.py | 84 +++++- vllm/multimodal/utils.py | 3 +- vllm/multimodal/video.py | 20 +- vllm/sequence.py | 68 ++--- vllm/v1/engine/async_llm.py | 4 + vllm/v1/engine/llm_engine.py | 4 +- vllm/v1/engine/processor.py | 73 +++-- vllm/v1/request.py | 26 +- vllm/v1/worker/gpu_model_runner.py | 2 +- vllm/worker/cpu_model_runner.py | 41 ++- vllm/worker/hpu_model_runner.py | 6 +- vllm/worker/model_runner.py | 25 +- vllm/worker/neuron_model_runner.py | 22 +- vllm/worker/openvino_model_runner.py | 21 +- vllm/worker/xpu_model_runner.py | 16 +- 48 files changed, 1133 insertions(+), 437 deletions(-) rename tests/multimodal/{test_base.py => test_inputs.py} (97%) create mode 100644 vllm/multimodal/inputs.py create mode 100644 vllm/multimodal/processing.py diff --git a/docs/source/models/enabling_multimodal_inputs.rst b/docs/source/models/enabling_multimodal_inputs.rst index 3d0d1aec69845..49b5285c45590 100644 --- a/docs/source/models/enabling_multimodal_inputs.rst +++ b/docs/source/models/enabling_multimodal_inputs.rst @@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i 3. Register maximum number of multi-modal tokens ------------------------------------------------ -For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance +For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item and register it via :meth:`INPUT_REGISTRY.register_dummy_data `. .. 
code-block:: diff diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py index e6ed87fc8ea08..163220c91a27d 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py @@ -6,7 +6,7 @@ from PIL.Image import Image from vllm.inputs import InputContext, token_inputs -from vllm.multimodal.base import MultiModalKwargs +from vllm.multimodal import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from .....conftest import IMAGE_ASSETS diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_inputs.py similarity index 97% rename from tests/multimodal/test_base.py rename to tests/multimodal/test_inputs.py index bfaf2cdeaa8d4..678bbb52b8c2f 100644 --- a/tests/multimodal/test_base.py +++ b/tests/multimodal/test_inputs.py @@ -1,6 +1,6 @@ import torch -from vllm.multimodal.base import MultiModalKwargs, NestedTensors +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors def assert_nested_tensors_equal(expected: NestedTensors, diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 4d3bbd805c152..e6c8793989e13 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -1,12 +1,12 @@ from array import array -from typing import Mapping +from typing import Callable, Dict, Mapping, Optional from unittest.mock import patch import pytest import torch from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext, - InputRegistry, token_inputs) + InputRegistry, ProcessorInputs, token_inputs) from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -34,10 +34,9 @@ def custom_processor(ctx: InputContext, inputs: DecoderOnlyInputs, *, num_crops=DEFAULT_NUM_CROPS): - # For testing purposes, we don't worry about the llm inputs / return - # type validation, and just return the value of the kwarg that we - # clobber. 
- return num_crops + # For testing purposes, we don't worry about the prompt + return token_inputs(prompt_token_ids=[], + mm_processor_kwargs={"num_crops": num_crops}) with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", return_value=custom_processor): @@ -109,6 +108,21 @@ def _get_num_crops_info(init_num_crops: int, inference_num_crops: int): return init_kwargs, inference_kwargs, expected_seq_count +def _get_processed_num_crops( + processor: Callable[[ProcessorInputs], ProcessorInputs], + inference_kwargs: Optional[Dict[str, int]], +) -> int: + processed_inputs = processor( + token_inputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=inference_kwargs)) + + assert "type" in processed_inputs + assert processed_inputs["type"] == "token" + assert "mm_processor_kwargs" in processed_inputs + return processed_inputs["mm_processor_kwargs"]["num_crops"] + + @pytest.mark.parametrize("init_num_crops,inference_num_crops", [ (None, None), (NUM_CROPS_OVERRIDE, None), @@ -124,10 +138,8 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops, ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - num_crops_val = processor( - token_inputs(prompt_token_ids=[], - prompt="", - mm_processor_kwargs=inference_kwargs)) + num_crops_val = _get_processed_num_crops(processor, inference_kwargs) + assert num_crops_val == expected_seq_count @@ -153,10 +165,7 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor = dummy_registry.create_input_processor(ctx.model_config) # Should filter out the inference time kwargs - num_crops_val = processor( - token_inputs(prompt_token_ids=[], - prompt="", - mm_processor_kwargs=mm_processor_kwargs)) + num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs) assert num_crops_val == DEFAULT_NUM_CROPS diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index e5a3b62258dd8..d614d3e67460f 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,5 +1,5 @@ """Compare the with and without prefix caching.""" -from vllm.inputs import DecoderOnlyInputs +from vllm.inputs import token_inputs from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import hash_block_tokens @@ -8,7 +8,7 @@ def make_request(request_id, prompt_token_ids): return Request( request_id=request_id, - inputs=DecoderOnlyInputs(prompt_token_ids=prompt_token_ids), + inputs=token_inputs(prompt_token_ids=prompt_token_ids), sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, arrival_time=0, diff --git a/vllm/config.py b/vllm/config.py index 8c0182dc5d634..3ae5a171434f5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -108,7 +108,7 @@ class ModelConfig: matches the model name exposed via the APIs. If multiple model names provided, the first name will be used. If not specified, the model name will be the same as `model`. - limit_mm_per_prompt: Maximum number of data instances per modality + limit_mm_per_prompt: Maximum number of data items per modality per prompt. Only applicable for multimodal models. 
override_neuron_config: Initialize non default neuron config or override default neuron config that are specific to Neuron devices, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1a371b52bb64b..5a5388708b1c6 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -19,6 +19,7 @@ from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( @@ -729,6 +730,9 @@ def _error_callback(self, exc: Exception) -> None: self.set_errored(exc) self._request_tracker.propagate_exception(exc) + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.engine.input_preprocessor + async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 69ed6e6bd59d2..f5299746d845d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -30,7 +30,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, - PromptType) + PromptType, SingletonInputsAdapter) from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -39,6 +39,7 @@ from vllm.model_executor.guided_decoding import ( get_local_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams @@ -226,6 +227,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, ) -> None: @@ -335,7 +337,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: model_config) self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) + self.tokenizer, + mm_registry) self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( @@ -851,13 +854,6 @@ def add_request( ) processed_inputs = self.input_processor(preprocessed_inputs) - # This is a bit of a hack - copy the mm_processor_kwargs that were - # used in the input processor to the processed output, since these - # kwargs are presumed to be immutable and the values should be aligned - # between the input processor (here) and the input mapper. 
- processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( - "mm_processor_kwargs") - self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, @@ -2019,7 +2015,7 @@ def _validate_model_inputs(self, inputs: ProcessorInputs, else: prompt_inputs = inputs - prompt_ids = prompt_inputs.get("prompt_token_ids") + prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 882742c2fc61b..fe21c58c775fe 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -31,6 +31,7 @@ # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -94,6 +95,8 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, parallel_config=engine_config.parallel_config, enable_lora=bool(engine_config.lora_config), ) + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer) # Send RPCGenerateRequest to the MQLLMEngine. self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) @@ -345,6 +348,9 @@ async def _check_success(error_message: str, socket: Socket): or response != VLLM_RPC_SUCCESS_STR): raise ValueError(error_message) + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.input_preprocessor + async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): return await self.tokenizer.get_lora_tokenizer_async(lora_request) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e0b59d94cfdc3..e15395d75c91f 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -62,7 +62,6 @@ def generate( async def beam_search( self, prompt: PromptType, - model_config: ModelConfig, request_id: str, params: BeamSearchParams, ) -> AsyncGenerator[RequestOutput, None]: @@ -74,13 +73,14 @@ async def beam_search( length_penalty = params.length_penalty include_stop_str_in_output = params.include_stop_str_in_output - tokenizer = await self.get_tokenizer() - input_preprocessor = InputPreprocessor(model_config, tokenizer) + preprocessor = await self.get_input_preprocessor() + tokenizer_group = preprocessor.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async() if is_explicit_encoder_decoder_prompt(prompt): raise NotImplementedError else: - processed_inputs = input_preprocessor._prompt_to_llm_inputs( + processed_inputs = preprocessor._prompt_to_llm_inputs( prompt, request_id=request_id, ) @@ -220,6 +220,7 @@ async def abort(self, request_id: str) -> None: Args: request_id: The unique id of the request. """ + ... @abstractmethod async def get_model_config(self) -> ModelConfig: @@ -228,8 +229,13 @@ async def get_model_config(self) -> ModelConfig: @abstractmethod async def get_decoding_config(self) -> DecodingConfig: - ... """Get the decoding configuration of the vLLM engine.""" + ... + + @abstractmethod + async def get_input_preprocessor(self) -> InputPreprocessor: + """Get the input processor of the vLLM engine.""" + ... 
@abstractmethod async def get_tokenizer( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 74867d8de8843..09edaf98f7d17 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -190,7 +190,6 @@ async def create_chat_completion( if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, - model_config=self.model_config, request_id=request_id, params=sampling_params, ) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index db31b1153d97e..936aae8f1c267 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -140,7 +140,6 @@ async def create_completion( if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, - model_config=self.model_config, request_id=request_id, params=sampling_params, ) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 68ac50a2c5a16..54fbd7a321a6f 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,9 +1,11 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, - TokensPrompt, build_explicit_enc_dec_prompt, - to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import DummyData, InputContext, InputRegistry + SingletonInputs, SingletonInputsAdapter, SingletonPrompt, + TextPrompt, TokenInputs, TokensPrompt, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + token_inputs, zip_enc_dec_prompts) +from .registry import (DummyData, InputContext, InputProcessingContext, + InputRegistry) INPUT_REGISTRY = InputRegistry() """ @@ -26,12 +28,14 @@ "EncoderDecoderInputs", "ProcessorInputs", "SingletonInputs", + "SingletonInputsAdapter", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", "INPUT_REGISTRY", "DummyData", "InputContext", + "InputProcessingContext", "InputRegistry", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 46b41f431bec7..07ff9faa50f13 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,10 +1,14 @@ +from dataclasses import dataclass +from functools import cached_property from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, Optional, Tuple, Union, cast) -from typing_extensions import NotRequired, TypedDict, TypeVar +import torch +from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict + from vllm.multimodal.inputs import MultiModalInputsV2 class TextPrompt(TypedDict): @@ -36,13 +40,13 @@ class TokensPrompt(TypedDict): multi_modal_data: NotRequired["MultiModalDataDict"] """ - Optional multi-modal data to pass to the model, + DEPRECATED: Optional multi-modal data to pass to the model, if the model supports it. """ mm_processor_kwargs: NotRequired[Dict[str, Any]] """ - Optional multi-modal processor kwargs to be forwarded to the + DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities have registered mappers etc for the model being considered, we attempt to pass the mm_processor_kwargs to each of them. 
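
# A short illustrative sketch of the non-deprecated path implied by the notes
# above: attach multi-modal data to a TextPrompt and let the preprocessor /
# multi-modal processor tokenize the text itself. The blank image and the
# llava-style "<image>" marker are placeholders for illustration only.
from PIL import Image

from vllm.inputs import TextPrompt

prompt = TextPrompt(
    prompt="USER: <image>\nWhat is in this image?\nASSISTANT:",
    multi_modal_data={"image": Image.new("RGB", (336, 336))},
)
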
@@ -176,7 +180,7 @@ def token_inputs( return inputs -DecoderOnlyInputs = TokenInputs +DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"] """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. @@ -191,19 +195,91 @@ class EncoderDecoderInputs(TypedDict): This specifies the required data for encoder-decoder models. """ - encoder: TokenInputs + encoder: Union[TokenInputs, "MultiModalInputsV2"] """The inputs for the encoder portion.""" - decoder: TokenInputs + decoder: Union[TokenInputs, "MultiModalInputsV2"] """The inputs for the decoder portion.""" -SingletonInputs = TokenInputs +SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"] """ A processed :class:`SingletonPrompt` which can be passed to :class:`vllm.sequence.Sequence`. """ + +@dataclass +class SingletonInputsAdapter: + """ + Unified interface to access the components of :class:`SingletonInputs`. + """ + inputs: SingletonInputs + + @cached_property + def prompt(self) -> Optional[str]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("prompt") + + assert_never(inputs) + + @cached_property + def prompt_token_ids(self) -> List[int]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("prompt_token_ids", []) + + assert_never(inputs) + + @cached_property + def prompt_embeds(self) -> Optional[torch.Tensor]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return None + + assert_never(inputs) + + @cached_property + def multi_modal_data(self) -> "MultiModalDataDict": + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_data", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_kwargs", {}) + + assert_never(inputs) + + @cached_property + def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_placeholders", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_placeholders", {}) + + assert_never(inputs) + + @cached_property + def mm_processor_kwargs(self) -> Dict[str, Any]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("mm_processor_kwargs", {}) + + if inputs["type"] == "multimodal": + return {} + + assert_never(inputs) + + ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] """ The inputs to :data:`vllm.inputs.InputProcessor`. @@ -234,10 +310,11 @@ def zip_enc_dec_prompts( ) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of - :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs - may also be provided; if a dict is passed, the same dictionary will be - used for every encoder/decoder prompt. If an iterable is provided, it will - be zipped with the encoder/decoder prompts. + :class:`ExplicitEncoderDecoderPrompt` instances. + + ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same + dictionary will be used for every encoder/decoder prompt. If an iterable is + provided, it will be zipped with the encoder/decoder prompts. 
""" if mm_processor_kwargs is None: mm_processor_kwargs = cast(Dict[str, Any], {}) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 509b0448b9e51..fdf28615fda10 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,11 +1,13 @@ import asyncio -from typing import List, Optional +from typing import List, Mapping, Optional, Union from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2 from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_warning_once @@ -23,11 +25,13 @@ def __init__( self, model_config: ModelConfig, tokenizer: Optional[BaseTokenizerGroup], + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ) -> None: super().__init__() self.model_config = model_config self.tokenizer = tokenizer + self.mm_registry = mm_registry def get_tokenizer_group(self) -> BaseTokenizerGroup: if self.tokenizer is None: @@ -198,14 +202,79 @@ async def _tokenize_prompt_async( prompt=prompt, lora_request=lora_request) + def _can_process_multimodal(self) -> bool: + model_config = self.model_config + + if not model_config.is_multimodal_model: + raise ValueError("Your model does not support multi-modal inputs") + + # Interim measure so we can handle models that have yet to be + # updated to use the new multi-modal processor + can_process_multimodal = self.mm_registry.has_processor(model_config) + if not can_process_multimodal: + logger.info( + "Your model uses the legacy input pipeline instead of the new " + "multi-modal processor. Please note that the legacy pipeline " + "will be removed in a future release. For more details, see: " + "https://github.com/vllm-project/vllm/issues/10114") + + return can_process_multimodal + + def _process_multimodal( + self, + prompt: Union[str, List[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + lora_request: Optional[LoRARequest], + ) -> MultiModalInputsV2: + """ + Apply the model's multi-modal processor to a multi-modal prompt, + returning the corresponding token IDs and metadata. 
+ """ + tokenizer_group = self.get_tokenizer_group() + tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) + + mm_processor = self.mm_registry.create_processor( + self.model_config, tokenizer) + + if isinstance(prompt, list): + prompt = tokenizer.decode(prompt) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs) + + async def _process_multimodal_async( + self, + prompt: Union[str, List[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + lora_request: Optional[LoRARequest], + ) -> MultiModalInputsV2: + """Async version of :meth:`_process_multimodal`.""" + tokenizer_group = self.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request + ) + + mm_processor = self.mm_registry.create_processor( + self.model_config, tokenizer) + if isinstance(prompt, list): + logger.warning("Passing `multi_modal_data` in TokensPrompt is" + "deprecated and will be removed in a future update") + prompt = tokenizer.decode(prompt) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs) + def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> SingletonInputs: - ''' - Extract the components of any single encoder or decoder input prompt. + """ + Extract the singleton inputs from a prompt. Arguments: @@ -215,12 +284,8 @@ def _prompt_to_llm_inputs( Returns: - * prompt - * prompt_token_ids - * multi_modal_data - * mm_processor_kwargs (request-level input processor/mapper overrides) - ''' - + * :class:`SingletonInputs` instance + """ parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": @@ -243,6 +308,14 @@ def _prompt_to_llm_inputs( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + return token_inputs( prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, @@ -253,13 +326,22 @@ def _prompt_to_llm_inputs( text_content = parsed["content"] prompt_text = text_content["prompt"] + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_text, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + prompt_token_ids = self._tokenize_prompt( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, @@ -299,6 +381,14 @@ async def _prompt_to_llm_inputs_async( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + if multi_modal_data is not None and self._can_process_multimodal(): + return await self._process_multimodal_async( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + return token_inputs( prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, @@ -309,13 +399,22 @@ async def _prompt_to_llm_inputs_async( text_content = 
parsed["content"] prompt_text = text_content["prompt"] + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return await self._process_multimodal_async( + prompt_text, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + prompt_token_ids = await self._tokenize_prompt_async( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, @@ -331,7 +430,8 @@ def _build_enc_dec_llm_inputs( encoder_inputs: SingletonInputs, decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - if encoder_inputs["type"] == "token": + if (encoder_inputs["type"] == "token" + or encoder_inputs["type"] == "multimodal"): pass else: assert_never(encoder_inputs) @@ -340,7 +440,8 @@ def _build_enc_dec_llm_inputs( dec_token_ids = self._prepare_decoder_input_ids_for_generation( None) decoder_inputs = token_inputs(dec_token_ids) - elif decoder_inputs["type"] == "token": + elif (decoder_inputs["type"] == "token" + or decoder_inputs["type"] == "multimodal"): dec_token_ids = self._prepare_decoder_input_ids_for_generation( decoder_inputs["prompt_token_ids"]) decoder_inputs["prompt_token_ids"] = dec_token_ids @@ -361,7 +462,7 @@ def _process_encoder_decoder_prompt( prompt: PromptType, request_id: str, ) -> EncoderDecoderInputs: - ''' + """ For encoder/decoder models only: Process an input prompt into an :class:`EncoderDecoderInputs` instance. @@ -391,8 +492,7 @@ def _process_encoder_decoder_prompt( Returns: * :class:`EncoderDecoderInputs` instance - ''' - + """ encoder_inputs: SingletonInputs decoder_inputs: Optional[SingletonInputs] @@ -460,7 +560,8 @@ def _build_decoder_only_llm_inputs( prompt_inputs: DecoderOnlyInputs, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: - if prompt_inputs["type"] == "token": + if (prompt_inputs["type"] == "token" + or prompt_inputs["type"] == "multimodal"): prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"], prompt_adapter_request=prompt_adapter_request, @@ -477,7 +578,7 @@ def _process_decoder_only_prompt( lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> DecoderOnlyInputs: - ''' + """ For decoder-only models: Process an input prompt into an :class:`DecoderOnlyInputs` instance. 
@@ -491,7 +592,7 @@ def _process_decoder_only_prompt( Returns: * :class:`DecoderOnlyInputs` instance - ''' + """ prompt_comps = self._prompt_to_llm_inputs( prompt, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 7d7a797be4f60..68b4756331e6d 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -5,14 +5,17 @@ Optional, Protocol, Type, cast) from torch import nn -from transformers import PretrainedConfig -from typing_extensions import TypeVar +from transformers import PretrainedConfig, ProcessorMixin +from typing_extensions import TypeVar, assert_never from vllm.logger import init_logger +from vllm.transformers_utils.processor import cached_get_processor +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, resolve_mm_processor_kwargs) -from .data import ProcessorInputs +from .data import ProcessorInputs, SingletonInputs +from .parse import is_encoder_decoder_inputs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -61,6 +64,19 @@ def get_hf_image_processor_config(self) -> Dict[str, Any]: return self.model_config.hf_image_processor_config +@dataclass(frozen=True) +class InputProcessingContext(InputContext): + tokenizer: AnyTokenizer + """The tokenizer used to tokenize the inputs.""" + + def get_hf_processor(self) -> ProcessorMixin: + return cached_get_processor( + self.model_config.tokenizer, + tokenizer=self.tokenizer, # Override the tokenizer with ours + trust_remote_code=self.model_config.trust_remote_code, + ) + + N = TypeVar("N", bound=Type[nn.Module]) @@ -94,7 +110,7 @@ def __call__( ... -class _MultiModalCounts(UserDict): +class _MultiModalCounts(UserDict[str, int]): """ Wraps `mm_counts` for a more informative error message when attempting to access a plugin that does not exist. 
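
# A small sketch of how InputProcessingContext (defined above) is meant to be
# used: build it from an existing ModelConfig plus vLLM's own tokenizer, then
# load the matching HuggingFace processor with that tokenizer substituted in.
# The helper name is illustrative.
from typing import TYPE_CHECKING

from transformers import ProcessorMixin

from vllm.inputs import InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer

if TYPE_CHECKING:
    from vllm.config import ModelConfig


def load_hf_processor(model_config: "ModelConfig",
                      tokenizer: AnyTokenizer) -> ProcessorMixin:
    ctx = InputProcessingContext(model_config, tokenizer)
    # Reuses vLLM's tokenizer so text is tokenized identically on both sides.
    return ctx.get_hf_processor()
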
@@ -287,6 +303,21 @@ def _get_model_input_processor(self, model_cls: Type[nn.Module]): return self._input_processors_by_model_type \ .get(model_cls, self._default_input_processor) + def _ensure_mm_kwargs( + self, + inputs: SingletonInputs, + mm_processor_kwargs: Dict[str, Any], + ): + if inputs["type"] == "token": + # In case the input processor for that model fails to set it + if "mm_processor_kwargs" not in inputs: + inputs["mm_processor_kwargs"] = mm_processor_kwargs + elif inputs["type"] == "multimodal": + # Be more strict in V2 + assert "mm_kwargs" in inputs + else: + assert_never(inputs["type"]) + def process_input(self, model_config: "ModelConfig", inputs: ProcessorInputs) -> ProcessorInputs: """ @@ -312,8 +343,21 @@ def process_input(self, model_config: "ModelConfig", processor, ) - return processor(InputContext(model_config), inputs, - **mm_processor_kwargs) + processed_inputs = processor( + InputContext(model_config), + inputs, + **mm_processor_kwargs, + ) + + if is_encoder_decoder_inputs(processed_inputs): + self._ensure_mm_kwargs(processed_inputs["encoder"], + mm_processor_kwargs) + self._ensure_mm_kwargs(processed_inputs["decoder"], + mm_processor_kwargs) + else: + self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs) + + return processed_inputs def create_input_processor(self, model_config: "ModelConfig"): """ diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 08ed84aa9c71a..6ec2d5a2a3909 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -30,8 +30,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.base import MultiModalData +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 37f38d4d76671..b39dfe706e0df 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -32,8 +32,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 767171dad7c7b..df7e768fe14d3 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -15,8 +15,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, token_inputs) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from 
vllm.utils import is_list_of diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 77efc9a26ef7a..07165ea688f94 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -25,8 +25,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index af712bf8f9506..005ae5e03cfed 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import NestedTensors +from vllm.multimodal.inputs import NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index aae534c0b5949..999739ccd98bf 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -51,8 +51,7 @@ from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.utils import LLMWrapper from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index de03d28638cda..4db65edc174f1 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -39,7 +39,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import NestedTensors, PlaceholderRange +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import is_list_of diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 6bd5e119dd2dd..a3e30ea2dd299 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -29,8 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) from vllm.sequence import 
IntermediateTensors, SequenceData diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 5acd87146c54e..3d26ede722dd1 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -42,8 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 9a19ccbca3f1e..2335baf459771 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -60,10 +60,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, - MultiModalKwargs) -from vllm.multimodal.base import MultiModalData +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, + MultiModalKwargs) from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData from vllm.transformers_utils.config import uses_mrope diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index ca4fc8ec952bf..1fc6c1be4b7bb 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -15,7 +15,7 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import ModelRegistry -from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors +from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 14911853abc73..03a5f3a91f7a1 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,7 +1,8 @@ -from .base import (BatchedTensorInputs, MultiModalDataBuiltins, - MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict, MultiModalPlaceholderMap, - MultiModalPlugin, NestedTensors) +from .base import MultiModalPlaceholderMap, MultiModalPlugin +from .inputs import (BatchedTensorInputs, MultiModalData, + MultiModalDataBuiltins, MultiModalDataDict, + MultiModalKwargs, MultiModalPlaceholderDict, + NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -15,6 +16,7 @@ __all__ = [ "BatchedTensorInputs", + "MultiModalData", "MultiModalDataBuiltins", "MultiModalDataDict", "MultiModalKwargs", diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index e71ae5feec1c6..1a230602966d4 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,5 +1,7 @@ from vllm.inputs.registry import InputContext -from vllm.multimodal.base 
import MultiModalKwargs, MultiModalPlugin + +from .base import MultiModalPlugin +from .inputs import AudioItem, MultiModalData, MultiModalKwargs class AudioPlugin(MultiModalPlugin): @@ -8,8 +10,12 @@ class AudioPlugin(MultiModalPlugin): def get_data_key(self) -> str: return "audio" - def _default_input_mapper(self, ctx: InputContext, data: object, - **mm_processor_kwargs) -> MultiModalKwargs: + def _default_input_mapper( + self, + ctx: InputContext, + data: MultiModalData[AudioItem], + **mm_processor_kwargs, + ) -> MultiModalKwargs: raise NotImplementedError("There is no default audio input mapper") def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index fa514d3fcb3b7..6eec660e42ac4 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,180 +1,23 @@ from abc import ABC, abstractmethod -from collections import UserDict, defaultdict -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, - NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar, - Union, cast, final) - -import numpy as np -import torch -import torch.types -from PIL import Image +from collections import defaultdict +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, + Optional, Sequence, Tuple, Type, TypeVar, Union) + from torch import nn -from typing_extensions import TypeAlias from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, - json_map_leaves, resolve_mm_processor_kwargs) +from vllm.utils import (get_allowed_kwarg_only_overrides, + resolve_mm_processor_kwargs) if TYPE_CHECKING: from vllm.config import ModelConfig from vllm.sequence import SequenceGroupMetadata -logger = init_logger(__name__) - -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] -""" -Uses a list instead of a tensor if the dimensions of each element do not match. -""" - -BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] -""" -A dictionary containing nested tensors which have been batched via -:meth:`MultiModalKwargs.batch`. -""" - - -class _MultiModalKwargsBase(UserDict[str, NestedTensors]): - pass - - -class MultiModalKwargs(_MultiModalKwargsBase): - """ - A dictionary that represents the keyword arguments to - :meth:`~torch.nn.Module.forward`. - """ - - @staticmethod - def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: - """ - Recursively stacks lists of tensors when they all have the same shape. - """ - if isinstance(nested_tensors, torch.Tensor): - return nested_tensors - - if isinstance(nested_tensors, np.ndarray): - return torch.from_numpy(nested_tensors) - - if isinstance(nested_tensors, (int, float)): - return torch.tensor(nested_tensors) +from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs, + PlaceholderRange) - stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors] - if not is_list_of(stacked, torch.Tensor, check="all"): - # Only tensors (not lists) can be stacked. - return stacked - - tensors_ = cast(List[torch.Tensor], stacked) - if any(t.shape != tensors_[0].shape for t in tensors_): - # The tensors have incompatible shapes and can't be stacked. - return tensors_ - - return torch.stack(tensors_) - - @staticmethod - def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: - """ - Batch multiple inputs together into a dictionary. - - The resulting dictionary has the same keys as the inputs. 
- If the corresponding value from each input is a tensor and they all - share the same shape, the output value is a single batched tensor; - otherwise, the output value is a list containing the original value - from each input. - """ - if len(inputs_list) == 0: - return {} - - item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) - - for inputs in inputs_list: - # For models that supports multiple modalities (e.g. Qwen2-VL), - # different modalities will return different data keys, - # so batch() should skip the same key check. - - for k, v in inputs.items(): - item_lists[k].append(v) - - return { - k: MultiModalKwargs._try_stack(item_list) - for k, item_list in item_lists.items() - } - - @staticmethod - def as_kwargs( - batched_inputs: BatchedTensorInputs, - *, - device: torch.types.Device, - ) -> BatchedTensorInputs: - json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) - - json_mapped = json_map_leaves( - lambda x: x.to(device, non_blocking=True), - json_inputs, - ) - - return cast(BatchedTensorInputs, json_mapped) - - -_T = TypeVar("_T") - -MultiModalData: TypeAlias = Union[_T, List[_T]] -""" -Either a single data instance, or a list of data instances. - -The number of data instances allowed per modality is restricted by -`--limit-mm-per-prompt`. -""" - - -@final -class MultiModalDataBuiltins(TypedDict, total=False): - """Modality types that are predefined by vLLM.""" - - image: MultiModalData[Image.Image] - """The input image(s).""" - - audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]] - """The input audio item(s) and corresponding sampling rate(s).""" - - video: MultiModalData[Tuple[np.ndarray]] - """The input video(s).""" - - -MultiModalDataDict = Union[MultiModalDataBuiltins, - Mapping[str, MultiModalData[object]]] -""" -A dictionary containing an item for each modality type to input. - -Note: - This dictionary also accepts modality keys defined outside - :class:`MultiModalDataBuiltins` as long as a customized plugin is registered - through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. - Read more on that :ref:`here `. -""" - - -class PlaceholderRange(TypedDict): - """ - Placeholder location information for multi-modal data. - - For example: - Prompt: AAAA BBBB What is in these images? - Images A and B will have: - A: { "offset": 0, "length": 4 } - B: { "offset": 5, "length": 4 } - """ - - offset: int - """The start index of the placeholder in the prompt.""" - - length: int - """The length of the placeholder.""" - - -MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]] -""" -A dictionary containing placeholder ranges. -""" +logger = init_logger(__name__) MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], MultiModalKwargs] @@ -192,6 +35,7 @@ class PlaceholderRange(TypedDict): model. This does not include tokens that correspond to the input text. 
""" +_T = TypeVar("_T") N = TypeVar("N", bound=Type[nn.Module]) @@ -224,7 +68,7 @@ def get_data_key(self) -> str: def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[object], + data: MultiModalData[Any], **mm_processor_kwargs, ) -> MultiModalKwargs: """ @@ -273,8 +117,8 @@ def wrapper(model_cls: N) -> N: def map_input( self, model_config: "ModelConfig", - data: MultiModalData[object], - mm_processor_kwargs: Dict[str, Any], + data: MultiModalData[Any], + mm_processor_kwargs: Optional[Dict[str, Any]], ) -> MultiModalKwargs: """ Transform the data into a dictionary of model inputs using the @@ -289,6 +133,7 @@ def map_input( - :ref:`input_processing_pipeline` - :ref:`enabling_multimodal_inputs` """ + # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture @@ -300,6 +145,9 @@ def map_input( raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + # In the case of the default mapper, we have to get resource # processor through its HuggingFace autoclass; since this goes # through **kwargs, we can't inspect it the same way, so we allow @@ -508,7 +356,7 @@ def append_items_from_seq_group( self, positions: range, multi_modal_items: List[_T], - multi_modal_placeholders: List[PlaceholderRange], + multi_modal_placeholders: Sequence[PlaceholderRange], ) -> List[_T]: """ Adds the multi-modal items that intersect ```positions`` to this diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 589b46266b08d..97bbce1ce1570 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -3,14 +3,14 @@ import torch from PIL import Image -from transformers.image_processing_base import BatchFeature from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.processor import get_image_processor from vllm.utils import is_list_of -from .base import MultiModalData, MultiModalKwargs, MultiModalPlugin +from .base import MultiModalPlugin +from .inputs import ImageItem, MultiModalData, MultiModalKwargs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -41,15 +41,11 @@ def _get_hf_image_processor( def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[object], + data: MultiModalData[ImageItem], **mm_processor_kwargs, ) -> MultiModalKwargs: model_config = ctx.model_config - # Processed by input processor - if isinstance(data, BatchFeature): - return MultiModalKwargs(data.data) - # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): image_processor = self._get_hf_image_processor( diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py new file mode 100644 index 0000000000000..64a4c58d5509c --- /dev/null +++ b/vllm/multimodal/inputs.py @@ -0,0 +1,225 @@ +from collections import UserDict, defaultdict +from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple, + TypedDict, TypeVar, Union, cast, final) + +import numpy as np +import torch +import torch.types +from PIL.Image import Image +from typing_extensions import TypeAlias + +from vllm.utils import JSONTree, is_list_of, json_map_leaves + +_T = TypeVar("_T") + +# yapf: disable +ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] +""" +A :class:`transformers.image_utils.ImageInput` representing a single image, +which can be passed to a HuggingFace :code:`ImageProcessor`. 
+""" + +VideoItem: TypeAlias = Union[ + List[Image], + np.ndarray, + torch.Tensor, + List[np.ndarray], + List[torch.Tensor], +] +""" + +A :class:`transformers.image_utils.VideoInput` representing a single video, +which can be passed to a HuggingFace :code:`VideoProcessor`. +""" + +AudioItem: TypeAlias = Union[ + np.ndarray, + List[float], + Tuple[np.ndarray, float], # DEPRECATED: Use mm_processor_kwargs instead +] +""" +Represents a single audio that can be inputted to a HuggingFace +:code:`AudioProcessor`. +""" +# yapf: enable + +MultiModalData: TypeAlias = Union[_T, List[_T]] +""" +Either a single data item, or a list of data items. + +The number of data items allowed per modality is restricted by +:code:`--limit-mm-per-prompt`. +""" + + +@final +class MultiModalDataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: MultiModalData[ImageItem] + """The input image(s).""" + + video: MultiModalData[VideoItem] + """The input video(s).""" + + audio: MultiModalData[AudioItem] + """The input audio(s).""" + + +MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]] +""" +A dictionary containing an entry for each modality type to input. + +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalDataBuiltins` as long as a customized plugin + is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. +""" + + +class PlaceholderRange(TypedDict): + """ + Placeholder location information for multi-modal data. + + For example: + Prompt: AAAA BBBB What is in these images? + Images A and B will have: + A: { "offset": 0, "length": 4 } + B: { "offset": 5, "length": 4 } + """ + + offset: int + """The start index of the placeholder in the prompt.""" + + length: int + """The length of the placeholder.""" + + +NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] +""" +Uses a list instead of a tensor if the dimensions of each element do not match. +""" + +BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] +""" +A dictionary containing nested tensors which have been batched via +:meth:`MultiModalKwargs.batch`. +""" + + +class MultiModalKwargs(UserDict[str, NestedTensors]): + """ + A dictionary that represents the keyword arguments to + :meth:`~torch.nn.Module.forward`. + """ + + @staticmethod + def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: + """ + Stack the inner dimensions that have the same shape in + a nested list of tensors. + + Thus, a dimension represented by a list means that the inner + dimensions are different for each element along that dimension. + """ + if isinstance(nested_tensors, torch.Tensor): + return nested_tensors + + # TODO: Remove these once all models have been migrated + if isinstance(nested_tensors, np.ndarray): + return torch.from_numpy(nested_tensors) + if isinstance(nested_tensors, (int, float)): + return torch.tensor(nested_tensors) + + stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors] + if not is_list_of(stacked, torch.Tensor, check="all"): + # Only tensors (not lists) can be stacked. + return stacked + + tensors_ = cast(List[torch.Tensor], stacked) + if any(t.shape != tensors_[0].shape for t in tensors_): + # The tensors have incompatible shapes and can't be stacked. + return tensors_ + + return torch.stack(tensors_) + + @staticmethod + def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: + """ + Batch multiple inputs together into a dictionary. 
+ + The resulting dictionary has the same keys as the inputs. + If the corresponding value from each input is a tensor and they all + share the same shape, the output value is a single batched tensor; + otherwise, the output value is a list containing the original value + from each input. + """ + if len(inputs_list) == 0: + return {} + + # We need to consider the case where each item in the batch + # contains different modalities (i.e. different keys). + item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) + + for inputs in inputs_list: + for k, v in inputs.items(): + item_lists[k].append(v) + + return { + k: MultiModalKwargs._try_stack(item_list) + for k, item_list in item_lists.items() + } + + @staticmethod + def as_kwargs( + batched_inputs: BatchedTensorInputs, + *, + device: torch.types.Device, + ) -> BatchedTensorInputs: + json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) + + json_mapped = json_map_leaves( + lambda x: x.to(device, non_blocking=True), + json_inputs, + ) + + return cast(BatchedTensorInputs, json_mapped) + + +MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]] +""" +A dictionary containing placeholder ranges. +""" + + +class MultiModalInputsV2(TypedDict): + """ + Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`, + ready to be passed to vLLM internals. + """ + + type: Literal["multimodal"] + """The type of inputs.""" + + prompt: str + """ + The original, unprocessed prompt text. + + Note: + Since prompt text is not required by vLLM internals, we leave this + unprocessed to save CPU computation. You can still call + :code:`tokenizer.decode(prompt_token_ids)` to get the processed text. + """ + + prompt_token_ids: List[int] + """The processed token IDs which includes placeholder tokens.""" + + mm_kwargs: MultiModalKwargs + """Keyword arguments to be directly passed to the model after batching.""" + + mm_placeholders: MultiModalPlaceholderDict + """ + For each modality, information about the placeholder tokens in + :code:`prompt_token_ids`. + """ diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py new file mode 100644 index 0000000000000..88a924da174a6 --- /dev/null +++ b/vllm/multimodal/processing.py @@ -0,0 +1,273 @@ +from dataclasses import dataclass +from functools import lru_cache, partial +from typing import (Any, Callable, Collection, Generic, List, Mapping, + Optional, TypedDict, TypeVar, final) + +from transformers import BatchFeature +from typing_extensions import TypeAlias + +from vllm.inputs import InputProcessingContext +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import is_list_of + +from .inputs import (AudioItem, ImageItem, MultiModalDataDict, + MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, + VideoItem) + +_T = TypeVar("_T") + +ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]] +""" +Given the original data item, HF-processed data, and index of the processed +item, output the replacement token IDs to be allocated in vLLM. +""" + + +@dataclass +class ModalityProcessingMetadata(Generic[_T]): + placeholder_replacements: Mapping[str, ReplacementFunc] + """ + A dictionary where each item represents the original placeholder in the + prompt text and the corresponding replacement. 
+ """ + + +class MultiModalProcessingMetadataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: ModalityProcessingMetadata[ImageItem] + video: ModalityProcessingMetadata[VideoItem] + audio: ModalityProcessingMetadata[AudioItem] + + +MultiModalProcessingMetadata: TypeAlias = \ + Mapping[str, ModalityProcessingMetadata[Any]] +""" +A dictionary containing an entry for each modality type to process. + +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin + is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. +""" + +MultiModalMultiData: TypeAlias = List[_T] +""" +A list of data items, where the number of data items allowed +per modality is restricted by :code:`--limit-mm-per-prompt`. +""" + + +@final +class MultiModalMultiDataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: MultiModalMultiData[ImageItem] + """The input images.""" + + video: MultiModalMultiData[VideoItem] + """The input videos.""" + + audio: MultiModalMultiData[AudioItem] + """The input audios.""" + + +MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]] +""" +A dictionary containing an entry for each modality type to input. + +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalMultiDataBuiltins` as long as a customized plugin + is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. +""" + + +def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict: + """ + Convert a :class:`MultiModalDataDict` containing single data items + to a :class:`MultiModalMultiDataDict` containing multiple data items + per entry. + """ + multi_data: Mapping[str, MultiModalMultiData[Any]] = {} + + for k, v in data.items(): + # yapf: disable + if k == "video": + # Special case since even a single item can be a list + multi_data[k] = v if is_list_of(v, list) else [v] # type: ignore[index] + elif k in ("image", "audio"): + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + else: + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + # yapf: enable + + return multi_data + + +def encode_no_special_tokens( + tokenizer: AnyTokenizer, + text: str, +) -> List[int]: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.encode(text, add_special_tokens=False)`. + """ + if isinstance(tokenizer, MistralTokenizer): + return tokenizer.tokenizer.encode(text, bos=False, eos=False) + + return tokenizer.encode(text, add_special_tokens=False) + + +@lru_cache +def candidate_placeholders( + tokenizer: AnyTokenizer, + placeholder_text: str, +) -> Collection[List[int]]: + """Generate token ID sequences that may represent a placeholder text.""" + # When the placeholder text is not mapped to a special token ID, + # it may be tokenized differently based on whether it is at the start/end + # of the string. 
So, we go through each combination of whether the text + # is at the start and end boundaries of the string. + + # Matches the placeholder when it is in the middle of the string + start_id, = encode_no_special_tokens(tokenizer, "a") + end_id, = encode_no_special_tokens(tokenizer, "b") + + candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text) + + start_id_, *candidate_a = encode_no_special_tokens( + tokenizer, + f"a{placeholder_text}", + ) + assert start_id == start_id_ + + start_id_, *candidate_ab, end_id_ = encode_no_special_tokens( + tokenizer, + f"a{placeholder_text}b", + ) + assert start_id == start_id_ and end_id == end_id_ + + *candidate_b, end_id_ = encode_no_special_tokens( + tokenizer, + f"{placeholder_text}b", + ) + assert end_id == end_id_ + + # Remove duplicates (need to convert to tuple to be hashable) + unique_candidates = { + tuple(c) + for c in [candidate_basic, candidate_a, candidate_ab, candidate_b] + } + + # Convert back to list + return [list(c) for c in unique_candidates] + + +def apply_placeholders( + token_ids: List[int], + placeholder_ids: List[int], + get_replacement_ids: Callable[[], List[int]], +) -> Optional[PlaceholderRange]: + """ + Find the first occurrence of :code:`placeholder_ids`, + and replace it with the output of :code:`get_replacement_ids`. + + This function updates :code:`token_ids` in place. + """ + placeholder_length = len(placeholder_ids) + + for start_idx in range(len(token_ids) - placeholder_length + 1): + if token_ids[start_idx:start_idx + placeholder_length] == placeholder_ids: + token_ids[start_idx:start_idx + placeholder_length] = get_replacement_ids() + + return PlaceholderRange(offset=start_idx, + length=placeholder_length) + + return None + + +class MultiModalProcessor: + """ + Helper class to process multi-modal inputs to be used in vLLM.
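# A minimal sketch of how the two helpers above cooperate (all token IDs are
# hypothetical): `candidate_placeholders` supplies the possible tokenizations
# of a placeholder such as "<image>", and `apply_placeholders` splices the
# replacement IDs into the prompt in place.
prompt_ids = [1, 9999, 7, 8]            # 9999 stands in for "<image>"
replacement_ids = [32000, 32000, 32000]

placeholder = apply_placeholders(
    prompt_ids,
    placeholder_ids=[9999],
    get_replacement_ids=lambda: replacement_ids,
)
# prompt_ids is now [1, 32000, 32000, 32000, 7, 8] and
# placeholder == PlaceholderRange(offset=1, length=1), i.e. the span of the
# original placeholder, not of its replacement.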
+ """ + + def __init__( + self, + ctx: InputProcessingContext, + metadata: MultiModalProcessingMetadata, + ) -> None: + super().__init__() + + self.ctx = ctx + self.metadata = metadata + + def __call__( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + return self.apply(prompt, mm_data, mm_processor_kwargs) + + def apply( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + tokenizer = self.ctx.tokenizer + hf_processor = self.ctx.get_hf_processor() + + processed_inputs = hf_processor( + text=prompt, # type: ignore + **mm_data, + **mm_processor_kwargs, + ) + new_token_ids, = processed_inputs.pop("input_ids").tolist() + mm_kwargs = MultiModalKwargs(processed_inputs) + + mm_placeholders: Mapping[str, List[PlaceholderRange]] = {} + + for modality, orig_inputs in to_multi_format(mm_data).items(): + assert isinstance(orig_inputs, list) + + metadata = self.metadata[modality] + placeholder_replacements = metadata.placeholder_replacements + + modality_placeholders: List[PlaceholderRange] = [] + + for item_idx, orig_item in enumerate(orig_inputs): + for match_text, replace_fn in placeholder_replacements.items(): + candidates = candidate_placeholders(tokenizer, match_text) + get_replacement_ids = partial( + replace_fn, + orig_item, + processed_inputs, + item_idx, + ) + + for match_ids in candidates: + # TODO(youkaichao): Don't update new_token_ids + placeholders = apply_placeholders( + new_token_ids, + match_ids, + get_replacement_ids, + ) + + if placeholders is not None: + modality_placeholders.append(placeholders) + + # yapf: disable + mm_placeholders[modality] = modality_placeholders # type: ignore[index] + # yapf: enable + + return MultiModalInputsV2( + type="multimodal", + prompt=prompt, + prompt_token_ids=new_token_ids, + mm_kwargs=mm_kwargs, + mm_placeholders=mm_placeholders, + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index b844c9e1c2e89..b992442d3b314 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,13 +1,20 @@ import functools from collections import UserDict -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence +from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, + Sequence, Type, TypeVar) +import torch.nn as nn +from typing_extensions import TypeAlias + +from vllm.inputs import InputProcessingContext from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer from .audio import AudioPlugin -from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalKwargs, - MultiModalPlugin, MultiModalTokensCalc, NestedTensors) +from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin +from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors +from .processing import MultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -15,8 +22,18 @@ logger = init_logger(__name__) +N = TypeVar("N", bound=Type[nn.Module]) + +MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], + MultiModalProcessor] +""" +Constructs a :class:`MultiModalProcessor` instance from the context. + +The processing metadata should be derived from the context. 
+""" + -class _MultiModalLimits(UserDict): +class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): """ Wraps `_limits_by_model` for a more informative error message when attempting to access a model that does not exist. @@ -45,6 +62,9 @@ def __init__( plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: self._plugins = {p.get_data_key(): p for p in plugins} + self._processor_factories: Dict[Type[nn.Module], + MultiModalProcessorFactory] = {} + # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} @@ -243,3 +263,59 @@ def get_mm_limits_per_prompt( This should be called after :meth:`init_mm_limits_per_prompt`. """ return self._limits_by_model[model_config] + + def register_processor( + self, + factory: MultiModalProcessorFactory, + ): + """ + Register a multi-modal processor to a model class. + + When the model receives multi-modal data, the provided function is + invoked to transform the data into a dictionary of model inputs. + + See also: + - :ref:`input_processing_pipeline` + - :ref:`enabling_multimodal_inputs` + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._processor_factories: + logger.warning( + "Model class %s already has an input mapper " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._processor_factories[model_cls] = factory + + return model_cls + + return wrapper + + def has_processor(self, model_config: "ModelConfig") -> bool: + """ + Test whether a multi-modal processor is defined for a specific model. + """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + return model_cls in self._processor_factories + + def create_processor( + self, + model_config: "ModelConfig", + tokenizer: AnyTokenizer, + ) -> MultiModalProcessor: + """ + Create a multi-modal processor for a specific model and tokenizer. 
+ """ + + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + processor_factory = self._processor_factories[model_cls] + + ctx = InputProcessingContext(model_config, tokenizer) + return processor_factory(ctx) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index bee3c25dbd8dd..40194716bbf94 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -11,9 +11,10 @@ import vllm.envs as envs from vllm.connections import global_http_connection from vllm.logger import init_logger -from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer +from .inputs import MultiModalDataDict, PlaceholderRange + logger = init_logger(__name__) cached_get_tokenizer = lru_cache(get_tokenizer) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index a518270974f92..ba9bf58a4a20c 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional import numpy as np @@ -9,8 +9,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import is_list_of -from .base import MultiModalData, MultiModalKwargs +from .base import MultiModalData from .image import ImagePlugin +from .inputs import MultiModalKwargs, VideoItem if TYPE_CHECKING: from vllm.config import ModelConfig @@ -20,17 +21,6 @@ cached_get_video_processor = lru_cache(get_video_processor) cached_get_tokenizer = lru_cache(get_tokenizer) -VideoInput = Union[ - "np.ndarray", # single video input - List["np.ndarray"], - # TODO: support more types - # List[Image.Image], List[List[Image.Image]], - # "torch.Tensor", - # List["torch.Tensor"], - # List[List["np.ndarrray"]], - # List[List["torch.Tensor"]], -] - class VideoPlugin(ImagePlugin): """Plugin for video data.""" @@ -53,13 +43,13 @@ def _get_hf_video_processor( def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[object], + data: MultiModalData[VideoItem], **mm_processor_kwargs, ) -> MultiModalKwargs: model_config = ctx.model_config if isinstance(data, list) and len(data) == 1: - data = data[0] + data = data[0] # type: ignore if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): video_processor = self._get_hf_video_processor( diff --git a/vllm/sequence.py b/vllm/sequence.py index 1370cb5c4f9d2..3b41d25a2fe42 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,25 +5,21 @@ from array import array from collections import defaultdict from dataclasses import dataclass, field -from functools import cached_property, reduce -from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, - Mapping, Optional) +from functools import reduce +from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Union import msgspec import torch -from typing_extensions import assert_never +from vllm.inputs import SingletonInputs, SingletonInputsAdapter from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams -if TYPE_CHECKING: - from vllm.inputs 
import SingletonInputs - VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_INVALID_TOKEN_ID = -1 @@ -407,14 +403,14 @@ class Sequence: def __init__( self, seq_id: int, - inputs: "SingletonInputs", + inputs: SingletonInputs, block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.seq_id = seq_id - self.inputs = inputs + self.inputs = SingletonInputsAdapter(inputs) self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request @@ -441,59 +437,29 @@ def __init__( def n_blocks(self) -> int: return (self.get_len() + self.block_size - 1) // self.block_size - @cached_property + @property def prompt(self) -> Optional[str]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("prompt") + return self.inputs.prompt - assert_never(inputs) - - @cached_property + @property def prompt_token_ids(self) -> List[int]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("prompt_token_ids", []) + return self.inputs.prompt_token_ids - assert_never(inputs) - - @cached_property + @property def prompt_embeds(self) -> Optional[torch.Tensor]: - inputs = self.inputs - - if inputs["type"] == "token": - return None - - assert_never(inputs) + return self.inputs.prompt_embeds - @cached_property + @property def multi_modal_data(self) -> "MultiModalDataDict": - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_data", {}) - - assert_never(inputs) - - @cached_property - def mm_processor_kwargs(self) -> Dict[str, Any]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("mm_processor_kwargs", {}) - - assert_never(inputs) + return self.inputs.multi_modal_data @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_placeholders", {}) + return self.inputs.multi_modal_placeholders - assert_never(inputs) + @property + def mm_processor_kwargs(self) -> Dict[str, Any]: + return self.inputs.mm_processor_kwargs @property def lora_int_id(self) -> int: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 2d7c58cfea13b..09bff9655a882 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -6,6 +6,7 @@ from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.protocol import EngineClient from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -321,6 +322,9 @@ async def get_model_config(self) -> ModelConfig: async def get_decoding_config(self): raise ValueError("Not Supported on V1 yet.") + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.processor.input_preprocessor + async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 5b45615a1b85b..4ebfff9584267 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -7,6 +7,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import RequestOutput from 
vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -32,6 +33,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, multiprocess_mode: bool = False, ) -> None: @@ -50,7 +52,7 @@ def __init__( # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config.model_config, vllm_config.lora_config, self.tokenizer, - input_registry) + input_registry, mm_registry) # Detokenizer (converts EngineCoreOutputs --> RequestOutput) self.detokenizer = Detokenizer( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 5f13cbf2e4036..5c1577190c75a 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -2,15 +2,17 @@ from typing import Any, Dict, Mapping, Optional, Tuple, Union from vllm.config import LoRAConfig, ModelConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderLLMInputs, InputRegistry, PromptType) +from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, + PromptType, SingletonInputsAdapter) +from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.config import try_get_generation_config -from vllm.transformers_utils.tokenizer_group import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest @@ -20,8 +22,9 @@ def __init__( self, model_config: ModelConfig, lora_config: Optional[LoRAConfig], - tokenizer: AnyTokenizer, + tokenizer: BaseTokenizerGroup, input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): self.model_config = model_config @@ -31,7 +34,8 @@ def __init__( self.generation_config_fields = _load_generation_config_dict( model_config) self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) + self.tokenizer, + mm_registry) self.input_processor = input_registry.create_input_processor( model_config) @@ -73,6 +77,19 @@ def process_inputs( self._validate_model_inputs(processed_inputs) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + if is_encoder_decoder_inputs(processed_inputs): + decoder_inputs = SingletonInputsAdapter( + processed_inputs["decoder"]) + encoder_inputs = SingletonInputsAdapter( + processed_inputs["encoder"]) + else: + decoder_inputs = SingletonInputsAdapter(processed_inputs) + encoder_inputs = None + + # TODO: Impl encoder-decoder + if encoder_inputs is not None: + raise NotImplementedError + assert isinstance(params, SamplingParams) # TODO: can we avoid cloning here in multiproc case sampling_params = params.clone() @@ -81,27 +98,43 @@ def process_inputs( # Make Request for Detokenizer. 
detokenizer_request = DetokenizerRequest( - request_id, processed_inputs.get("prompt"), - processed_inputs.get("prompt_token_ids"), + request_id, + decoder_inputs.prompt, + decoder_inputs.prompt_token_ids, sampling_params.skip_special_tokens, sampling_params.spaces_between_special_tokens, - sampling_params.output_kind, sampling_params.stop, - sampling_params.include_stop_str_in_output) + sampling_params.output_kind, + sampling_params.stop, + sampling_params.include_stop_str_in_output, + ) # Make Request for EngineCore. engine_core_request = EngineCoreRequest( - request_id, processed_inputs.get("prompt"), - processed_inputs.get("prompt_token_ids"), - processed_inputs.get("multi_modal_data"), - processed_inputs.get("multi_modal_placeholders"), - processed_inputs.get("mm_processor_kwargs"), sampling_params, - eos_token_id, arrival_time, lora_request) + request_id, + decoder_inputs.prompt, + decoder_inputs.prompt_token_ids, + decoder_inputs.multi_modal_data, + decoder_inputs.multi_modal_placeholders, + decoder_inputs.mm_processor_kwargs, + sampling_params, + eos_token_id, + arrival_time, + lora_request, + ) return detokenizer_request, engine_core_request - def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderLLMInputs]): - prompt_ids = inputs.get("prompt_token_ids") + def _validate_model_inputs(self, inputs: ProcessorInputs): + if is_encoder_decoder_inputs(inputs): + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length + prompt_inputs = inputs["decoder" if self.model_config. + is_multimodal_model else "encoder"] + else: + prompt_inputs = inputs + + prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids + if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") @@ -117,6 +150,10 @@ def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, "inputs, the number of image tokens depends on the number " "of images, and possibly their aspect ratios as well.") + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens + def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: config = try_get_generation_config( diff --git a/vllm/v1/request.py b/vllm/v1/request.py index f35cf738c89bf..51fb4003e5fe0 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,7 +1,7 @@ import enum -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union -from vllm.inputs.data import DecoderOnlyInputs +from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams @@ -9,23 +9,20 @@ from vllm.v1.engine import EngineCoreRequest from vllm.v1.utils import ConstantList -if TYPE_CHECKING: - from vllm.inputs import DecoderOnlyInputs - class Request: def __init__( self, request_id: str, - inputs: "DecoderOnlyInputs", + inputs: DecoderOnlyInputs, sampling_params: SamplingParams, eos_token_id: Optional[int], arrival_time: float, lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id - self.inputs = inputs + self.inputs = SingletonInputsAdapter(inputs) self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. 
self.eos_token_id = eos_token_id @@ -41,17 +38,17 @@ def __init__( assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens - self.prompt = inputs.get("prompt") - self.prompt_token_ids = inputs["prompt_token_ids"] + self.prompt = self.inputs.prompt + self.prompt_token_ids = self.inputs.prompt_token_ids self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: List[int] = [] self._all_token_ids: List[int] = self.prompt_token_ids.copy() self.num_computed_tokens = 0 # Raw multimodal data before the mm input mapper (e.g., PIL images). - self.mm_data = inputs.get("multi_modal_data") - self.mm_processor_kwargs = inputs.get("mm_processor_kwargs") - mm_positions = inputs.get("multi_modal_placeholders") + self.mm_data = self.inputs.multi_modal_data + self.mm_processor_kwargs = self.inputs.mm_processor_kwargs + mm_positions = self.inputs.multi_modal_placeholders if mm_positions: # FIXME(woosuk): Support other modalities. self.mm_positions = mm_positions.get("image", []) @@ -64,8 +61,7 @@ def __init__( def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( request_id=request.request_id, - inputs=DecoderOnlyInputs( - type="token", + inputs=token_inputs( prompt_token_ids=request.prompt_token_ids, prompt=request.prompt, multi_modal_data=request.mm_data, @@ -114,7 +110,7 @@ def get_finished_reason(self) -> Union[str, None]: return RequestStatus.get_finished_reason(self.status) def has_encoder_inputs(self) -> bool: - return self.mm_data is not None + return len(self.mm_data) > 0 @property def num_encoder_inputs(self) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 81480786a09e1..eebd1de96537f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -28,7 +28,7 @@ from vllm.v1.sample.metadata import SamplingMetadata if TYPE_CHECKING: - from vllm.multimodal.base import PlaceholderRange + from vllm.multimodal.inputs import PlaceholderRange from vllm.v1.core.scheduler import SchedulerOutput logger = init_logger(__name__) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 09c62fbb9875f..d3e1202c15e61 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -148,19 +148,29 @@ def build(self) -> ModelInputForCPU: query_lens=seq_lens, ) - def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata, - seq_data: SequenceData, computed_len: int, - mm_processor_kwargs: Dict[str, Any]): - + def _compute_multi_modal_input( + self, + seq_data: SequenceData, + computed_len: int, + seq_group_metadata: SequenceGroupMetadata, + ): # NOTE: mm_data only includes the subset of multi-modal items that # intersect with the current prefill positions. mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( - seq_group, range(computed_len, len(seq_data.get_token_ids()))) + seq_group_metadata, + range(computed_len, len(seq_data.get_token_ids())), + ) if not mm_data: - return + return None, None, None - mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) + if self.runner.mm_registry.has_processor(self.runner.model_config): + mm_kwargs = mm_data + else: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) # special processing for mrope position deltas. 
mrope_positions = None @@ -202,7 +212,7 @@ def _prepare_prompt( slot_mapping: List[int] = [] seq_lens: List[int] = [] - multi_model_kwargs_list: List[MultiModalKwargs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -223,11 +233,14 @@ def _prepare_prompt( mrope_positions = None if seq_group_metadata.multi_modal_data: - mm_kwargs, placeholder_maps, mrope_positions = self \ - ._compute_multi_modal_input( - seq_group_metadata, seq_data, computed_len, - seq_group_metadata.mm_processor_kwargs) - multi_model_kwargs_list.append(mm_kwargs) + ( + mm_kwargs, + placeholder_maps, + mrope_positions, + ) = self._compute_multi_modal_input(seq_data, computed_len, + seq_group_metadata) + + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( placeholder_map) @@ -302,7 +315,7 @@ def _prepare_prompt( multi_modal_placeholder_index_maps=placeholder_index_maps, ) - multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 92d6552b2f428..1ff30d685c6b1 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -716,7 +716,7 @@ def _prepare_prompt( context_lens: List[int] = [] query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - multi_model_kwargs_list: List[MultiModalKwargs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] if len(seq_group_metadata_list) == 0: return PreparePromptMetadata.empty() @@ -777,7 +777,7 @@ def _prepare_prompt( mm_data = seq_group_metadata.multi_modal_data if mm_data: mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_model_kwargs_list.append(mm_kwargs) + multi_modal_kwargs_list.append(mm_kwargs) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -876,7 +876,7 @@ def _prepare_prompt( multi_modal_placeholder_index_maps= None # FIXME(kzawora): mutli-modality will not work here ) - multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return PreparePromptMetadata(input_tokens=input_tokens, input_positions=input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2da02f21f8342..042f9f07eace6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -252,7 +252,7 @@ def __init__( prompt_adapter_request: Optional[PromptAdapterRequest] = None, # Multi-modal inputs. 
- multi_model_kwargs: Optional[MultiModalKwargs] = None, + multi_modal_kwargs: Optional[MultiModalKwargs] = None, multi_modal_placeholder_maps: Optional[Dict[ str, MultiModalPlaceholderMap]] = None, @@ -373,7 +373,7 @@ def __init__( prompt_adapter_prompt_mapping or []) self.prompt_adapter_request = prompt_adapter_request - self.multi_model_kwargs = multi_model_kwargs + self.multi_modal_kwargs = multi_modal_kwargs self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.prefix_cache_hit = prefix_cache_hit @@ -661,10 +661,15 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, if not mm_data: return - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) - inter_data.multi_model_kwargs = mm_kwargs + if self.runner.mm_registry.has_processor(self.runner.model_config): + mm_kwargs = mm_data + else: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) + + inter_data.multi_modal_kwargs = mm_kwargs inter_data.multi_modal_placeholder_maps = placeholder_maps # special processing for mrope position deltas. @@ -938,11 +943,11 @@ def build(self) -> ModelInputForGPU: ) # Multi-modal data. - multi_model_kwargs_list = [ - data.multi_model_kwargs for data in self.inter_data_list - if data.multi_model_kwargs is not None + multi_modal_kwargs_list = [ + data.multi_modal_kwargs for data in self.inter_data_list + if data.multi_modal_kwargs is not None ] - multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return self.model_input_cls( input_tokens=input_tokens_tensor, diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 0ed33e435aa2f..ae4eb6ba6eaec 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -67,7 +67,8 @@ def __init__( self.pin_memory = is_pin_memory_available() # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ .create_input_mapper(self.model_config) # Lazy initialization. 
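# The same gating pattern recurs across the worker runners touched by this
# patch (a condensed sketch; `runner` stands for the respective model runner):
# when a multi-modal processor is registered for the model, `mm_data` already
# holds processed `MultiModalKwargs`, so the legacy per-modality input mapper
# is bypassed.
if runner.mm_registry.has_processor(runner.model_config):
    mm_kwargs = mm_data
else:
    mm_kwargs = runner.multi_modal_input_mapper(
        mm_data,
        seq_group_metadata.mm_processor_kwargs,
    )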
@@ -122,7 +123,7 @@ def _prepare_prompt( input_block_ids: List[int] = [] seq_lens: List[int] = [] - multi_model_kwargs_list: List[MultiModalKwargs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -144,12 +145,15 @@ def _prepare_prompt( mm_data = seq_group_metadata.multi_modal_data if mm_data: - # Process multi-modal data - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs, - ) - multi_model_kwargs_list.append(mm_kwargs) + if self.mm_registry.has_processor(self.model_config): + mm_kwargs = mm_data + else: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) + + multi_modal_kwargs_list.append(mm_kwargs) max_seq_len = max(seq_lens) assert max_seq_len > 0 @@ -167,7 +171,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return (input_tokens, input_positions, input_block_ids, seq_lens, multi_modal_kwargs) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 378e1e06039b2..6000e5dfe4e30 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -70,7 +70,8 @@ def __init__( ) # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ .create_input_mapper(self.model_config) # Lazy initialization. @@ -102,7 +103,7 @@ def _prepare_model_input( seq_lens: List[int] = [] past_lens: List[int] = [] query_lens: List[int] = [] - multi_model_kwargs_list: List[MultiModalKwargs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -222,11 +223,15 @@ def _prepare_model_input( mm_data, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata. 
- mm_processor_kwargs) - multi_model_kwargs_list.append(mm_kwargs) + if self.mm_registry.has_processor(self.model_config): + mm_kwargs = mm_data + else: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) + + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( @@ -275,7 +280,7 @@ def _prepare_model_input( multi_modal_placeholder_index_maps=placeholder_index_maps, ) - multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return ModelInput( input_tokens, diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index c9e637c057979..e6322e095bbb9 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -160,7 +160,7 @@ def _prepare_prompt( input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] - multi_model_kwargs_list: List[MultiModalKwargs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -191,8 +191,16 @@ def _prepare_prompt( mm_data, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - mm_kwargs = self.runner.multi_modal_input_mapper(mm_data) - multi_model_kwargs_list.append(mm_kwargs) + if self.runner.mm_registry.has_processor( + self.runner.model_config): + mm_kwargs = mm_data + else: + mm_kwargs = self.runner.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) + + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( @@ -264,7 +272,7 @@ def _prepare_prompt( block_tables=torch.tensor([], device=self.device, dtype=torch.int), ) - multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs)
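# An end-to-end sketch of the new processing path added by this patch
# (model_config, tokenizer and image stand in for real objects; the prompt
# format is hypothetical):
from vllm.multimodal import MULTIMODAL_REGISTRY

processor = MULTIMODAL_REGISTRY.create_processor(model_config, tokenizer)
result = processor(
    prompt="USER: <image>\nWhat is in this image? ASSISTANT:",
    mm_data={"image": image},
    mm_processor_kwargs={},
)
# `result` is a MultiModalInputsV2 dict:
#   result["prompt_token_ids"] -> token IDs with placeholders already expanded
#   result["mm_kwargs"]        -> MultiModalKwargs passed to the model's forward()
#   result["mm_placeholders"]  -> {"image": [{"offset": ..., "length": ...}]}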