Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Enable Online Multi-image Support for MLlama #9393

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions tests/entrypoints/test_chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (parse_chat_messages,
parse_chat_messages_futures)
from vllm.entrypoints.llm import apply_hf_chat_template
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"


@pytest.fixture(scope="module")
Expand All @@ -39,6 +41,30 @@ def phi3v_tokenizer():
)


@pytest.fixture(scope="module")
def mllama_model_config():
return ModelConfig(MLLAMA_MODEL_ID,
task="generate",
tokenizer=MLLAMA_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
seed=0,
limit_mm_per_prompt={
"image": 2,
})


@pytest.fixture(scope="module")
def mllama_tokenizer():
return TokenizerGroup(
MLLAMA_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)


@pytest.fixture(scope="module")
def image_url():
image = ImageAsset('cherry_blossom')
Expand Down Expand Up @@ -414,3 +440,153 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
_assert_mm_data_is_image_input(mm_data, 2)


### Mllama currently wraps images / texts as interleaved dictionaries
def test_mllama_single_image(
mllama_model_config,
mllama_tokenizer,
image_url,
):
"""Ensures that a single image is parsed correctly mllama."""
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
'type': 'text',
'text': 'The content of this image is:'
}, {
"image_url": image_url
}]
}], mllama_model_config, mllama_tokenizer)
_assert_mm_data_is_image_input(mm_data, 1)
assert conversation == [{
'role':
'user',
'content': [{
'type': 'text',
'text': 'The content of this image is:'
}, {
'type': 'image'
}]
}]


def test_mllama_interleaved_images(
mllama_model_config,
mllama_tokenizer,
image_url,
):
"""Ensures that multiple image are parsed as interleaved dicts."""
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [
{
'type': 'text',
'text': 'The content of the first image is:'
},
{
"image_url": image_url
},
{
'type': 'text',
'text': 'The content of the second image is:'
},
{
"image_url": image_url
},
]
}], mllama_model_config, mllama_tokenizer)
_assert_mm_data_is_image_input(mm_data, 2)
assert conversation == [{
'role':
'user',
'content': [{
'type': 'text',
'text': 'The content of the first image is:'
}, {
'type': 'image'
}, {
'type': 'text',
'text': 'The content of the second image is:'
}, {
'type': 'image'
}]
}]


@pytest.mark.parametrize("model", [MLLAMA_MODEL_ID])
def test_multimodal_image_parsing_matches_hf(model, image_url):
"""Checks end to end hf alignment for multimodal [image] parsing."""

def get_conversation(is_hf: bool):
img_part = {"type": "image_url", "image_url": {"url": image_url}}
if is_hf:
img_part = {'type': 'image'}
return [{
'role':
'user',
'content': [
{
'type': 'text',
'text': 'The content of the first image is:'
},
img_part,
{
'type': 'text',
'text': 'The content of the second image is:'
},
img_part,
{
'type': 'text',
'text': 'What animal is in the first image?'
},
]
}]

# Build a config for the model
model_config = ModelConfig(model,
task="generate",
tokenizer=MLLAMA_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
seed=0,
limit_mm_per_prompt={
"image": 2,
})

# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
MLLAMA_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer

# Build and parse a conversation with {"type": "image"} using the tokenizer
hf_conversation = get_conversation(is_hf=True)
hf_result = tokenizer.apply_chat_template(
hf_conversation,
tokenize=False,
add_generation_prompt=True,
)

# Now parse with vLLMs chat utils & apply the template
vllm_conversation = get_conversation(is_hf=False)
conversation, _ = parse_chat_messages(
vllm_conversation,
model_config,
tokenizer_group,
)

vllm_result = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=None,
add_generation_prompt=True,
)

assert hf_result == vllm_result
90 changes: 53 additions & 37 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,53 +483,69 @@ def _parse_chat_message_content_parts(
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
content: List[Union[str, Dict[str, str]]] = []

mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT

has_image = False
for part in parts:
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
texts.append(text)
else: # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

# if part_type is text/refusal/image_url/audio_url but
# content is empty, logg a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
continue

if part_type in ("text", "refusal"):
texts.append(content)
elif part_type == "image_url":
mm_parser.parse_image(content)
has_image = True
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
parse_res = _parse_chat_message_content_part(
part, mm_parser, wrap_dicts=keep_multimodal_content)
if parse_res:
content.append(parse_res)

text_prompt = "\n".join(texts)
if keep_multimodal_content:
text_prompt = "\n".join(texts)
role_content = [{'type': 'text', 'text': text_prompt}]

if has_image:
role_content = [{'type': 'image'}] + role_content
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=role_content)] # type: ignore
else:
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts, text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
content=content)] # type: ignore
texts = cast(List[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
part: ChatCompletionContentPartParam,
mm_parser: BaseMultiModalContentParser,
wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
{"type": "image"}, respectively. Otherwise multimodal data will be
handled by mm_parser, and texts will be returned as strings to be joined
with multimodal placeholders.
"""
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
return text

# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

# if part_type is text/refusal/image_url/audio_url but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
alex-jw-brooks marked this conversation as resolved.
Show resolved Hide resolved
return None

if part_type in ("text", "refusal"):
return {'type': 'text', 'text': content} if wrap_dicts else content

if part_type == "image_url":
mm_parser.parse_image(content)
return {'type': 'image'} if wrap_dicts else None

if part_type == "audio_url":
mm_parser.parse_audio(content)
return {'type': 'audio'} if wrap_dicts else None

raise NotImplementedError(f"Unknown part type: {part_type}")


# No need to validate using Pydantic again
Expand Down