From 09704007a0875c593f2eac42a89e29efc8936ffe Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 5 Nov 2024 10:07:31 +0800 Subject: [PATCH] [Core] Make encoder-decoder inputs a nested structure to be more composable (#9604) Signed-off-by: DarkLight1337 Signed-off-by: Sumit Dubey --- tests/core/utils.py | 57 ++-- .../output_processor/test_stop_checker.py | 3 +- tests/test_cache_block_hashing.py | 7 +- tests/tokenization/test_detokenize.py | 6 +- vllm/engine/llm_engine.py | 51 ++-- vllm/engine/protocol.py | 23 +- vllm/inputs/__init__.py | 11 +- vllm/inputs/data.py | 51 ++-- vllm/inputs/parse.py | 15 +- vllm/inputs/preprocess.py | 269 +++++++++--------- vllm/inputs/registry.py | 14 +- vllm/model_executor/models/mllama.py | 96 +++++-- vllm/model_executor/models/registry.py | 5 + vllm/sequence.py | 113 +++----- 14 files changed, 372 insertions(+), 349 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index a95a573db7cd3..cd0caa4704e11 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -4,6 +4,7 @@ from typing import Tuple from vllm import SamplingParams +from vllm.inputs import EncoderDecoderInputs, token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, Sequence, SequenceGroup @@ -27,10 +28,7 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), - inputs={ - "prompt": prompt_str, - "prompt_token_ids": prompt_tokens, - }, + inputs=token_inputs(prompt_tokens, prompt=prompt_str), block_size=block_size) seq_group = SequenceGroup(request_id=request_id, seqs=[prompt], @@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder( encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - inputs = { - "prompt": decoder_prompt_str, - "prompt_token_ids": decoder_prompt_tokens, - "encoder_prompt": encoder_prompt_str, - "encoder_prompt_token_ids": encoder_prompt_tokens, - "multi_modal_data": None, + inputs: EncoderDecoderInputs = { + "decoder": token_inputs(decoder_prompt_tokens, + prompt=decoder_prompt_str), + "encoder": token_inputs(encoder_prompt_tokens, + prompt=encoder_prompt_str), } decoder_prompt = Sequence(int(request_id), - inputs=inputs, - block_size=block_size, - from_decoder_prompt=True) + inputs=inputs["decoder"], + block_size=block_size) encoder_prompt = Sequence(int(request_id), - inputs=inputs, - block_size=block_size, - from_decoder_prompt=False) + inputs=inputs["encoder"], + block_size=block_size) + seq_group = SequenceGroup(request_id=request_id, seqs=[decoder_prompt], sampling_params=SamplingParams(best_of=best_of), @@ -108,7 +104,7 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - inputs={"prompt_token_ids": prompt_token_ids}, + inputs=token_inputs(prompt_token_ids), block_size=16, ) @@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder( prompt_token_ids = [0] * seq_prompt_len - inputs = { - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "encoder_prompt": "", - "encoder_prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, + inputs: EncoderDecoderInputs = { + "decoder": token_inputs(prompt_token_ids), + "encoder": token_inputs(prompt_token_ids), } seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): # Construct decoder input sequences - seq = Sequence(seq_id=seq_id_start + seq_id_offset, - 
inputs=inputs, - block_size=16, - from_decoder_prompt=True) + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + inputs=inputs["decoder"], + block_size=16, + ) for i in range(output_len): seq.append_token_id( @@ -167,10 +161,11 @@ def create_seq_group_encoder_decoder( seqs.append(seq) # Encoder input sequence - encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens), - inputs=inputs, - block_size=16, - from_decoder_prompt=False) + encoder_seq = Sequence( + seq_id=seq_id_start + len(seq_output_lens), + inputs=inputs["encoder"], + block_size=16, + ) return SequenceGroup(request_id=request_id, seqs=seqs, diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py index 0d84443c51f99..cc14e8cbf75df 100644 --- a/tests/engine/output_processor/test_stop_checker.py +++ b/tests/engine/output_processor/test_stop_checker.py @@ -4,6 +4,7 @@ from transformers import PreTrainedTokenizer from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.inputs import token_inputs from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob, Sequence, SequenceStatus @@ -15,7 +16,7 @@ def sequence_with_eos(text: str, eos_token: str, """ seq = Sequence( seq_id=0, - inputs={"prompt_token_ids": []}, + inputs=token_inputs([]), block_size=16, eos_token_id=eos_token_id, ) diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 3576a4834ebc3..e8f8499aa88ca 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -6,6 +6,7 @@ import pytest +from vllm.inputs import token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import Sequence from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -70,10 +71,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) seq = Sequence(seq_id, - inputs={ - "prompt": prompt, - "prompt_token_ids": prompt_token_ids, - }, + inputs=token_inputs(prompt_token_ids, + prompt=prompt), block_size=block_size, eos_token_id=tokenizer.tokenizer.eos_token_id, lora_request=lora_request) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 1d07885349409..a3e70a40db979 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -3,6 +3,7 @@ import pytest from transformers import AutoTokenizer +from vllm.inputs import token_inputs from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from vllm.transformers_utils.detokenizer import (Detokenizer, detokenize_incrementally) @@ -169,10 +170,7 @@ def create_sequence(prompt_token_ids=None): prompt_token_ids = prompt_token_ids or [1] return Sequence( seq_id=0, - inputs={ - "prompt": "", - "prompt_token_ids": prompt_token_ids, - }, + inputs=token_inputs(prompt_token_ids, prompt=""), block_size=16, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2c584218485c8..a1809b1a9dd26 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -10,7 +10,7 @@ from typing import Set, Type, Union, cast, overload import torch -from typing_extensions import TypeIs, TypeVar +from typing_extensions import TypeVar import vllm.envs as envs from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, @@ -29,9 +29,9 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from 
vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderInputs, InputRegistry, PromptType, - TokensPrompt) +from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, + PromptType) +from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.logits_process import get_bad_words_logits_processors @@ -638,7 +638,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], + processed_inputs: ProcessorInputs, params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -669,18 +669,19 @@ def _add_processed_request( seq_id = next(self.seq_counter) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + if is_encoder_decoder_inputs(processed_inputs): + decoder_inputs = processed_inputs["decoder"] + encoder_inputs = processed_inputs["encoder"] + else: + decoder_inputs = processed_inputs + encoder_inputs = None + + seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) - encoder_seq = None - if 'encoder_prompt_token_ids' in processed_inputs: - encoder_seq = Sequence(seq_id, - processed_inputs, - block_size, - eos_token_id, - lora_request, - prompt_adapter_request, - from_decoder_prompt=False) + encoder_seq = (None if encoder_inputs is None else Sequence( + seq_id, encoder_inputs, block_size, eos_token_id, lora_request, + prompt_adapter_request)) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -874,7 +875,7 @@ def _validate_token_prompt(self, prompt: PromptType, # This needs to happen before multimodal input pre-processing, which # may add dummy tokens that aren't part of the tokenizer's # vocabulary. - if self._is_token_prompt(prompt): + if is_token_prompt(prompt): prompt_ids = prompt["prompt_token_ids"] if len(prompt_ids) == 0: # Empty prompt check is handled later @@ -884,10 +885,6 @@ def _validate_token_prompt(self, prompt: PromptType, raise ValueError( "Token id {} is out of vocabulary".format(max_input_id)) - @staticmethod - def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: - return isinstance(prompt, dict) and "prompt_token_ids" in prompt - def _create_sequence_group_with_sampling( self, request_id: str, @@ -1978,17 +1975,17 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: def is_encoder_decoder_model(self): return self.input_preprocessor.is_encoder_decoder_model() - def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderInputs], + def _validate_model_inputs(self, inputs: ProcessorInputs, lora_request: Optional[LoRARequest]): - if self.model_config.is_multimodal_model: + if is_encoder_decoder_inputs(inputs): # For encoder-decoder multimodal models, the max_prompt_len # restricts the decoder prompt length - prompt_ids = inputs.get("prompt_token_ids") - elif self.is_encoder_decoder_model(): - prompt_ids = inputs.get("encoder_prompt_token_ids") + prompt_inputs = inputs["decoder" if self.model_config. 
+ is_multimodal_model else "encoder"] else: - prompt_ids = inputs.get("prompt_token_ids") + prompt_inputs = inputs + + prompt_ids = prompt_inputs.get("prompt_token_ids") if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 6a09361c56865..e0b59d94cfdc3 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,11 +1,12 @@ import asyncio from abc import ABC, abstractmethod -from typing import AsyncGenerator, List, Mapping, Optional, Union +from typing import AsyncGenerator, List, Mapping, Optional from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -60,7 +61,7 @@ def generate( async def beam_search( self, - prompt: Union[PromptType, List[int]], + prompt: PromptType, model_config: ModelConfig, request_id: str, params: BeamSearchParams, @@ -76,11 +77,19 @@ async def beam_search( tokenizer = await self.get_tokenizer() input_preprocessor = InputPreprocessor(model_config, tokenizer) - (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) = input_preprocessor._extract_prompt_components( - prompt, - request_id=request_id, - ) + if is_explicit_encoder_decoder_prompt(prompt): + raise NotImplementedError + else: + processed_inputs = input_preprocessor._prompt_to_llm_inputs( + prompt, + request_id=request_id, + ) + + prompt_token_ids = processed_inputs["prompt_token_ids"] + prompt_text = processed_inputs.get("prompt") + multi_modal_data = processed_inputs.get("multi_modal_data") + mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs") + tokenized_length = len(prompt_token_ids) sort_beams_key = create_sort_beams_key_function( diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index ac7b3ca28b406..68ac50a2c5a16 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,8 +1,8 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs, - SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - token_inputs, zip_enc_dec_prompts) + ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, + SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import DummyData, InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -22,9 +22,10 @@ "ExplicitEncoderDecoderPrompt", "TokenInputs", "token_inputs", - "SingletonInputs", "DecoderOnlyInputs", "EncoderDecoderInputs", + "ProcessorInputs", + "SingletonInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index ba393cbcce4eb..46b41f431bec7 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,4 +1,4 @@ -from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, Optional, Tuple, Union, cast) from typing_extensions import NotRequired, TypedDict, TypeVar @@ -122,27 +122,30 @@ 
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): class TokenInputs(TypedDict): """Represents token-based inputs.""" + + type: Literal["token"] + """The type of inputs.""" + prompt_token_ids: List[int] """The token IDs of the prompt.""" - prompt: NotRequired[Optional[str]] + prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. """ - multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, if the model supports it. """ - multi_modal_placeholders: NotRequired[ - Optional["MultiModalPlaceholderDict"]] + multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"] """ Placeholder ranges for the multi-modal data. """ - mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities @@ -159,7 +162,7 @@ def token_inputs( mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" - inputs = TokenInputs(prompt_token_ids=prompt_token_ids) + inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) if prompt is not None: inputs["prompt"] = prompt @@ -173,12 +176,6 @@ def token_inputs( return inputs -SingletonInputs = TokenInputs -""" -A processed :class:`SingletonPrompt` which can be passed to -:class:`vllm.sequence.Sequence`. -""" - DecoderOnlyInputs = TokenInputs """ The inputs in :class:`~vllm.LLMEngine` before they are @@ -187,28 +184,30 @@ def token_inputs( """ -class EncoderDecoderInputs(TokenInputs): +class EncoderDecoderInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. This specifies the required data for encoder-decoder models. """ - encoder_prompt_token_ids: List[int] - """The token IDs of the encoder prompt.""" + encoder: TokenInputs + """The inputs for the encoder portion.""" - encoder_prompt: NotRequired[Optional[str]] - """ - The original encoder prompt text corresponding to the token IDs, if - available. - """ + decoder: TokenInputs + """The inputs for the decoder portion.""" - encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] - """ - Optional multi-modal data to pass to the encoder model, - if the model supports it. - """ +SingletonInputs = TokenInputs +""" +A processed :class:`SingletonPrompt` which can be passed to +:class:`vllm.sequence.Sequence`. +""" + +ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] +""" +The inputs to :data:`vllm.inputs.InputProcessor`. 
+""" _T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e79d2c813bb4f..09f1ff2cb42e9 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -4,9 +4,9 @@ from vllm.utils import is_list_of -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt, - TextPrompt, TokensPrompt) +from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, + ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, + TokensPrompt) class ParsedText(TypedDict): @@ -98,12 +98,15 @@ def parse_singleton_prompt( raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") +def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: + return isinstance(prompt, dict) and "prompt_token_ids" in prompt + + def is_explicit_encoder_decoder_prompt( prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: return isinstance(prompt, dict) and "encoder_prompt" in prompt def is_encoder_decoder_inputs( - inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], -) -> TypeIs[EncoderDecoderInputs]: - return "encoder_prompt_token_ids" in inputs + inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: + return "encoder" in inputs and "decoder" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 82ce7d392b719..a5c787a56b5a9 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional from typing_extensions import assert_never @@ -10,22 +10,12 @@ from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_warning_once -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType, - SingletonPrompt) +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, + PromptType, SingletonInputs, SingletonPrompt, token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt -if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict - logger = init_logger(__name__) -PromptComponents = Tuple[Optional[str], List[int], - Optional["MultiModalDataDict"], Optional[Dict[str, - Any]]] -DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional["MultiModalDataDict"], - Optional[Dict[str, Any]]] - class InputPreprocessor: @@ -115,7 +105,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: "default" decoder prompt be . However, it is possible that in the future - other models may have different or more + other models may have different or more complex logic for the default decoder prompt. This motivates having a special helper method for default decoder prompts. @@ -132,7 +122,6 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: def _prepare_decoder_input_ids_for_generation( self, decoder_input_ids: Optional[List[int]], - force_bos: bool = True, ) -> List[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. 
@@ -162,8 +151,8 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if force_bos and (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): + if (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids return decoder_input_ids @@ -209,12 +198,12 @@ async def _tokenize_prompt_async( prompt=prompt, lora_request=lora_request) - def _extract_prompt_components( + def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: + ) -> SingletonInputs: ''' Extract the components of any single encoder or decoder input prompt. @@ -241,34 +230,52 @@ def _extract_prompt_components( request_id=request_id, lora_request=lora_request, ) - multi_modal_data = None - mm_processor_kwargs = None - elif parsed["type"] == "tokens": - prompt_text = None - prompt_token_ids = parsed["content"]["prompt_token_ids"] - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if parsed["type"] == "tokens": + tokens_content = parsed["content"] + + prompt_token_ids = tokens_content["prompt_token_ids"] + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + + return token_inputs( + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + if parsed["type"] == "text": + text_content = parsed["content"] + + prompt_text = text_content["prompt"] prompt_token_ids = self._tokenize_prompt( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - else: - assert_never(parsed) + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) - return (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) + assert_never(parsed) - async def _extract_prompt_components_async( + async def _prompt_to_llm_inputs_async( self, prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: + ) -> SingletonInputs: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) @@ -279,59 +286,74 @@ async def _extract_prompt_components_async( request_id=request_id, lora_request=lora_request, ) - multi_modal_data = None - mm_processor_kwargs = None - elif parsed["type"] == "tokens": - prompt_text = None - prompt_token_ids = parsed["content"]["prompt_token_ids"] - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if parsed["type"] == "tokens": + tokens_content = 
parsed["content"] + + prompt_token_ids = tokens_content["prompt_token_ids"] + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + + return token_inputs( + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + if parsed["type"] == "text": + text_content = parsed["content"] + + prompt_text = text_content["prompt"] prompt_token_ids = await self._tokenize_prompt_async( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - else: - assert_never(parsed) + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) - return (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) + assert_never(parsed) def _build_enc_dec_llm_inputs( self, - encoder_comps: PromptComponents, - decoder_comps: DecoderPromptComponents, - mm_processor_kwargs: Dict[str, Any], + encoder_inputs: SingletonInputs, + decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps - - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if decoder_mm_data is not None: - raise ValueError( - "Multi-modality decoder inputs of encoder-decoder models are " - "not supported yet") - - # For Multi-Modal models (e.g., mllama), the text input can be - # <|image|><|begin_of_text|>hello world. And we should not add - # another <|begin_of_text|> to the beginning. - decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation( - decoder_prompt_ids, - force_bos=(encoder_mm_data is None and decoder_mm_data is None))) + if encoder_inputs["type"] == "token": + pass + else: + assert_never(encoder_inputs) + + if decoder_inputs is None: + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + None) + decoder_inputs = token_inputs(dec_token_ids) + elif decoder_inputs["type"] == "token": + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + decoder_inputs["prompt_token_ids"]) + decoder_inputs["prompt_token_ids"] = dec_token_ids + + if "multi_modal_data" in decoder_inputs: + raise ValueError("Multi-modal decoder inputs of encoder-" + "decoder models are not supported yet") + else: + assert_never(encoder_inputs) return EncoderDecoderInputs( - prompt_token_ids=decoder_prompt_ids, - prompt=decoder_prompt, - multi_modal_data=decoder_mm_data, - mm_processor_kwargs=mm_processor_kwargs, - encoder_prompt_token_ids=encoder_prompt_ids, - encoder_prompt=encoder_prompt, - encoder_multi_modal_data=encoder_mm_data, + encoder=encoder_inputs, + decoder=decoder_inputs, ) def _process_encoder_decoder_prompt( @@ -341,8 +363,7 @@ def _process_encoder_decoder_prompt( ) -> EncoderDecoderInputs: ''' For encoder/decoder models only: - Process an input prompt into an - :class:`EncoderDecoderInputs` instance. + Process an input prompt into an :class:`EncoderDecoderInputs` instance. 
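# --- Illustrative sketch (not part of this patch): _prompt_to_llm_inputs now
# --- returns a single TokenInputs dict instead of the old 4-tuple of prompt
# --- components, so callers read the optional fields with .get() (as the
# --- updated beam_search does).  The token IDs below are made up.
from vllm.inputs import token_inputs

processed = token_inputs(prompt_token_ids=[128000, 3923, 374])

assert processed["type"] == "token"
assert processed["prompt_token_ids"] == [128000, 3923, 374]
# Optional fields are simply absent when not provided:
assert processed.get("prompt") is None
assert processed.get("multi_modal_data") is None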
There are two types of input prompts: singleton prompts which carry only the @@ -361,7 +382,7 @@ def _process_encoder_decoder_prompt( have any possible singleton type; thus this method relies on helper functions to obtain token ids for the sub-prompts. - + Arguments: * prompt: an input prompt @@ -372,40 +393,31 @@ def _process_encoder_decoder_prompt( * :class:`EncoderDecoderInputs` instance ''' - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] if is_explicit_encoder_decoder_prompt(prompt): - encoder_comps = self._extract_prompt_components( + encoder_inputs = self._prompt_to_llm_inputs( prompt["encoder_prompt"], request_id=request_id, ) if (decoder_input := prompt["decoder_prompt"]) is None: - decoder_comps = None, None, None, None + decoder_inputs = None else: - decoder_comps = self._extract_prompt_components( + decoder_inputs = self._prompt_to_llm_inputs( decoder_input, request_id=request_id, ) - # Handle this carefully in case it was directly initialized by user - mm_processor_kwargs = prompt.get("mm_processor_kwargs", {}) else: - encoder_comps = self._extract_prompt_components( + encoder_inputs = self._prompt_to_llm_inputs( prompt, request_id=request_id, ) - # If there are no decoder components, we assume the - # mm_processor_kwargs are in the encoder prompt - mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ - -1] is not None else {} - decoder_comps = None, None, None, None - - return self._build_enc_dec_llm_inputs( - encoder_comps, - decoder_comps, - mm_processor_kwargs, - ) + + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) async def _process_encoder_decoder_prompt_async( self, @@ -413,59 +425,50 @@ async def _process_encoder_decoder_prompt_async( request_id: str, ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] if is_explicit_encoder_decoder_prompt(prompt): - encoder_task = self._extract_prompt_components_async( + encoder_task = self._prompt_to_llm_inputs_async( prompt["encoder_prompt"], request_id=request_id, ) if (decoder_input := prompt["decoder_prompt"]) is None: - encoder_comps = await encoder_task - decoder_comps = None, None, None, None + encoder_inputs = await encoder_task + decoder_inputs = None else: - decoder_task = self._extract_prompt_components_async( + decoder_task = self._prompt_to_llm_inputs_async( decoder_input, request_id=request_id, ) - encoder_comps, decoder_comps = await asyncio.gather( + encoder_inputs, decoder_inputs = await asyncio.gather( encoder_task, decoder_task) - mm_processor_kwargs = prompt["mm_processor_kwargs"] else: - encoder_comps = await self._extract_prompt_components_async( + encoder_inputs = await self._prompt_to_llm_inputs_async( prompt, request_id=request_id, ) - # If there are no decoder components, we assume the - # mm_processor_kwargs are in the encoder prompt - mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ - -1] is not None else {} - decoder_comps = None, None, None, None - - return self._build_enc_dec_llm_inputs( - encoder_comps, - decoder_comps, - mm_processor_kwargs, - ) + + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) def _build_decoder_only_llm_inputs( self, - prompt_comps: PromptComponents, + prompt_inputs: 
DecoderOnlyInputs, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: - (prompt, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) = prompt_comps - - prompt_token_ids = self._apply_prompt_adapter( - prompt_token_ids, prompt_adapter_request=prompt_adapter_request) + if prompt_inputs["type"] == "token": + prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( + prompt_inputs["prompt_token_ids"], + prompt_adapter_request=prompt_adapter_request, + ) + else: + assert_never(prompt_inputs) - return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs) + return prompt_inputs def _process_decoder_only_prompt( self, @@ -490,7 +493,7 @@ def _process_decoder_only_prompt( * :class:`DecoderOnlyInputs` instance ''' - prompt_comps = self._extract_prompt_components( + prompt_comps = self._prompt_to_llm_inputs( prompt, request_id=request_id, lora_request=lora_request, @@ -509,7 +512,7 @@ async def _process_decoder_only_prompt_async( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> DecoderOnlyInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" - prompt_comps = await self._extract_prompt_components_async( + prompt_comps = await self._prompt_to_llm_inputs_async( prompt, request_id=request_id, lora_request=lora_request, @@ -526,7 +529,7 @@ def preprocess( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: + ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of @@ -554,7 +557,7 @@ async def preprocess_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: + ) -> ProcessorInputs: """Async version of :meth:`preprocess`.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index fbf912a212568..7d7a797be4f60 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -2,7 +2,7 @@ from collections import UserDict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple, - Optional, Protocol, Type) + Optional, Protocol, Type, cast) from torch import nn from transformers import PretrainedConfig @@ -12,7 +12,7 @@ from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, resolve_mm_processor_kwargs) -from .data import DecoderOnlyInputs +from .data import ProcessorInputs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -109,7 +109,7 @@ def __getitem__(self, key: str) -> int: raise KeyError(msg) from exc -InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs] +InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs] """Preprocess the inputs to the model.""" @@ -254,8 +254,8 @@ def dummy_data_for_profiling( def _default_input_processor( self, ctx: InputContext, - inputs: DecoderOnlyInputs, - ) -> DecoderOnlyInputs: + inputs: ProcessorInputs, + ) -> ProcessorInputs: """The default input processor is a no-op.""" return inputs @@ -288,7 +288,7 @@ def _get_model_input_processor(self, model_cls: Type[nn.Module]): .get(model_cls, self._default_input_processor) def 
process_input(self, model_config: "ModelConfig", - inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: + inputs: ProcessorInputs) -> ProcessorInputs: """ Apply an input processor to an instance of model inputs. @@ -308,7 +308,7 @@ def process_input(self, model_config: "ModelConfig", # If it's empty, it'll fall back to the default kwarg values mm_processor_kwargs = resolve_mm_processor_kwargs( model_config.mm_processor_kwargs, - inputs.get("mm_processor_kwargs"), + cast(Dict[str, Any], inputs.get("mm_processor_kwargs")), processor, ) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index d30b9addd09f1..251bfc079684e 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -36,8 +36,8 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - EncoderDecoderInputs, InputContext) +from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, + InputContext, TokenInputs, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -52,6 +52,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SequenceData +from vllm.utils import is_list_of from .clip import CLIPMLP from .interfaces import SupportsMultiModal @@ -86,41 +87,58 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: return num_images -def input_processor_for_mllama(ctx: InputContext, - inputs: Union[DecoderOnlyInputs, - EncoderDecoderInputs]): - # move encoder_prompt to prompt - if inputs.get("prompt") is None: - inputs["prompt"] = inputs["encoder_prompt"] - inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"] +def input_processor_for_mllama( + ctx: InputContext, + inputs: EncoderDecoderInputs, +) -> EncoderDecoderInputs: + # Example input to processor: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000], + # }, + # } + + # move encoder prompt to decoder + dec_inputs = TokenInputs(**inputs["encoder"]) + + multi_modal_data = dec_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + # text-only + return EncoderDecoderInputs( + encoder=token_inputs([]), + decoder=dec_inputs, + ) - # process multi-modal data - multi_modal_data = inputs.get("encoder_multi_modal_data") + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_data = [image_data] - if multi_modal_data is None or "image" not in multi_modal_data \ - or multi_modal_data["image"] is None: - # text-only - inputs["encoder_prompt"] = "" - inputs["encoder_prompt_token_ids"] = [] - inputs["encoder_multi_modal_data"] = {} - return inputs + assert is_list_of(image_data, Image.Image) - if isinstance(multi_modal_data['image'], Image.Image): - multi_modal_data['image'] = [multi_modal_data['image']] # Since only the last group of consecutive images # are attended by the decoded tokens, we 
only need to # get the number of tiles for those images. num_decode_images = _get_num_image_in_last_group( - inputs["prompt_token_ids"]) + dec_inputs["prompt_token_ids"]) + hf_config = ctx.model_config.hf_config + vision_config = hf_config.vision_config + num_tiles = 0 - for image in multi_modal_data["image"][::-1]: + for image in image_data[::-1]: width, height = image.size - tile_size = hf_config.vision_config.image_size + tile_size = vision_config.image_size canvas_height, canvas_width = get_optimal_tiled_canvas( image_height=height, image_width=width, - max_image_tiles=hf_config.vision_config.max_num_tiles, + max_image_tiles=vision_config.max_num_tiles, tile_size=tile_size, ) num_tiles_height = canvas_height // tile_size @@ -133,14 +151,34 @@ def input_processor_for_mllama(ctx: InputContext, # Set encoder prompt length based on the number of tiles. # This tells the block manager to allocate correct number # of slots for encoder tokens. - assert hf_config.vision_config.image_size % 14 == 0, \ + assert vision_config.image_size % 14 == 0, \ "chunk size should be multiple of 14" - token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + token_per_chunk = (vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk - inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens - inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens - return inputs + # Example output from processor: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128256, 128256, ..., 128256], + # 'prompt': '<|image|><|image|>...<|image|>', + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # } + return EncoderDecoderInputs( + encoder=token_inputs( + prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens, + prompt=MLLAMA_IMAGE_TOKEN * num_tokens, + multi_modal_data=multi_modal_data, + ), + decoder=dec_inputs, + ) def get_max_mllama_image_tokens(ctx: InputContext) -> int: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3a929f5cb5195..af52fbffba19e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -343,6 +343,11 @@ def register_model( def _raise_for_unsupported(self, architectures: List[str]): all_supported_archs = self.get_supported_archs() + if any(arch in all_supported_archs for arch in architectures): + raise ValueError( + f"Model architectures {architectures} failed " + "to be inspected. Please check the logs for more details.") + raise ValueError( f"Model architectures {architectures} are not supported for now. 
" f"Supported architectures: {all_supported_archs}") diff --git a/vllm/sequence.py b/vllm/sequence.py index 44a9257c9a4c1..7d7ddc7ec4447 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,12 +9,12 @@ from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, Mapping, Optional) from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union, cast +from typing import Set, Tuple, Union import msgspec import torch +from typing_extensions import assert_never -from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams @@ -379,15 +379,10 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. - - The sequence is constructed from the :code:`SingletonInputs` instance - passed in through the :code:`inputs` constructor argument. - - For encoder/decoder models, SingletonInputs encapsulates both a - decoder and encoder prompt, creating an ambiguity about which - prompt to construct the sequence from. The `from_decoder_prompt` - constructor argument signals whether to construct the Sequence - from the SingletonInputs decoder prompt, or encoder prompt. + + The sequence is constructed from the :data:`DecoderOnlyInputs` + (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder) + instance passed in through the :code:`inputs` constructor argument. Args: seq_id: The ID of the sequence. @@ -397,10 +392,6 @@ class Sequence: eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. lora_request: LoRA request. prompt_adapter_request: Prompt Adapter request. - from_decoder_prompt: Construct Sequence from SingletonInputs decoder - prompt (True) or encoder prompt (False.) Must be - True for decoder-only model. - """ def __init__( @@ -411,7 +402,6 @@ def __init__( eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - from_decoder_prompt: bool = True, ) -> None: self.seq_id = seq_id self.inputs = inputs @@ -419,33 +409,6 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.from_decoder_prompt = from_decoder_prompt - - # For decoder-only models, a Sequence is constructed - # from an DecoderOnlyInputs instance (the `inputs` arg.) - # - # For encoder/decoder models the same `inputs` - # instance could be utilized to construct either an - # encoder sequence or a decoder sequence, because - # `DecoderOnlyInputs` has both decoder- and encoder-oriented - # member variables (i.e. it encapsulates both an encoder - # and a decoder prompt.) The decision of which type of sequence - # to generate is determined by the `from_decoder_prompt` argument. - # - # When constructing a encoder sequence - # (`from_decoder_prompt` False) it matters that - # the `DecoderOnlyInputs` instance stored in `inputs` is valid - # in the sense that its encoder-related member variables are - # populated; below, an exception is raised if this is - # not the case. - # - # When constructing a decoder sequence (`from_decoder_prompt` True) - # it does not matter whether `inputs` has its encoder-related - # member variables populated. 
- if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)): - raise ValueError("Cannot extract encoder input prompt from " - f"invalid input {inputs}; did you forget the " - "encoder input prompt fields?") self.data = SequenceData.from_seqs(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -470,45 +433,57 @@ def n_blocks(self) -> int: @cached_property def prompt(self) -> Optional[str]: - # Select decoder or encoder input prompt str, as appropriate - prompt_key: str = ("prompt" - if self.from_decoder_prompt else "encoder_prompt") + inputs = self.inputs - return cast(Optional[str], self.inputs.get(prompt_key)) + if inputs["type"] == "token": + return inputs.get("prompt") + + assert_never(inputs) @cached_property def prompt_token_ids(self) -> List[int]: - # Select decoder or encoder input prompt token ids, as appropriate - prompt_token_ids_key: str = ("prompt_token_ids" - if self.from_decoder_prompt else - "encoder_prompt_token_ids") + inputs = self.inputs - # Cache computed prompt token ids - return cast(List[int], self.inputs.get(prompt_token_ids_key)) + if inputs["type"] == "token": + return inputs.get("prompt_token_ids", []) - @property - def multi_modal_data(self) -> MultiModalDataDict: + assert_never(inputs) + + @cached_property + def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs - if (inputs.get("multi_modal_data") - and inputs.get("encoder_multi_modal_data")): - raise ValueError( - "Multi-modal data in both encoder and decoder is not supported." - ) + if inputs["type"] == "token": + return None - return cast( - MultiModalDataDict, - (inputs.get("multi_modal_data") - or inputs.get("encoder_multi_modal_data") or {}), - ) + assert_never(inputs) + + @cached_property + def multi_modal_data(self) -> "MultiModalDataDict": + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_data", {}) + + assert_never(inputs) + + @cached_property + def mm_processor_kwargs(self) -> Dict[str, Any]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("mm_processor_kwargs", {}) + + assert_never(inputs) @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - return self.inputs.get("multi_modal_placeholders") or {} + inputs = self.inputs - @property - def mm_processor_kwargs(self) -> Dict[str, Any]: - return self.inputs.get("mm_processor_kwargs") or {} + if inputs["type"] == "token": + return inputs.get("multi_modal_placeholders", {}) + + assert_never(inputs) @property def lora_int_id(self) -> int:
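# --- Illustrative sketch (not part of this patch): the rewritten accessors
# --- above discriminate on inputs["type"] and fall back to empty defaults,
# --- so the encoder_*-prefixed keys are gone.  Assumes the post-patch API;
# --- the prompt and block_size below are made up.
from vllm.inputs import token_inputs
from vllm.sequence import Sequence

seq = Sequence(seq_id=0,
               inputs=token_inputs([1, 2, 3], prompt="1 2 3"),
               block_size=16)

assert seq.prompt == "1 2 3"
assert seq.prompt_token_ids == [1, 2, 3]
assert seq.multi_modal_data == {}  # absent multi-modal keys default to {}
assert seq.mm_processor_kwargs == {}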