diff --git a/aisuite/utils/tool_manager.py b/aisuite/utils/tool_manager.py index 235f67f..3da948e 100644 --- a/aisuite/utils/tool_manager.py +++ b/aisuite/utils/tool_manager.py @@ -5,7 +5,7 @@ from docstring_parser import parse -class ToolManager: +class Tools: def __init__(self, tools: list[Callable] = None): self._tools = {} if tools: @@ -18,7 +18,7 @@ def _add_tool(self, func: Callable, param_model: Optional[Type[BaseModel]] = Non if param_model: tool_spec = self._convert_to_tool_spec(func, param_model) else: - tool_spec, param_model = self._infer_from_signature(func) + tool_spec, param_model = self.__infer_from_signature(func) self._tools[func.__name__] = { "function": func, @@ -26,12 +26,14 @@ def _add_tool(self, func: Callable, param_model: Optional[Type[BaseModel]] = Non "spec": tool_spec, } + # Return tools in the specified format (default OpenAI). def tools(self, format="openai") -> list: """Return tools in the specified format (default OpenAI).""" if format == "openai": - return self._convert_to_openai_format() + return self.__convert_to_openai_format() return [tool["spec"] for tool in self._tools.values()] + # Convert the function and its Pydantic model to a unified tool specification. def _convert_to_tool_spec( self, func: Callable, param_model: Type[BaseModel] ) -> Dict[str, Any]: @@ -83,7 +85,7 @@ def _convert_to_tool_spec( }, } - def _extract_param_descriptions(self, func: Callable) -> dict[str, str]: + def __extract_param_descriptions(self, func: Callable) -> dict[str, str]: """Extract parameter descriptions from function docstring. Args: @@ -101,7 +103,7 @@ def _extract_param_descriptions(self, func: Callable) -> dict[str, str]: return param_descriptions - def _infer_from_signature( + def __infer_from_signature( self, func: Callable ) -> tuple[Dict[str, Any], Type[BaseModel]]: """Infer parameters(required and optional) and requirements directly from the function signature.""" @@ -110,7 +112,7 @@ def _infer_from_signature( required_fields = [] # Get function's docstring and parse parameter descriptions - param_descriptions = self._extract_param_descriptions(func) + param_descriptions = self.__extract_param_descriptions(func) docstring = inspect.getdoc(func) or "" for param_name, param in signature.parameters.items(): @@ -144,13 +146,82 @@ def _infer_from_signature( return tool_spec, param_model - def _convert_to_openai_format(self) -> list: + def __convert_to_openai_format(self) -> list: """Convert tools to OpenAI's format.""" return [ {"type": "function", "function": tool["spec"]} for tool in self._tools.values() ] + def results_to_messages(self, results: list, message: any) -> list: + """Converts results to messages.""" + # if message is empty return empty list + if not message or len(results) == 0: + return [] + + messages = [] + # Iterate over results and match with tool calls from the message + for result in results: + # Find matching tool call from message.tool_calls + for tool_call in message.tool_calls: + if tool_call.id == result["tool_call_id"]: + messages.append( + { + "role": "tool", + "name": result["name"], + "content": json.dumps(result["content"]), + "tool_call_id": tool_call.id, + } + ) + break + + return messages + + def execute(self, tool_calls) -> list: + """Executes registered tools based on the tool calls from the model. + + Args: + tool_calls: List of tool calls from the model + + Returns: + List of results from executing each tool call + """ + results = [] + + # Handle single tool call or list of tool calls + if not isinstance(tool_calls, list): + tool_calls = [tool_calls] + + for tool_call in tool_calls: + # Handle both dictionary and object-style tool calls + if isinstance(tool_call, dict): + tool_name = tool_call["function"]["name"] + arguments = tool_call["function"]["arguments"] + else: + tool_name = tool_call.function.name + arguments = tool_call.function.arguments + + # Ensure arguments is a dict + if isinstance(arguments, str): + arguments = json.loads(arguments) + + if tool_name not in self._tools: + raise ValueError(f"Tool '{tool_name}' not registered.") + + tool = self._tools[tool_name] + tool_func = tool["function"] + param_model = tool["param_model"] + + # Validate and parse the arguments with Pydantic if a model exists + try: + validated_args = param_model(**arguments) + result = tool_func(**validated_args.model_dump()) + results.append(result) + except ValidationError as e: + raise ValueError(f"Error in tool '{tool_name}' parameters: {e}") + + return results + def execute_tool(self, tool_calls) -> tuple[list, list]: """Executes registered tools based on the tool calls from the model. diff --git a/tests/providers/test_google_provider.py b/tests/providers/test_google_provider.py index a589b44..f34c5f0 100644 --- a/tests/providers/test_google_provider.py +++ b/tests/providers/test_google_provider.py @@ -2,6 +2,7 @@ from unittest.mock import patch, MagicMock from aisuite.providers.google_provider import GoogleProvider from vertexai.generative_models import Content, Part +import json @pytest.fixture(autouse=True) @@ -25,54 +26,99 @@ def test_missing_env_vars(): def test_vertex_interface(): """High-level test that the interface is initialized and chat completions are requested successfully.""" - user_greeting = "Hello!" - message_history = [{"role": "user", "content": user_greeting}] - selected_model = "our-favorite-model" - response_text_content = "mocked-text-response-from-model" - - interface = GoogleProvider() - mock_response = MagicMock() - mock_response.candidates = [MagicMock()] - mock_response.candidates[0].content.parts[0].text = response_text_content - - with patch( - "aisuite.providers.google_provider.GenerativeModel" - ) as mock_generative_model: - mock_model = MagicMock() - mock_generative_model.return_value = mock_model - mock_chat = MagicMock() - mock_model.start_chat.return_value = mock_chat - mock_chat.send_message.return_value = mock_response - - response = interface.chat_completions_create( - messages=message_history, - model=selected_model, - temperature=0.7, - ) - - # Assert that GenerativeModel was called with correct arguments. - mock_generative_model.assert_called_once() - args, kwargs = mock_generative_model.call_args - assert args[0] == selected_model - assert "generation_config" in kwargs - - # Assert that start_chat was called with correct history. - mock_model.start_chat.assert_called_once() - _chat_args, chat_kwargs = mock_model.start_chat.call_args - assert "history" in chat_kwargs - assert isinstance(chat_kwargs["history"], list) - - # Assert that send_message was called with the last message. - mock_chat.send_message.assert_called_once_with(user_greeting) - - # Assert that the response is in the correct format. - assert response.choices[0].message.content == response_text_content + # Test case 1: Regular text response + def test_text_response(): + user_greeting = "Hello!" + message_history = [{"role": "user", "content": user_greeting}] + selected_model = "our-favorite-model" + response_text_content = "mocked-text-response-from-model" + + interface = GoogleProvider() + mock_response = MagicMock() + mock_response.candidates = [MagicMock()] + mock_response.candidates[0].content.parts = [MagicMock()] + mock_response.candidates[0].content.parts[0].text = response_text_content + # Ensure function_call attribute doesn't exist + del mock_response.candidates[0].content.parts[0].function_call + + with patch( + "aisuite.providers.google_provider.GenerativeModel" + ) as mock_generative_model: + mock_model = MagicMock() + mock_generative_model.return_value = mock_model + mock_chat = MagicMock() + mock_model.start_chat.return_value = mock_chat + mock_chat.send_message.return_value = mock_response + + response = interface.chat_completions_create( + messages=message_history, + model=selected_model, + temperature=0.7, + ) + + # Assert the response is in the correct format + assert response.choices[0].message.content == response_text_content + assert response.choices[0].finish_reason == "stop" + + # Test case 2: Function call response + def test_function_call(): + user_greeting = "What's the weather?" + message_history = [{"role": "user", "content": user_greeting}] + selected_model = "our-favorite-model" + + interface = GoogleProvider() + mock_response = MagicMock() + mock_response.candidates = [MagicMock()] + mock_response.candidates[0].content.parts = [MagicMock()] + + # Mock the function call response + function_call_mock = MagicMock() + function_call_mock.name = "get_weather" + function_call_mock.args = {"location": "San Francisco"} + mock_response.candidates[0].content.parts[0].function_call = function_call_mock + mock_response.candidates[0].content.parts[0].text = None + + with patch( + "aisuite.providers.google_provider.GenerativeModel" + ) as mock_generative_model: + mock_model = MagicMock() + mock_generative_model.return_value = mock_model + mock_chat = MagicMock() + mock_model.start_chat.return_value = mock_chat + mock_chat.send_message.return_value = mock_response + + response = interface.chat_completions_create( + messages=message_history, + model=selected_model, + temperature=0.7, + ) + + # Assert the response contains the function call + assert response.choices[0].message.content is None + assert response.choices[0].message.tool_calls[0].type == "function" + assert ( + response.choices[0].message.tool_calls[0].function.name == "get_weather" + ) + assert json.loads( + response.choices[0].message.tool_calls[0].function.arguments + ) == {"location": "San Francisco"} + assert response.choices[0].finish_reason == "tool_calls" + + # Run both test cases + test_text_response() + test_function_call() def test_convert_openai_to_vertex_ai(): + """Test the message conversion from OpenAI format to Vertex AI format.""" interface = GoogleProvider() - message_history = [{"role": "user", "content": "Hello!"}] - result = interface.convert_openai_to_vertex_ai(message_history) + message = {"role": "user", "content": "Hello!"} + + # Use the transformer to convert the message + result = interface.transformer.convert_request([message]) + + # Verify the conversion result + assert len(result) == 1 assert isinstance(result[0], Content) assert result[0].role == "user" assert len(result[0].parts) == 1 @@ -80,21 +126,25 @@ def test_convert_openai_to_vertex_ai(): assert result[0].parts[0].text == "Hello!" -def test_transform_roles(): +def test_role_conversions(): + """Test that different message roles are converted correctly.""" interface = GoogleProvider() messages = [ - {"role": "system", "content": "Google: system message = 1st user message."}, - {"role": "user", "content": "User message 1."}, - {"role": "assistant", "content": "Assistant message 1."}, + {"role": "system", "content": "System message"}, + {"role": "user", "content": "User message"}, + {"role": "assistant", "content": "Assistant message"}, ] - expected_output = [ - {"role": "user", "content": "Google: system message = 1st user message."}, - {"role": "user", "content": "User message 1."}, - {"role": "model", "content": "Assistant message 1."}, - ] + result = interface.transformer.convert_request(messages) + + # System and user messages should both be converted to "user" role in Vertex AI + assert len(result) == 3 + assert result[0].role == "user" # system converted to user + assert result[0].parts[0].text == "System message" - result = interface.transform_roles(messages) + assert result[1].role == "user" + assert result[1].parts[0].text == "User message" - assert result == expected_output + assert result[2].role == "model" # assistant converted to model + assert result[2].parts[0].text == "Assistant message" diff --git a/tests/utils/test_tool_manager.py b/tests/utils/test_tool_manager.py new file mode 100644 index 0000000..629b17b --- /dev/null +++ b/tests/utils/test_tool_manager.py @@ -0,0 +1,200 @@ +import unittest +from pydantic import BaseModel +from typing import Dict +from aisuite.utils.tool_manager import Tools # Import your ToolManager class +from enum import Enum + + +# Define a sample tool function and Pydantic model for testing +class TemperatureUnit(str, Enum): + CELSIUS = "Celsius" + FAHRENHEIT = "Fahrenheit" + + +class TemperatureParamsV2(BaseModel): + location: str + unit: TemperatureUnit = TemperatureUnit.CELSIUS + + +class TemperatureParams(BaseModel): + location: str + unit: str = "Celsius" + + +def get_current_temperature(location: str, unit: str = "Celsius") -> Dict[str, str]: + """Gets the current temperature for a specific location and unit.""" + return {"location": location, "unit": unit, "temperature": "72"} + + +def missing_annotation_tool(location, unit="Celsius"): + """Tool function without type annotations.""" + return {"location": location, "unit": unit, "temperature": "72"} + + +def get_current_temperature_v2( + location: str, unit: TemperatureUnit = TemperatureUnit.CELSIUS +) -> Dict[str, str]: + """Gets the current temperature for a specific location and unit (with enum support).""" + return {"location": location, "unit": unit, "temperature": "72"} + + +class TestToolManager(unittest.TestCase): + def setUp(self): + self.tool_manager = Tools() + + def test_add_tool_with_pydantic_model(self): + """Test adding a tool with an explicit Pydantic model.""" + self.tool_manager._add_tool(get_current_temperature, TemperatureParams) + + expected_tool_spec = [ + { + "type": "function", + "function": { + "name": "get_current_temperature", + "description": "Gets the current temperature for a specific location and unit.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "", + }, + "unit": { + "type": "string", + "description": "", + "default": "Celsius", + }, + }, + "required": ["location"], + }, + }, + } + ] + + tools = self.tool_manager.tools() + self.assertIn( + "get_current_temperature", [tool["function"]["name"] for tool in tools] + ) + assert ( + tools == expected_tool_spec + ), f"Expected {expected_tool_spec}, but got {tools}" + + def test_add_tool_with_signature_inference(self): + """Test adding a tool and inferring parameters from the function signature.""" + self.tool_manager._add_tool(get_current_temperature) + # Expected output from tool_manager.tools() when called with OpenAI format + expected_tool_spec = [ + { + "type": "function", + "function": { + "name": "get_current_temperature", + "description": "Gets the current temperature for a specific location and unit.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "", # No description provided in function signature + }, + "unit": { + "type": "string", + "description": "", + "default": "Celsius", + }, + }, + "required": ["location"], + }, + }, + } + ] + tools = self.tool_manager.tools() + print(tools) + self.assertIn( + "get_current_temperature", [tool["function"]["name"] for tool in tools] + ) + assert ( + tools == expected_tool_spec + ), f"Expected {expected_tool_spec}, but got {tools}" + + def test_add_tool_missing_annotation_raises_exception(self): + """Test that adding a tool with missing type annotations raises a TypeError.""" + with self.assertRaises(TypeError): + self.tool_manager._add_tool(missing_annotation_tool) + + def test_execute_tool_valid_parameters(self): + """Test executing a registered tool with valid parameters.""" + self.tool_manager._add_tool(get_current_temperature, TemperatureParams) + tool_call = { + "id": "call_1", + "function": { + "name": "get_current_temperature", + "arguments": {"location": "San Francisco", "unit": "Celsius"}, + }, + } + result, result_message = self.tool_manager.execute_tool(tool_call) + + # Assuming result is returned as a list with a single dictionary + result_dict = result[0] if isinstance(result, list) else result + + # Check that the result matches expected output + self.assertEqual(result_dict["location"], "San Francisco") + self.assertEqual(result_dict["unit"], "Celsius") + self.assertEqual(result_dict["temperature"], "72") + + def test_execute_tool_invalid_parameters(self): + """Test that executing a tool with invalid parameters raises a ValueError.""" + self.tool_manager._add_tool(get_current_temperature, TemperatureParams) + tool_call = { + "id": "call_1", + "function": { + "name": "get_current_temperature", + "arguments": {"location": 123}, # Invalid type for location + }, + } + + with self.assertRaises(ValueError) as context: + self.tool_manager.execute_tool(tool_call) + + # Verify the error message contains information about the validation error + self.assertIn( + "Error in tool 'get_current_temperature' parameters", str(context.exception) + ) + + def test_add_tool_with_enum(self): + """Test adding a tool with an enum parameter.""" + self.tool_manager._add_tool(get_current_temperature_v2, TemperatureParamsV2) + + expected_tool_spec = [ + { + "type": "function", + "function": { + "name": "get_current_temperature_v2", + "description": "Gets the current temperature for a specific location and unit (with enum support).", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "", + }, + "unit": { + "type": "string", + "enum": ["Celsius", "Fahrenheit"], + "description": "", + "default": "Celsius", + }, + }, + "required": ["location"], + }, + }, + } + ] + + tools = self.tool_manager.tools() + assert ( + tools == expected_tool_spec + ), f"Expected {expected_tool_spec}, but got {tools}" + + +if __name__ == "__main__": + unittest.main()