Skip to content

Commit

Permalink
[Frontend] Enable Online Multi-image Support for MLlama (vllm-project…
Browse files Browse the repository at this point in the history
…#9393)

Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: NickLucche <[email protected]>
  • Loading branch information
2 people authored and NickLucche committed Oct 31, 2024
1 parent c404399 commit e620599
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 37 deletions.
176 changes: 176 additions & 0 deletions tests/entrypoints/test_chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (parse_chat_messages,
parse_chat_messages_futures)
from vllm.entrypoints.llm import apply_hf_chat_template
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"


@pytest.fixture(scope="module")
Expand All @@ -39,6 +41,30 @@ def phi3v_tokenizer():
)


@pytest.fixture(scope="module")
def mllama_model_config():
return ModelConfig(MLLAMA_MODEL_ID,
task="generate",
tokenizer=MLLAMA_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
seed=0,
limit_mm_per_prompt={
"image": 2,
})


@pytest.fixture(scope="module")
def mllama_tokenizer():
return TokenizerGroup(
MLLAMA_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)


@pytest.fixture(scope="module")
def image_url():
image = ImageAsset('cherry_blossom')
Expand Down Expand Up @@ -414,3 +440,153 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
_assert_mm_data_is_image_input(mm_data, 2)


### Mllama currently wraps images / texts as interleaved dictionaries
def test_mllama_single_image(
mllama_model_config,
mllama_tokenizer,
image_url,
):
"""Ensures that a single image is parsed correctly mllama."""
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
'type': 'text',
'text': 'The content of this image is:'
}, {
"image_url": image_url
}]
}], mllama_model_config, mllama_tokenizer)
_assert_mm_data_is_image_input(mm_data, 1)
assert conversation == [{
'role':
'user',
'content': [{
'type': 'text',
'text': 'The content of this image is:'
}, {
'type': 'image'
}]
}]


def test_mllama_interleaved_images(
mllama_model_config,
mllama_tokenizer,
image_url,
):
"""Ensures that multiple image are parsed as interleaved dicts."""
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [
{
'type': 'text',
'text': 'The content of the first image is:'
},
{
"image_url": image_url
},
{
'type': 'text',
'text': 'The content of the second image is:'
},
{
"image_url": image_url
},
]
}], mllama_model_config, mllama_tokenizer)
_assert_mm_data_is_image_input(mm_data, 2)
assert conversation == [{
'role':
'user',
'content': [{
'type': 'text',
'text': 'The content of the first image is:'
}, {
'type': 'image'
}, {
'type': 'text',
'text': 'The content of the second image is:'
}, {
'type': 'image'
}]
}]


@pytest.mark.parametrize("model", [MLLAMA_MODEL_ID])
def test_multimodal_image_parsing_matches_hf(model, image_url):
"""Checks end to end hf alignment for multimodal [image] parsing."""

def get_conversation(is_hf: bool):
img_part = {"type": "image_url", "image_url": {"url": image_url}}
if is_hf:
img_part = {'type': 'image'}
return [{
'role':
'user',
'content': [
{
'type': 'text',
'text': 'The content of the first image is:'
},
img_part,
{
'type': 'text',
'text': 'The content of the second image is:'
},
img_part,
{
'type': 'text',
'text': 'What animal is in the first image?'
},
]
}]

# Build a config for the model
model_config = ModelConfig(model,
task="generate",
tokenizer=MLLAMA_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
seed=0,
limit_mm_per_prompt={
"image": 2,
})

# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
MLLAMA_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer

# Build and parse a conversation with {"type": "image"} using the tokenizer
hf_conversation = get_conversation(is_hf=True)
hf_result = tokenizer.apply_chat_template(
hf_conversation,
tokenize=False,
add_generation_prompt=True,
)

# Now parse with vLLMs chat utils & apply the template
vllm_conversation = get_conversation(is_hf=False)
conversation, _ = parse_chat_messages(
vllm_conversation,
model_config,
tokenizer_group,
)

vllm_result = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=None,
add_generation_prompt=True,
)

assert hf_result == vllm_result
91 changes: 54 additions & 37 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,53 +483,70 @@ def _parse_chat_message_content_parts(
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
content: List[Union[str, Dict[str, str]]] = []

mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT

has_image = False
for part in parts:
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
texts.append(text)
else: # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

# if part_type is text/refusal/image_url/audio_url but
# content is empty, logg a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
continue

if part_type in ("text", "refusal"):
texts.append(content)
elif part_type == "image_url":
mm_parser.parse_image(content)
has_image = True
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
parse_res = _parse_chat_message_content_part(
part, mm_parser, wrap_dicts=keep_multimodal_content)
if parse_res:
content.append(parse_res)

text_prompt = "\n".join(texts)
if keep_multimodal_content:
text_prompt = "\n".join(texts)
role_content = [{'type': 'text', 'text': text_prompt}]

if has_image:
role_content = [{'type': 'image'}] + role_content
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=role_content)] # type: ignore
else:
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts, text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
content=content)] # type: ignore
texts = cast(List[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
part: ChatCompletionContentPartParam,
mm_parser: BaseMultiModalContentParser,
wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
{"type": "image"}, respectively. Otherwise multimodal data will be
handled by mm_parser, and texts will be returned as strings to be joined
with multimodal placeholders.
"""
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
return text

# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

# if part_type is text/refusal/image_url/audio_url but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning(
"Skipping multimodal part (type: '%s')"
"with empty / unparsable content.", part_type)
return None

if part_type in ("text", "refusal"):
return {'type': 'text', 'text': content} if wrap_dicts else content

if part_type == "image_url":
mm_parser.parse_image(content)
return {'type': 'image'} if wrap_dicts else None

if part_type == "audio_url":
mm_parser.parse_audio(content)
return {'type': 'audio'} if wrap_dicts else None

raise NotImplementedError(f"Unknown part type: {part_type}")


# No need to validate using Pydantic again
Expand Down

0 comments on commit e620599

Please sign in to comment.