Skip to content

Commit

Permalink
Refactor GoogleProvider to use generate_content method and improve message handling in chat completions
Browse files Browse the repository at this point in the history
  • Loading branch information
vargacypher committed Dec 17, 2024
1 parent 9da9ba5 commit da0a998
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 33 deletions.
64 changes: 47 additions & 17 deletions aisuite/providers/google_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ChatCompletionResponse,
ChatCompletionMessageToolCall,
Function,
Message,
)
from typing import Any

Expand Down Expand Up @@ -68,12 +69,7 @@ def chat_completions_create(self, model, messages, **kwargs):
transformed_messages = self.transform_roles(messages)

# Convert the messages to the format expected Google
final_message_history = self.convert_openai_to_vertex_ai(
transformed_messages[:-1]
)

# Get the last message from the transformed messages
last_message = transformed_messages[-1]["content"]
final_message_history = self.convert_openai_to_vertex_ai(transformed_messages)

model_kwargs = {
"model_name": model,
Expand All @@ -86,8 +82,7 @@ def chat_completions_create(self, model, messages, **kwargs):
model = GenerativeModel(**model_kwargs)

# Start a chat with the GenerativeModel and send the last message
chat = model.start_chat(history=final_message_history)
response = chat.send_message(last_message)
response = model.generate_content(final_message_history)

# Convert the response to the format expected by the OpenAI API
return self.normalize_response(response)
Expand All @@ -96,25 +91,60 @@ def convert_openai_to_vertex_ai(self, messages):
"""Convert OpenAI messages to Google AI messages."""
from vertexai.generative_models import Content, Part

function_calls = {}
history = []

for message in messages:
tool_calls = message.get("tool_calls")
if tool_calls:
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
function_calls[function_name] = Part.from_dict(
{
"function_call": {
"name": function_name,
"args": function_args,
}
}
)
continue

if message["role"] == "tool":
history.append(
Content(role="model", parts=[function_calls.get(message["name"])])
)
parts = [
Part.from_function_response(
name=message["name"],
response=json.loads(message["content"]),
)
]
else:
parts = [Part.from_text(message["content"])]

role = message["role"]
content = message["content"]
parts = [Part.from_text(content)]
history.append(Content(role=role, parts=parts))

return history

def transform_roles(self, messages):
"""Transform the roles in the messages based on the provided transformations."""
openai_roles_to_google_roles = {
"system": "user",
"user": "user",
"assistant": "model",
"tool": "tool",
}
transformed_messages = []

for message in messages:
if isinstance(message, Message):
message = message.model_dump()
if role := openai_roles_to_google_roles.get(message["role"], None):
message["role"] = role
return messages
transformed_messages.append(message)
return transformed_messages

def normalize_response(self, response: GenerationResponse):
"""Normalize the response from Google AI to match OpenAI's response format."""
Expand All @@ -138,14 +168,14 @@ def normalize_response(self, response: GenerationResponse):
for fc in function_calls
]
openai_response.choices[0].finish_reason = "tool_calls"
openai_response.choices[0].message.content = ""
openai_response.choices[0].message.role = "assistant"
openai_response.choices[0].message.tool_calls = tool_calls
else:
# Handle regular text response
openai_response.choices[0].message.content = (
candidate.content.parts[0].text if candidate.content.parts else ""
)

try:
openai_response.choices[0].message.content = candidate.content.parts[0].text
except AttributeError:
openai_response.choices[0].message.content = ""

return openai_response

def _extract_tool_calls(self, response: GenerationResponse) -> list[dict]:
Expand Down
2 changes: 1 addition & 1 deletion examples/simple_tool_calling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
"if response.choices[0].message.tool_calls:\n",
" tool_results, result_as_message = tool_manager.execute_tool(response.choices[0].message.tool_calls)\n",
" messages.append(response.choices[0].message) # Model's function call message\n",
" messages.append(result_as_message[0])\n",
" messages.extend(result_as_message)\n",
"\n",
" final_response = client.chat.completions.create(\n",
" model=model, messages=messages, tools=tool_manager.tools())\n",
Expand Down
24 changes: 9 additions & 15 deletions tests/providers/test_google_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ def test_vertex_interface():
interface = GoogleProvider()
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content.parts = [MagicMock()]
mock_response.candidates[0].content.parts[0].text = response_text_content

with patch(
"aisuite.providers.google_provider.GenerativeModel"
) as mock_generative_model:
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_chat = MagicMock()
mock_model.start_chat.return_value = mock_chat
mock_chat.send_message.return_value = mock_response
mock_model.generate_content.return_value = mock_response

response = interface.chat_completions_create(
messages=message_history,
Expand All @@ -56,14 +55,10 @@ def test_vertex_interface():
assert kwargs["model_name"] == selected_model
assert "generation_config" in kwargs

# Assert that start_chat was called with correct history.
mock_model.start_chat.assert_called_once()
_chat_args, chat_kwargs = mock_model.start_chat.call_args
assert "history" in chat_kwargs
assert isinstance(chat_kwargs["history"], list)

# Assert that send_message was called with the last message.
mock_chat.send_message.assert_called_once_with(user_greeting)
# Assert that generate_content was called with correct history.
mock_model.generate_content.assert_called_once()
_chat_args, chat_kwargs = mock_model.generate_content.call_args
assert isinstance(_chat_args[0], list)

# Assert that the response is in the correct format.
assert response.choices[0].message.content == response_text_content
Expand Down Expand Up @@ -128,6 +123,7 @@ def test_tool_calls():
interface = GoogleProvider()
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content.parts = [MagicMock()]
mock_response.candidates[0].content.parts[0].text = response_text_content

# Create a mock function call with the necessary attributes
Expand All @@ -142,9 +138,7 @@ def test_tool_calls():
) as mock_generative_model:
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_chat = MagicMock()
mock_model.start_chat.return_value = mock_chat
mock_chat.send_message.return_value = mock_response
mock_model.generate_content.return_value = mock_response

response = interface.chat_completions_create(
messages=message_history,
Expand All @@ -154,7 +148,7 @@ def test_tool_calls():
)

# Assert that the response is in the correct format.
assert response.choices[0].message.content == ""
assert response.choices[0].message.content == response_text_content
assert response.choices[0].finish_reason == "tool_calls"
assert len(response.choices[0].message.tool_calls) == 1
assert response.choices[0].message.tool_calls[0].function.name == "example_tool"
Expand Down

0 comments on commit da0a998

Please sign in to comment.