From e620599c101940fb8f1103556796e211b97a3614 Mon Sep 17 00:00:00 2001
From: Alex Brooks <alex.brooks@ibm.com>
Date: Wed, 23 Oct 2024 11:28:57 -0600
Subject: [PATCH] [Frontend] Enable Online Multi-image Support for MLlama
 (#9393)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
---
 tests/entrypoints/test_chat_utils.py | 176 +++++++++++++++++++++++++++
 vllm/entrypoints/chat_utils.py       |  91 ++++++++------
 2 files changed, 230 insertions(+), 37 deletions(-)

diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index 1d8c328b73259..f64743e065fc8 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -8,11 +8,13 @@
 from vllm.config import ModelConfig
 from vllm.entrypoints.chat_utils import (parse_chat_messages,
                                          parse_chat_messages_futures)
+from vllm.entrypoints.llm import apply_hf_chat_template
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import encode_image_base64
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
 PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
+MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
 
 @pytest.fixture(scope="module")
@@ -39,6 +41,30 @@ def phi3v_tokenizer():
     )
 
 
+@pytest.fixture(scope="module")
+def mllama_model_config():
+    return ModelConfig(MLLAMA_MODEL_ID,
+                       task="generate",
+                       tokenizer=MLLAMA_MODEL_ID,
+                       tokenizer_mode="auto",
+                       trust_remote_code=True,
+                       dtype="bfloat16",
+                       seed=0,
+                       limit_mm_per_prompt={
+                           "image": 2,
+                       })
+
+
+@pytest.fixture(scope="module")
+def mllama_tokenizer():
+    return TokenizerGroup(
+        MLLAMA_MODEL_ID,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+
+
 @pytest.fixture(scope="module")
 def image_url():
     image = ImageAsset('cherry_blossom')
@@ -414,3 +440,153 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
         "<|image_1|>\n<|image_2|>\nWhat's in these images?"
     }]
     _assert_mm_data_is_image_input(mm_data, 2)
