Skip to content

Commit

Permalink
Fix mllama content parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-jw-brooks committed Oct 22, 2024
1 parent 1c6fd08 commit e433606
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 39 deletions.
33 changes: 22 additions & 11 deletions tests/entrypoints/test_chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
from vllm.entrypoints.llm import apply_hf_chat_template
from vllm.entrypoints.chat_utils import (parse_chat_messages,
parse_chat_messages_futures)
from vllm.entrypoints.llm import apply_hf_chat_template
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
Expand Down Expand Up @@ -515,27 +515,38 @@ def test_mllama_interleaved_images(
}]
}]


def test_mllama_parse_matches_hf(
mllama_model_config,
mllama_tokenizer,
image_url,
):
"""Checks end to end correctness of hf allignment for mllama parsing."""
"""Checks end to end correctness of hf alignment for mllama parsing."""

def get_conversation(is_hf: bool):
img_part = {"type": "image_url", "image_url": {"url": image_url}}
if is_hf:
img_part = {'type': 'image'}
return [
return [{
'role':
'user',
'content': [
{
'role': 'user', 'content': [
{'type': 'text', 'text': 'The content of the first image is:'},
img_part,
{'type': 'text', 'text': 'The content of the second image is:'},
img_part,
{'type': 'text', 'text': 'What animal is in the first image?'},
]
}
'type': 'text',
'text': 'The content of the first image is:'
},
img_part,
{
'type': 'text',
'text': 'The content of the second image is:'
},
img_part,
{
'type': 'text',
'text': 'What animal is in the first image?'
},
]
}]

tokenizer = mllama_tokenizer.tokenizer
# Build and parse a conversation with {"type": "image"} using the tokenizer
Expand Down
84 changes: 56 additions & 28 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}

# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
Expand Down Expand Up @@ -482,44 +483,71 @@ def _parse_chat_message_content_parts(
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
content = []

mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT

for part in parts:
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
texts.append(text)
else: # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

# if part_type is text/refusal/image_url/audio_url but
# content is empty, logg a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
continue

if part_type in ("text", "refusal"):
texts.append(content)
elif part_type == "image_url":
mm_parser.parse_image(content)
# has_image = True
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")

parse_res = _parse_chat_message_content_part(
part, mm_parser, wrap_dicts=keep_multimodal_content)
if parse_res:
content.append(parse_res)

if keep_multimodal_content:
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=content)] # type: ignore
texts = cast(List[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts,
text_prompt,
)
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
        part: ChatCompletionContentPartParam,
        mm_parser: BaseMultiModalContentParser,
        wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
    """Parse a single part of a conversation message.

    When ``wrap_dicts`` is True, text and image parts are returned as
    structured dictionaries — ``{"type": "text", "text": ...}`` and
    ``{"type": "image"}`` respectively — so the caller can hand an
    interleaved content list straight to the HF processor. When it is
    False, multimodal data is registered with ``mm_parser`` and only text
    is returned (as a plain string) to be joined with placeholder tokens.

    Returns None for parts that contribute no text to the prompt (e.g.
    multimodal parts in placeholder mode, or empty/unparsable parts).
    """
    # Bare strings are plain text; validate and pass them through as-is.
    if isinstance(part, str):
        return _TextParser(part)

    # Structured dict part: classify it and extract its payload.
    part_type, content = _parse_chat_message_content_mm_part(part)

    # A recognized part type with an empty/unparsable payload is dropped
    # (with a warning) rather than treated as an error.
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
        logger.warning("Skipping multimodal part "
                       "with empty / unparsable content.")
        return None

    if part_type in ("text", "refusal"):
        return {'type': 'text', 'text': content} if wrap_dicts else content

    if part_type == "image_url":
        mm_parser.parse_image(content)
        # In wrapped mode the image is represented inline by a type marker;
        # otherwise the parser tracks it for placeholder insertion.
        return {'type': 'image'} if wrap_dicts else None

    if part_type == "audio_url":
        mm_parser.parse_audio(content)
        # NOTE(review): audio is never wrapped, even when wrap_dicts is
        # True — only text/image wrapping is needed for mllama.
        return None

    raise NotImplementedError(f"Unknown part type: {part_type}")


# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
Expand Down

0 comments on commit e433606

Please sign in to comment.