Set MessageRole appropriately for Vertex
tbeamish-benchsci committed Feb 29, 2024
1 parent 563e4a6 commit e425f90
Showing 3 changed files with 11 additions and 12 deletions.
llama-index-core/llama_index/core/base/llms/types.py (1 change: 1 addition & 0 deletions)
@@ -14,6 +14,7 @@ class MessageRole(str, Enum):
     FUNCTION = "function"
     TOOL = "tool"
     CHATBOT = "chatbot"
+    MODEL = "model"
 
 
 # ===== Generic Model Input - Chat =====
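
For context, a minimal sketch of how the new member behaves. Only FUNCTION, TOOL, CHATBOT, and the added MODEL line are confirmed by the hunk above; the remaining members are assumptions reconstructed from roles referenced elsewhere in this commit.

from enum import Enum

class MessageRole(str, Enum):
    SYSTEM = "system"  # assumed member
    USER = "user"  # referenced in base.py below
    ASSISTANT = "assistant"  # referenced in base.py below
    FUNCTION = "function"
    TOOL = "tool"
    CHATBOT = "chatbot"
    MODEL = "model"  # added here: Gemini-family APIs label assistant turns "model"

# Because MessageRole subclasses str, the new member compares equal to the
# raw role string that Gemini/Vertex return:
assert MessageRole.MODEL == "model"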
llama_index/llms/vertex/gemini_utils.py (path inferred)

@@ -3,11 +3,12 @@
 import google.ai.generativelanguage as glm
 import google.generativeai as genai
 import PIL
-from llama_index.core.base.llms.types import MessageRole
 
 from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
     CompletionResponse,
+    MessageRole,
 )

 ROLES_TO_GEMINI = {

@@ -105,7 +106,7 @@ def merge_neighboring_same_role_messages(

         # Create a new ChatMessage or similar object with merged content
         merged_message = ChatMessage(
-            role=current_message.role,
+            role=ROLES_TO_GEMINI[current_message.role],
             content="\n".join([str(msg_content) for msg_content in merged_content]),
             additional_kwargs=current_message.additional_kwargs,
         )
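
The fix here is the role= line: the merged message's role is now routed through ROLES_TO_GEMINI instead of being passed along verbatim. The dict's definition is truncated above, so the mapping contents below are an assumption; a plausible sketch of why the lookup matters, given that Gemini's chat API accepts only "user" and "model" as roles:

from llama_index.core.base.llms.types import ChatMessage, MessageRole

# Assumed mapping (the real definition is cut off at "ROLES_TO_GEMINI = {"):
ROLES_TO_GEMINI = {
    MessageRole.USER: "user",
    MessageRole.ASSISTANT: "model",
    MessageRole.MODEL: "model",
}

# Without the MODEL = "model" member added in types.py, validating a
# ChatMessage whose role is the Gemini-style string "model" would fail.
merged_message = ChatMessage(
    role=ROLES_TO_GEMINI[MessageRole.ASSISTANT],  # "model", not "assistant"
    content="merged response text",
)
assert merged_message.role == MessageRole.MODEL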
llama_index/llms/vertex/base.py (path inferred)

@@ -6,20 +6,17 @@

 import google.api_core
 import vertexai
-from llama_index.core.base.llms.types import ChatMessage, MessageRole
 from tenacity import (
     before_sleep_log,
     retry,
     retry_if_exception_type,
     stop_after_attempt,
     wait_exponential,
 )
-from vertexai.language_models import (
-    ChatMessage as VertexChatMessage,
-)
-from vertexai.language_models import (
-    InputOutputTextPair,
-)
+from vertexai.language_models import ChatMessage as VertexChatMessage
+from vertexai.language_models import InputOutputTextPair
 
+from llama_index.core.base.llms.types import ChatMessage, MessageRole
+
 CHAT_MODELS = ["chat-bison", "chat-bison-32k", "chat-bison@001"]
 TEXT_MODELS = ["text-bison", "text-bison-32k", "text-bison@001"]

@@ -171,7 +168,7 @@ def _parse_chat_history(history: Any, is_gemini: bool) -> Any:
             if is_gemini:
                 raise ValueError("Gemini model don't support system messages")
             context = message.content
-        elif message.role == MessageRole.ASSISTANT or message.role == MessageRole.USER:
+        elif message.role == MessageRole.MODEL or message.role == MessageRole.USER:
             if is_gemini:
                 from llama_index.llms.vertex.gemini_utils import (
                     convert_chat_message_to_gemini_content,

@@ -185,12 +182,12 @@
             else:
                 vertex_message = VertexChatMessage(
                     content=message.content,
-                    author="bot" if message.role == MessageRole.ASSISTANT else "user",
+                    author=("bot" if message.role == MessageRole.MODEL else "user"),
                 )
                 vertex_messages.append(vertex_message)
         else:
             raise ValueError(
-                f"Unexpected message with type {type(message)} at the position {i}."
+                f"Unexpected message with role {message.role} at the position {i}."
             )
     if len(vertex_messages) % 2 != 0:
         raise ValueError("total no of messages should be even")
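
Taken together, the base.py changes make _parse_chat_history treat MODEL, rather than ASSISTANT, as the role that marks a model turn when converting history into Vertex messages. A minimal sketch of the non-Gemini branch in isolation; the helper name to_vertex_message is hypothetical, while VertexChatMessage and its content/author parameters come from the imports and calls shown above:

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from vertexai.language_models import ChatMessage as VertexChatMessage

def to_vertex_message(message: ChatMessage) -> VertexChatMessage:
    # Vertex's PaLM-era chat models name the assistant author "bot";
    # after this commit the trigger role is MODEL instead of ASSISTANT.
    author = "bot" if message.role == MessageRole.MODEL else "user"
    return VertexChatMessage(content=message.content, author=author)

history = [
    ChatMessage(role=MessageRole.USER, content="Hi"),
    ChatMessage(role=MessageRole.MODEL, content="Hello!"),
]
vertex_messages = [to_vertex_message(m) for m in history]
# An even number of messages, satisfying the parity check at the end of the hunk.
assert len(vertex_messages) % 2 == 0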
