From 770ec6024fc00cd696899f5c6fdc53b7148876e6 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 13:29:32 -0700 Subject: [PATCH] [Model] Add support for the multi-modal Llama 3.2 model (#8811) Co-authored-by: simon-mo Co-authored-by: Chang Su Co-authored-by: Simon Mo Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> Co-authored-by: Roger Wang --- docs/source/models/supported_models.rst | 5 + examples/offline_inference_vision_language.py | 24 + examples/openai_vision_api_client.py | 4 +- requirements-common.txt | 2 +- .../vision_language/__init__.py | 0 .../vision_language/test_mllama.py | 283 ++++ vllm/config.py | 4 +- vllm/engine/llm_engine.py | 6 +- vllm/entrypoints/chat_utils.py | 28 +- vllm/entrypoints/openai/serving_chat.py | 2 + vllm/inputs/data.py | 6 + vllm/inputs/preprocess.py | 22 +- vllm/inputs/registry.py | 54 +- vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/mllama.py | 1135 +++++++++++++++++ vllm/multimodal/base.py | 6 + vllm/multimodal/image.py | 5 + vllm/sequence.py | 12 +- vllm/transformers_utils/config.py | 17 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/mllama.py | 28 + vllm/transformers_utils/tokenizer.py | 1 - vllm/worker/enc_dec_model_runner.py | 40 +- vllm/worker/utils.py | 4 - 24 files changed, 1647 insertions(+), 45 deletions(-) create mode 100644 tests/models/encoder_decoder/vision_language/__init__.py create mode 100644 tests/models/encoder_decoder/vision_language/test_mllama.py create mode 100644 vllm/model_executor/models/mllama.py create mode 100644 vllm/transformers_utils/configs/mllama.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index d86d0860f7f29..bf690726a637b 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -254,6 +254,11 @@ Multimodal Language Models - Image\ :sup:`+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - + * - :code:`MllamaForConditionalGeneration` + - Llama 3.2 + - Image + - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. + - * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - Image\ :sup:`E` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 6675aa0109a68..6d34621a8a9bc 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -242,6 +242,29 @@ def run_qwen2_vl(question, modality): return llm, prompt, stop_token_ids +# LLama +def run_mllama(question, modality): + assert modality == "image" + + model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + + # Note: The default setting of max_num_seqs (256) and + # max_model_len (131072) for this model may cause OOM. + # You may lower either to run this example on lower-end GPUs. + + # The configuration below has been confirmed to launch on a + # single H100 GPU. 
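+    # enforce_eager=True below skips CUDA graph capture, which lowers GPU
+    # memory usage (and startup time) at some cost to decode throughput.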
+ llm = LLM( + model=model_name, + max_num_seqs=16, + enforce_eager=True, + ) + + prompt = f"<|image|><|begin_of_text|>{question}" + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -256,6 +279,7 @@ def run_qwen2_vl(question, modality): "internvl_chat": run_internvl, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, + "mllama": run_mllama, } diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 1ba702ef019e4..71ae03e4d148b 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -38,7 +38,7 @@ "content": [ { "type": "text", - "text": "What’s in this image?" + "text": "What's in this image?" }, { "type": "image_url", @@ -75,7 +75,7 @@ def encode_image_base64_from_url(image_url: str) -> str: "content": [ { "type": "text", - "text": "What’s in this image?" + "text": "What's in this image?" }, { "type": "image_url", diff --git a/requirements-common.txt b/requirements-common.txt index c113ff3630425..2fc89c026901b 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -4,7 +4,7 @@ numpy < 2.0.0 requests tqdm py-cpuinfo -transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfox. +transformers >= 4.45.0 # Required for Llama 3.2. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi < 0.113.0; python_version < '3.9' diff --git a/tests/models/encoder_decoder/vision_language/__init__.py b/tests/models/encoder_decoder/vision_language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py new file mode 100644 index 0000000000000..cda0926d0baf9 --- /dev/null +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -0,0 +1,283 @@ +from typing import List, Optional, Tuple, Type, overload + +import pytest +from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, + BatchEncoding) + +from vllm.multimodal.utils import rescale_image_size +from vllm.sequence import SampleLogprobs + +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ....utils import multi_gpu_test +from ...utils import check_logprobs_close + +_LIMIT_IMAGE_PER_PROMPT = 1 + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "<|image|><|begin_of_text|>The meaning of the image is", + "cherry_blossom": + "<|image|><|begin_of_text|>The city is", +}) + +text_only_prompts = [ + "The color of the sky is blue but sometimes it can also be", +] + +models = [ + "meta-llama/Llama-3.2-11B-Vision-Instruct", +] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], + model: str): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + + tokenizer = AutoTokenizer.from_pretrained(model) + eos_token_id = tokenizer.eos_token_id + + hf_output_ids = [ + token_id for idx, token_id in enumerate(output_ids) + if token_id != image_token_id or output_ids[idx - 1] != image_token_id + ] + + assert output_str[0] == " " + hf_output_str = output_str[1:] + if hf_output_ids[-1] == eos_token_id: + hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) + + return hf_output_ids, hf_output_str, out_logprobs + + +@overload +def 
run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + sizes: List[Tuple[int, int]], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: Optional[List[float]] = None, + sizes: Optional[List[Tuple[int, int]]] = None, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + images = [asset.pil_image for asset in image_assets] + + if size_factors is not None: + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + elif sizes is not None: + inputs_per_image = [( + [ + prompt if size is not None else text_only_prompts[0] + for size in sizes + ], + [ + image.resize(size) if size is not None else None + for size in sizes + ], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + if len(sizes) == 0: + inputs_per_image.append( + (text_only_prompts, [None] * len(text_only_prompts))) + else: + raise ValueError("You must provide either `size_factors` or `sizes`") + + _run_test(hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend) + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test are from IMAGE_ASSETS. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
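+    # (Switching the multiprocessing start method to "spawn" would also avoid
+    # the problem, but running vLLM first sidesteps it entirely.)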
+ + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + dtype=dtype, + max_num_seqs=16, + max_model_len=4096, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT + }) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs + ] + + def process(hf_inputs: BatchEncoding): + return hf_inputs + + from transformers import AutoConfig + from transformers.models.mllama import MllamaConfig as MllamaConfigHf + + # use transformer's MllamaConfig for hf_runner + # and vllm's MllamaConfig for vllm_runner + AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True) + with hf_runner(model, + dtype=dtype, + postprocess_inputs=process, + auto_cls=AutoModelForVision2Seq) as hf_model: + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs + ] + + from vllm.transformers_utils.configs.mllama import MllamaConfig + AutoConfig.register("mllama", MllamaConfig, exist_ok=True) + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, model) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "sizes", + [ + # Text only + [], + # Single-size + [(512, 512)], + # Single-size, batched + [(512, 512), (512, 512), (512, 512)], + # Multi-size, batched + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028)], + # Multi-size, batched, including text only + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028), None], + # mllama has 8 possible aspect ratios, carefully set the sizes + # to cover all of them + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype, + max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "sizes", + [ + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028), None], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes, + dtype, max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + ) diff --git a/vllm/config.py b/vllm/config.py index 308f29a3dc371..108badf150c86 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -576,7 +576,9 @@ def get_multimodal_config(self) -> "MultiModalConfig": @property def is_encoder_decoder_model(self) -> bool: """Extract the HF 
encoder/decoder model flag.""" - return getattr(self.hf_config, "is_encoder_decoder", False) + return getattr(self.hf_config, "is_encoder_decoder", False) or ( + (hasattr(self.hf_config, "text_config") and getattr( + self.hf_config.text_config, "is_encoder_decoder", False))) @property def is_embedding_model(self) -> bool: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c341b236003a3..768ac69c3692d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1734,7 +1734,11 @@ def is_embedding_model(self): def _validate_model_inputs(self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs]): - if self.is_encoder_decoder_model(): + if self.model_config.is_multimodal_model: + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length + prompt_ids = inputs.get("prompt_token_ids") + elif self.is_encoder_decoder_model(): prompt_ids = inputs.get("encoder_prompt_token_ids") else: prompt_ids = inputs.get("prompt_token_ids") diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f1ce2c36fcceb..4a575ae8f8537 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -159,6 +159,8 @@ def _placeholder_str(self, modality: ModalityStr, hf_config.image_token_index) if model_type in ("chameleon", "internvl_chat"): return "" + if model_type == "mllama": + return "<|image|>" if model_type == "qwen2_vl": return "<|vision_start|><|image_pad|><|vision_end|>" @@ -358,6 +360,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _ImageParser = partial(cast, ChatCompletionContentPartImageParam) _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} def _parse_chat_message_content_parts( @@ -368,7 +371,11 @@ def _parse_chat_message_content_parts( texts: List[str] = [] mm_parser = mm_tracker.create_parser() + keep_multimodal_content = \ + mm_tracker._model_config.hf_config.model_type in \ + MODEL_KEEP_MULTI_MODAL_CONTENT + has_image = False for part in parts: part_type = part["type"] if part_type == "text": @@ -383,6 +390,7 @@ def _parse_chat_message_content_parts( "will be ignored.") mm_parser.parse_image(image_url["url"]) + has_image = True elif part_type == "audio_url": audio_url = _AudioParser(part)["audio_url"] @@ -394,12 +402,20 @@ def _parse_chat_message_content_parts( raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) - mm_placeholder_counts = mm_parser.mm_placeholder_counts() - if mm_placeholder_counts: - text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, - text_prompt) - - return [ConversationMessage(role=role, content=text_prompt)] + if keep_multimodal_content: + text_prompt = "\n".join(texts) + role_content = [{'type': 'text', 'text': text_prompt}] + + if has_image: + role_content = [{'type': 'image'}] + role_content + return [ConversationMessage(role=role, + content=role_content)] # type: ignore + else: + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_counts, text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] # No need to validate using Pydantic again diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0321ea98ec742..94076ea3a51db 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ 
b/vllm/entrypoints/openai/serving_chat.py @@ -309,6 +309,8 @@ async def chat_completion_stream_generator( async for res in result_generator: if res.prompt_token_ids is not None: num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 75ab0c770155b..a71e9a7b5db66 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs): available. """ + encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + """ + Optional multi-modal data to pass to the encoder model, + if the model supports it. + """ + _T1 = TypeVar("_T1", bound=SingletonPromptInputs, diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index be2aa5f8cb7d0..bee3d1ed75cbb 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -128,6 +128,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: def _prepare_decoder_input_ids_for_generation( self, decoder_input_ids: Optional[List[int]], + force_bos: bool = True, ) -> List[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. @@ -157,8 +158,8 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): + if force_bos and (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids return decoder_input_ids @@ -295,18 +296,25 @@ def _build_enc_dec_llm_inputs( encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps - if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal encoder-decoder models are " - "not supported yet") + if decoder_mm_data is not None: + raise ValueError( + "Multi-modality decoder inputs of encoder-decoder models are " + "not supported yet") - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + # For Multi-Modal models (e.g., mllama), the text input can be + # <|image|><|begin_of_text|>hello world. And we should not add + # another <|begin_of_text|> to the beginning. 
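+        # force_bos is False only when encoder multi-modal data is present
+        # (decoder multi-modal data was rejected above), so text-only
+        # encoder-decoder models keep the usual decoder-start-token handling.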
+ decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation( + decoder_prompt_ids, + force_bos=(encoder_mm_data is None and decoder_mm_data is None))) return EncoderDecoderLLMInputs( prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, + multi_modal_data=decoder_mm_data, encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, + encoder_multi_modal_data=encoder_mm_data, ) def _process_encoder_decoder_prompt( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 6ab23d1c4b769..159d958ebf671 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -112,6 +112,8 @@ class InputRegistry: def __init__(self) -> None: self._dummy_factories_by_model_type: Dict[Type[nn.Module], DummyDataFactory] = {} + self._dummy_encoder_factories_by_model_type: Dict[ + Type[nn.Module], DummyDataFactory] = {} self._input_processors_by_model_type: Dict[Type[nn.Module], InputProcessor] = {} @@ -162,11 +164,44 @@ def _get_dummy_data_factory(self, model_cls: Type[nn.Module]): return self._dummy_factories_by_model_type \ .get(model_cls, self._default_dummy_data_factory) + def register_dummy_encoder_data(self, factory: DummyDataFactory): + """ + Register a dummy encoder data factory to a model class + + This is similar to :meth:`~register_dummy_data`, but for encoder input. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_encoder_factories_by_model_type: + logger.warning( + "Model class %s already has dummy encoder data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_encoder_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]): + if model_cls in self._dummy_encoder_factories_by_model_type: + dummy_factory = self._dummy_encoder_factories_by_model_type[ + model_cls] + else: + logger.warning( + "No dummy encoder data factory registered to %s. " + "Using the dummy data factory for the model instead.", + model_cls) + dummy_factory = self._get_dummy_data_factory(model_cls) + return dummy_factory + def dummy_data_for_profiling( self, model_config: "ModelConfig", seq_len: int, mm_registry: "MultiModalRegistry", + is_encoder_data: bool = False, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data for profiling the memory usage of a model. 
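+        If ``is_encoder_data`` is True, the dummy encoder data factory is
+        used, falling back to the regular dummy data factory when none is
+        registered for the model.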
@@ -184,8 +219,10 @@ def dummy_data_for_profiling( from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - dummy_factory = self._get_dummy_data_factory(model_cls) - + if is_encoder_data: + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + else: + dummy_factory = self._get_dummy_data_factory(model_cls) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) mm_processor_kwargs = get_allowed_kwarg_only_overrides( dummy_factory, overrides=model_config.mm_processor_kwargs) @@ -196,10 +233,15 @@ def dummy_data_for_profiling( # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids - assert len(num_tokens) >= seq_len, ( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(num_tokens)} tokens instead.") - + if len(num_tokens) < seq_len: + if is_encoder_data: + logger.warning( + "Expected at least %d dummy encoder tokens for profiling, " + "but found %d tokens instead.", seq_len, len(num_tokens)) + else: + raise AssertionError( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but found {len(num_tokens)} tokens instead.") if mm_data is not None: for k, v in mm_data.items(): num_items = len(v) if isinstance(v, list) else 1 diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 3f52eb44edfff..3a6fa9e26ff4b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -101,6 +101,8 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), "UltravoxModel": ("ultravox", "UltravoxModel"), + "MllamaForConditionalGeneration": ("mllama", + "MllamaForConditionalGeneration"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py new file mode 100644 index 0000000000000..aa868a3b8da28 --- /dev/null +++ b/vllm/model_executor/models/mllama.py @@ -0,0 +1,1135 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Mllama model.""" +import math +from array import array +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers.models.mllama.configuration_mllama as config_mllama +from PIL import Image +from torch import nn +from transformers.modeling_outputs import (BaseModelOutput, + CausalLMOutputWithPast) +from transformers.models.mllama.image_processing_mllama import ( + get_optimal_tiled_canvas) + +import vllm.distributed.parallel_state as ps +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData + +from .clip import CLIPMLP +from .interfaces import SupportsMultiModal +from .llama import LlamaDecoderLayer, LlamaMLP + +logger = init_logger(__name__) +MLLAMA_IMAGE_TOKEN_ID = 128256 +MLLAMA_IMAGE_TOKEN = "<|image|>" + + +class MllamaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: """ + """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" + aspect_ratio_ids: torch.Tensor + """Shape: `(batch_size, max_num_image)`""" + aspect_ratio_mask: torch.Tensor + """Shape: `(batch_size, max_num_image, max_num_tiles)`""" + + +# TODO: support LlamaImageEmbeddingInputs + + +def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): + # move encoder_prompt to prompt + if llm_inputs.get("prompt") is None: + llm_inputs["prompt"] = llm_inputs["encoder_prompt"] + llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + + # process multi-modal data + assert "decoder_multi_modal_data" not in llm_inputs, \ + "multi-modal data should be put in encoder message of mllama" + multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + + if multi_modal_data is None or "image" not in multi_modal_data \ + or multi_modal_data["image"] is None: + # text-only + llm_inputs["encoder_prompt"] = "" + llm_inputs["encoder_prompt_token_ids"] = [] + llm_inputs["encoder_multi_modal_data"] = {} + return llm_inputs + + # get num_tiles + if isinstance(multi_modal_data['image'], Image.Image): + multi_modal_data['image'] = [multi_modal_data['image']] + hf_config = ctx.model_config.hf_config + num_tiles = 0 + for image in multi_modal_data["image"]: + width, height = image.size + tile_size = hf_config.vision_config.image_size + canvas_height, canvas_width = get_optimal_tiled_canvas( + image_height=height, + image_width=width, + max_image_tiles=hf_config.vision_config.max_num_tiles, + 
tile_size=tile_size, + ) + num_tiles_height = canvas_height // tile_size + num_tiles_width = canvas_width // tile_size + num_tiles += num_tiles_height * num_tiles_width + + # set encoder prompt based on num_tiles + assert hf_config.vision_config.image_size % 14 == 0, \ + "chunk size should be multiple of 14" + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + num_tokens = num_tiles * token_per_chunk + llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens + llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID + ] * num_tokens + + return llm_inputs + + +def get_max_mllama_image_tokens(ctx: InputContext) -> int: + hf_config = ctx.model_config.hf_config + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + return hf_config.vision_config.max_num_tiles * token_per_chunk + + +def dummy_decoder_seq_data(seq_len: int, num_images: int): + # <|image|> * num_images + 0 * (seq_len - num_images) + assert seq_len >= num_images, \ + "seq_len should be greater than or equal to num_images" + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [MLLAMA_IMAGE_TOKEN_ID]) * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) + return SequenceData(token_ids) + + +def dummy_encoder_seq_data(ctx: InputContext, num_images: int): + num_tokens = get_max_mllama_image_tokens(ctx) * num_images + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [MLLAMA_IMAGE_TOKEN_ID]) * num_tokens + return SequenceData(token_ids) + + +def dummy_image(num_images: int, ): + width = height = 1024 + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + return dummy_decoder_seq_data(seq_len, num_images), None + + +def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, + 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length) + attention_mask = attention_mask.reshape(batch_size, + max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose( + -1, -2) * torch.finfo(dtype).min + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +class ColumnParallelConv2dPatch(torch.nn.Module): + """Conv2D Patching layer with model parallelism. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. 
+ Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + bias: bool = False, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) + self._linear = ColumnParallelLinear( + in_channels * kernel_size[0] * kernel_size[1], + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._unfold(x) + x = x.permute(0, 2, 1) + x, _ = self._linear(x) + return x + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.hidden_size) + if is_gated: + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, + self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size)**2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter(torch.zeros(1)) + + # position embedding + position_embedding = torch.randn(self.num_patches, self.hidden_size) + self.embedding = nn.Parameter(self.scale * position_embedding) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.num_patches * self.hidden_size) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + # position embeddings + gated_position_embedding = (1 - self.gate.tanh()) * self.embedding + hidden_state = hidden_state + gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) + gated_tile_position_embedding = self.gate.tanh( + ) * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +# TODO: support other attention backends for attention in vision model +class MllamaVisionSdpaAttention(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + + model_parallel_size = get_tensor_model_parallel_world_size() + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + self.num_local_heads = self.num_heads // 
model_parallel_size + self.q_size = self.num_local_heads * self.head_dim + self.kv_size = self.num_local_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + input_is_parallel=True, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_state) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(q.shape[0], q.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + k = k.view(k.shape[0], k.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + v = v.view(v.shape[0], v.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + + # TODO: remove padding in image encoder + attn_output = F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attention_mask, + dropout_p=0.0) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(attn_output.shape[0], + attn_output.shape[1], -1) + output, _ = self.o_proj(attn_output) + return output + + +class MllamaVisionEncoderLayer(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = False): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention(config) + self.mlp = CLIPMLP(config) + + self.input_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) + self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, + attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + num_layers=32, + is_gated=False, + output_hidden_states=None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + MllamaVisionEncoderLayer(config, is_gated) + for _ in range(num_layers) + ]) + self.output_hidden_states = output_hidden_states or [] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + encoder_states = () + + for i, encoder_layer in enumerate(self.layers): + if i in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + + if len(self.layers) - 1 in self.output_hidden_states: + encoder_states = encoder_states + 
(hidden_states, ) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.in_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = ColumnParallelConv2dPatch( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.class_embedding = nn.Parameter(self.scale * + torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + config) + + self.pre_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + self.post_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + output_hidden_states=config.intermediate_layers_indices) + self.global_transformer = MllamaVisionEncoder(config, + config.num_global_layers, + is_gated=True) + + def apply_class_embedding(self, + hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, + hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward(self, pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, \ + height, width = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, + height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1) + + # patch embedding + patch_embeds = self.patch_embedding( + pixel_values.to(self.layernorm_pre.weight.dtype)) + hidden_state = patch_embeds + hidden_state = ps.get_tp_group().all_gather(hidden_state) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, + aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, 0, 0, num_padding_patches + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, 
padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.layernorm_pre.weight.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, + dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state, intermediate_hidden_states = output[0], output[1] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, + dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), dim) + hidden_state = self.global_transformer( + hidden_state, attention_mask=attention_mask)[0] + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, + num_tiles, num_patches, dim) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, + num_patches + num_padding_patches, -1) + intermediate_hidden_states = intermediate_hidden_states[:, :, : + slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], + dim=-1) + return hidden_state + + +class MllamaTextRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + MllamaTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Optional[config_mllama.MllamaTextConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.config = config + self.model_parallel_size = get_tensor_model_parallel_world_size() + self.num_heads = self.config.num_attention_heads + self.num_local_heads = self.num_heads // self.model_parallel_size + self.num_key_value_heads = self.config.num_key_value_heads + self.num_local_key_value_heads = \ + self.num_key_value_heads // self.model_parallel_size + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_idx = layer_idx + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + 
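+        # Per-tensor-parallel-rank slice sizes, used below to split the fused
+        # QKV projection output in forward().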
self.q_local_size = self.num_local_heads * self.head_dim + self.kv_local_size = self.num_local_key_value_heads * self.head_dim + + # TODO: change to Q/KV separate linear after #7448 is merged + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + ) + # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, + # use huggingface's instead + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.scaling = self.head_dim**-0.5 + + self.attn = Attention( + self.num_local_heads, + self.head_dim, + self.scaling, + self.num_local_key_value_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cross_attention_states: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv_dec, _ = self.qkv_proj(hidden_states) + q, _, _ = qkv_dec.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + if cross_attention_states is None: + k = None + v = None + else: + qkv_enc, _ = self.qkv_proj(cross_attention_states) + _, k, v = qkv_enc.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + k = k.view(-1, self.num_local_key_value_heads, self.head_dim) + v = v.view(-1, self.num_local_key_value_heads, self.head_dim) + k = self.k_norm(k) + q = q.view(-1, self.num_local_heads, self.head_dim) + q = self.q_norm(q) + + output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + out, _ = self.o_proj(output) + return out + + +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention + and feedforward.""" + + def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \ + -> None: + super().__init__() + self.layer_idx = layer_idx + self.cross_attn = MllamaTextCrossAttention( + config=config, + layer_idx=layer_idx, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) + + self.mlp = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + full_text_row_masked_out_mask: torch.Tensor, + kv_cache: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_attn_gate.tanh( + ) * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = full_text_row_masked_out_mask * 
hidden_states + hidden_states = residual + self.cross_attn_mlp_gate.tanh( + ) * hidden_states + return hidden_states + + +class MllamaTextModel(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "model" + + def __init__(self, config: config_mllama.MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, + config.hidden_size) + self.cross_attention_layers = config.cross_attention_layers + + layers = [] + for layer_idx in range(config.num_hidden_layers): + if layer_idx in self.cross_attention_layers: + layers.append( + MllamaCrossAttentionDecoderLayer(config, layer_idx)) + else: + # TODO: force LlamaDecoderLayer to config.attention_bias=False + layers.append( + LlamaDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config)) + + self.layers = nn.ModuleList(layers) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + skip_cross_attention: bool, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): + if not skip_cross_attention: + hidden_states = decoder_layer( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask= + full_text_row_masked_out_mask, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + elif isinstance(decoder_layer, LlamaDecoderLayer): + hidden_states, residual = decoder_layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + residual=None, + ) + hidden_states = hidden_states + residual + else: + raise ValueError( + f"Unknown decoder layer type {type(decoder_layer)}") + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MllamaForCausalLM(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "language_model" + _no_split_modules = [ + "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" + ] + + def __init__(self, config: config_mllama.MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): + super().__init__() + self.vocab_size = config.vocab_size + self.model = MllamaTextModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + skip_cross_attention: bool, + ) -> torch.Tensor: + 
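+        # Delegate to MllamaTextModel; logits are computed later from lm_head
+        # in MllamaForConditionalGeneration.compute_logits().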
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            cross_attention_states=cross_attention_states,
+            cross_attention_mask=cross_attention_mask,
+            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            skip_cross_attention=skip_cross_attention,
+        )
+        return hidden_states
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
+@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
+class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
+
+    def __init__(self,
+                 config: config_mllama.MllamaConfig,
+                 multimodal_config: MultiModalConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.vocab_size = config.text_config.vocab_size
+        self.hidden_size = config.text_config.hidden_size
+        self.max_num_tiles = config.vision_config.max_num_tiles
+        self.vision_output_dim = config.vision_config.vision_output_dim
+        self.pad_token_id = \
+            config.pad_token_id if config.pad_token_id is not None else -1
+        self.image_size = config.vision_config.image_size
+
+        self.vision_model = MllamaVisionModel(config.vision_config)
+        self.language_model = MllamaForCausalLM(
+            config.text_config,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+        self.multi_modal_projector = nn.Linear(
+            config.vision_config.vision_output_dim,
+            config.text_config.hidden_size,
+            bias=True,
+        )
+        self.logits_processor = LogitsProcessor(config.output_hidden_states,
+                                                config.text_config.vocab_size)
+        self.sampler = Sampler()
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.language_model.lm_head,
+                                       hidden_states, sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def _parse_and_validate_image_input(self, **kwargs: object):
+        # tensor with the same shape will be batched together by
+        # MultiModalInputs.batch, so pixel_values here can be:
+        #   - List[List[torch.Tensor]]:
+        #       with shape (num_tiles, 3, image_res, image_res)
+        #   - List[torch.Tensor]:
+        #       with shape (num_image, num_tiles, 3, image_res, image_res)
+        #   - torch.Tensor:
+        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
+        pixel_values: Optional[Union[List[List[torch.Tensor]],
+                                     List[torch.Tensor],
+                                     torch.Tensor]] = kwargs.pop(
+                                         "pixel_values", None)
+        image_embeds: Optional[Union[List[List[torch.Tensor]],
+                                     List[torch.Tensor],
+                                     torch.Tensor]] = kwargs.pop(
+                                         "image_embeds", None)
+        aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]],
+                                         List[torch.Tensor],
+                                         torch.Tensor]] = kwargs.pop(
+                                             "aspect_ratio_ids", None)
+        aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]],
+                                          List[torch.Tensor],
+                                          torch.Tensor]] = kwargs.pop(
+                                              "aspect_ratio_mask", None)
+
+        if pixel_values is None and image_embeds is None:
+            return None
+
+        if pixel_values is not None and image_embeds is not None:
+            raise ValueError(
+                "Both pixel values and image embeds are provided.")
+
+        if pixel_values is not None:
+            assert aspect_ratio_ids is not None
+            assert aspect_ratio_mask is not None
+            max_num_images = max([len(x[0]) for x in pixel_values])
+            if max_num_images == 0:
+                raise ValueError("No images provided.")
+            max_num_tiles = max(
+                max([len(x) for x in y[0]]) for y in pixel_values)
+            device = self.multi_modal_projector.weight.device
+            bsz = len(pixel_values)
+            out_num_tiles = []
+            out_images = torch.zeros(
+                bsz,
+                max_num_images,
+                max_num_tiles,
+                3,
+                self.image_size,
+                self.image_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            out_ar_ids = torch.ones(bsz,
+                                    max_num_images,
+                                    dtype=torch.int64,
+                                    device=device)
+            out_ar_mask = torch.zeros(bsz,
+                                      max_num_images,
+                                      max_num_tiles,
+                                      dtype=torch.int64,
+                                      device=device)
+            for b in range(len(pixel_values)):
+                _num_tiles = []
+                for i in range(len(pixel_values[b][0])):
+                    img = pixel_values[b][0][i]
+                    out_images[b, i, :img.shape[0]] = img
+                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
+                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]
+                    _num_tiles.append(img.shape[0])
+                out_num_tiles.append(_num_tiles)
+
+            return MllamaImagePixelInputs(
+                type="pixel_values",
+                data=out_images,
+                aspect_ratio_ids=out_ar_ids,
+                aspect_ratio_mask=out_ar_mask,
+            )
+
+        if image_embeds is not None:
+            raise NotImplementedError
+
+        raise AssertionError("This line should be unreachable.")
+
+    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
+                            attn_metadata: AttentionMetadata):
+
+        cross_attention_states_flat = torch.zeros(
+            sum(attn_metadata.encoder_seq_lens),
+            cross_attention_states.shape[-1],
+            device=cross_attention_states.device,
+            dtype=cross_attention_states.dtype)
+        start_pos = 0
+        for seq_len, vision_token_in_batch in zip(
+                attn_metadata.encoder_seq_lens, cross_attention_states):
+            end_pos = start_pos + seq_len
+            cross_attention_states_flat[
+                start_pos:end_pos] = vision_token_in_batch[:seq_len]
+            start_pos = end_pos
+        cross_attention_states = cross_attention_states_flat
+
+        full_text_row_masked_out_mask = torch.ones(
+            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
+        start_pos = 0
+        for seq_len, encoder_seq_len in zip(
+                attn_metadata.seq_lens_tensor.cpu(),
+                attn_metadata.encoder_seq_lens):
+            if encoder_seq_len == 0:
+                full_text_row_masked_out_mask[start_pos:start_pos +
+                                              seq_len] = False
+            start_pos += seq_len
+        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
+            cross_attention_states.device)
+
+        return cross_attention_states, full_text_row_masked_out_mask
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        **kwargs: object,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if attn_metadata.num_prefill_tokens > 0 and \
+            attn_metadata.num_decode_tokens > 0:
+            raise ValueError("Chunk prefill not supported")
+        image_inputs = self._parse_and_validate_image_input(**kwargs)
+        if image_inputs is None:
+            cross_attention_mask = None
+            full_text_row_masked_out_mask = (
+                attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
+                    input_ids.device)
+            cross_attention_states = None
+            skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
+        else:
+            # NOTE: llama's reference implementation runs vision model on CPU
+            pixel_values = image_inputs['data']
+            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
+            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
+            cross_attention_states = self.vision_model(pixel_values,
+                                                       aspect_ratio_ids,
+                                                       aspect_ratio_mask)
+            cross_attention_states = self.multi_modal_projector(
+                cross_attention_states)
+
+            bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
+            cross_attention_states = cross_attention_states.view(
+                bsz, -1, image_token_dim)
+
+            cross_attention_states, full_text_row_masked_out_mask = \
+                self.flat_encoder_result(cross_attention_states, attn_metadata)
+            skip_cross_attention = False
+            # TODO: support multi-image by this mask
+            cross_attention_mask = None
+
+        outputs = self.language_model(
+            input_ids=input_ids,
+            positions=positions,
+            cross_attention_states=cross_attention_states,
+            cross_attention_mask=cross_attention_mask,
+            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            skip_cross_attention=skip_cross_attention,
+        )
+
+        return outputs
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        updated_params = set()
+        for name, loaded_weight in weights:
+            if 'patch_embedding.weight' in name:
+                name = name.replace('patch_embedding.weight',
+                                    'patch_embedding._linear.weight')
+                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                updated_params.add(name)
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict.pop(name)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index 87d3a4576f332..8bcb38ef241ed 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -54,6 +54,12 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
         if isinstance(nested_tensors, torch.Tensor):
             return nested_tensors
 
+        if isinstance(nested_tensors, np.ndarray):
+            return torch.from_numpy(nested_tensors)
+
+        if isinstance(nested_tensors, (int, float)):
+            return torch.tensor(nested_tensors)
+
         stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
         if not is_list_of(stacked, torch.Tensor, check="all"):
             # Only tensors (not lists) can be stacked.
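A note for readers (the sketch below is not part of the patch): the _try_stack branches added above are what allow mllama's preprocessed inputs, which can arrive as numpy arrays or plain Python scalars, to be coerced to tensors before batching. The following is a minimal standalone rendition of that behavior; the name try_stack and the simplified same-shape check (in place of the library's is_list_of helper) are illustrative only.

import numpy as np
import torch

def try_stack(nested):
    # Terminal cases mirror the three branches in the hunk above.
    if isinstance(nested, torch.Tensor):
        return nested
    if isinstance(nested, np.ndarray):
        return torch.from_numpy(nested)
    if isinstance(nested, (int, float)):
        return torch.tensor(nested)
    # Recurse, then stack only if every element became a same-shape tensor;
    # otherwise the (possibly ragged) list is returned unchanged.
    stacked = [try_stack(t) for t in nested]
    if not all(isinstance(t, torch.Tensor) for t in stacked):
        return stacked
    if any(t.shape != stacked[0].shape for t in stacked):
        return stacked
    return torch.stack(stacked)

# Per-image metadata arriving as numpy arrays is batched into one tensor.
print(try_stack([np.array([1]), np.array([6])]).shape)  # torch.Size([2, 1])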
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index 31b1c3f93411a..d3a230e40477e 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -2,6 +2,7 @@
 
 import torch
 from PIL import Image
+from transformers.image_processing_base import BatchFeature
 
 from vllm.config import ModelConfig
 from vllm.inputs.registry import InputContext
@@ -39,6 +40,10 @@ def _default_input_mapper(
     ) -> MultiModalInputs:
         model_config = ctx.model_config
 
+        # Processed by input processor
+        if isinstance(data, BatchFeature):
+            return MultiModalInputs(data.data)
+
         # PIL image
         if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
             image_processor = self._get_hf_image_processor(model_config)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index fda7ef87749a1..49a198df045bd 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -13,6 +13,7 @@
 import msgspec
 import torch
 
+from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs
 from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -21,7 +22,6 @@
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 if TYPE_CHECKING:
-    from vllm.inputs import LLMInputs
     from vllm.multimodal.base import MultiModalDataDict
 
 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@@ -471,7 +471,15 @@ def prompt_token_ids(self) -> List[int]:
 
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
-        return self.inputs.get("multi_modal_data") or {}
+        if self.inputs.get("multi_modal_data") and self.inputs.get(
+                "encoder_multi_modal_data"):
+            raise ValueError(
+                "Multi-modal data in both encoder and decoder is not supported."
+            )
+        inputs = self.inputs
+        return self.inputs.get("multi_modal_data") or (cast(
+            EncoderDecoderLLMInputs,
+            inputs).get("encoder_multi_modal_data")) or {}
 
     @property
     def lora_int_id(self) -> int:
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 1744935d624fb..3871c0cb8b819 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -22,9 +22,10 @@
                                              EAGLEConfig, ExaoneConfig,
                                              GraniteConfig, InternVLChatConfig,
                                              JAISConfig, MedusaConfig,
-                                             MLPSpeculatorConfig, MPTConfig,
-                                             NemotronConfig, RWConfig,
-                                             SolarConfig, UltravoxConfig)
+                                             MllamaConfig, MLPSpeculatorConfig,
+                                             MPTConfig, NemotronConfig,
+                                             RWConfig, SolarConfig,
+                                             UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.utils import check_gguf_file
 
@@ -37,6 +38,10 @@
 
 logger = init_logger(__name__)
 
+_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
+    "mllama": MllamaConfig
+}
+
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
@@ -55,11 +60,15 @@
     # Granite can be removed from here once we have upgraded to
    # transformers 4.45+
     "granite": GraniteConfig,
+    **_CONFIG_REGISTRY_OVERRIDE_HF
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
     with contextlib.suppress(ValueError):
-        AutoConfig.register(name, cls)
+        if name in _CONFIG_REGISTRY_OVERRIDE_HF:
+            AutoConfig.register(name, cls, exist_ok=True)
+        else:
+            AutoConfig.register(name, cls)
 
 
 class ConfigFormat(str, enum.Enum):
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index ea4fc8ad21f35..d5b13adb58a0b 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -10,6 +10,7 @@
 from vllm.transformers_utils.configs.internvl import InternVLChatConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.medusa import MedusaConfig
+from vllm.transformers_utils.configs.mllama import MllamaConfig
 from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
@@ -26,6 +27,7 @@
     "MedusaConfig",
     "EAGLEConfig",
     "ExaoneConfig",
+    "MllamaConfig",
     "MLPSpeculatorConfig",
     "NemotronConfig",
     "SolarConfig",
diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py
new file mode 100644
index 0000000000000..49e766d7fa1f4
--- /dev/null
+++ b/vllm/transformers_utils/configs/mllama.py
@@ -0,0 +1,28 @@
+from transformers.models.mllama import configuration_mllama as mllama_hf_config
+
+
+class MllamaTextConfig(mllama_hf_config.MllamaTextConfig):
+    '''
+    Use this class to override is_encoder_decoder:
+        - transformers regards mllama as is_encoder_decoder=False
+        - vllm needs is_encoder_decoder=True to enable cross-attention
+    '''
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.is_encoder_decoder = True
+
+
+class MllamaConfig(mllama_hf_config.MllamaConfig):
+
+    def __init__(
+        self,
+        text_config=None,
+        **kwargs,
+    ):
+        if isinstance(text_config, dict):
+            text_config = MllamaTextConfig(**text_config)
+        super().__init__(text_config=text_config, **kwargs)
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index f9fb8d1e103b7..2a2d74382e37a 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -111,7 +111,6 @@ def get_tokenizer(
             'encoding and decoding.',
             FutureWarning,
             stacklevel=2)
-
     if tokenizer_mode == "mistral":
         tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                      revision=revision)
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index 709efdc8b9d57..bd716ac3e7ec3 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -18,7 +18,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
+                             MultiModalRegistry)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, PoolerOutput,
                            SequenceGroupMetadata)
@@ -52,6 +53,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
             "virtual_engine": self.virtual_engine,
             "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
             "finished_requests_ids": self.finished_requests_ids,
+            "multi_modal_kwargs": self.multi_modal_kwargs,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         _add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -194,6 +196,8 @@ def execute_model(
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_seqlen_agnostic else {}
+
+        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
         hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
@@ -202,6 +206,8 @@ def execute_model(
             kv_caches=kv_caches,
             attn_metadata=model_input.attn_metadata,
             intermediate_tensors=intermediate_tensors,
+            **MultiModalInputs.as_kwargs(multi_modal_kwargs,
+                                         device=self.device),
             **seqlen_agnostic_kwargs)
 
         logits = self.model.compute_logits(hidden_or_intermediate_states,
@@ -288,8 +294,7 @@ def profile_run(self) -> None:
         max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
             self.model_config)
         if max_mm_tokens > 0:
-            raise NotImplementedError(
-                "Multi-modal encoder-decoder models are not supported yet")
+            logger.info("Starting profile run for multi-modal models.")
 
         batch_size = 0
         for group_id in range(max_num_seqs):
@@ -297,24 +302,39 @@
                        (group_id < max_num_batched_tokens % max_num_seqs))
             batch_size += seq_len
 
-            seq_data, _ = self.input_registry \
-                .dummy_data_for_profiling(self.model_config,
+            decoder_seq_data, decoder_dummy_multi_modal_data \
+                = self.input_registry.dummy_data_for_profiling(
+                    self.model_config,
                                           seq_len,
-                                          self.mm_registry)
+                                          self.mm_registry,
+                                          is_encoder_data=False)
+            encoder_seq_data, encoder_dummy_multi_modal_data \
+                = self.input_registry.dummy_data_for_profiling(
+                    self.model_config,
+                    seq_len,
+                    self.mm_registry,
+                    is_encoder_data=True)
 
             # Having more tokens is over-conservative but otherwise fine
-            assert len(seq_data.prompt_token_ids) >= seq_len, (
+            assert len(decoder_seq_data.prompt_token_ids) >= seq_len, (
                 f"Expected at least {seq_len} dummy tokens for profiling, "
-                f"but got: {len(seq_data.prompt_token_ids)}")
+                f"but got: {len(decoder_seq_data.prompt_token_ids)}")
+
+            assert decoder_dummy_multi_modal_data is None or \
+                encoder_dummy_multi_modal_data is None, (
+                "Multi-modal data can't be provided in both encoder and decoder"
+            )
 
             seq = SequenceGroupMetadata(
                 request_id=str(group_id),
                 is_prompt=True,
-                seq_data={group_id: seq_data},
+                seq_data={group_id: decoder_seq_data},
                 sampling_params=sampling_params,
                 block_tables=None,
-                encoder_seq_data=seq_data,
+                encoder_seq_data=encoder_seq_data,
                 cross_block_table=None,
+                multi_modal_data=decoder_dummy_multi_modal_data
+                or encoder_dummy_multi_modal_data,
             )
             seqs.append(seq)
 
diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py
index a58b80e4f2adb..a07395dfc61d8 100644
--- a/vllm/worker/utils.py
+++ b/vllm/worker/utils.py
@@ -39,10 +39,6 @@ def assert_enc_dec_mr_supported_scenario(
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
 
-    if enc_dec_mr.model_config.is_multimodal_model:
-        raise NotImplementedError(
-            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
-
     if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
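A closing note on the config override added in vllm/transformers_utils/configs/mllama.py above (this note and the snippet are not part of the patch): the subclass exists only to flip is_encoder_decoder on the text config, since transformers reports it as False while vLLM's cross-attention path requires True. A minimal sketch of that behavior, assuming transformers >= 4.45 as required by this patch and default config values:

# Sketch only: compares the vLLM override against the upstream HF config.
from transformers.models.mllama import configuration_mllama as hf_mllama

from vllm.transformers_utils.configs.mllama import MllamaConfig

vllm_cfg = MllamaConfig(text_config={})          # dict goes through the override
hf_cfg = hf_mllama.MllamaConfig(text_config={})  # plain HF behavior

print(vllm_cfg.text_config.is_encoder_decoder)  # True  -> cross-attention enabled
print(hf_cfg.text_config.is_encoder_decoder)    # False -> treated as decoder-only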