From 1d49e15f4b91f6e4b931d8ae42f69dc678ce8ee4 Mon Sep 17 00:00:00 2001
From: Logan
Date: Fri, 20 Sep 2024 13:32:09 -0600
Subject: [PATCH] Update chat message class for multi-modal (#15969)

---
 .../core/agent/legacy/react/base.py           |  2 +
 .../llama_index/core/agent/react/step.py      |  1 +
 .../core/agent/react_multimodal/step.py       |  1 +
 .../llama_index/core/base/llms/base.py        | 22 +++++++++
 .../llama_index/core/base/llms/types.py       | 49 ++++++++++++++++++-
 .../llama_index/core/chat_engine/simple.py    | 32 ++++++++++--
 .../llama_index/core/llms/function_calling.py |  1 +
 llama-index-core/llama_index/core/llms/llm.py | 18 +++++--
 .../core/memory/chat_summary_memory_buffer.py |  3 ++
 .../llama_index/core/prompts/base.py          |  4 +-
 llama-index-core/llama_index/core/types.py    |  6 ++-
 .../core/utilities/token_counting.py          |  2 +-
 .../reflective/self_reflection.py             |  8 ++-
 13 files changed, 134 insertions(+), 15 deletions(-)

diff --git a/llama-index-core/llama_index/core/agent/legacy/react/base.py b/llama-index-core/llama_index/core/agent/legacy/react/base.py
index 9d70172bcd2dd..04ef72d55e14c 100644
--- a/llama-index-core/llama_index/core/agent/legacy/react/base.py
+++ b/llama-index-core/llama_index/core/agent/legacy/react/base.py
@@ -151,6 +151,7 @@ def _extract_reasoning_step(
         if output.message.content is None:
             raise ValueError("Got empty message.")
         message_content = output.message.content
+        current_reasoning = []
         try:
             reasoning_step = self._output_parser.parse(message_content, is_streaming)
@@ -268,6 +269,7 @@ def _infer_stream_chunk_is_final(self, chunk: ChatResponse) -> bool:
             bool: Boolean on whether the chunk is the start of the final response
         """
         latest_content = chunk.message.content
+
         if latest_content:
             if not latest_content.startswith(
                 "Thought"
diff --git a/llama-index-core/llama_index/core/agent/react/step.py b/llama-index-core/llama_index/core/agent/react/step.py
index bf4e6e1f0ff10..c43b729029d30 100644
--- a/llama-index-core/llama_index/core/agent/react/step.py
+++ b/llama-index-core/llama_index/core/agent/react/step.py
@@ -237,6 +237,7 @@ def _extract_reasoning_step(
         if output.message.content is None:
             raise ValueError("Got empty message.")
         message_content = output.message.content
+        current_reasoning = []
         try:
             reasoning_step = self._output_parser.parse(message_content, is_streaming)
diff --git a/llama-index-core/llama_index/core/agent/react_multimodal/step.py b/llama-index-core/llama_index/core/agent/react_multimodal/step.py
index c42104cbfdf44..ff087a3ef0159 100644
--- a/llama-index-core/llama_index/core/agent/react_multimodal/step.py
+++ b/llama-index-core/llama_index/core/agent/react_multimodal/step.py
@@ -244,6 +244,7 @@ def _extract_reasoning_step(
         if output.message.content is None:
             raise ValueError("Got empty message.")
         message_content = output.message.content
+        current_reasoning = []
         try:
             reasoning_step = self._output_parser.parse(message_content, is_streaming)
diff --git a/llama-index-core/llama_index/core/base/llms/base.py b/llama-index-core/llama_index/core/base/llms/base.py
index cca8a8d78065e..860b31586b461 100644
--- a/llama-index-core/llama_index/core/base/llms/base.py
+++ b/llama-index-core/llama_index/core/base/llms/base.py
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 from typing import (
     Any,
+    List,
     Sequence,
 )

@@ -13,6 +14,7 @@
     CompletionResponseAsyncGen,
     CompletionResponseGen,
     LLMMetadata,
+    TextBlock,
 )
 from llama_index.core.base.query_pipeline.query import (
     ChainableMixin,
@@ -46,6 +48,26 @@ def metadata(self) -> LLMMetadata:
             LLMMetadata: LLM metadata containing various information about the LLM.
         """

+    def convert_chat_messages(self, messages: Sequence[ChatMessage]) -> List[Any]:
+        """Convert chat messages to an LLM specific message format."""
+        converted_messages = []
+        for message in messages:
+            if isinstance(message.content, str):
+                converted_messages.append(message)
+            elif isinstance(message.content, List):
+                content_string = ""
+                for block in message.content:
+                    if isinstance(block, TextBlock):
+                        content_string += block.text
+                    else:
+                        raise ValueError("LLM only supports text inputs")
+                message.content = content_string
+                converted_messages.append(message)
+            else:
+                raise ValueError(f"Invalid message content: {message.content!s}")
+
+        return converted_messages
+
     @abstractmethod
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         """Chat endpoint for LLM.
diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py
index 8c3248e3769f3..971db8e743841 100644
--- a/llama-index-core/llama_index/core/base/llms/types.py
+++ b/llama-index-core/llama_index/core/base/llms/types.py
@@ -1,8 +1,22 @@
+import base64
+import requests
 from enum import Enum
-from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, List, Any
+from io import BytesIO
+from typing import (
+    Any,
+    AsyncGenerator,
+    Dict,
+    Generator,
+    Literal,
+    Optional,
+    Union,
+    List,
+    Any,
+)

 from llama_index.core.bridge.pydantic import BaseModel, Field, ConfigDict
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
+from llama_index.core.schema import ImageType

 try:
     from pydantic import BaseModel as V2BaseModel
@@ -26,6 +40,39 @@ class MessageRole(str, Enum):


 # ===== Generic Model Input - Chat =====
+class ContentBlockTypes(str, Enum):
+    TEXT = "text"
+    IMAGE = "image"
+
+
+class TextBlock(BaseModel):
+    type: Literal[ContentBlockTypes.TEXT] = ContentBlockTypes.TEXT
+
+    text: str
+
+
+class ImageBlock(BaseModel):
+    type: Literal[ContentBlockTypes.IMAGE] = ContentBlockTypes.IMAGE
+
+    image: Optional[str] = None
+    image_path: Optional[str] = None
+    image_url: Optional[str] = None
+    image_mimetype: Optional[str] = None
+
+    def resolve_image(self) -> ImageType:
+        """Resolve an image such that PIL can read it."""
+        if self.image is not None:
+            return BytesIO(base64.b64decode(self.image))
+        elif self.image_path is not None:
+            return self.image_path
+        elif self.image_url is not None:
+            # load image from URL
+            response = requests.get(self.image_url)
+            return BytesIO(response.content)
+        else:
+            raise ValueError("No image found in the chat message!")
+
+
 class ChatMessage(BaseModel):
     """Chat message."""
diff --git a/llama-index-core/llama_index/core/chat_engine/simple.py b/llama-index-core/llama_index/core/chat_engine/simple.py
index 6a35de5570eca..c96b000b32888 100644
--- a/llama-index-core/llama_index/core/chat_engine/simple.py
+++ b/llama-index-core/llama_index/core/chat_engine/simple.py
@@ -80,7 +80,13 @@ def chat(
         if hasattr(self._memory, "tokenizer_fn"):
             initial_token_count = len(
                 self._memory.tokenizer_fn(
-                    " ".join([(m.content or "") for m in self._prefix_messages])
+                    " ".join(
+                        [
+                            (m.content or "")
+                            for m in self._prefix_messages
+                            if isinstance(m.content, str)
+                        ]
+                    )
                 )
             )
         else:
@@ -107,7 +113,13 @@ def stream_chat(
         if hasattr(self._memory, "tokenizer_fn"):
             initial_token_count = len(
                 self._memory.tokenizer_fn(
-                    " ".join([(m.content or "") for m in self._prefix_messages])
+                    " ".join(
+                        [
+                            (m.content or "")
+                            for m in self._prefix_messages
+                            if isinstance(m.content, str)
+                        ]
+                    )
                 )
             )
         else:
@@ -138,7 +150,13 @@ async def achat(
         if hasattr(self._memory, "tokenizer_fn"):
             initial_token_count = len(
                 self._memory.tokenizer_fn(
-                    " ".join([(m.content or "") for m in self._prefix_messages])
+                    " ".join(
+                        [
+                            (m.content or "")
+                            for m in self._prefix_messages
+                            if isinstance(m.content, str)
+                        ]
+                    )
                 )
             )
         else:
@@ -165,7 +183,13 @@ async def astream_chat(
         if hasattr(self._memory, "tokenizer_fn"):
             initial_token_count = len(
                 self._memory.tokenizer_fn(
-                    " ".join([(m.content or "") for m in self._prefix_messages])
+                    " ".join(
+                        [
+                            (m.content or "")
+                            for m in self._prefix_messages
+                            if isinstance(m.content, str)
+                        ]
+                    )
                 )
             )
         else:
diff --git a/llama-index-core/llama_index/core/llms/function_calling.py b/llama-index-core/llama_index/core/llms/function_calling.py
index c4809ef554877..005d676c891ca 100644
--- a/llama-index-core/llama_index/core/llms/function_calling.py
+++ b/llama-index-core/llama_index/core/llms/function_calling.py
@@ -244,6 +244,7 @@ async def apredict_and_call(
             allow_parallel_tool_calls=allow_parallel_tool_calls,
             **kwargs,
         )
+
         tool_calls = self.get_tool_calls_from_response(
             response, error_on_no_tool_call=error_on_no_tool_call
         )
diff --git a/llama-index-core/llama_index/core/llms/llm.py b/llama-index-core/llama_index/core/llms/llm.py
index 813c726e5a0fe..8978f6c162cef 100644
--- a/llama-index-core/llama_index/core/llms/llm.py
+++ b/llama-index-core/llama_index/core/llms/llm.py
@@ -755,11 +755,16 @@ def predict_and_call(
                 handle_reasoning_failure_fn=kwargs.get("handle_reasoning_failure_fn", None),
             )

-        if isinstance(user_msg, ChatMessage):
+        if isinstance(user_msg, ChatMessage) and isinstance(user_msg.content, str):
             user_msg = user_msg.content
         elif isinstance(user_msg, str):
             pass
-        elif not user_msg and chat_history is not None and len(chat_history) > 0:
+        elif (
+            not user_msg
+            and chat_history is not None
+            and len(chat_history) > 0
+            and isinstance(chat_history[-1].content, str)
+        ):
             user_msg = chat_history[-1].content
         else:
             raise ValueError("No user message provided or found in chat history.")
@@ -813,11 +818,16 @@ async def apredict_and_call(
             handle_reasoning_failure_fn=kwargs.get("handle_reasoning_failure_fn", None),
         )

-        if isinstance(user_msg, ChatMessage):
+        if isinstance(user_msg, ChatMessage) and isinstance(user_msg.content, str):
             user_msg = user_msg.content
         elif isinstance(user_msg, str):
             pass
-        elif not user_msg and chat_history is not None and len(chat_history) > 0:
+        elif (
+            not user_msg
+            and chat_history is not None
+            and len(chat_history) > 0
+            and isinstance(chat_history[-1].content, str)
+        ):
             user_msg = chat_history[-1].content
         else:
             raise ValueError("No user message provided or found in chat history.")
diff --git a/llama-index-core/llama_index/core/memory/chat_summary_memory_buffer.py b/llama-index-core/llama_index/core/memory/chat_summary_memory_buffer.py
index 8770e8f1f491b..45f25781d3c4f 100644
--- a/llama-index-core/llama_index/core/memory/chat_summary_memory_buffer.py
+++ b/llama-index-core/llama_index/core/memory/chat_summary_memory_buffer.py
@@ -286,6 +286,9 @@ def _get_prompt_to_summarize(
         # TODO: This probably works better when question/answers are considered together.
         prompt = '"Transcript so far: '
         for msg in chat_history_to_be_summarized:
+            if not isinstance(msg.content, str):
+                continue
+
             prompt += msg.role + ": "
             if msg.content:
                 prompt += msg.content + "\n\n"
diff --git a/llama-index-core/llama_index/core/prompts/base.py b/llama-index-core/llama_index/core/prompts/base.py
index 80982d9f80f60..bf62357a69a76 100644
--- a/llama-index-core/llama_index/core/prompts/base.py
+++ b/llama-index-core/llama_index/core/prompts/base.py
@@ -304,7 +304,9 @@ def format_messages(
         messages: List[ChatMessage] = []
         for message_template in self.message_templates:
-            template_vars = get_template_vars(message_template.content or "")
+            message_content = message_template.content or ""
+
+            template_vars = get_template_vars(message_content)
             relevant_kwargs = {
                 k: v for k, v in mapped_all_kwargs.items() if k in template_vars
             }
diff --git a/llama-index-core/llama_index/core/types.py b/llama-index-core/llama_index/core/types.py
index 63cd2365197da..4c7a0d84efc53 100644
--- a/llama-index-core/llama_index/core/types.py
+++ b/llama-index-core/llama_index/core/types.py
@@ -53,9 +53,11 @@ def format_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
         # or the last message
         if messages:
             if messages[0].role == MessageRole.SYSTEM:
-                messages[0].content = self.format(messages[0].content or "")
+                message_content = messages[0].content or ""
+                messages[0].content = self.format(message_content)
             else:
-                messages[-1].content = self.format(messages[-1].content or "")
+                message_content = messages[-1].content or ""
+                messages[-1].content = self.format(message_content)

         return messages
diff --git a/llama-index-core/llama_index/core/utilities/token_counting.py b/llama-index-core/llama_index/core/utilities/token_counting.py
index a06d0c14c3879..b6ddc9892f200 100644
--- a/llama-index-core/llama_index/core/utilities/token_counting.py
+++ b/llama-index-core/llama_index/core/utilities/token_counting.py
@@ -43,7 +43,7 @@ def estimate_tokens_in_messages(self, messages: List[ChatMessage]) -> int:
             if message.role:
                 tokens += self.get_string_tokens(message.role)

-            if message.content:
+            if isinstance(message.content, str):
                 tokens += self.get_string_tokens(message.content)

             additional_kwargs = {**message.additional_kwargs}
diff --git a/llama-index-integrations/agent/llama-index-agent-introspective/llama_index/agent/introspective/reflective/self_reflection.py b/llama-index-integrations/agent/llama-index-agent-introspective/llama_index/agent/introspective/reflective/self_reflection.py
index 9df7a8b6f48fc..d4740f13f1063 100644
--- a/llama-index-integrations/agent/llama-index-agent-introspective/llama_index/agent/introspective/reflective/self_reflection.py
+++ b/llama-index-integrations/agent/llama-index-agent-introspective/llama_index/agent/introspective/reflective/self_reflection.py
@@ -269,7 +269,9 @@ def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
         reflection, reflection_msg = self._reflect(chat_history=messages)
         is_done = reflection.is_done

-        critique_msg = ChatMessage(role=MessageRole.USER, content=reflection_msg)
+        critique_msg = ChatMessage(
+            role=MessageRole.USER, content=reflection_msg.content
+        )
         task.extra_state["new_memory"].put(critique_msg)

         # correction phase
@@ -386,7 +388,9 @@ async def arun_step(
         reflection, reflection_msg = await self._areflect(chat_history=messages)
         is_done = reflection.is_done

-        critique_msg = ChatMessage(role=MessageRole.USER, content=reflection_msg)
+        critique_msg = ChatMessage(
+            role=MessageRole.USER, content=reflection_msg.content
+        )
         task.extra_state["new_memory"].put(critique_msg)

         # correction phase
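
Reviewer note: a minimal usage sketch of the block-based message content introduced in `llama_index/core/base/llms/types.py` above. It assumes `ChatMessage.content` now accepts a list of blocks (which is what the new `BaseLLM.convert_chat_messages` checks for); the image path is a placeholder, not part of this patch.

```python
from llama_index.core.base.llms.types import (
    ChatMessage,
    ImageBlock,
    MessageRole,
    TextBlock,
)

# A multi-modal message: content is a list of typed blocks rather than a plain string.
msg = ChatMessage(
    role=MessageRole.USER,
    content=[
        TextBlock(text="Describe this image in one sentence."),
        ImageBlock(image_path="./cat.png", image_mimetype="image/png"),  # placeholder path
    ],
)

# ImageBlock.resolve_image() returns something PIL can open: the raw path when
# image_path is set, or a BytesIO for base64 (`image`) and `image_url` sources.
image_source = msg.content[1].resolve_image()
```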
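And a sketch of the text-only fallback through the new `BaseLLM.convert_chat_messages`. `MockLLM` is used purely as a stand-in for any `BaseLLM` subclass that only understands string content; the behavior shown (text blocks concatenated in place, image blocks rejected) follows the implementation added in `base/llms/base.py` above.

```python
from llama_index.core.base.llms.types import ChatMessage, MessageRole, TextBlock
from llama_index.core.llms.mock import MockLLM

llm = MockLLM()  # stand-in for any text-only BaseLLM subclass

messages = [
    ChatMessage(
        role=MessageRole.USER,
        content=[TextBlock(text="Hello, "), TextBlock(text="world!")],
    )
]

# TextBlocks are concatenated into a single string; a non-text block such as
# ImageBlock would raise ValueError("LLM only supports text inputs") here.
converted = llm.convert_chat_messages(messages)
assert converted[0].content == "Hello, world!"
```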