Skip to content

Commit

Permalink
Fix mullama content parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-jw-brooks committed Oct 22, 2024
1 parent 1c6fd08 commit eec0eaf
Showing 1 changed file with 56 additions and 28 deletions.
84 changes: 56 additions & 28 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}

# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
Expand Down Expand Up @@ -482,44 +483,71 @@ def _parse_chat_message_content_parts(
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
content = []

mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT

for part in parts:
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
texts.append(text)
else: # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

# if part_type is text/refusal/image_url/audio_url but
# content is empty, logg a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
continue

if part_type in ("text", "refusal"):
texts.append(content)
elif part_type == "image_url":
mm_parser.parse_image(content)
# has_image = True
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")

parse_res = _parse_chat_message_content_part(
part, mm_parser, wrap_dicts=keep_multimodal_content)
if parse_res:
content.append(parse_res)

if keep_multimodal_content:
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=content)] # type: ignore
texts = cast(List[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts,
text_prompt,
)
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content_part(
part: ChatCompletionContentPartParam,
mm_parser: BaseMultiModalContentParser,
wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
{"type": "image"}, respectively. Otherwise multimodal data will be
handled by mm_parser, and texts will be returned as strings to be joined
with multimodal placeholders.
"""
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
return text
else: # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)

# if part_type is text/refusal/image_url/audio_url but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
return None

if part_type in ("text", "refusal"):
if wrap_dicts:
return {'type': 'text', 'text': content}
return content
elif part_type == "image_url":
mm_parser.parse_image(content)
if wrap_dicts:
return {'type': 'image'}
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
return None


# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
Expand Down

0 comments on commit eec0eaf

Please sign in to comment.