From ce2c9effde7b39dafc3dc08d9bb4905cf9b205dc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 5 Nov 2024 17:11:33 +0000 Subject: [PATCH] Initial prototype for multi-modal processor Signed-off-by: DarkLight1337 --- .../dev/multimodal/multimodal_index.rst | 2 +- .../models/enabling_multimodal_inputs.rst | 2 +- .../mm_processor_kwargs/test_qwen.py | 4 +- tests/multimodal/test_base.py | 22 +- vllm/config.py | 2 +- vllm/engine/async_llm_engine.py | 4 + vllm/engine/llm_engine.py | 5 +- vllm/engine/multiprocessing/client.py | 6 + vllm/engine/protocol.py | 16 +- vllm/entrypoints/openai/serving_chat.py | 1 - vllm/entrypoints/openai/serving_completion.py | 1 - vllm/inputs/__init__.py | 4 +- vllm/inputs/data.py | 13 +- vllm/inputs/preprocess.py | 130 +++++++-- vllm/inputs/registry.py | 17 +- vllm/model_executor/models/chatglm.py | 4 +- vllm/model_executor/models/fuyu.py | 4 +- vllm/model_executor/models/h2ovl.py | 10 +- vllm/model_executor/models/internvl.py | 6 +- vllm/model_executor/models/minicpmv.py | 4 +- vllm/model_executor/models/mllama.py | 2 +- vllm/model_executor/models/molmo.py | 4 +- vllm/model_executor/models/pixtral.py | 10 +- vllm/model_executor/models/qwen.py | 12 +- vllm/model_executor/models/qwen2_audio.py | 8 +- vllm/model_executor/models/qwen2_vl.py | 8 +- vllm/model_executor/models/ultravox.py | 8 +- vllm/multimodal/__init__.py | 10 +- vllm/multimodal/audio.py | 4 +- vllm/multimodal/base.py | 187 +----------- vllm/multimodal/image.py | 10 +- vllm/multimodal/inputs.py | 229 +++++++++++++++ vllm/multimodal/processing.py | 275 ++++++++++++++++++ vllm/multimodal/registry.py | 84 +++++- vllm/multimodal/video.py | 6 +- vllm/sequence.py | 37 ++- vllm/spec_decode/draft_model_runner.py | 4 +- vllm/v1/engine/llm_engine.py | 15 +- vllm/worker/cpu_enc_dec_model_runner.py | 4 +- vllm/worker/cpu_model_runner.py | 16 +- vllm/worker/embedding_model_runner.py | 4 +- vllm/worker/enc_dec_model_runner.py | 4 +- vllm/worker/model_runner.py | 27 +- vllm/worker/neuron_model_runner.py | 24 +- vllm/worker/openvino_model_runner.py | 23 +- vllm/worker/xpu_model_runner.py | 20 +- 46 files changed, 942 insertions(+), 350 deletions(-) create mode 100644 vllm/multimodal/inputs.py create mode 100644 vllm/multimodal/processing.py diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index e112b43aade5e..30f543abc20c7 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -53,7 +53,7 @@ Base Classes .. autodata:: vllm.multimodal.MultiModalDataDict -.. autoclass:: vllm.multimodal.MultiModalInputs +.. autoclass:: vllm.multimodal.MultiModalKwargs :members: :show-inheritance: diff --git a/docs/source/models/enabling_multimodal_inputs.rst b/docs/source/models/enabling_multimodal_inputs.rst index 3d0d1aec69845..49b5285c45590 100644 --- a/docs/source/models/enabling_multimodal_inputs.rst +++ b/docs/source/models/enabling_multimodal_inputs.rst @@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i 3. Register maximum number of multi-modal tokens ------------------------------------------------ -For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance +For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item and register it via :meth:`INPUT_REGISTRY.register_dummy_data `. .. 
code-block:: diff diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py index a01651b171d60..9f2daa4c7273f 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py @@ -6,7 +6,7 @@ from PIL.Image import Image from vllm.inputs import InputContext, token_inputs -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from .....conftest import IMAGE_ASSETS @@ -96,7 +96,7 @@ def test_input_mapper_valid_mm_data(input_mapper_for_qwen, mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data) # Ensure that we get the appropriately shaped pixel_values # for images and image embeddings, respectively. - assert isinstance(mapped_img_data, MultiModalInputs) + assert isinstance(mapped_img_data, MultiModalKwargs) assert "pixel_values" in mapped_img_data assert mapped_img_data["pixel_values"].shape == expected_shape diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py index 68d05de904ba8..bfaf2cdeaa8d4 100644 --- a/tests/multimodal/test_base.py +++ b/tests/multimodal/test_base.py @@ -1,6 +1,6 @@ import torch -from vllm.multimodal.base import MultiModalInputs, NestedTensors +from vllm.multimodal.base import MultiModalKwargs, NestedTensors def assert_nested_tensors_equal(expected: NestedTensors, @@ -13,8 +13,8 @@ def assert_nested_tensors_equal(expected: NestedTensors, assert_nested_tensors_equal(expected_item, actual_item) -def assert_multimodal_inputs_equal(expected: MultiModalInputs, - actual: MultiModalInputs): +def assert_multimodal_inputs_equal(expected: MultiModalKwargs, + actual: MultiModalKwargs): assert set(expected.keys()) == set(actual.keys()) for key in expected: assert_nested_tensors_equal(expected[key], actual[key]) @@ -22,7 +22,7 @@ def assert_multimodal_inputs_equal(expected: MultiModalInputs, def test_multimodal_input_batch_single_tensor(): t = torch.rand([1, 2]) - result = MultiModalInputs.batch([{"image": t}]) + result = MultiModalKwargs.batch([{"image": t}]) assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)}) @@ -30,7 +30,7 @@ def test_multimodal_input_batch_multiple_tensors(): a = torch.rand([1, 1, 2]) b = torch.rand([1, 1, 2]) c = torch.rand([1, 1, 2]) - result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}]) assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])}) @@ -38,7 +38,7 @@ def test_multimodal_input_batch_multiple_heterogeneous_tensors(): a = torch.rand([1, 2, 2]) b = torch.rand([1, 3, 2]) c = torch.rand([1, 4, 2]) - result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}]) assert_multimodal_inputs_equal(result, {"image": [a, b, c]}) @@ -46,7 +46,7 @@ def test_multimodal_input_batch_nested_tensors(): a = torch.rand([2, 3]) b = torch.rand([2, 3]) c = torch.rand([2, 3]) - result = MultiModalInputs.batch([{ + result = MultiModalKwargs.batch([{ "image": [a] }, { "image": [b] @@ -65,7 +65,7 @@ def test_multimodal_input_batch_heterogeneous_lists(): a = torch.rand([1, 2, 3]) b = torch.rand([1, 2, 3]) c = torch.rand([1, 2, 3]) - result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + result = 
MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}]) assert_multimodal_inputs_equal( result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]}) @@ -76,7 +76,7 @@ def test_multimodal_input_batch_multiple_batchable_lists(): b = torch.rand([1, 2, 3]) c = torch.rand([1, 2, 3]) d = torch.rand([1, 2, 3]) - result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}]) + result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}]) assert_multimodal_inputs_equal( result, {"image": torch.stack([torch.stack([a, b]), @@ -88,8 +88,8 @@ def test_multimodal_input_batch_mixed_stacking_depths(): b = torch.rand([1, 3, 3]) c = torch.rand([1, 4, 3]) - result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}]) assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]}) - result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}]) + result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b, c]}]) assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]}) diff --git a/vllm/config.py b/vllm/config.py index 814e00c8785f0..86127e23e2d3c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -107,7 +107,7 @@ class ModelConfig: matches the model name exposed via the APIs. If multiple model names provided, the first name will be used. If not specified, the model name will be the same as `model`. - limit_mm_per_prompt: Maximum number of data instances per modality + limit_mm_per_prompt: Maximum number of data items per modality per prompt. Only applicable for multimodal models. override_neuron_config: Initialize non default neuron config or override default neuron config that are specific to Neuron devices, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index b0fdc67776bbd..d0ad253489827 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -19,6 +19,7 @@ from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( @@ -721,6 +722,9 @@ def _error_callback(self, exc: Exception) -> None: self.set_errored(exc) self._request_tracker.propagate_exception(exc) + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.engine.input_preprocessor + async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a1809b1a9dd26..3cf4e582a245a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -39,6 +39,7 @@ from vllm.model_executor.guided_decoding import ( get_local_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams @@ -226,6 +227,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, ) -> None: @@ -338,7 +340,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: 
model_config) self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) + self.tokenizer, + mm_registry) self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 882742c2fc61b..fe21c58c775fe 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -31,6 +31,7 @@ # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -94,6 +95,8 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, parallel_config=engine_config.parallel_config, enable_lora=bool(engine_config.lora_config), ) + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer) # Send RPCGenerateRequest to the MQLLMEngine. self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) @@ -345,6 +348,9 @@ async def _check_success(error_message: str, socket: Socket): or response != VLLM_RPC_SUCCESS_STR): raise ValueError(error_message) + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.input_preprocessor + async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): return await self.tokenizer.get_lora_tokenizer_async(lora_request) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e0b59d94cfdc3..e15395d75c91f 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -62,7 +62,6 @@ def generate( async def beam_search( self, prompt: PromptType, - model_config: ModelConfig, request_id: str, params: BeamSearchParams, ) -> AsyncGenerator[RequestOutput, None]: @@ -74,13 +73,14 @@ async def beam_search( length_penalty = params.length_penalty include_stop_str_in_output = params.include_stop_str_in_output - tokenizer = await self.get_tokenizer() - input_preprocessor = InputPreprocessor(model_config, tokenizer) + preprocessor = await self.get_input_preprocessor() + tokenizer_group = preprocessor.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async() if is_explicit_encoder_decoder_prompt(prompt): raise NotImplementedError else: - processed_inputs = input_preprocessor._prompt_to_llm_inputs( + processed_inputs = preprocessor._prompt_to_llm_inputs( prompt, request_id=request_id, ) @@ -220,6 +220,7 @@ async def abort(self, request_id: str) -> None: Args: request_id: The unique id of the request. """ + ... @abstractmethod async def get_model_config(self) -> ModelConfig: @@ -228,8 +229,13 @@ async def get_model_config(self) -> ModelConfig: @abstractmethod async def get_decoding_config(self) -> DecodingConfig: - ... """Get the decoding configuration of the vLLM engine.""" + ... + + @abstractmethod + async def get_input_preprocessor(self) -> InputPreprocessor: + """Get the input processor of the vLLM engine.""" + ... 
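The hunks above move tokenizer access behind the engine client: ``beam_search`` no longer receives a ``model_config`` and instead asks the client for its ``InputPreprocessor``. A minimal illustration of the resulting call pattern (the helper itself is hypothetical; only the accessor methods come from this patch):

.. code-block:: python

    from typing import List

    async def tokenize_via_client(engine_client, prompt: str) -> List[int]:
        # New accessor added to engine clients by this patch.
        preprocessor = await engine_client.get_input_preprocessor()
        # Same pattern as beam_search above: resolve the (LoRA) tokenizer
        # through the preprocessor's tokenizer group.
        tokenizer_group = preprocessor.get_tokenizer_group()
        tokenizer = await tokenizer_group.get_lora_tokenizer_async()
        return tokenizer.encode(prompt)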
@abstractmethod async def get_tokenizer( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9551b4f2091dd..359578fd9f7e4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -187,7 +187,6 @@ async def create_chat_completion( if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, - model_config=self.model_config, request_id=request_id, params=sampling_params, ) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 570232be38379..5a9da5e0c47b4 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -140,7 +140,6 @@ async def create_completion( if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, - model_config=self.model_config, request_id=request_id, params=sampling_params, ) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 68ac50a2c5a16..338589ed04dc4 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -3,7 +3,8 @@ SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import DummyData, InputContext, InputRegistry +from .registry import (DummyData, InputContext, InputProcessingContext, + InputRegistry) INPUT_REGISTRY = InputRegistry() """ @@ -32,6 +33,7 @@ "INPUT_REGISTRY", "DummyData", "InputContext", + "InputProcessingContext", "InputRegistry", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 46b41f431bec7..86db955d7baa2 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict + from vllm.multimodal.inputs import MultiModalKwargsV2 class TextPrompt(TypedDict): @@ -36,13 +37,13 @@ class TokensPrompt(TypedDict): multi_modal_data: NotRequired["MultiModalDataDict"] """ - Optional multi-modal data to pass to the model, + DEPRECATED: Optional multi-modal data to pass to the model, if the model supports it. """ mm_processor_kwargs: NotRequired[Dict[str, Any]] """ - Optional multi-modal processor kwargs to be forwarded to the + DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities have registered mappers etc for the model being considered, we attempt to pass the mm_processor_kwargs to each of them. @@ -176,7 +177,7 @@ def token_inputs( return inputs -DecoderOnlyInputs = TokenInputs +DecoderOnlyInputs = Union[TokenInputs, MultiModalKwargsV2] """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. @@ -191,14 +192,14 @@ class EncoderDecoderInputs(TypedDict): This specifies the required data for encoder-decoder models. """ - encoder: TokenInputs + encoder: Union[TokenInputs, MultiModalKwargsV2] """The inputs for the encoder portion.""" - decoder: TokenInputs + decoder: Union[TokenInputs, MultiModalKwargsV2] """The inputs for the decoder portion.""" -SingletonInputs = TokenInputs +SingletonInputs = Union[TokenInputs, MultiModalKwargsV2] """ A processed :class:`SingletonPrompt` which can be passed to :class:`vllm.sequence.Sequence`. 
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a5c787a56b5a9..42f7281bc37dc 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,11 +1,13 @@ import asyncio -from typing import List, Optional +from typing import List, Mapping, Optional, Union from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.processing import MultiModalDataDict, MultiModalKwargsV2 from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_warning_once @@ -23,11 +25,13 @@ def __init__( self, model_config: ModelConfig, tokenizer: Optional[BaseTokenizerGroup], + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ) -> None: super().__init__() self.model_config = model_config self.tokenizer = tokenizer + self.mm_registry = mm_registry def get_tokenizer_group(self) -> BaseTokenizerGroup: if self.tokenizer is None: @@ -198,14 +202,66 @@ async def _tokenize_prompt_async( prompt=prompt, lora_request=lora_request) + def _can_process_multimodal(self) -> bool: + # Interim measure so we can handle models that have yet to be + # updated to use the new multi-modal processor + return self.mm_registry.has_processor(self.model_config) + + def _process_multimodal( + self, + prompt: Union[str, List[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + lora_request: Optional[LoRARequest], + ) -> MultiModalKwargsV2: + """ + Apply the model's multi-modal processor to a multi-modal prompt, + returning the corresponding token IDs and metadata. + """ + tokenizer_group = self.get_tokenizer_group() + tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) + + mm_processor = self.mm_registry.create_processor( + self.model_config, tokenizer) + + if isinstance(prompt, list): + prompt = tokenizer.decode(prompt) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs) + + async def _process_multimodal_async( + self, + prompt: Union[str, List[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + lora_request: Optional[LoRARequest], + ) -> MultiModalKwargsV2: + """Async version of :meth:`_process_multimodal`.""" + tokenizer_group = self.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request + ) + + mm_processor = self.mm_registry.create_processor( + self.model_config, tokenizer) + if isinstance(prompt, list): + logger.warning("Passing `multi_modal_data` in TokensPrompt is" + "deprecated and will be removed in a future update") + prompt = tokenizer.decode(prompt) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs) + def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> SingletonInputs: - ''' - Extract the components of any single encoder or decoder input prompt. + """ + Extract the singleton inputs from a prompt. 
Arguments: @@ -215,12 +271,8 @@ def _prompt_to_llm_inputs( Returns: - * prompt - * prompt_token_ids - * multi_modal_data - * mm_processor_kwargs (request-level input processor/mapper overrides) - ''' - + * :class:`SingletonInputs` instance + """ parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": @@ -243,6 +295,14 @@ def _prompt_to_llm_inputs( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + return token_inputs( prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, @@ -253,13 +313,22 @@ def _prompt_to_llm_inputs( text_content = parsed["content"] prompt_text = text_content["prompt"] + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_text, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + prompt_token_ids = self._tokenize_prompt( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, @@ -299,6 +368,14 @@ async def _prompt_to_llm_inputs_async( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + if multi_modal_data is not None and self._can_process_multimodal(): + return await self._process_multimodal_async( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + return token_inputs( prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, @@ -309,13 +386,22 @@ async def _prompt_to_llm_inputs_async( text_content = parsed["content"] prompt_text = text_content["prompt"] + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return await self._process_multimodal_async( + prompt_text, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + prompt_token_ids = await self._tokenize_prompt_async( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, @@ -331,7 +417,8 @@ def _build_enc_dec_llm_inputs( encoder_inputs: SingletonInputs, decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - if encoder_inputs["type"] == "token": + if (encoder_inputs["type"] == "token" + or encoder_inputs["type"] == "multimodal"): pass else: assert_never(encoder_inputs) @@ -340,7 +427,8 @@ def _build_enc_dec_llm_inputs( dec_token_ids = self._prepare_decoder_input_ids_for_generation( None) decoder_inputs = token_inputs(dec_token_ids) - elif decoder_inputs["type"] == "token": + elif (decoder_inputs["type"] == "token" + or decoder_inputs["type"] == "multimodal"): dec_token_ids = self._prepare_decoder_input_ids_for_generation( decoder_inputs["prompt_token_ids"]) decoder_inputs["prompt_token_ids"] = dec_token_ids @@ -361,7 +449,7 @@ def _process_encoder_decoder_prompt( 
prompt: PromptType, request_id: str, ) -> EncoderDecoderInputs: - ''' + """ For encoder/decoder models only: Process an input prompt into an :class:`EncoderDecoderInputs` instance. @@ -391,8 +479,7 @@ def _process_encoder_decoder_prompt( Returns: * :class:`EncoderDecoderInputs` instance - ''' - + """ encoder_inputs: SingletonInputs decoder_inputs: Optional[SingletonInputs] @@ -460,7 +547,8 @@ def _build_decoder_only_llm_inputs( prompt_inputs: DecoderOnlyInputs, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: - if prompt_inputs["type"] == "token": + if (prompt_inputs["type"] == "token" + or prompt_inputs["type"] == "multimodal"): prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"], prompt_adapter_request=prompt_adapter_request, @@ -477,7 +565,7 @@ def _process_decoder_only_prompt( lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> DecoderOnlyInputs: - ''' + """ For decoder-only models: Process an input prompt into an :class:`DecoderOnlyInputs` instance. @@ -491,7 +579,7 @@ def _process_decoder_only_prompt( Returns: * :class:`DecoderOnlyInputs` instance - ''' + """ prompt_comps = self._prompt_to_llm_inputs( prompt, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 7d7a797be4f60..47edcfd8ed1c8 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -5,10 +5,12 @@ Optional, Protocol, Type, cast) from torch import nn -from transformers import PretrainedConfig +from transformers import PretrainedConfig, ProcessorMixin from typing_extensions import TypeVar from vllm.logger import init_logger +from vllm.transformers_utils.processor import cached_get_processor +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, resolve_mm_processor_kwargs) @@ -61,6 +63,19 @@ def get_hf_image_processor_config(self) -> Dict[str, Any]: return self.model_config.hf_image_processor_config +@dataclass(frozen=True) +class InputProcessingContext(InputContext): + tokenizer: AnyTokenizer + """The tokenizer used to tokenize the inputs.""" + + def get_hf_processor(self) -> ProcessorMixin: + return cached_get_processor( + self.model_config.tokenizer, + tokenizer=self.tokenizer, # Override the tokenizer with ours + trust_remote_code=self.model_config.trust_remote_code, + ) + + N = TypeVar("N", bound=Type[nn.Module]) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index c3c9ec703c1e6..d0c5486ca5ac5 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.base import MultiModalData from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -74,7 +74,7 @@ def mm_input_mapper_for_glmv( raise pixel_values = raw_batch_data['images'] - return MultiModalInputs({'pixel_values': pixel_values}) + return MultiModalKwargs({'pixel_values': pixel_values}) def merge_glm_vision_embeddings( diff --git a/vllm/model_executor/models/fuyu.py 
b/vllm/model_executor/models/fuyu.py index 0de590d1d8372..1975d8a570f9a 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -35,7 +35,7 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) @@ -219,7 +219,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): ]) # image has been processed with prompt in input processor - return MultiModalInputs({"pixel_values": data}) + return MultiModalKwargs({"pixel_values": data}) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu) diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 43242fe370ba2..767171dad7c7b 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -16,7 +16,7 @@ token_inputs) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.utils import is_list_of @@ -324,12 +324,12 @@ def input_mapper( data: object, *, max_dynamic_patch: Optional[int] = None, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: # NOTE: Preprocessing for the image data is done in the # 'input_processor' function during actual inference. if isinstance(data, dict): - return MultiModalInputs(data) + return MultiModalKwargs(data) # The section below is only used with dummy data during # memory profiling. 
@@ -347,7 +347,7 @@ def input_mapper( pixel_values = [image_pixel_values_mapper(img) for img in data] else: - return MultiModalInputs({"image_embeds": data}) + return MultiModalKwargs({"image_embeds": data}) model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, @@ -359,7 +359,7 @@ def input_mapper( return_tensors="pt", )[0] - return MultiModalInputs({ + return MultiModalKwargs({ "pixel_values": pixel_values, "image_token_id": image_token_id }) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index d2ec0ff6e74c6..dfa869c1f5531 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -26,7 +26,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -346,7 +346,7 @@ def input_mapper( # we can't stack here because images may have different num_patches data = [image_pixel_values_mapper(img) for img in data] else: - return MultiModalInputs({"image_embeds": data}) + return MultiModalKwargs({"image_embeds": data}) model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, @@ -355,7 +355,7 @@ def input_mapper( add_special_tokens=False, return_tensors="pt")[0] - return MultiModalInputs({ + return MultiModalKwargs({ "pixel_values": data, "image_token_id": image_token_id }) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index f90df6b7df036..6ad04631f208e 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -53,7 +53,7 @@ from vllm.model_executor.models.utils import LLMWrapper from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData @@ -375,7 +375,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): batch_data["slice_start_id"] = data[0]["slice_start_id"] batch_data["slice_end_id"] = data[0]["slice_end_id"] - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 251bfc079684e..f7d2889a6fa5b 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1163,7 +1163,7 @@ def sample( def _parse_and_validate_image_input(self, **kwargs: object): # tensor with the same shape will be batched together by - # MultiModalInputs.batch, so pixel_values here can be: + # MultiModalKwargs.batch, so pixel_values here can be: # - List[List[torch.Tensor]]: # with shape (num_tiles, 3, image_res, image_res) # - List[torch.Tensor]: diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index ba798833e26a9..3aa68893de2d1 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -37,7 +37,7 @@ from 
vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) @@ -865,7 +865,7 @@ def image_input_mapper_for_molmo( ctx: InputContext, data: object, ): - return MultiModalInputs(data) + return MultiModalKwargs(data) def dummy_data_for_molmo(ctx: InputContext, seq_len: int, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 051454c49bff8..2e2683646ffdf 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -27,7 +27,7 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) from vllm.sequence import IntermediateTensors, SequenceData @@ -91,8 +91,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, def input_mapper_for_pixtral(ctx: InputContext, - data: object) -> MultiModalInputs: - """Maps the input data to its MultiModalInputs (if any). + data: object) -> MultiModalKwargs: + """Maps the input data to its MultiModalKwargs (if any). Args: ctx: Context of the loaded model. @@ -100,7 +100,7 @@ def input_mapper_for_pixtral(ctx: InputContext, to pixel_values in .forward() for a visual QWenLMHeadModel model. Returns: - MultiModalInputs containing the stacked normalized images tensor or + MultiModalKwargs containing the stacked normalized images tensor or image embeddings. """ # Early exit if we have provided an image to a language only Qwen model @@ -118,7 +118,7 @@ def input_mapper_for_pixtral(ctx: InputContext, dtype=torch.float16) images.append(image) - return MultiModalInputs({"images": images}) + return MultiModalKwargs({"images": images}) def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index b2b5c70182135..a53ddf12c8f5d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -44,7 +44,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of @@ -723,8 +723,8 @@ def input_processor_for_qwen(ctx: InputContext, multi_modal_data=multi_modal_data) -def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: - """Maps the input data to its MultiModalInputs (if any). +def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalKwargs: + """Maps the input data to its MultiModalKwargs (if any). Args: ctx: Context of the loaded model. 
@@ -732,7 +732,7 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: to pixel_values in .forward() for a visual QWenLMHeadModel model. Returns: - MultiModalInputs containing the stacked normalized images tensor or + MultiModalKwargs containing the stacked normalized images tensor or image embeddings. """ # Early exit if we have provided an image to a language only Qwen model @@ -741,7 +741,7 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: logger.warning( "Images were provided but this model has no visual config; " "multimodal inputs will not be forwarded to the model.") - return MultiModalInputs() + return MultiModalKwargs() model_config = ctx.model_config tokenizer = cached_get_tokenizer( @@ -785,7 +785,7 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: data = [data] transformed_images = [transform(datum) for datum in data] pixel_values = torch.stack(transformed_images, dim=0) - return MultiModalInputs({"pixel_values": pixel_values}) + return MultiModalKwargs({"pixel_values": pixel_values}) def build_normalization_transform(image_size: int) -> transforms.Compose: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 6114548bda42c..c6edcf0071118 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -43,7 +43,7 @@ default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData @@ -222,13 +222,13 @@ def input_processor_for_qwen2_audio( def input_mapper_for_qwen2_audio( ctx: InputContext, multi_modal_data: Union[np.ndarray, List[np.ndarray]], -) -> MultiModalInputs: +) -> MultiModalKwargs: """Input mapper for Qwen2-Audio.""" if not isinstance(multi_modal_data, list): multi_modal_data = [multi_modal_data] if len(multi_modal_data) == 0: - return MultiModalInputs() + return MultiModalKwargs() processor = cached_get_processor(ctx.model_config.model) audio_feature_extractor = processor.feature_extractor @@ -255,7 +255,7 @@ def input_mapper_for_qwen2_audio( logger.error("Failed to process audio (%s)", multi_modal_data) raise - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_audio) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d801903f8f9fe..5ed57ce2f2f9d 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -58,7 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, - MultiModalInputs) + MultiModalKwargs) from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer @@ -567,10 +567,10 @@ def mm_input_mapper_for_qwen2_vl( *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, -) -> MultiModalInputs: +) -> MultiModalKwargs: """Input mapper for Qwen2-VL.""" if data_type_key == "image" and 
isinstance(data, dict): - return MultiModalInputs({ + return MultiModalKwargs({ "image_embeds": data.get("image_embeds"), "image_grid_thw": data.get("image_grid_thw"), }) @@ -608,7 +608,7 @@ def mm_input_mapper_for_qwen2_vl( logger.error("Failed to process image (%s)", data) raise - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 749750fc9c16e..13dfb21352529 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, @@ -116,11 +116,11 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): data = [data] if len(data) == 0: - return MultiModalInputs() + return MultiModalKwargs() # If the audio inputs are embeddings, no need for preprocessing if is_list_of(data, torch.Tensor, check="all"): - return MultiModalInputs({"audio_embeds": data}) + return MultiModalKwargs({"audio_embeds": data}) audio_features = [] for audio_input in data: @@ -154,7 +154,7 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): # Remove the batch dimension because we're wrapping it in a list. audio_features.append(single_audio_features.squeeze(0)) - return MultiModalInputs({"audio_features": audio_features}) + return MultiModalKwargs({"audio_features": audio_features}) def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 53da2badb9b98..ae60e0d60bc79 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,7 +1,7 @@ -from .base import (BatchedTensorInputs, MultiModalDataBuiltins, - MultiModalDataDict, MultiModalInputs, - MultiModalPlaceholderDict, MultiModalPlaceholderMap, - MultiModalPlugin, NestedTensors) +from .base import MultiModalPlaceholderMap, MultiModalPlugin +from .inputs import (BatchedTensorInputs, MultiModalDataBuiltins, + MultiModalDataDict, MultiModalKwargs, + MultiModalPlaceholderDict, NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -17,7 +17,7 @@ "BatchedTensorInputs", "MultiModalDataBuiltins", "MultiModalDataDict", - "MultiModalInputs", + "MultiModalKwargs", "MultiModalPlaceholderDict", "MultiModalPlaceholderMap", "MultiModalPlugin", diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 04d71826f29fa..e71ae5feec1c6 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,5 +1,5 @@ from vllm.inputs.registry import InputContext -from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin +from vllm.multimodal.base import MultiModalKwargs, MultiModalPlugin class AudioPlugin(MultiModalPlugin): @@ -9,7 +9,7 @@ def get_data_key(self) -> str: return "audio" def _default_input_mapper(self, ctx: InputContext, data: object, - **mm_processor_kwargs) -> MultiModalInputs: + **mm_processor_kwargs) -> MultiModalKwargs: raise NotImplementedError("There is no default audio input 
mapper") def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 6b10d0c609f13..a94ed1dc68013 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,186 +1,26 @@ -import sys from abc import ABC, abstractmethod -from collections import UserDict, defaultdict -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, - NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar, - Union, cast, final) - -import numpy as np -import torch -import torch.types -from PIL import Image +from collections import defaultdict +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, + Optional, Sequence, Tuple, Type, TypeVar, Union) + from torch import nn -from typing_extensions import TypeAlias from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, - json_map_leaves, resolve_mm_processor_kwargs) +from vllm.utils import (get_allowed_kwarg_only_overrides, + resolve_mm_processor_kwargs) if TYPE_CHECKING: from vllm.config import ModelConfig from vllm.sequence import SequenceGroupMetadata -logger = init_logger(__name__) - -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] -""" -Uses a list instead of a tensor if the dimensions of each element do not match. -""" - -BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] -""" -A dictionary containing nested tensors which have been batched via -:meth:`MultiModalInputs.batch`. -""" - -if sys.version_info < (3, 9): - # UserDict cannot be subscripted - class _MultiModalInputsBase(UserDict): - pass -else: - - class _MultiModalInputsBase(UserDict[str, NestedTensors]): - pass - - -class MultiModalInputs(_MultiModalInputsBase): - """ - A dictionary that represents the keyword arguments to - :meth:`~torch.nn.Module.forward`. - """ - - @staticmethod - def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: - """ - Recursively stacks lists of tensors when they all have the same shape. - """ - if isinstance(nested_tensors, torch.Tensor): - return nested_tensors - - if isinstance(nested_tensors, np.ndarray): - return torch.from_numpy(nested_tensors) - - if isinstance(nested_tensors, (int, float)): - return torch.tensor(nested_tensors) - - stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] - if not is_list_of(stacked, torch.Tensor, check="all"): - # Only tensors (not lists) can be stacked. - return stacked - - tensors_ = cast(List[torch.Tensor], stacked) - if any(t.shape != tensors_[0].shape for t in tensors_): - # The tensors have incompatible shapes and can't be stacked. - return tensors_ - - return torch.stack(tensors_) - - @staticmethod - def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: - """ - Batch multiple inputs together into a dictionary. - - The resulting dictionary has the same keys as the inputs. - If the corresponding value from each input is a tensor and they all - share the same shape, the output value is a single batched tensor; - otherwise, the output value is a list containing the original value - from each input. - """ - if len(inputs_list) == 0: - return {} - - item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) - - for inputs in inputs_list: - # For models that supports multiple modalities (e.g. Qwen2-VL), - # different modalities will return different data keys, - # so batch() should skip the same key check. 
- - for k, v in inputs.items(): - item_lists[k].append(v) - - return { - k: MultiModalInputs._try_stack(item_list) - for k, item_list in item_lists.items() - } - - @staticmethod - def as_kwargs( - batched_inputs: BatchedTensorInputs, - *, - device: torch.types.Device, - ) -> BatchedTensorInputs: - json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) - - json_mapped = json_map_leaves( - lambda x: x.to(device, non_blocking=True), - json_inputs, - ) - - return cast(BatchedTensorInputs, json_mapped) +from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs, + PlaceholderRange) - -_T = TypeVar("_T") - -MultiModalData: TypeAlias = Union[_T, List[_T]] -""" -Either a single data instance, or a list of data instances. - -The number of data instances allowed per modality is restricted by -`--limit-mm-per-prompt`. -""" - - -@final -class MultiModalDataBuiltins(TypedDict, total=False): - """Modality types that are predefined by vLLM.""" - - image: MultiModalData[Image.Image] - """The input image(s).""" - - audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]] - """The input audio item(s) and corresponding sampling rate(s).""" - - -MultiModalDataDict = Union[MultiModalDataBuiltins, - Mapping[str, MultiModalData[object]]] -""" -A dictionary containing an item for each modality type to input. - -Note: - This dictionary also accepts modality keys defined outside - :class:`MultiModalDataBuiltins` as long as a customized plugin is registered - through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. - Read more on that :ref:`here `. -""" - - -class PlaceholderRange(TypedDict): - """ - Placeholder location information for multi-modal data. - - For example: - Prompt: AAAA BBBB What is in these images? - Images A and B will have: - A: { "offset": 0, "length": 4 } - B: { "offset": 5, "length": 4 } - """ - - offset: int - """The start index of the placeholder in the prompt.""" - - length: int - """The length of the placeholder.""" - - -MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]] -""" -A dictionary containing placeholder ranges. -""" +logger = init_logger(__name__) MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], - MultiModalInputs] + MultiModalKwargs] """ Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers @@ -195,6 +35,7 @@ class PlaceholderRange(TypedDict): model. This does not include tokens that correspond to the input text. """ +_T = TypeVar("_T") N = TypeVar("N", bound=Type[nn.Module]) @@ -229,7 +70,7 @@ def _default_input_mapper( ctx: InputContext, data: MultiModalData[object], **mm_processor_kwargs, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: """ Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to @@ -273,7 +114,7 @@ def wrapper(model_cls: N) -> N: def map_input(self, model_config: "ModelConfig", data: MultiModalData[object], - mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs: + mm_processor_kwargs: Dict[str, Any]) -> MultiModalKwargs: """ Transform the data into a dictionary of model inputs using the input mapper registered for that model. 
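``map_input`` still drives the legacy mapper path, now returning ``MultiModalKwargs``. For reference, a toy mapper in the same shape as the per-model mappers updated by this patch (the model class and mapper are hypothetical; the decorator is the existing ``register_image_input_mapper`` API):

.. code-block:: python

    import torch

    from vllm.inputs import InputContext
    from vllm.multimodal import MultiModalKwargs

    def my_image_input_mapper(
        ctx: InputContext,
        data: object,
    ) -> MultiModalKwargs:
        # Already-processed tensors are forwarded as-is, matching the
        # convention used by the mappers above (e.g. fuyu, h2ovl).
        if isinstance(data, torch.Tensor):
            return MultiModalKwargs({"pixel_values": data})
        raise TypeError(f"Invalid image type: {type(data)}")

    # Attached to a (hypothetical) model class via the existing decorator:
    # @MULTIMODAL_REGISTRY.register_image_input_mapper(my_image_input_mapper)
    # class MyVisionLM(nn.Module): ...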
@@ -500,7 +341,7 @@ def from_seq_group( def append_items_from_seq_group( self, positions: range, multi_modal_items: List[_T], - multi_modal_placeholders: List[PlaceholderRange]) -> List[_T]: + multi_modal_placeholders: Sequence[PlaceholderRange]) -> List[_T]: """ Adds the multi-modal items that intersect ```positions`` to this placeholder map and returns the intersecting items. diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 3f6bb6c8338d2..589b46266b08d 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -10,7 +10,7 @@ from vllm.transformers_utils.processor import get_image_processor from vllm.utils import is_list_of -from .base import MultiModalData, MultiModalInputs, MultiModalPlugin +from .base import MultiModalData, MultiModalKwargs, MultiModalPlugin if TYPE_CHECKING: from vllm.config import ModelConfig @@ -43,12 +43,12 @@ def _default_input_mapper( ctx: InputContext, data: MultiModalData[object], **mm_processor_kwargs, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: model_config = ctx.model_config # Processed by input processor if isinstance(data, BatchFeature): - return MultiModalInputs(data.data) + return MultiModalKwargs(data.data) # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): @@ -78,11 +78,11 @@ def _default_input_mapper( type(image_processor).__name__) raise - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) # Image embedding elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): - return MultiModalInputs({"image_embeds": data}) + return MultiModalKwargs({"image_embeds": data}) raise TypeError(f"Invalid image type: {type(data)}") diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py new file mode 100644 index 0000000000000..c19befa856ada --- /dev/null +++ b/vllm/multimodal/inputs.py @@ -0,0 +1,229 @@ +import sys +from collections import UserDict, defaultdict +from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple, + TypedDict, TypeVar, Union, cast, final) + +import numpy as np +import torch +import torch.types +from PIL.Image import Image +from typing_extensions import TypeAlias + +from vllm.utils import JSONTree, is_list_of, json_map_leaves + +_T = TypeVar("_T") + +# yapf: disable +ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] +""" +A :class:`transformers.image_utils.ImageInput` representing a single image, +which can be passed to a HuggingFace :code:`ImageProcessor`. +""" + +VideoItem: TypeAlias = Union[ + List[Image], + np.ndarray, + torch.Tensor, + List[np.ndarray], + List[torch.Tensor], +] +""" + +A :class:`transformers.image_utils.VideoInput` representing a single video, +which can be passed to a HuggingFace :code:`VideoProcessor`. +""" + +AudioItem: TypeAlias = Union[ + np.ndarray, + List[float], + Tuple[np.ndarray, float], # DEPRECATED: Use mm_processor_kwargs instead +] +""" +Represents a single audio that can be inputted to a HuggingFace +:code:`AudioProcessor`. +""" +# yapf: enable + +MultiModalData: TypeAlias = Union[_T, List[_T]] +""" +Either a single data item, or a list of data items. + +The number of data items allowed per modality is restricted by +:code:`--limit-mm-per-prompt`. 
+""" + + +@final +class MultiModalDataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: MultiModalData[ImageItem] + """The input image(s).""" + + video: MultiModalData[VideoItem] + """The input video(s).""" + + audio: MultiModalData[AudioItem] + """The input audio(s).""" + + +MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]] +""" +A dictionary containing an entry for each modality type to input. + +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalDataBuiltins` as long as a customized plugin + is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. +""" + + +class PlaceholderRange(TypedDict): + """ + Placeholder location information for multi-modal data. + + For example: + Prompt: AAAA BBBB What is in these images? + Images A and B will have: + A: { "offset": 0, "length": 4 } + B: { "offset": 5, "length": 4 } + """ + + offset: int + """The start index of the placeholder in the prompt.""" + + length: int + """The length of the placeholder.""" + + +NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] +""" +Uses a list instead of a tensor if the dimensions of each element do not match. +""" + +BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] +""" +A dictionary containing nested tensors which have been batched via +:meth:`MultiModalKwargs.batch`. +""" + +if sys.version_info < (3, 9): + # UserDict cannot be subscripted + class _MultiModalKwargsBase(UserDict): + pass +else: + + class _MultiModalKwargsBase(UserDict[str, NestedTensors]): + pass + + +class MultiModalKwargs(_MultiModalKwargsBase): + """ + A dictionary that represents the keyword arguments to + :meth:`~torch.nn.Module.forward`. + """ + + @staticmethod + def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: + """ + Stack the inner dimensions that have the same shape in + a nested list of tensors. + + Thus, a dimension represented by a list means that the inner + dimensions are different for each element along that dimension. + """ + if isinstance(nested_tensors, torch.Tensor): + return nested_tensors + + stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors] + if not is_list_of(stacked, torch.Tensor, check="all"): + # Only tensors (not lists) can be stacked. + return stacked + + tensors_ = cast(List[torch.Tensor], stacked) + if any(t.shape != tensors_[0].shape for t in tensors_): + # The tensors have incompatible shapes and can't be stacked. + return tensors_ + + return torch.stack(tensors_) + + @staticmethod + def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: + """ + Batch multiple inputs together into a dictionary. + + The resulting dictionary has the same keys as the inputs. + If the corresponding value from each input is a tensor and they all + share the same shape, the output value is a single batched tensor; + otherwise, the output value is a list containing the original value + from each input. + """ + if len(inputs_list) == 0: + return {} + + # We need to consider the case where each item in the batch + # contains different modalities (i.e. different keys). 
+ item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) + + for inputs in inputs_list: + for k, v in inputs.items(): + item_lists[k].append(v) + + return { + k: MultiModalKwargs._try_stack(item_list) + for k, item_list in item_lists.items() + } + + @staticmethod + def as_kwargs( + batched_inputs: BatchedTensorInputs, + *, + device: torch.types.Device, + ) -> BatchedTensorInputs: + json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) + + json_mapped = json_map_leaves( + lambda x: x.to(device, non_blocking=True), + json_inputs, + ) + + return cast(BatchedTensorInputs, json_mapped) + + +MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]] +""" +A dictionary containing placeholder ranges. +""" + + +class MultiModalKwargsV2(TypedDict): + """ + Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`, + ready to be passed to vLLM internals. + """ + + type: Literal["multimodal"] + """The type of inputs.""" + + prompt: str + """ + The original, unprocessed prompt text. + + Note: + Since prompt text is not required by vLLM internals, we leave this + unprocessed to save CPU computation. You can still call + :code:`tokenizer.decode(prompt_token_ids)` to get the processed text. + """ + + prompt_token_ids: List[int] + """The processed token IDs which includes placeholder tokens.""" + + mm_kwargs: MultiModalKwargs + """Keyword arguments to be directly passed to the model after batching.""" + + mm_placeholders: MultiModalPlaceholderDict + """ + For each modality, information about the placeholder tokens in + :code:`prompt_token_ids`. + """ diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py new file mode 100644 index 0000000000000..a0de8e9469194 --- /dev/null +++ b/vllm/multimodal/processing.py @@ -0,0 +1,275 @@ +from dataclasses import dataclass +from functools import lru_cache, partial +from typing import (Any, Callable, Collection, Generic, List, Mapping, + Optional, TypedDict, TypeVar, final) + +from transformers import BatchFeature +from typing_extensions import TypeAlias + +from vllm.inputs import InputProcessingContext +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import is_list_of + +from .inputs import (AudioItem, ImageItem, MultiModalDataDict, + MultiModalKwargsV2, MultiModalKwargs, PlaceholderRange, + VideoItem) + +_T = TypeVar("_T") + +ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]] +""" +Given the original data item, HF-processed data, and index of the processed +item, output the replacement token IDs to be allocated in vLLM. +""" + + +@dataclass +class ModalityProcessingMetadata(Generic[_T]): + placeholder_replacements: Mapping[str, ReplacementFunc] + """ + A dictionary where each item represents the original placeholder in the + prompt text and the corresponding replacement. + """ + + +class MultiModalProcessingMetadataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: ModalityProcessingMetadata[ImageItem] + video: ModalityProcessingMetadata[VideoItem] + audio: ModalityProcessingMetadata[AudioItem] + + +MultiModalProcessingMetadata: TypeAlias = \ + Mapping[str, ModalityProcessingMetadata[Any]] +""" +A dictionary containing an entry for each modality type to process. + +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin + is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. 
+ Read more on that :ref:`here `. +""" + +MultiModalMultiData: TypeAlias = List[_T] +""" +A list of data items, where the number of data items allowed +per modality is restricted by :code:`--limit-mm-per-prompt`. +""" + + +@final +class MultiModalMultiDataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: MultiModalMultiData[ImageItem] + """The input images.""" + + video: MultiModalMultiData[VideoItem] + """The input videos.""" + + audio: MultiModalMultiData[AudioItem] + """The input audios.""" + + +MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]] +""" +A dictionary containing an entry for each modality type to input. + +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalMultiDataBuiltins` as long as a customized plugin + is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. +""" + + +def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict: + """ + Convert a :class:`MultiModalDataDict` containing single data items + to a :class:`MultiModalMultiDataDict` containing multiple data items + per entry. + """ + multi_data: Mapping[str, MultiModalMultiData[Any]] = {} + + for k, v in data.items(): + # yapf: disable + if k == "image": + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + elif k == "video": + # Special case since even a single item can be a list + multi_data[k] = v if is_list_of(v, list) else [v] # type: ignore[index] + elif k == "audio": + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + else: + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + # yapf: enable + + return multi_data + + +def encode_no_special_tokens( + tokenizer: AnyTokenizer, + text: str, +) -> List[int]: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.encode(text, add_special_tokens=False)`. + """ + if isinstance(tokenizer, MistralTokenizer): + return tokenizer.tokenizer.encode(text, bos=False, eos=False) + + return tokenizer.encode(text, add_special_tokens=False) + + +@lru_cache +def candidate_placeholders( + tokenizer: AnyTokenizer, + placeholder_text: str, +) -> Collection[List[int]]: + """Generate token ID sequences that may represent a placeholder text.""" + # When the placeholder text is not mapped to a special token ID, + # it may be tokenized differently based on whether it is at the start/end + # of the string. 
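+    # (Hypothetical illustration: a BPE tokenizer might encode "<image>" on
+    # its own as [27, 9977, 29], yet produce a different sequence for the
+    # same substring inside "a<image>b"; the exact IDs depend entirely on
+    # the tokenizer.)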
+    # So, we go through each combination of whether the text
+    # is at the start and end boundaries of the string.
+
+    # Matches the placeholder when it is in the middle of the string
+    start_id, = encode_no_special_tokens(tokenizer, "a")
+    end_id, = encode_no_special_tokens(tokenizer, "b")
+
+    candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)
+
+    start_id_, *candidate_a = encode_no_special_tokens(
+        tokenizer,
+        f"a{placeholder_text}",
+    )
+    assert start_id == start_id_
+
+    start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
+        tokenizer,
+        f"a{placeholder_text}b",
+    )
+    assert start_id == start_id_ and end_id == end_id_
+
+    *candidate_b, end_id_ = encode_no_special_tokens(
+        tokenizer,
+        f"{placeholder_text}b",
+    )
+    assert end_id == end_id_
+
+    # Remove duplicates (need to convert to tuple to be hashable)
+    unique_candidates = {
+        tuple(c)
+        for c in [candidate_basic, candidate_a, candidate_ab, candidate_b]
+    }
+
+    # Convert back to list
+    return [list(c) for c in unique_candidates]
+
+
+def apply_placeholders(
+    token_ids: List[int],
+    placeholder_ids: List[int],
+    get_replacement_ids: Callable[[], List[int]],
+) -> Optional[PlaceholderRange]:
+    """
+    Find the first occurrence of :code:`placeholder_ids`,
+    and replace it with the output of :code:`get_replacement_ids`.
+
+    This function updates :code:`token_ids` in place.
+    """
+    placeholder_length = len(placeholder_ids)
+
+    for start_idx in range(len(token_ids) - placeholder_length + 1):
+        end_idx = start_idx + placeholder_length
+
+        if token_ids[start_idx:end_idx] == placeholder_ids:
+            replacement_ids = get_replacement_ids()
+            token_ids[start_idx:end_idx] = replacement_ids
+
+            # The returned range covers the replacement tokens that now
+            # occupy this position in ``token_ids``.
+            return PlaceholderRange(offset=start_idx,
+                                    length=len(replacement_ids))
+
+    return None
+
+
+class MultiModalProcessor:
+    """
+    Helper class to process multi-modal inputs to be used in vLLM.
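+
+    A rough usage sketch (``ctx``, ``image``, ``IMAGE_TOKEN_ID`` and
+    ``NUM_IMAGE_TOKENS`` are illustrative assumptions, not part of this
+    change)::
+
+        def replace_image(image, hf_inputs, item_idx):
+            # One fixed-size run of image tokens per input image.
+            return [IMAGE_TOKEN_ID] * NUM_IMAGE_TOKENS
+
+        metadata = {
+            "image": ModalityProcessingMetadata(placeholder_replacements={
+                "<image>": replace_image,
+            }),
+        }
+
+        processor = MultiModalProcessor(ctx, metadata)
+        inputs = processor("USER: <image> What is this?",
+                           {"image": image}, {})
+
+        # ``inputs["prompt_token_ids"]`` now contains the expanded image
+        # tokens, and ``inputs["mm_placeholders"]["image"]`` records where
+        # they were placed.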
+ """ + + def __init__( + self, + ctx: InputProcessingContext, + metadata: MultiModalProcessingMetadata, + ) -> None: + super().__init__() + + self.ctx = ctx + self.metadata = metadata + + def __call__( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> MultiModalKwargsV2: + return self.apply(prompt, mm_data, mm_processor_kwargs) + + def apply( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> MultiModalKwargsV2: + tokenizer = self.ctx.tokenizer + hf_processor = self.ctx.get_hf_processor() + + processed_inputs = hf_processor( + text=prompt, # type: ignore + **mm_data, + **mm_processor_kwargs, + ) + new_token_ids, = processed_inputs.pop("input_ids").tolist() + mm_kwargs = MultiModalKwargs(processed_inputs) + + mm_placeholders: Mapping[str, List[PlaceholderRange]] = {} + + for modality, orig_inputs in to_multi_format(mm_data).items(): + assert isinstance(orig_inputs, list) + + metadata = self.metadata[modality] + placeholder_replacements = metadata.placeholder_replacements + + modality_placeholders: List[PlaceholderRange] = [] + + for item_idx, orig_item in enumerate(orig_inputs): + for match_text, replace_fn in placeholder_replacements.items(): + candidates = candidate_placeholders(tokenizer, match_text) + get_replacement_ids = partial( + replace_fn, + orig_item, + processed_inputs, + item_idx, + ) + + for match_ids in candidates: + # TODO(youkaichao): Don't update new_token_ids + placeholders = apply_placeholders( + new_token_ids, + match_ids, + get_replacement_ids, + ) + + if placeholders is not None: + modality_placeholders.append(placeholders) + + # yapf: disable + mm_placeholders[modality] = modality_placeholders # type: ignore[index] + # yapf: enable + + return MultiModalKwargsV2( + type="multimodal", + prompt=prompt, + prompt_token_ids=new_token_ids, + mm_kwargs=mm_kwargs, + mm_placeholders=mm_placeholders, + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index bce2f4c6abe5b..219075d08ce7b 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,13 +1,20 @@ import functools from collections import UserDict -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence +from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, + Sequence, Type, TypeVar) +import torch.nn as nn +from typing_extensions import TypeAlias + +from vllm.inputs import InputProcessingContext from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer from .audio import AudioPlugin -from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, +from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalKwargs, MultiModalPlugin, MultiModalTokensCalc, NestedTensors) from .image import ImagePlugin +from .processing import MultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -15,6 +22,16 @@ logger = init_logger(__name__) +N = TypeVar("N", bound=Type[nn.Module]) + +MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], + MultiModalProcessor] +""" +Constructs a :class:`MultiModalProcessor` instance from the context. + +The processing metadata should be derived from the context. 
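+
+A minimal factory might look like this (``build_metadata_for_my_model`` is a
+hypothetical helper, not part of this change)::
+
+    def my_processor_factory(
+            ctx: InputProcessingContext) -> MultiModalProcessor:
+        metadata = build_metadata_for_my_model(ctx)
+        return MultiModalProcessor(ctx, metadata)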
+""" + class _MultiModalLimits(UserDict): """ @@ -45,6 +62,9 @@ def __init__( plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: self._plugins = {p.get_data_key(): p for p in plugins} + self._processor_factories: Dict[Type[nn.Module], + MultiModalProcessorFactory] = {} + # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} @@ -103,7 +123,7 @@ def map_input( model_config: "ModelConfig", data: MultiModalDataDict, mm_processor_kwargs: Optional[Dict[str, Any]] = None, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: """ Apply an input mapper to the data passed to the model. @@ -139,7 +159,7 @@ def map_input( merged_dict[input_key] = input_tensor - return MultiModalInputs(merged_dict) + return MultiModalKwargs(merged_dict) def create_input_mapper(self, model_config: "ModelConfig"): """ @@ -243,3 +263,59 @@ def get_mm_limits_per_prompt( This should be called after :meth:`init_mm_limits_per_prompt`. """ return self._limits_by_model[model_config] + + def register_processor( + self, + factory: MultiModalProcessorFactory, + ): + """ + Register a multi-modal processor to a model class. + + When the model receives multi-modal data, the provided function is + invoked to transform the data into a dictionary of model inputs. + + See also: + - :ref:`input_processing_pipeline` + - :ref:`enabling_multimodal_inputs` + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._processor_factories: + logger.warning( + "Model class %s already has an input mapper " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._processor_factories[model_cls] = factory + + return model_cls + + return wrapper + + def has_processor(self, model_config: "ModelConfig") -> bool: + """ + Test whether a multi-modal processor is defined for a specific model. + """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + return model_cls in self._processor_factories + + def create_processor( + self, + model_config: "ModelConfig", + tokenizer: AnyTokenizer, + ) -> MultiModalProcessor: + """ + Create a multi-modal processor for a specific model and tokenizer. 
+ """ + + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + processor_factory = self._processor_factories[model_cls] + + ctx = InputProcessingContext(model_config, tokenizer) + return processor_factory(ctx) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 6c2c6720f4276..27487d9d5946e 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -8,7 +8,7 @@ from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.tokenizer import get_tokenizer -from .base import MultiModalData, MultiModalInputs +from .base import MultiModalData, MultiModalKwargs from .image import ImagePlugin if TYPE_CHECKING: @@ -54,7 +54,7 @@ def _default_input_mapper( ctx: InputContext, data: MultiModalData[object], **mm_processor_kwargs, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: model_config = ctx.model_config if isinstance(data, list) and len(data) == 1: @@ -78,7 +78,7 @@ def _default_input_mapper( logger.error("Failed to process video (%s)", data) raise - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) raise TypeError(f"Invalid video type: {type(data)}") diff --git a/vllm/sequence.py b/vllm/sequence.py index 7d7ddc7ec4447..2d35e525ec5cf 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -435,7 +435,7 @@ def n_blocks(self) -> int: def prompt(self) -> Optional[str]: inputs = self.inputs - if inputs["type"] == "token": + if inputs["type"] == "token" or inputs["type"] == "multimodal": return inputs.get("prompt") assert_never(inputs) @@ -444,7 +444,7 @@ def prompt(self) -> Optional[str]: def prompt_token_ids(self) -> List[int]: inputs = self.inputs - if inputs["type"] == "token": + if inputs["type"] == "token" or inputs["type"] == "multimodal": return inputs.get("prompt_token_ids", []) assert_never(inputs) @@ -453,7 +453,7 @@ def prompt_token_ids(self) -> List[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs - if inputs["type"] == "token": + if inputs["type"] == "token" or inputs["type"] == "multimodal": return None assert_never(inputs) @@ -465,23 +465,35 @@ def multi_modal_data(self) -> "MultiModalDataDict": if inputs["type"] == "token": return inputs.get("multi_modal_data", {}) + if inputs["type"] == "multimodal": + return inputs.get("mm_kwargs", {}) + assert_never(inputs) - @cached_property - def mm_processor_kwargs(self) -> Dict[str, Any]: + @property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: inputs = self.inputs if inputs["type"] == "token": - return inputs.get("mm_processor_kwargs", {}) + return inputs.get("multi_modal_placeholders", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_placeholders", {}) assert_never(inputs) - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + @cached_property + def mm_processor_kwargs(self) -> Dict[str, Any]: inputs = self.inputs if inputs["type"] == "token": - return inputs.get("multi_modal_placeholders", {}) + return { + "needs_mm_mapper": True, + **inputs.get("mm_processor_kwargs", {}), + } + + if inputs["type"] == "multimodal": + return {} assert_never(inputs) @@ -953,6 +965,13 @@ def __post_init__(self): else: self.token_chunk_size = 1 + @property + def needs_mm_mapper(self): + # Interim measure so we can handle models that have yet to be + # updated to use the new multi-modal processor + return (self.mm_processor_kwargs is not None + and 
self.mm_processor_kwargs.get("needs_mm_mapper", False)) + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 17cc0ad1a4a3a..13d65d8304d19 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -18,7 +18,7 @@ "CUDA and ROCm flash attention backend.") from err from vllm.logger import init_logger -from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) @@ -274,7 +274,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), **kwargs, ) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 64cc18149d6c5..9b1bc841d905a 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -8,10 +8,11 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderLLMInputs, InputRegistry, PromptType) + EncoderDecoderInputs, InputRegistry, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import CompletionOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -39,6 +40,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, ) -> None: @@ -128,8 +130,11 @@ def __init__( self.generation_config_fields = _load_generation_config_dict( model_config) - self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) + self.input_preprocessor = InputPreprocessor( + model_config, + self.tokenizer, + mm_registry, + ) self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( model_config) @@ -213,7 +218,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs], + processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -436,7 +441,7 @@ def check_health(self) -> None: self.model_executor.check_health() def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderLLMInputs]): + EncoderDecoderInputs]): prompt_ids = inputs.get("prompt_token_ids") if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index 8ebbf6db939bc..994af7c5a455f 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -5,7 +5,7 @@ from vllm.attention import AttentionMetadata from 
vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad from vllm.worker.cpu_model_runner import (CPUModelRunner, @@ -287,7 +287,7 @@ def execute_model( kv_caches, "attn_metadata": model_input.attn_metadata, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), "intermediate_tensors": intermediate_tensors, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index fdd72a452f2ad..c040f9fa191e2 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalPlaceholderMap) + MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) from vllm.transformers_utils.config import uses_mrope @@ -159,7 +159,11 @@ def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata, if not mm_data: return - mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) + if "needs_mm_mapper" in mm_processor_kwargs: + mm_kwargs = self.multi_modal_input_mapper(mm_data, + mm_processor_kwargs) + else: + mm_kwargs = mm_data # special processing for mrope position deltas. mrope_positions = None @@ -201,7 +205,7 @@ def _prepare_prompt( slot_mapping: List[int] = [] seq_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -226,7 +230,7 @@ def _prepare_prompt( ._compute_multi_modal_input( seq_group_metadata, seq_data, computed_len, seq_group_metadata.mm_processor_kwargs) - multi_modal_inputs_list.append(mm_kwargs) + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( placeholder_map) @@ -298,7 +302,7 @@ def _prepare_prompt( multi_modal_placeholder_index_maps=placeholder_index_maps, ) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) @@ -527,7 +531,7 @@ def execute_model( kv_caches, "attn_metadata": model_input.attn_metadata, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), "intermediate_tensors": intermediate_tensors, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index ff288d5ca1512..37cfcbf13d7a3 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -8,7 +8,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, 
PoolerOutput, SequenceData, SequenceGroupMetadata) @@ -104,7 +104,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device)) if (self.observability_config is not None diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 90a43196084ea..008e0c9745994 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -18,7 +18,7 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.utils import get_architecture_class_name -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, @@ -206,7 +206,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), **seqlen_agnostic_kwargs) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 328dab598f8ef..384d94c46876f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -38,7 +38,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalPlaceholderMap, + MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry) from vllm.platforms import current_platform from vllm.prompt_adapter.layers import PromptAdapterMapping @@ -240,7 +240,7 @@ def __init__( prompt_adapter_request: Optional[PromptAdapterRequest] = None, # Multi-modal inputs. - multi_modal_inputs: Optional[MultiModalInputs] = None, + multi_modal_kwargs: Optional[MultiModalKwargs] = None, multi_modal_placeholder_maps: Optional[Dict[ str, MultiModalPlaceholderMap]] = None, @@ -361,7 +361,7 @@ def __init__( prompt_adapter_prompt_mapping or []) self.prompt_adapter_request = prompt_adapter_request - self.multi_modal_inputs = multi_modal_inputs + self.multi_modal_kwargs = multi_modal_kwargs self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.prefix_cache_hit = prefix_cache_hit @@ -646,10 +646,13 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, if not mm_data: return - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) - inter_data.multi_modal_inputs = mm_kwargs + if seq_group_metadata.needs_mm_mapper: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, seq_group_metadata.mm_processor_kwargs) + else: + mm_kwargs = mm_data + + inter_data.multi_modal_kwargs = mm_kwargs inter_data.multi_modal_placeholder_maps = placeholder_maps # special processing for mrope position deltas. @@ -923,11 +926,11 @@ def build(self) -> ModelInputForGPU: ) # Multi-modal data. 
- multi_modal_inputs_list = [ - data.multi_modal_inputs for data in self.inter_data_list - if data.multi_modal_inputs is not None + multi_modal_kwargs_list = [ + data.multi_modal_kwargs for data in self.inter_data_list + if data.multi_modal_kwargs is not None ] - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return self.model_input_cls( input_tokens=input_tokens_tensor, @@ -1640,7 +1643,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), **seqlen_agnostic_kwargs) diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 2da22cbfc7cb5..b9fc41bfe175a 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalKwargs) from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -122,7 +122,7 @@ def _prepare_prompt( input_block_ids: List[int] = [] seq_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -144,12 +144,16 @@ def _prepare_prompt( mm_data = seq_group_metadata.multi_modal_data if mm_data: - # Process multi-modal data - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs, - ) - multi_modal_inputs_list.append(mm_kwargs) + if seq_group_metadata.needs_mm_mapper: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. 
+ mm_processor_kwargs, + ) + else: + mm_kwargs = mm_data + + multi_modal_kwargs_list.append(mm_kwargs) max_seq_len = max(seq_lens) assert max_seq_len > 0 @@ -167,7 +171,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return (input_tokens, input_positions, input_block_ids, seq_lens, multi_modal_kwargs) @@ -314,7 +318,7 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), ) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index c9c87ea748081..2b69d579c17ce 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.openvino import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalPlaceholderMap) + MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import SequenceGroupMetadata from vllm.worker.model_runner_base import ModelRunnerBase @@ -102,7 +102,7 @@ def _prepare_model_input( seq_lens: List[int] = [] past_lens: List[int] = [] query_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -222,11 +222,16 @@ def _prepare_model_input( mm_data, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata. - mm_processor_kwargs) - multi_modal_inputs_list.append(mm_kwargs) + if seq_group_metadata.needs_mm_mapper: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. 
+ mm_processor_kwargs, + ) + else: + mm_kwargs = mm_data + + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( @@ -275,7 +280,7 @@ def _prepare_model_input( multi_modal_placeholder_index_maps=placeholder_index_maps, ) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return ModelInput( input_tokens, @@ -341,7 +346,7 @@ def execute_model( kv_caches, "attn_metadata": attn_metadata, - **MultiModalInputs.as_kwargs(multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {}, device=self.device), } diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index bae8b469767b2..b34c4a413156d 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalPlaceholderMap, + MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata @@ -160,7 +160,7 @@ def _prepare_prompt( input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -191,8 +191,16 @@ def _prepare_prompt( mm_data, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - mm_kwargs = self.runner.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) + if seq_group_metadata.needs_mm_mapper: + mm_kwargs = self.runner.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. + mm_processor_kwargs, + ) + else: + mm_kwargs = mm_data + + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( @@ -264,7 +272,7 @@ def _prepare_prompt( block_tables=torch.tensor([], device=self.device, dtype=torch.int), ) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) @@ -565,7 +573,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device)) # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: