From da0a998684e43679bd706f3e8e378d8e7abe0102 Mon Sep 17 00:00:00 2001 From: Guilherme Cardoso de Vargas <77084039+vargacypher@users.noreply.github.com> Date: Tue, 17 Dec 2024 11:50:50 +0000 Subject: [PATCH] Refactor GoogleProvider to use generate_content method and improve message handling in chat completions --- aisuite/providers/google_provider.py | 64 ++++++++++++++++++------- examples/simple_tool_calling.ipynb | 2 +- tests/providers/test_google_provider.py | 24 ++++------ 3 files changed, 57 insertions(+), 33 deletions(-) diff --git a/aisuite/providers/google_provider.py b/aisuite/providers/google_provider.py index 2ca420f5..9070cc6d 100644 --- a/aisuite/providers/google_provider.py +++ b/aisuite/providers/google_provider.py @@ -18,6 +18,7 @@ ChatCompletionResponse, ChatCompletionMessageToolCall, Function, + Message, ) from typing import Any @@ -68,12 +69,7 @@ def chat_completions_create(self, model, messages, **kwargs): transformed_messages = self.transform_roles(messages) # Convert the messages to the format expected Google - final_message_history = self.convert_openai_to_vertex_ai( - transformed_messages[:-1] - ) - - # Get the last message from the transformed messages - last_message = transformed_messages[-1]["content"] + final_message_history = self.convert_openai_to_vertex_ai(transformed_messages) model_kwargs = { "model_name": model, @@ -86,8 +82,7 @@ def chat_completions_create(self, model, messages, **kwargs): model = GenerativeModel(**model_kwargs) # Start a chat with the GenerativeModel and send the last message - chat = model.start_chat(history=final_message_history) - response = chat.send_message(last_message) + response = model.generate_content(final_message_history) # Convert the response to the format expected by the OpenAI API return self.normalize_response(response) @@ -96,25 +91,60 @@ def convert_openai_to_vertex_ai(self, messages): """Convert OpenAI messages to Google AI messages.""" from vertexai.generative_models import Content, Part + function_calls = {} history = [] + for message in messages: + tool_calls = message.get("tool_calls") + if tool_calls: + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + function_calls[function_name] = Part.from_dict( + { + "function_call": { + "name": function_name, + "args": function_args, + } + } + ) + continue + + if message["role"] == "tool": + history.append( + Content(role="model", parts=[function_calls.get(message["name"])]) + ) + parts = [ + Part.from_function_response( + name=message["name"], + response=json.loads(message["content"]), + ) + ] + else: + parts = [Part.from_text(message["content"])] + role = message["role"] - content = message["content"] - parts = [Part.from_text(content)] history.append(Content(role=role, parts=parts)) + return history def transform_roles(self, messages): """Transform the roles in the messages based on the provided transformations.""" openai_roles_to_google_roles = { "system": "user", + "user": "user", "assistant": "model", + "tool": "tool", } + transformed_messages = [] for message in messages: + if isinstance(message, Message): + message = message.model_dump() if role := openai_roles_to_google_roles.get(message["role"], None): message["role"] = role - return messages + transformed_messages.append(message) + return transformed_messages def normalize_response(self, response: GenerationResponse): """Normalize the response from Google AI to match OpenAI's response format.""" @@ -138,14 +168,14 @@ def normalize_response(self, response: GenerationResponse): for fc in function_calls ] openai_response.choices[0].finish_reason = "tool_calls" - openai_response.choices[0].message.content = "" openai_response.choices[0].message.role = "assistant" openai_response.choices[0].message.tool_calls = tool_calls - else: - # Handle regular text response - openai_response.choices[0].message.content = ( - candidate.content.parts[0].text if candidate.content.parts else "" - ) + + try: + openai_response.choices[0].message.content = candidate.content.parts[0].text + except AttributeError: + openai_response.choices[0].message.content = "" + return openai_response def _extract_tool_calls(self, response: GenerationResponse) -> list[dict]: diff --git a/examples/simple_tool_calling.ipynb b/examples/simple_tool_calling.ipynb index 579bab66..f9bbc162 100644 --- a/examples/simple_tool_calling.ipynb +++ b/examples/simple_tool_calling.ipynb @@ -78,7 +78,7 @@ "if response.choices[0].message.tool_calls:\n", " tool_results, result_as_message = tool_manager.execute_tool(response.choices[0].message.tool_calls)\n", " messages.append(response.choices[0].message) # Model's function call message\n", - " messages.append(result_as_message[0])\n", + " messages.extend(result_as_message)\n", "\n", " final_response = client.chat.completions.create(\n", " model=model, messages=messages, tools=tool_manager.tools())\n", diff --git a/tests/providers/test_google_provider.py b/tests/providers/test_google_provider.py index 3d4c5f35..b12b3a61 100644 --- a/tests/providers/test_google_provider.py +++ b/tests/providers/test_google_provider.py @@ -33,6 +33,7 @@ def test_vertex_interface(): interface = GoogleProvider() mock_response = MagicMock() mock_response.candidates = [MagicMock()] + mock_response.candidates[0].content.parts = [MagicMock()] mock_response.candidates[0].content.parts[0].text = response_text_content with patch( @@ -40,9 +41,7 @@ def test_vertex_interface(): ) as mock_generative_model: mock_model = MagicMock() mock_generative_model.return_value = mock_model - mock_chat = MagicMock() - mock_model.start_chat.return_value = mock_chat - mock_chat.send_message.return_value = mock_response + mock_model.generate_content.return_value = mock_response response = interface.chat_completions_create( messages=message_history, @@ -56,14 +55,10 @@ def test_vertex_interface(): assert kwargs["model_name"] == selected_model assert "generation_config" in kwargs - # Assert that start_chat was called with correct history. - mock_model.start_chat.assert_called_once() - _chat_args, chat_kwargs = mock_model.start_chat.call_args - assert "history" in chat_kwargs - assert isinstance(chat_kwargs["history"], list) - - # Assert that send_message was called with the last message. - mock_chat.send_message.assert_called_once_with(user_greeting) + # Assert that generate_content was called with correct history. + mock_model.generate_content.assert_called_once() + _chat_args, chat_kwargs = mock_model.generate_content.call_args + assert isinstance(_chat_args[0], list) # Assert that the response is in the correct format. assert response.choices[0].message.content == response_text_content @@ -128,6 +123,7 @@ def test_tool_calls(): interface = GoogleProvider() mock_response = MagicMock() mock_response.candidates = [MagicMock()] + mock_response.candidates[0].content.parts = [MagicMock()] mock_response.candidates[0].content.parts[0].text = response_text_content # Create a mock function call with the necessary attributes @@ -142,9 +138,7 @@ def test_tool_calls(): ) as mock_generative_model: mock_model = MagicMock() mock_generative_model.return_value = mock_model - mock_chat = MagicMock() - mock_model.start_chat.return_value = mock_chat - mock_chat.send_message.return_value = mock_response + mock_model.generate_content.return_value = mock_response response = interface.chat_completions_create( messages=message_history, @@ -154,7 +148,7 @@ def test_tool_calls(): ) # Assert that the response is in the correct format. - assert response.choices[0].message.content == "" + assert response.choices[0].message.content == response_text_content assert response.choices[0].finish_reason == "tool_calls" assert len(response.choices[0].message.tool_calls) == 1 assert response.choices[0].message.tool_calls[0].function.name == "example_tool"