Skip to content

Commit

Permalink
Merge messages from same role in Vertex if using Gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
tbeamish-benchsci committed Feb 29, 2024
1 parent 348cad7 commit 563e4a6
Showing 1 changed file with 18 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
from llama_index.core.llms.llm import LLM
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from llama_index.llms.gemini.utils import merge_neighboring_same_role_messages
from llama_index.llms.vertex.gemini_utils import create_gemini_client, is_gemini_model
from llama_index.llms.vertex.utils import (
CHAT_MODELS,
Expand Down Expand Up @@ -130,7 +131,9 @@ def metadata(self) -> LLMMetadata:
return LLMMetadata(
is_chat_model=self._is_chat_model,
model_name=self.model,
system_role=MessageRole.USER, # Vertex does not support the default: MessageRole.SYSTEM
system_role=(
MessageRole.USER if self._is_gemini else MessageRole.SYSTEM
), # Gemini does not support the default: MessageRole.SYSTEM
)

@property
Expand All @@ -152,8 +155,13 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:

@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
question = _parse_message(messages[-1], self._is_gemini)
chat_history = _parse_chat_history(messages[:-1], self._is_gemini)
merged_messages = (
merge_neighboring_same_role_messages(messages)
if self._is_gemini
else messages
)
question = _parse_message(merged_messages[-1], self._is_gemini)
chat_history = _parse_chat_history(merged_messages[:-1], self._is_gemini)
chat_params = {**chat_history}

kwargs = kwargs if kwargs else {}
Expand Down Expand Up @@ -209,8 +217,13 @@ def complete(
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
question = _parse_message(messages[-1], self._is_gemini)
chat_history = _parse_chat_history(messages[:-1], self._is_gemini)
merged_messages = (
merge_neighboring_same_role_messages(messages)
if self._is_gemini
else messages
)
question = _parse_message(merged_messages[-1], self._is_gemini)
chat_history = _parse_chat_history(merged_messages[:-1], self._is_gemini)
chat_params = {**chat_history}
kwargs = kwargs if kwargs else {}
params = {**self._model_kwargs, **kwargs}
Expand Down

0 comments on commit 563e4a6

Please sign in to comment.