diff --git a/src/banks/extensions/completion.py b/src/banks/extensions/completion.py index c39d333..9ed72d5 100644 --- a/src/banks/extensions/completion.py +++ b/src/banks/extensions/completion.py @@ -89,15 +89,16 @@ def _do_completion(self, model_name, caller): Helper callback. """ messages, tools = self._body_to_messages(caller()) - messages_as_dict = [m.model_dump() for m in messages] + message_dicts = [m.model_dump() for m in messages] + tool_dicts = [t.model_dump() for t in tools] or None - response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict, tools=tools or None)) + response = cast(ModelResponse, completion(model=model_name, messages=message_dicts, tools=tool_dicts)) choices = cast(list[Choices], response.choices) tool_calls = choices[0].message.tool_calls if not tool_calls: return choices[0].message.content - messages.append(choices[0].message) # type:ignore + message_dicts.append(choices[0].message.model_dump()) for tool_call in tool_calls: if not tool_call.function.name: msg = "Malformed response: function name is empty" @@ -107,14 +108,13 @@ def _do_completion(self, model_name, caller): function_args = json.loads(tool_call.function.arguments) function_response = func(**function_args) - messages.append( + message_dicts.append( ChatMessage( tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response - ) + ).model_dump() ) - messages_as_dict = [m.model_dump() for m in messages] - response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict)) + response = cast(ModelResponse, completion(model=model_name, messages=message_dicts, tools=tool_dicts)) choices = cast(list[Choices], response.choices) return choices[0].message.content @@ -123,14 +123,16 @@ async def _do_completion_async(self, model_name, caller): Helper callback. """ messages, tools = self._body_to_messages(caller()) + message_dicts = [m.model_dump() for m in messages] + tool_dicts = [t.model_dump() for t in tools] or None - response = cast(ModelResponse, await acompletion(model=model_name, messages=messages, tools=tools)) + response = cast(ModelResponse, await acompletion(model=model_name, messages=message_dicts, tools=tool_dicts)) choices = cast(list[Choices], response.choices) tool_calls = choices[0].message.tool_calls or [] if not tool_calls: return choices[0].message.content - messages.append(choices[0].message) # type:ignore + message_dicts.append(choices[0].message.model_dump()) for tool_call in tool_calls: if not tool_call.function.name: msg = "Function name is empty" @@ -138,30 +140,34 @@ async def _do_completion_async(self, model_name, caller): func = self._get_tool_callable(tools, tool_call) - messages.append( + message_dicts.append( ChatMessage( tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=func(**json.loads(tool_call.function.arguments)), - ) + ).model_dump() ) - response = cast(ModelResponse, await acompletion(model=model_name, messages=messages)) + response = cast(ModelResponse, await acompletion(model=model_name, messages=message_dicts, tools=tool_dicts)) choices = cast(list[Choices], response.choices) return choices[0].message.content def _body_to_messages(self, body: str) -> tuple[list[ChatMessage], list[Tool]]: + """Converts each line in the body of a block into a chat message.""" body = body.strip() messages = [] tools = [] for line in body.split("\n"): try: + # Try to parse a chat message messages.append(ChatMessage.model_validate_json(line)) except ValidationError: # pylint: disable=R0801 try: + # If not a chat message, try to parse a tool tools.append(Tool.model_validate_json(line)) except ValidationError: + # Give up pass if not messages: diff --git a/tests/e2e/test_function_calling.py b/tests/e2e/test_function_calling.py new file mode 100644 index 0000000..3384113 --- /dev/null +++ b/tests/e2e/test_function_calling.py @@ -0,0 +1,54 @@ +import platform + +import pytest + +from banks import Prompt + +from .conftest import anthropic_api_key_set, openai_api_key_set + + +def get_laptop_info(): + """Get information about the user laptop. + + For example, it returns the operating system and version, along with hardware and network specs.""" + return str(platform.uname()) + + +@pytest.mark.e2e +@openai_api_key_set +def test_function_call_openai(): + p = Prompt(""" + {% set response %} + {% completion model="gpt-3.5-turbo-0125" %} + {% chat role="user" %}{{ query }}{% endchat %} + {{ get_laptop_info | tool }} + {% endcompletion %} + {% endset %} + + {# the variable 'response' contains the result #} + + {{ response }} + """) + + res = p.text({"query": "Can you guess the name of my laptop?", "get_laptop_info": get_laptop_info}) + assert res + + +@pytest.mark.e2e +@anthropic_api_key_set +def test_function_call_anthropic(): + p = Prompt(""" + {% set response %} + {% completion model="claude-3-5-sonnet-20240620" %} + {% chat role="user" %}{{ query }}{% endchat %} + {{ get_laptop_info | tool }} + {% endcompletion %} + {% endset %} + + {# the variable 'response' contains the result #} + + {{ response }} + """) + + res = p.text({"query": "Can you guess the name of my laptop? Use tools.", "get_laptop_info": get_laptop_info}) + assert res diff --git a/tests/test_completion.py b/tests/test_completion.py index ea5b885..08f4239 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -125,7 +125,9 @@ async def test__do_completion_async_no_tools(ext, mocked_choices_no_tools): mocked_completion.return_value.choices = mocked_choices_no_tools await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') mocked_completion.assert_called_with( - model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] + model="test-model", + messages=[{"role": "user", "content": "hello", "tool_call_id": None, "name": None}], + tools=None, ) @@ -143,7 +145,6 @@ def test__do_completion_with_tools(ext, mocked_choices_with_tools): calls = mocked_completion.call_args_list assert len(calls) == 2 # complete query, complete with tool results assert len(calls[0].kwargs["tools"]) == 2 - assert "tools" not in calls[1].kwargs for m in calls[1].kwargs["messages"]: if type(m) is ChatMessage: assert m.role == "tool" @@ -151,16 +152,20 @@ def test__do_completion_with_tools(ext, mocked_choices_with_tools): @pytest.mark.asyncio -async def test__do_completion_async_with_tools(ext, mocked_choices_with_tools): +async def test__do_completion_async_with_tools(ext, mocked_choices_with_tools, tools): ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}") - ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"])) + ext._body_to_messages = mock.MagicMock( + return_value=( + [ChatMessage(role="user", content="message1"), ChatMessage(role="user", content="message2")], + tools, + ) + ) with mock.patch("banks.extensions.completion.acompletion") as mocked_completion: mocked_completion.return_value.choices = mocked_choices_with_tools await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') calls = mocked_completion.call_args_list assert len(calls) == 2 # complete query, complete with tool results - assert calls[0].kwargs["tools"] == ["tool1", "tool2"] - assert "tools" not in calls[1].kwargs + assert calls[0].kwargs["tools"] == [t.model_dump() for t in tools] for m in calls[1].kwargs["messages"]: if type(m) is ChatMessage: assert m.role == "tool" @@ -190,7 +195,9 @@ async def test__do_completion_async_no_prompt_no_tools(ext, mocked_choices_no_to mocked_completion.return_value.choices = mocked_choices_no_tools await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') mocked_completion.assert_called_with( - model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] + model="test-model", + messages=[{"role": "user", "content": "hello", "tool_call_id": None, "name": None}], + tools=None, )