Skip to content

Commit

Permalink
Pass all placeholders for mllama
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-jw-brooks committed Oct 15, 2024
1 parent 5d264f4 commit 8ca401b
Showing 1 changed file with 7 additions and 20 deletions.
27 changes: 7 additions & 20 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,6 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}


def _parse_chat_message_content_parts(
Expand All @@ -395,11 +394,7 @@ def _parse_chat_message_content_parts(
texts: List[str] = []

mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT

has_image = False
for part in parts:
part_type = part["type"]
if part_type == "text":
Expand All @@ -414,7 +409,6 @@ def _parse_chat_message_content_parts(
"will be ignored.")

mm_parser.parse_image(image_url["url"])
has_image = True
elif part_type == "audio_url":
audio_url = _AudioParser(part)["audio_url"]

Expand All @@ -426,20 +420,13 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(f"Unknown part type: {part_type}")

text_prompt = "\n".join(texts)
if keep_multimodal_content:
text_prompt = "\n".join(texts)
role_content = [{'type': 'text', 'text': text_prompt}]

if has_image:
role_content = [{'type': 'image'}] + role_content
return [ConversationMessage(role=role,
content=role_content)] # type: ignore
else:
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts, text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts,
text_prompt,
)
return [ConversationMessage(role=role, content=text_prompt)]


# No need to validate using Pydantic again
Expand Down

0 comments on commit 8ca401b

Please sign in to comment.