+
+
+### Mllama currently wraps images / texts as interleaved dictionaries
+def test_mllama_single_image(
+    mllama_model_config,
+    mllama_tokenizer,
+    image_url,
+):
+    """Ensures that a single image is parsed correctly mllama."""
+    conversation, mm_data = parse_chat_messages([{
+        "role":
+        "user",
+        "content": [{
+            'type': 'text',
+            'text': 'The content of this image is:'
+        }, {
+            "image_url": image_url
+        }]
+    }], mllama_model_config, mllama_tokenizer)
+    _assert_mm_data_is_image_input(mm_data, 1)
+    assert conversation == [{
+        'role':
+        'user',
+        'content': [{
+            'type': 'text',
+            'text': 'The content of this image is:'
+        }, {
+            'type': 'image'
+        }]
+    }]
+
+
+def test_mllama_interleaved_images(
+    mllama_model_config,
+    mllama_tokenizer,
+    image_url,
+):
+    """Ensures that multiple image are parsed as interleaved dicts."""
+    conversation, mm_data = parse_chat_messages([{
+        "role":
+        "user",
+        "content": [
+            {
+                'type': 'text',
+                'text': 'The content of the first image is:'
+            },
+            {
+                "image_url": image_url
+            },
+            {
+                'type': 'text',
+                'text': 'The content of the second image is:'
+            },
+            {
+                "image_url": image_url
+            },
+        ]
+    }], mllama_model_config, mllama_tokenizer)
+    _assert_mm_data_is_image_input(mm_data, 2)
+    assert conversation == [{
+        'role':
+        'user',
+        'content': [{
+            'type': 'text',
+            'text': 'The content of the first image is:'
+        }, {
+            'type': 'image'
+        }, {
+            'type': 'text',
+            'text': 'The content of the second image is:'
+        }, {
+            'type': 'image'
+        }]
+    }]
+
+
+@pytest.mark.parametrize("model", [MLLAMA_MODEL_ID])
+def test_multimodal_image_parsing_matches_hf(model, image_url):
+    """Checks end to end hf alignment for multimodal [image] parsing."""
+
+    def get_conversation(is_hf: bool):
+        img_part = {"type": "image_url", "image_url": {"url": image_url}}
+        if is_hf:
+            img_part = {'type': 'image'}
+        return [{
+            'role':
+            'user',
+            'content': [
+                {
+                    'type': 'text',
+                    'text': 'The content of the first image is:'
+                },
+                img_part,
+                {
+                    'type': 'text',
+                    'text': 'The content of the second image is:'
+                },
+                img_part,
+                {
+                    'type': 'text',
+                    'text': 'What animal is in the first image?'
+                },
+            ]
+        }]
+
+    # Build a config for the model
+    model_config = ModelConfig(model,
+                               task="generate",
+                               tokenizer=MLLAMA_MODEL_ID,
+                               tokenizer_mode="auto",
+                               trust_remote_code=True,
+                               dtype="bfloat16",
+                               seed=0,
+                               limit_mm_per_prompt={
+                                   "image": 2,
+                               })
+
+    # Build the tokenizer group and grab the underlying tokenizer
+    tokenizer_group = TokenizerGroup(
+        MLLAMA_MODEL_ID,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+    tokenizer = tokenizer_group.tokenizer
+
+    # Build and parse a conversation with {"type": "image"} using the tokenizer
+    hf_conversation = get_conversation(is_hf=True)
+    hf_result = tokenizer.apply_chat_template(
+        hf_conversation,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+
+    # Now parse with vLLMs chat utils & apply the template
+    vllm_conversation = get_conversation(is_hf=False)
+    conversation, _ = parse_chat_messages(
+        vllm_conversation,
+        model_config,
+        tokenizer_group,
+    )
+
+    vllm_result = apply_hf_chat_template(
+        tokenizer,
+        conversation=conversation,
+        chat_template=None,
+        add_generation_prompt=True,
+    )
+
+    assert hf_result == vllm_result
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index f64af27a957be..ddc5e0b90e858 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -483,53 +483,70 @@ def _parse_chat_message_content_parts(
     parts: Iterable[ChatCompletionContentPartParam],
     mm_tracker: BaseMultiModalItemTracker,
 ) -> List[ConversationMessage]:
-    texts: List[str] = []
+    content: List[Union[str, Dict[str, str]]] = []
 
     mm_parser = mm_tracker.create_parser()
     keep_multimodal_content = \
         mm_tracker._model_config.hf_config.model_type in \
             MODEL_KEEP_MULTI_MODAL_CONTENT
 
-    has_image = False
     for part in parts:
-        if isinstance(part, str):  # Handle plain text parts
-            text = _TextParser(part)
-            texts.append(text)
-        else:  # Handle structured dictionary parts
-            part_type, content = _parse_chat_message_content_mm_part(part)
-
-            # if part_type is text/refusal/image_url/audio_url but
-            # content is empty, logg a warning and skip
-            if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
-                logger.warning("Skipping multimodal part "
-                               "with empty / unparsable content.")
-                continue
-
-            if part_type in ("text", "refusal"):
-                texts.append(content)
-            elif part_type == "image_url":
-                mm_parser.parse_image(content)
-                has_image = True
-            elif part_type == "audio_url":
-                mm_parser.parse_audio(content)
-            else:
-                raise NotImplementedError(f"Unknown part type: {part_type}")
+        parse_res = _parse_chat_message_content_part(
+            part, mm_parser, wrap_dicts=keep_multimodal_content)
+        if parse_res:
+            content.append(parse_res)
 
-    text_prompt = "\n".join(texts)
     if keep_multimodal_content:
-        text_prompt = "\n".join(texts)
-        role_content = [{'type': 'text', 'text': text_prompt}]
-
-        if has_image:
-            role_content = [{'type': 'image'}] + role_content
+        # Parsing wraps images and texts as interleaved dictionaries
         return [ConversationMessage(role=role,
-                                    content=role_content)]  # type: ignore
-    else:
-        mm_placeholder_counts = mm_parser.mm_placeholder_counts()
-        if mm_placeholder_counts:
-            text_prompt = _get_full_multimodal_text_prompt(
-                mm_placeholder_counts, text_prompt)
-        return [ConversationMessage(role=role, content=text_prompt)]
+                                    content=content)]  # type: ignore
+    texts = cast(List[str], content)
+    text_prompt = "\n".join(texts)
+    mm_placeholder_counts = mm_parser.mm_placeholder_counts()
+    if mm_placeholder_counts:
+        text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
+                                                       text_prompt)
+    return [ConversationMessage(role=role, content=text_prompt)]
+
+
+def _parse_chat_message_content_part(
+        part: ChatCompletionContentPartParam,
+        mm_parser: BaseMultiModalContentParser,
+        wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
+    """Parses a single part of a conversation. If wrap_dicts is True,
+    structured dictionary pieces for texts and images will be
+    wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
+    {"type": "image"}, respectively. Otherwise multimodal data will be
+    handled by mm_parser, and texts will be returned as strings to be joined
+    with multimodal placeholders.
+    """
+    if isinstance(part, str):  # Handle plain text parts
+        text = _TextParser(part)
+        return text
+
+    # Handle structured dictionary parts
+    part_type, content = _parse_chat_message_content_mm_part(part)
+
+    # if part_type is text/refusal/image_url/audio_url but
+    # content is empty, log a warning and skip
+    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
+        logger.warning(
+            "Skipping multimodal part (type: '%s')"
+            "with empty / unparsable content.", part_type)
+        return None
+
+    if part_type in ("text", "refusal"):
+        return {'type': 'text', 'text': content} if wrap_dicts else content
+
+    if part_type == "image_url":
+        mm_parser.parse_image(content)
+        return {'type': 'image'} if wrap_dicts else None
+
+    if part_type == "audio_url":
+        mm_parser.parse_audio(content)
+        return {'type': 'audio'} if wrap_dicts else None
+
+    raise NotImplementedError(f"Unknown part type: {part_type}")
 
 
 # No need to validate using Pydantic again