From eec0eaf2ebc5f4a7f783934de87a2cb4e21a48f4 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 22 Oct 2024 02:16:17 -0400 Subject: [PATCH] Fix mllama content parsing --- vllm/entrypoints/chat_utils.py | 84 ++++++++++++++++++++++------------ 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 90e2a9000e0e6..0165f50a8354e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -410,6 +410,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _ImageParser = partial(cast, ChatCompletionContentPartImageParam) _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} # Define a mapping from part types to their corresponding parsing functions. MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = { @@ -482,44 +483,71 @@ def _parse_chat_message_content_parts( parts: Iterable[ChatCompletionContentPartParam], mm_tracker: BaseMultiModalItemTracker, ) -> List[ConversationMessage]: - texts: List[str] = [] + content = [] mm_parser = mm_tracker.create_parser() + keep_multimodal_content = \ + mm_tracker._model_config.hf_config.model_type in \ + MODEL_KEEP_MULTI_MODAL_CONTENT for part in parts: - if isinstance(part, str): # Handle plain text parts - text = _TextParser(part) - texts.append(text) - else: # Handle structured dictionary parts - part_type, content = _parse_chat_message_content_mm_part(part) - - # if part_type is text/refusal/image_url/audio_url but - # content is empty, logg a warning and skip - if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: - logger.warning("Skipping multimodal part " - "with empty / unparsable content.") - continue - - if part_type in ("text", "refusal"): - texts.append(content) - elif part_type == "image_url": - mm_parser.parse_image(content) - # 
has_image = True - elif part_type == "audio_url": - mm_parser.parse_audio(content) - else: - raise NotImplementedError(f"Unknown part type: {part_type}") - + parse_res = _parse_chat_message_content_part( + part, mm_parser, wrap_dicts=keep_multimodal_content) + if parse_res: + content.append(parse_res) + + if keep_multimodal_content: + # Parsing wraps images and texts as interleaved dictionaries + return [ConversationMessage(role=role, + content=content)] # type: ignore + texts = cast(List[str], content) text_prompt = "\n".join(texts) mm_placeholder_counts = mm_parser.mm_placeholder_counts() if mm_placeholder_counts: - text_prompt = _get_full_multimodal_text_prompt( - mm_placeholder_counts, - text_prompt, - ) + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, + text_prompt) return [ConversationMessage(role=role, content=text_prompt)] +def _parse_chat_message_content_part( + part: ChatCompletionContentPartParam, + mm_parser: BaseMultiModalContentParser, + wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]: + """Parses a single part of a conversation. If wrap_dicts is True, + structured dictionary pieces for texts and images will be + wrapped in dictionaries, i.e., {"type": "text", "text": ...} and + {"type": "image"}, respectively. Otherwise multimodal data will be + handled by mm_parser, and texts will be returned as strings to be joined + with multimodal placeholders. 
+ """ + if isinstance(part, str): # Handle plain text parts + text = _TextParser(part) + return text + else: # Handle structured dictionary parts + part_type, content = _parse_chat_message_content_mm_part(part) + + # if part_type is text/refusal/image_url/audio_url but + # content is empty, log a warning and skip + if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: + logger.warning("Skipping multimodal part " + "with empty / unparsable content.") + return None + + if part_type in ("text", "refusal"): + if wrap_dicts: + return {'type': 'text', 'text': content} + return content + elif part_type == "image_url": + mm_parser.parse_image(content) + if wrap_dicts: + return {'type': 'image'} + elif part_type == "audio_url": + mm_parser.parse_audio(content) + else: + raise NotImplementedError(f"Unknown part type: {part_type}") + return None + + # No need to validate using Pydantic again _AssistantParser = partial(cast, ChatCompletionAssistantMessageParam) _ToolParser = partial(cast, ChatCompletionToolMessageParam)