From 563e4a60d8c8afe2974321672adb739f110c1086 Mon Sep 17 00:00:00 2001
From: Timothy Beamish
Date: Thu, 29 Feb 2024 09:56:40 -0800
Subject: [PATCH] Merge messages from same role in Vertex if using Gemini

---
 .../llama_index/llms/vertex/base.py           | 23 ++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
index 02971bd1833e9f..e327f863e9987c 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
@@ -16,6 +16,7 @@
 from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
 from llama_index.core.llms.llm import LLM
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
+from llama_index.llms.gemini.utils import merge_neighboring_same_role_messages
 from llama_index.llms.vertex.gemini_utils import create_gemini_client, is_gemini_model
 from llama_index.llms.vertex.utils import (
     CHAT_MODELS,
@@ -130,7 +131,9 @@ def metadata(self) -> LLMMetadata:
         return LLMMetadata(
             is_chat_model=self._is_chat_model,
             model_name=self.model,
-            system_role=MessageRole.USER,  # Vertex does not support the default: MessageRole.SYSTEM
+            system_role=(
+                MessageRole.USER if self._is_gemini else MessageRole.SYSTEM
+            ),  # Gemini does not support the default: MessageRole.SYSTEM
         )
 
     @property
@@ -152,8 +155,13 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
 
     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
-        question = _parse_message(messages[-1], self._is_gemini)
-        chat_history = _parse_chat_history(messages[:-1], self._is_gemini)
+        merged_messages = (
+            merge_neighboring_same_role_messages(messages)
+            if self._is_gemini
+            else messages
+        )
+        question = _parse_message(merged_messages[-1], self._is_gemini)
+        chat_history = _parse_chat_history(merged_messages[:-1], self._is_gemini)
         chat_params = {**chat_history}
         kwargs = kwargs if kwargs else {}
 
@@ -209,8 +217,13 @@ def complete(
     def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        question = _parse_message(messages[-1], self._is_gemini)
-        chat_history = _parse_chat_history(messages[:-1], self._is_gemini)
+        merged_messages = (
+            merge_neighboring_same_role_messages(messages)
+            if self._is_gemini
+            else messages
+        )
+        question = _parse_message(merged_messages[-1], self._is_gemini)
+        chat_history = _parse_chat_history(merged_messages[:-1], self._is_gemini)
         chat_params = {**chat_history}
         kwargs = kwargs if kwargs else {}
         params = {**self._model_kwargs, **kwargs}
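
Note on the merging step: Gemini rejects chat histories in which the same role appears in two consecutive turns, which is why this patch funnels messages through merge_neighboring_same_role_messages before parsing. It pairs with the metadata change above: because the system prompt is mapped to MessageRole.USER for Gemini, it would otherwise sit next to the real user message as a second consecutive USER turn. Below is a minimal sketch of that merging behavior; the function name merge_same_role_messages_sketch is illustrative only, it assumes ChatMessage objects carrying simple string content, and the actual helper lives in llama_index.llms.gemini.utils and may handle additional cases (e.g. multi-part content).

    from typing import List, Sequence

    from llama_index.core.llms import ChatMessage, MessageRole


    def merge_same_role_messages_sketch(
        messages: Sequence[ChatMessage],
    ) -> List[ChatMessage]:
        # Sketch only: collapse runs of consecutive messages that share a
        # role into a single message, since Gemini rejects histories where
        # the same role speaks twice in a row. The real implementation is
        # llama_index.llms.gemini.utils.merge_neighboring_same_role_messages.
        merged: List[ChatMessage] = []
        for message in messages:
            if merged and merged[-1].role == message.role:
                # Same role as the previous turn: join the contents rather
                # than starting a new turn.
                merged[-1] = ChatMessage(
                    role=message.role,
                    content=f"{merged[-1].content}\n{message.content}",
                )
            else:
                merged.append(message)
        return merged


    # Two consecutive USER messages (e.g. a system prompt demoted to USER,
    # followed by the actual question) collapse into one turn.
    history = [
        ChatMessage(role=MessageRole.USER, content="You are a helpful assistant."),
        ChatMessage(role=MessageRole.USER, content="What is Vertex AI?"),
    ]
    assert len(merge_same_role_messages_sketch(history)) == 1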