Fix tests broken due to pydantic checks.
rohitprasad15 committed Jan 22, 2025
1 parent d9b7485 commit 5b547b8
Showing 3 changed files with 383 additions and 62 deletions.
85 changes: 78 additions & 7 deletions aisuite/utils/tool_manager.py
@@ -5,7 +5,7 @@
from docstring_parser import parse


class ToolManager:
class Tools:
def __init__(self, tools: list[Callable] = None):
self._tools = {}
if tools:
@@ -18,20 +18,22 @@ def _add_tool(self, func: Callable, param_model: Optional[Type[BaseModel]] = Non
if param_model:
tool_spec = self._convert_to_tool_spec(func, param_model)
else:
tool_spec, param_model = self._infer_from_signature(func)
tool_spec, param_model = self.__infer_from_signature(func)

self._tools[func.__name__] = {
"function": func,
"param_model": param_model,
"spec": tool_spec,
}

# Return tools in the specified format (default OpenAI).
def tools(self, format="openai") -> list:
"""Return tools in the specified format (default OpenAI)."""
if format == "openai":
return self._convert_to_openai_format()
return self.__convert_to_openai_format()
return [tool["spec"] for tool in self._tools.values()]

# Convert the function and its Pydantic model to a unified tool specification.
def _convert_to_tool_spec(
self, func: Callable, param_model: Type[BaseModel]
) -> Dict[str, Any]:
@@ -83,7 +85,7 @@ def _convert_to_tool_spec(
},
}

def _extract_param_descriptions(self, func: Callable) -> dict[str, str]:
def __extract_param_descriptions(self, func: Callable) -> dict[str, str]:
"""Extract parameter descriptions from function docstring.
Args:
@@ -101,7 +103,7 @@ def _extract_param_descriptions(self, func: Callable) -> dict[str, str]:

return param_descriptions

def _infer_from_signature(
def __infer_from_signature(
self, func: Callable
) -> tuple[Dict[str, Any], Type[BaseModel]]:
"""Infer parameters(required and optional) and requirements directly from the function signature."""
@@ -110,7 +112,7 @@ def _infer_from_signature(
required_fields = []

# Get function's docstring and parse parameter descriptions
param_descriptions = self._extract_param_descriptions(func)
param_descriptions = self.__extract_param_descriptions(func)
docstring = inspect.getdoc(func) or ""

for param_name, param in signature.parameters.items():
@@ -144,13 +146,82 @@

return tool_spec, param_model

def _convert_to_openai_format(self) -> list:
def __convert_to_openai_format(self) -> list:
"""Convert tools to OpenAI's format."""
return [
{"type": "function", "function": tool["spec"]}
for tool in self._tools.values()
]

def results_to_messages(self, results: list, message: Any) -> list:
"""Converts results to messages."""
# if message is empty return empty list
if not message or len(results) == 0:
return []

messages = []
# Iterate over results and match with tool calls from the message
for result in results:
# Find matching tool call from message.tool_calls
for tool_call in message.tool_calls:
if tool_call.id == result["tool_call_id"]:
messages.append(
{
"role": "tool",
"name": result["name"],
"content": json.dumps(result["content"]),
"tool_call_id": tool_call.id,
}
)
break

return messages

def execute(self, tool_calls) -> list:
"""Executes registered tools based on the tool calls from the model.
Args:
tool_calls: List of tool calls from the model
Returns:
List of results from executing each tool call
"""
results = []

# Handle single tool call or list of tool calls
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]

for tool_call in tool_calls:
# Handle both dictionary and object-style tool calls
if isinstance(tool_call, dict):
tool_name = tool_call["function"]["name"]
arguments = tool_call["function"]["arguments"]
else:
tool_name = tool_call.function.name
arguments = tool_call.function.arguments

# Ensure arguments is a dict
if isinstance(arguments, str):
arguments = json.loads(arguments)

if tool_name not in self._tools:
raise ValueError(f"Tool '{tool_name}' not registered.")

tool = self._tools[tool_name]
tool_func = tool["function"]
param_model = tool["param_model"]

# Validate and parse the arguments with the tool's Pydantic param_model
try:
validated_args = param_model(**arguments)
result = tool_func(**validated_args.model_dump())
results.append(result)
except ValidationError as e:
raise ValueError(f"Error in tool '{tool_name}' parameters: {e}")

return results

def execute_tool(self, tool_calls) -> tuple[list, list]:
"""Executes registered tools based on the tool calls from the model.
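Taken together, the tool_manager.py changes above rename ToolManager to Tools and switch its internal helpers to name-mangled (double-underscore) methods. Below is a minimal usage sketch, assuming the collapsed __init__ body registers each callable passed to the constructor via _add_tool; the get_weather function and its values are illustrative, not part of the commit.

from aisuite.utils.tool_manager import Tools

def get_weather(location: str, unit: str = "celsius") -> dict:
    """Look up the current weather.

    Args:
        location: City to query.
        unit: Temperature unit, defaults to celsius.
    """
    return {"location": location, "temperature": 22, "unit": unit}

tools = Tools([get_weather])

# OpenAI-style specs, e.g. for the `tools` parameter of a chat completion request.
openai_specs = tools.tools(format="openai")

# Execute a dictionary-style tool call, exercising the dict branch of execute().
results = tools.execute(
    [{"function": {"name": "get_weather", "arguments": '{"location": "San Francisco"}'}}]
)
# Expected: [{"location": "San Francisco", "temperature": 22, "unit": "celsius"}]

Because the helpers now use double underscores, Python name-mangles them to _Tools__infer_from_signature and so on, so any code or test that previously called self._infer_from_signature directly has to go through the public interface instead.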
160 changes: 105 additions & 55 deletions tests/providers/test_google_provider.py
@@ -2,6 +2,7 @@
from unittest.mock import patch, MagicMock
from aisuite.providers.google_provider import GoogleProvider
from vertexai.generative_models import Content, Part
import json


@pytest.fixture(autouse=True)
@@ -25,76 +26,125 @@ def test_missing_env_vars():
def test_vertex_interface():
"""High-level test that the interface is initialized and chat completions are requested successfully."""

user_greeting = "Hello!"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "our-favorite-model"
response_text_content = "mocked-text-response-from-model"

interface = GoogleProvider()
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content.parts[0].text = response_text_content

with patch(
"aisuite.providers.google_provider.GenerativeModel"
) as mock_generative_model:
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_chat = MagicMock()
mock_model.start_chat.return_value = mock_chat
mock_chat.send_message.return_value = mock_response

response = interface.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=0.7,
)

# Assert that GenerativeModel was called with correct arguments.
mock_generative_model.assert_called_once()
args, kwargs = mock_generative_model.call_args
assert args[0] == selected_model
assert "generation_config" in kwargs

# Assert that start_chat was called with correct history.
mock_model.start_chat.assert_called_once()
_chat_args, chat_kwargs = mock_model.start_chat.call_args
assert "history" in chat_kwargs
assert isinstance(chat_kwargs["history"], list)

# Assert that send_message was called with the last message.
mock_chat.send_message.assert_called_once_with(user_greeting)

# Assert that the response is in the correct format.
assert response.choices[0].message.content == response_text_content
# Test case 1: Regular text response
def test_text_response():
user_greeting = "Hello!"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "our-favorite-model"
response_text_content = "mocked-text-response-from-model"

interface = GoogleProvider()
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content.parts = [MagicMock()]
mock_response.candidates[0].content.parts[0].text = response_text_content
# Ensure function_call attribute doesn't exist
del mock_response.candidates[0].content.parts[0].function_call

with patch(
"aisuite.providers.google_provider.GenerativeModel"
) as mock_generative_model:
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_chat = MagicMock()
mock_model.start_chat.return_value = mock_chat
mock_chat.send_message.return_value = mock_response

response = interface.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=0.7,
)

# Assert the response is in the correct format
assert response.choices[0].message.content == response_text_content
assert response.choices[0].finish_reason == "stop"

# Test case 2: Function call response
def test_function_call():
user_greeting = "What's the weather?"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "our-favorite-model"

interface = GoogleProvider()
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content.parts = [MagicMock()]

# Mock the function call response
function_call_mock = MagicMock()
function_call_mock.name = "get_weather"
function_call_mock.args = {"location": "San Francisco"}
mock_response.candidates[0].content.parts[0].function_call = function_call_mock
mock_response.candidates[0].content.parts[0].text = None

with patch(
"aisuite.providers.google_provider.GenerativeModel"
) as mock_generative_model:
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_chat = MagicMock()
mock_model.start_chat.return_value = mock_chat
mock_chat.send_message.return_value = mock_response

response = interface.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=0.7,
)

# Assert the response contains the function call
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls[0].type == "function"
assert (
response.choices[0].message.tool_calls[0].function.name == "get_weather"
)
assert json.loads(
response.choices[0].message.tool_calls[0].function.arguments
) == {"location": "San Francisco"}
assert response.choices[0].finish_reason == "tool_calls"

# Run both test cases
test_text_response()
test_function_call()


def test_convert_openai_to_vertex_ai():
"""Test the message conversion from OpenAI format to Vertex AI format."""
interface = GoogleProvider()
message_history = [{"role": "user", "content": "Hello!"}]
result = interface.convert_openai_to_vertex_ai(message_history)
message = {"role": "user", "content": "Hello!"}

# Use the transformer to convert the message
result = interface.transformer.convert_request([message])

# Verify the conversion result
assert len(result) == 1
assert isinstance(result[0], Content)
assert result[0].role == "user"
assert len(result[0].parts) == 1
assert isinstance(result[0].parts[0], Part)
assert result[0].parts[0].text == "Hello!"


def test_transform_roles():
def test_role_conversions():
"""Test that different message roles are converted correctly."""
interface = GoogleProvider()

messages = [
{"role": "system", "content": "Google: system message = 1st user message."},
{"role": "user", "content": "User message 1."},
{"role": "assistant", "content": "Assistant message 1."},
{"role": "system", "content": "System message"},
{"role": "user", "content": "User message"},
{"role": "assistant", "content": "Assistant message"},
]

expected_output = [
{"role": "user", "content": "Google: system message = 1st user message."},
{"role": "user", "content": "User message 1."},
{"role": "model", "content": "Assistant message 1."},
]
result = interface.transformer.convert_request(messages)

# System and user messages should both be converted to "user" role in Vertex AI
assert len(result) == 3
assert result[0].role == "user" # system converted to user
assert result[0].parts[0].text == "System message"

result = interface.transform_roles(messages)
assert result[1].role == "user"
assert result[1].parts[0].text == "User message"

assert result == expected_output
assert result[2].role == "model" # assistant converted to model
assert result[2].parts[0].text == "Assistant message"
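A side note on the del in test_text_response above: a MagicMock fabricates attributes on first access, so without del mock_response.candidates[0].content.parts[0].function_call the text-only part would still appear to carry a function call to any provider-side check for a function_call attribute. A small standalone sketch of that mock behaviour, using only unittest.mock:

from unittest.mock import MagicMock

part = MagicMock()
part.text = "plain text response"

# MagicMock invents attributes on demand, so this is True by default.
assert hasattr(part, "function_call")

# Deleting the attribute makes subsequent access raise AttributeError,
# so hasattr() now reports False and the part can be treated as text-only.
del part.function_call
assert not hasattr(part, "function_call")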