Update chat message class for multi-modal (#15969)
logan-markewich authored Sep 20, 2024
1 parent fcc89e8 commit 1d49e15
Showing 13 changed files with 134 additions and 15 deletions.
2 changes: 2 additions & 0 deletions llama-index-core/llama_index/core/agent/legacy/react/base.py
@@ -151,6 +151,7 @@ def _extract_reasoning_step(
if output.message.content is None:
raise ValueError("Got empty message.")
message_content = output.message.content

current_reasoning = []
try:
reasoning_step = self._output_parser.parse(message_content, is_streaming)
@@ -268,6 +269,7 @@ def _infer_stream_chunk_is_final(self, chunk: ChatResponse) -> bool:
bool: Boolean on whether the chunk is the start of the final response
"""
latest_content = chunk.message.content

if latest_content:
if not latest_content.startswith(
"Thought"
1 change: 1 addition & 0 deletions llama-index-core/llama_index/core/agent/react/step.py
@@ -237,6 +237,7 @@ def _extract_reasoning_step(
if output.message.content is None:
raise ValueError("Got empty message.")
message_content = output.message.content

current_reasoning = []
try:
reasoning_step = self._output_parser.parse(message_content, is_streaming)
@@ -244,6 +244,7 @@ def _extract_reasoning_step(
if output.message.content is None:
raise ValueError("Got empty message.")
message_content = output.message.content

current_reasoning = []
try:
reasoning_step = self._output_parser.parse(message_content, is_streaming)
22 changes: 22 additions & 0 deletions llama-index-core/llama_index/core/base/llms/base.py
@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import (
Any,
List,
Sequence,
)

@@ -13,6 +14,7 @@
CompletionResponseAsyncGen,
CompletionResponseGen,
LLMMetadata,
TextBlock,
)
from llama_index.core.base.query_pipeline.query import (
ChainableMixin,
@@ -46,6 +48,26 @@ def metadata(self) -> LLMMetadata:
LLMMetadata: LLM metadata containing various information about the LLM.
"""

def convert_chat_messages(self, messages: Sequence[ChatMessage]) -> List[Any]:
"""Convert chat messages to an LLM specific message format."""
converted_messages = []
for message in messages:
if isinstance(message.content, str):
converted_messages.append(message)
elif isinstance(message.content, List):
content_string = ""
for block in message.content:
if isinstance(block, TextBlock):
content_string += block.text
else:
raise ValueError("LLM only supports text inputs")
message.content = content_string
converted_messages.append(message)
else:
raise ValueError(f"Invalid message content: {message.content!s}")

return converted_messages

@abstractmethod
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
"""Chat endpoint for LLM.
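As a rough usage sketch (not part of the diff), the new default convert_chat_messages flattens block-based content into a plain string for text-only LLMs; the llm variable below stands for any BaseLLM subclass instance and is assumed, not shown in the commit:

from llama_index.core.base.llms.types import ChatMessage, MessageRole, TextBlock

messages = [
    # A plain-string message is passed through unchanged.
    ChatMessage(role=MessageRole.USER, content="plain string"),
    # Text blocks are concatenated into a single string.
    ChatMessage(
        role=MessageRole.USER,
        content=[TextBlock(text="part one, "), TextBlock(text="part two")],
    ),
]

converted = llm.convert_chat_messages(messages)
# converted[1].content == "part one, part two"
# Any non-text block raises ValueError("LLM only supports text inputs").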
49 changes: 48 additions & 1 deletion llama-index-core/llama_index/core/base/llms/types.py
@@ -1,8 +1,22 @@
import base64
import requests
from enum import Enum
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, List, Any
from io import BytesIO
from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
Literal,
Optional,
Union,
List,
Any,
)

from llama_index.core.bridge.pydantic import BaseModel, Field, ConfigDict
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.core.schema import ImageType

try:
from pydantic import BaseModel as V2BaseModel
@@ -26,6 +40,39 @@ class MessageRole(str, Enum):


# ===== Generic Model Input - Chat =====
class ContentBlockTypes(str, Enum):
TEXT = "text"
IMAGE = "image"


class TextBlock(BaseModel):
type: Literal[ContentBlockTypes.TEXT] = ContentBlockTypes.TEXT

text: str


class ImageBlock(BaseModel):
type: Literal[ContentBlockTypes.IMAGE] = ContentBlockTypes.IMAGE

image: Optional[str] = None
image_path: Optional[str] = None
image_url: Optional[str] = None
image_mimetype: Optional[str] = None

def resolve_image(self) -> ImageType:
"""Resolve an image such that PIL can read it."""
if self.image is not None:
return BytesIO(base64.b64decode(self.image))
elif self.image_path is not None:
return self.image_path
elif self.image_url is not None:
# load image from URL
response = requests.get(self.image_url)
return BytesIO(response.content)
else:
raise ValueError("No image found in the chat message!")


class ChatMessage(BaseModel):
"""Chat message."""

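A brief illustration of the new content block types (the URL below is a placeholder; PIL itself is only implied by the resolve_image docstring):

from llama_index.core.base.llms.types import ImageBlock, TextBlock

text_block = TextBlock(text="Describe this image.")

# resolve_image() returns something PIL can read: decoded base64 bytes or the
# downloaded URL content as a BytesIO, or the local path when image_path is set.
image_block = ImageBlock(image_url="https://example.com/diagram.png")
image_data = image_block.resolve_image()
# For a URL or base64 payload, image_data is a BytesIO of the raw image bytes.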
32 changes: 28 additions & 4 deletions llama-index-core/llama_index/core/chat_engine/simple.py
@@ -80,7 +80,13 @@ def chat(
if hasattr(self._memory, "tokenizer_fn"):
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in self._prefix_messages])
" ".join(
[
(m.content or "")
for m in self._prefix_messages
if isinstance(m.content, str)
]
)
)
)
else:
@@ -107,7 +113,13 @@ def stream_chat(
if hasattr(self._memory, "tokenizer_fn"):
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in self._prefix_messages])
" ".join(
[
(m.content or "")
for m in self._prefix_messages
if isinstance(m.content, str)
]
)
)
)
else:
@@ -138,7 +150,13 @@ async def achat(
if hasattr(self._memory, "tokenizer_fn"):
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in self._prefix_messages])
" ".join(
[
(m.content or "")
for m in self._prefix_messages
if isinstance(m.content, str)
]
)
)
)
else:
@@ -165,7 +183,13 @@ async def astream_chat(
if hasattr(self._memory, "tokenizer_fn"):
initial_token_count = len(
self._memory.tokenizer_fn(
" ".join([(m.content or "") for m in self._prefix_messages])
" ".join(
[
(m.content or "")
for m in self._prefix_messages
if isinstance(m.content, str)
]
)
)
)
else:
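The same guard is repeated in chat, stream_chat, achat, and astream_chat; stripped of context it amounts to the snippet below, where prefix_messages is an illustrative stand-in for self._prefix_messages:

from llama_index.core.base.llms.types import ChatMessage, MessageRole, TextBlock

prefix_messages = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are a helpful assistant."),
    ChatMessage(role=MessageRole.USER, content=[TextBlock(text="block content")]),
]

# Only string content feeds the tokenizer; block-based content is skipped, so the
# initial token count no longer fails when a prefix message is multi-modal.
joined = " ".join(
    [(m.content or "") for m in prefix_messages if isinstance(m.content, str)]
)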
1 change: 1 addition & 0 deletions llama-index-core/llama_index/core/llms/function_calling.py
@@ -244,6 +244,7 @@ async def apredict_and_call(
allow_parallel_tool_calls=allow_parallel_tool_calls,
**kwargs,
)

tool_calls = self.get_tool_calls_from_response(
response, error_on_no_tool_call=error_on_no_tool_call
)
18 changes: 14 additions & 4 deletions llama-index-core/llama_index/core/llms/llm.py
@@ -755,11 +755,16 @@ def predict_and_call(
handle_reasoning_failure_fn=kwargs.get("handle_reasoning_failure_fn", None),
)

if isinstance(user_msg, ChatMessage):
if isinstance(user_msg, ChatMessage) and isinstance(user_msg.content, str):
user_msg = user_msg.content
elif isinstance(user_msg, str):
pass
elif not user_msg and chat_history is not None and len(chat_history) > 0:
elif (
not user_msg
and chat_history is not None
and len(chat_history) > 0
and isinstance(chat_history[-1].content, str)
):
user_msg = chat_history[-1].content
else:
raise ValueError("No user message provided or found in chat history.")
@@ -813,11 +818,16 @@ async def apredict_and_call(
handle_reasoning_failure_fn=kwargs.get("handle_reasoning_failure_fn", None),
)

if isinstance(user_msg, ChatMessage):
if isinstance(user_msg, ChatMessage) and isinstance(user_msg.content, str):
user_msg = user_msg.content
elif isinstance(user_msg, str):
pass
elif not user_msg and chat_history is not None and len(chat_history) > 0:
elif (
not user_msg
and chat_history is not None
and len(chat_history) > 0
and isinstance(chat_history[-1].content, str)
):
user_msg = chat_history[-1].content
else:
raise ValueError("No user message provided or found in chat history.")
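In effect, predict_and_call and apredict_and_call now only coerce string content; a sketch of the accepted shapes, where tools is an assumed list of tools and llm an assumed LLM instance (neither appears in the diff):

from llama_index.core.base.llms.types import ChatMessage, MessageRole, TextBlock

# Plain string: accepted as-is.
llm.predict_and_call(tools, user_msg="What is the capital of France?")

# ChatMessage with string content: unwrapped to its string.
llm.predict_and_call(
    tools, user_msg=ChatMessage(role=MessageRole.USER, content="Same question")
)

# A ChatMessage whose content is a list of blocks no longer slips through:
# llm.predict_and_call(tools, user_msg=ChatMessage(role=MessageRole.USER, content=[TextBlock(text="hi")]))
# -> ValueError("No user message provided or found in chat history.")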
@@ -286,6 +286,9 @@ def _get_prompt_to_summarize(
# TODO: This probably works better when question/answers are considered together.
prompt = '"Transcript so far: '
for msg in chat_history_to_be_summarized:
if not isinstance(msg.content, str):
continue

prompt += msg.role + ": "
if msg.content:
prompt += msg.content + "\n\n"
4 changes: 3 additions & 1 deletion llama-index-core/llama_index/core/prompts/base.py
@@ -304,7 +304,9 @@ def format_messages(

messages: List[ChatMessage] = []
for message_template in self.message_templates:
template_vars = get_template_vars(message_template.content or "")
message_content = message_template.content or ""

template_vars = get_template_vars(message_content)
relevant_kwargs = {
k: v for k, v in mapped_all_kwargs.items() if k in template_vars
}
6 changes: 4 additions & 2 deletions llama-index-core/llama_index/core/types.py
@@ -53,9 +53,11 @@ def format_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
# or the last message
if messages:
if messages[0].role == MessageRole.SYSTEM:
messages[0].content = self.format(messages[0].content or "")
message_content = messages[0].content or ""
messages[0].content = self.format(message_content)
else:
messages[-1].content = self.format(messages[-1].content or "")
message_content = messages[-1].content or ""
messages[-1].content = self.format(message_content)

return messages

@@ -43,7 +43,7 @@ def estimate_tokens_in_messages(self, messages: List[ChatMessage]) -> int:
if message.role:
tokens += self.get_string_tokens(message.role)

if message.content:
if isinstance(message.content, str):
tokens += self.get_string_tokens(message.content)

additional_kwargs = {**message.additional_kwargs}
@@ -269,7 +269,9 @@ def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
reflection, reflection_msg = self._reflect(chat_history=messages)
is_done = reflection.is_done

critique_msg = ChatMessage(role=MessageRole.USER, content=reflection_msg)
critique_msg = ChatMessage(
role=MessageRole.USER, content=reflection_msg.content
)
task.extra_state["new_memory"].put(critique_msg)

# correction phase
@@ -386,7 +388,9 @@ async def arun_step(
reflection, reflection_msg = await self._areflect(chat_history=messages)
is_done = reflection.is_done

critique_msg = ChatMessage(role=MessageRole.USER, content=reflection_msg)
critique_msg = ChatMessage(
role=MessageRole.USER, content=reflection_msg.content
)
task.extra_state["new_memory"].put(critique_msg)

# correction phase