From 8ca401baeee57585856f2007975ceae848220b63 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 15 Oct 2024 17:59:59 -0400 Subject: [PATCH] Pass all placeholders for mllama --- vllm/entrypoints/chat_utils.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 785dcbfa83119..e682225c3d37c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -384,7 +384,6 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _ImageParser = partial(cast, ChatCompletionContentPartImageParam) _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) -MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} def _parse_chat_message_content_parts( @@ -395,11 +394,7 @@ def _parse_chat_message_content_parts( texts: List[str] = [] mm_parser = mm_tracker.create_parser() - keep_multimodal_content = \ - mm_tracker._model_config.hf_config.model_type in \ - MODEL_KEEP_MULTI_MODAL_CONTENT - has_image = False for part in parts: part_type = part["type"] if part_type == "text": @@ -414,7 +409,6 @@ def _parse_chat_message_content_parts( "will be ignored.") mm_parser.parse_image(image_url["url"]) - has_image = True elif part_type == "audio_url": audio_url = _AudioParser(part)["audio_url"] @@ -426,20 +420,13 @@ def _parse_chat_message_content_parts( raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) - if keep_multimodal_content: - text_prompt = "\n".join(texts) - role_content = [{'type': 'text', 'text': text_prompt}] - - if has_image: - role_content = [{'type': 'image'}] + role_content - return [ConversationMessage(role=role, - content=role_content)] # type: ignore - else: - mm_placeholder_counts = mm_parser.mm_placeholder_counts() - if mm_placeholder_counts: - text_prompt = _get_full_multimodal_text_prompt( - mm_placeholder_counts, text_prompt) - return [ConversationMessage(role=role, content=text_prompt)] + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_counts, + text_prompt, + ) + return [ConversationMessage(role=role, content=text_prompt)] # No need to validate using Pydantic again