From f580225406fc4f55055f1469f947a16439097a8a Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 23 Oct 2024 11:28:57 -0600 Subject: [PATCH] [Frontend] Enable Online Multi-image Support for MLlama (#9393) Signed-off-by: Alex-Brooks Co-authored-by: Cyrus Leung Signed-off-by: qishuai --- tests/entrypoints/test_chat_utils.py | 176 +++++++++++++++++++++++++++ vllm/entrypoints/chat_utils.py | 91 ++++++++------ 2 files changed, 230 insertions(+), 37 deletions(-) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 1d8c328b73259..f64743e065fc8 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -8,11 +8,13 @@ from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import (parse_chat_messages, parse_chat_messages_futures) +from vllm.entrypoints.llm import apply_hf_chat_template from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import encode_image_base64 from vllm.transformers_utils.tokenizer_group import TokenizerGroup PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" +MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" @pytest.fixture(scope="module") @@ -39,6 +41,30 @@ def phi3v_tokenizer(): ) +@pytest.fixture(scope="module") +def mllama_model_config(): + return ModelConfig(MLLAMA_MODEL_ID, + task="generate", + tokenizer=MLLAMA_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="bfloat16", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }) + + +@pytest.fixture(scope="module") +def mllama_tokenizer(): + return TokenizerGroup( + MLLAMA_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + + @pytest.fixture(scope="module") def image_url(): image = ImageAsset('cherry_blossom') @@ -414,3 +440,153 @@ def test_parse_chat_messages_multiple_images_uncommon_input( "<|image_1|>\n<|image_2|>\nWhat's in these images?" }] _assert_mm_data_is_image_input(mm_data, 2) + + +### Mllama currently wraps images / texts as interleaved dictionaries +def test_mllama_single_image( + mllama_model_config, + mllama_tokenizer, + image_url, +): + """Ensures that a single image is parsed correctly mllama.""" + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [{ + 'type': 'text', + 'text': 'The content of this image is:' + }, { + "image_url": image_url + }] + }], mllama_model_config, mllama_tokenizer) + _assert_mm_data_is_image_input(mm_data, 1) + assert conversation == [{ + 'role': + 'user', + 'content': [{ + 'type': 'text', + 'text': 'The content of this image is:' + }, { + 'type': 'image' + }] + }] + + +def test_mllama_interleaved_images( + mllama_model_config, + mllama_tokenizer, + image_url, +): + """Ensures that multiple image are parsed as interleaved dicts.""" + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [ + { + 'type': 'text', + 'text': 'The content of the first image is:' + }, + { + "image_url": image_url + }, + { + 'type': 'text', + 'text': 'The content of the second image is:' + }, + { + "image_url": image_url + }, + ] + }], mllama_model_config, mllama_tokenizer) + _assert_mm_data_is_image_input(mm_data, 2) + assert conversation == [{ + 'role': + 'user', + 'content': [{ + 'type': 'text', + 'text': 'The content of the first image is:' + }, { + 'type': 'image' + }, { + 'type': 'text', + 'text': 'The content of the second image is:' + }, { + 'type': 'image' + }] + }] + + +@pytest.mark.parametrize("model", [MLLAMA_MODEL_ID]) +def test_multimodal_image_parsing_matches_hf(model, image_url): + """Checks end to end hf alignment for multimodal [image] parsing.""" + + def get_conversation(is_hf: bool): + img_part = {"type": "image_url", "image_url": {"url": image_url}} + if is_hf: + img_part = {'type': 'image'} + return [{ + 'role': + 'user', + 'content': [ + { + 'type': 'text', + 'text': 'The content of the first image is:' + }, + img_part, + { + 'type': 'text', + 'text': 'The content of the second image is:' + }, + img_part, + { + 'type': 'text', + 'text': 'What animal is in the first image?' + }, + ] + }] + + # Build a config for the model + model_config = ModelConfig(model, + task="generate", + tokenizer=MLLAMA_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="bfloat16", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }) + + # Build the tokenizer group and grab the underlying tokenizer + tokenizer_group = TokenizerGroup( + MLLAMA_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + tokenizer = tokenizer_group.tokenizer + + # Build and parse a conversation with {"type": "image"} using the tokenizer + hf_conversation = get_conversation(is_hf=True) + hf_result = tokenizer.apply_chat_template( + hf_conversation, + tokenize=False, + add_generation_prompt=True, + ) + + # Now parse with vLLMs chat utils & apply the template + vllm_conversation = get_conversation(is_hf=False) + conversation, _ = parse_chat_messages( + vllm_conversation, + model_config, + tokenizer_group, + ) + + vllm_result = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=None, + add_generation_prompt=True, + ) + + assert hf_result == vllm_result diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f64af27a957be..ddc5e0b90e858 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -483,53 +483,70 @@ def _parse_chat_message_content_parts( parts: Iterable[ChatCompletionContentPartParam], mm_tracker: BaseMultiModalItemTracker, ) -> List[ConversationMessage]: - texts: List[str] = [] + content: List[Union[str, Dict[str, str]]] = [] mm_parser = mm_tracker.create_parser() keep_multimodal_content = \ mm_tracker._model_config.hf_config.model_type in \ MODEL_KEEP_MULTI_MODAL_CONTENT - has_image = False for part in parts: - if isinstance(part, str): # Handle plain text parts - text = _TextParser(part) - texts.append(text) - else: # Handle structured dictionary parts - part_type, content = _parse_chat_message_content_mm_part(part) - - # if part_type is text/refusal/image_url/audio_url but - # content is empty, logg a warning and skip - if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: - logger.warning("Skipping multimodal part " - "with empty / unparsable content.") - continue - - if part_type in ("text", "refusal"): - texts.append(content) - elif part_type == "image_url": - mm_parser.parse_image(content) - has_image = True - elif part_type == "audio_url": - mm_parser.parse_audio(content) - else: - raise NotImplementedError(f"Unknown part type: {part_type}") + parse_res = _parse_chat_message_content_part( + part, mm_parser, wrap_dicts=keep_multimodal_content) + if parse_res: + content.append(parse_res) - text_prompt = "\n".join(texts) if keep_multimodal_content: - text_prompt = "\n".join(texts) - role_content = [{'type': 'text', 'text': text_prompt}] - - if has_image: - role_content = [{'type': 'image'}] + role_content + # Parsing wraps images and texts as interleaved dictionaries return [ConversationMessage(role=role, - content=role_content)] # type: ignore - else: - mm_placeholder_counts = mm_parser.mm_placeholder_counts() - if mm_placeholder_counts: - text_prompt = _get_full_multimodal_text_prompt( - mm_placeholder_counts, text_prompt) - return [ConversationMessage(role=role, content=text_prompt)] + content=content)] # type: ignore + texts = cast(List[str], content) + text_prompt = "\n".join(texts) + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, + text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] + + +def _parse_chat_message_content_part( + part: ChatCompletionContentPartParam, + mm_parser: BaseMultiModalContentParser, + wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]: + """Parses a single part of a conversation. If wrap_dicts is True, + structured dictionary pieces for texts and images will be + wrapped in dictionaries, i.e., {"type": "text", "text", ...} and + {"type": "image"}, respectively. Otherwise multimodal data will be + handled by mm_parser, and texts will be returned as strings to be joined + with multimodal placeholders. + """ + if isinstance(part, str): # Handle plain text parts + text = _TextParser(part) + return text + + # Handle structured dictionary parts + part_type, content = _parse_chat_message_content_mm_part(part) + + # if part_type is text/refusal/image_url/audio_url but + # content is empty, log a warning and skip + if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: + logger.warning( + "Skipping multimodal part (type: '%s')" + "with empty / unparsable content.", part_type) + return None + + if part_type in ("text", "refusal"): + return {'type': 'text', 'text': content} if wrap_dicts else content + + if part_type == "image_url": + mm_parser.parse_image(content) + return {'type': 'image'} if wrap_dicts else None + + if part_type == "audio_url": + mm_parser.parse_audio(content) + return {'type': 'audio'} if wrap_dicts else None + + raise NotImplementedError(f"Unknown part type: {part_type}") # No need to validate using Pydantic again