From 3d2bc5887a41b4c6f842949243b0a79ed185360a Mon Sep 17 00:00:00 2001
From: Timothy Beamish
Date: Thu, 29 Feb 2024 15:14:25 -0800
Subject: [PATCH] Set MessageRole appropriately for Vertex

---
 .../llama_index/core/base/llms/types.py |  1 +
 .../llama_index/llms/gemini/utils.py    |  5 +++--
 .../llama_index/llms/vertex/base.py     |  9 +++++++--
 .../llama_index/llms/vertex/utils.py    | 17 +++++++----------
 4 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py
index 66ad45b08de361..fb19cb58892bd4 100644
--- a/llama-index-core/llama_index/core/base/llms/types.py
+++ b/llama-index-core/llama_index/core/base/llms/types.py
@@ -14,6 +14,7 @@ class MessageRole(str, Enum):
     FUNCTION = "function"
     TOOL = "tool"
     CHATBOT = "chatbot"
+    MODEL = "model"
 
 
 # ===== Generic Model Input - Chat =====
diff --git a/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/utils.py b/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/utils.py
index 8bc01a53c076bf..44415e6a2bec66 100644
--- a/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/utils.py
@@ -3,11 +3,12 @@
 import google.ai.generativelanguage as glm
 import google.generativeai as genai
 import PIL
-from llama_index.core.base.llms.types import MessageRole
+
 from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
     CompletionResponse,
+    MessageRole,
 )
 
 ROLES_TO_GEMINI = {
@@ -105,7 +106,7 @@ def merge_neighboring_same_role_messages(
 
         # Create a new ChatMessage or similar object with merged content
         merged_message = ChatMessage(
-            role=current_message.role,
+            role=ROLES_TO_GEMINI[current_message.role],
             content="\n".join([str(msg_content) for msg_content in merged_content]),
             additional_kwargs=current_message.additional_kwargs,
         )
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
index e327f863e9987c..08cd7569f2f265 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
@@ -296,8 +296,13 @@ def gen() -> CompletionResponseGen:
     async def achat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponse:
-        question = _parse_message(messages[-1], self._is_gemini)
-        chat_history = _parse_chat_history(messages[:-1], self._is_gemini)
+        merged_messages = (
+            merge_neighboring_same_role_messages(messages)
+            if self._is_gemini
+            else messages
+        )
+        question = _parse_message(merged_messages[-1], self._is_gemini)
+        chat_history = _parse_chat_history(merged_messages[:-1], self._is_gemini)
         chat_params = {**chat_history}
         kwargs = kwargs if kwargs else {}
         params = {**self._model_kwargs, **kwargs}
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py
index c9f7bb2f9e97da..69c2ec481a3d2f 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py
@@ -6,7 +6,6 @@
 
 import google.api_core
 import vertexai
-from llama_index.core.base.llms.types import ChatMessage, MessageRole
 from tenacity import (
     before_sleep_log,
     retry,
@@ -14,12 +13,10 @@
     stop_after_attempt,
     wait_exponential,
 )
-from vertexai.language_models import (
-    ChatMessage as VertexChatMessage,
-)
-from vertexai.language_models import (
-    InputOutputTextPair,
-)
+from vertexai.language_models import ChatMessage as VertexChatMessage
+from vertexai.language_models import InputOutputTextPair
+
+from llama_index.core.base.llms.types import ChatMessage, MessageRole
 
 CHAT_MODELS = ["chat-bison", "chat-bison-32k", "chat-bison@001"]
 TEXT_MODELS = ["text-bison", "text-bison-32k", "text-bison@001"]
@@ -171,7 +168,7 @@ def _parse_chat_history(history: Any, is_gemini: bool) -> Any:
             if is_gemini:
                 raise ValueError("Gemini model don't support system messages")
             context = message.content
-        elif message.role == MessageRole.ASSISTANT or message.role == MessageRole.USER:
+        elif message.role == MessageRole.MODEL or message.role == MessageRole.USER:
             if is_gemini:
                 from llama_index.llms.vertex.gemini_utils import (
                     convert_chat_message_to_gemini_content,
@@ -185,12 +182,12 @@ def _parse_chat_history(history: Any, is_gemini: bool) -> Any:
             else:
                 vertex_message = VertexChatMessage(
                     content=message.content,
-                    author="bot" if message.role == MessageRole.ASSISTANT else "user",
+                    author=("bot" if message.role == MessageRole.MODEL else "user"),
                 )
                 vertex_messages.append(vertex_message)
         else:
             raise ValueError(
-                f"Unexpected message with type {type(message)} at the position {i}."
+                f"Unexpected message with role {message.role} at the position {i}."
             )
     if len(vertex_messages) % 2 != 0:
         raise ValueError("total no of messages should be even")
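
Reviewer note, not part of the patch: a minimal sketch of how the pieces above fit together, assuming the patch is applied and assuming ROLES_TO_GEMINI (truncated in the hunk above) maps MessageRole.USER to "user" and MessageRole.ASSISTANT to "model". Under that assumption, merge_neighboring_same_role_messages now rewrites roles through that table, which is why the MessageRole enum must gain a MODEL member for the merged ChatMessage to validate:

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.llms.gemini.utils import merge_neighboring_same_role_messages

history = [
    ChatMessage(role=MessageRole.USER, content="Hello."),
    ChatMessage(role=MessageRole.USER, content="Anyone there?"),
    ChatMessage(role=MessageRole.ASSISTANT, content="Hi, how can I help?"),
]

# The two neighboring USER messages collapse into one, and each merged
# message carries the Gemini-native role string ("user" or "model")
# instead of the original MessageRole value. The "model" string only
# round-trips through ChatMessage because MODEL = "model" was added to
# the enum in types.py above.
for message in merge_neighboring_same_role_messages(history):
    print(message.role, repr(message.content))

Downstream, _parse_chat_history in vertex/utils.py accepts the merged messages: MODEL maps to author="bot" for the chat-bison models and, for Gemini, is converted via convert_chat_message_to_gemini_content.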