Skip to content

Commit

Permalink
fix: fix function calling with the latest litellm version
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Nov 30, 2024
1 parent b84f584 commit 0e808cd
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 19 deletions.
30 changes: 18 additions & 12 deletions src/banks/extensions/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,16 @@ def _do_completion(self, model_name, caller):
Helper callback.
"""
messages, tools = self._body_to_messages(caller())
messages_as_dict = [m.model_dump() for m in messages]
message_dicts = [m.model_dump() for m in messages]
tool_dicts = [t.model_dump() for t in tools] or None

response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict, tools=tools or None))
response = cast(ModelResponse, completion(model=model_name, messages=message_dicts, tools=tool_dicts))
choices = cast(list[Choices], response.choices)
tool_calls = choices[0].message.tool_calls
if not tool_calls:
return choices[0].message.content

messages.append(choices[0].message) # type:ignore
message_dicts.append(choices[0].message.model_dump())
for tool_call in tool_calls:
if not tool_call.function.name:
msg = "Malformed response: function name is empty"
Expand All @@ -107,14 +108,13 @@ def _do_completion(self, model_name, caller):

function_args = json.loads(tool_call.function.arguments)
function_response = func(**function_args)
messages.append(
message_dicts.append(
ChatMessage(
tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response
)
).model_dump()
)

messages_as_dict = [m.model_dump() for m in messages]
response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict))
response = cast(ModelResponse, completion(model=model_name, messages=message_dicts, tools=tool_dicts))
choices = cast(list[Choices], response.choices)
return choices[0].message.content

Expand All @@ -123,45 +123,51 @@ async def _do_completion_async(self, model_name, caller):
Helper callback.
"""
messages, tools = self._body_to_messages(caller())
message_dicts = [m.model_dump() for m in messages]
tool_dicts = [t.model_dump() for t in tools] or None

response = cast(ModelResponse, await acompletion(model=model_name, messages=messages, tools=tools))
response = cast(ModelResponse, await acompletion(model=model_name, messages=message_dicts, tools=tool_dicts))
choices = cast(list[Choices], response.choices)
tool_calls = choices[0].message.tool_calls or []
if not tool_calls:
return choices[0].message.content

messages.append(choices[0].message) # type:ignore
message_dicts.append(choices[0].message.model_dump())
for tool_call in tool_calls:
if not tool_call.function.name:
msg = "Function name is empty"
raise LLMError(msg)

func = self._get_tool_callable(tools, tool_call)

messages.append(
message_dicts.append(
ChatMessage(
tool_call_id=tool_call.id,
role="tool",
name=tool_call.function.name,
content=func(**json.loads(tool_call.function.arguments)),
)
).model_dump()
)

response = cast(ModelResponse, await acompletion(model=model_name, messages=messages))
response = cast(ModelResponse, await acompletion(model=model_name, messages=message_dicts, tools=tool_dicts))
choices = cast(list[Choices], response.choices)
return choices[0].message.content

def _body_to_messages(self, body: str) -> tuple[list[ChatMessage], list[Tool]]:
"""Converts each line in the body of a block into a chat message."""
body = body.strip()
messages = []
tools = []
for line in body.split("\n"):
try:
# Try to parse a chat message
messages.append(ChatMessage.model_validate_json(line))
except ValidationError: # pylint: disable=R0801
try:
# If not a chat message, try to parse a tool
tools.append(Tool.model_validate_json(line))
except ValidationError:
# Give up
pass

if not messages:
Expand Down
54 changes: 54 additions & 0 deletions tests/e2e/test_function_calling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import platform

import pytest

from banks import Prompt

from .conftest import anthropic_api_key_set, openai_api_key_set


def get_laptop_info():
"""Get information about the user laptop.
For example, it returns the operating system and version, along with hardware and network specs."""
return str(platform.uname())


@pytest.mark.e2e
@openai_api_key_set
def test_function_call_openai():
p = Prompt("""
{% set response %}
{% completion model="gpt-3.5-turbo-0125" %}
{% chat role="user" %}{{ query }}{% endchat %}
{{ get_laptop_info | tool }}
{% endcompletion %}
{% endset %}
{# the variable 'response' contains the result #}
{{ response }}
""")

res = p.text({"query": "Can you guess the name of my laptop?", "get_laptop_info": get_laptop_info})
assert res


@pytest.mark.e2e
@anthropic_api_key_set
def test_function_call_anthropic():
p = Prompt("""
{% set response %}
{% completion model="claude-3-5-sonnet-20240620" %}
{% chat role="user" %}{{ query }}{% endchat %}
{{ get_laptop_info | tool }}
{% endcompletion %}
{% endset %}
{# the variable 'response' contains the result #}
{{ response }}
""")

res = p.text({"query": "Can you guess the name of my laptop? Use tools.", "get_laptop_info": get_laptop_info})
assert res
21 changes: 14 additions & 7 deletions tests/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ async def test__do_completion_async_no_tools(ext, mocked_choices_no_tools):
mocked_completion.return_value.choices = mocked_choices_no_tools
await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}')
mocked_completion.assert_called_with(
model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[]
model="test-model",
messages=[{"role": "user", "content": "hello", "tool_call_id": None, "name": None}],
tools=None,
)


Expand All @@ -143,24 +145,27 @@ def test__do_completion_with_tools(ext, mocked_choices_with_tools):
calls = mocked_completion.call_args_list
assert len(calls) == 2 # complete query, complete with tool results
assert len(calls[0].kwargs["tools"]) == 2
assert "tools" not in calls[1].kwargs
for m in calls[1].kwargs["messages"]:
if type(m) is ChatMessage:
assert m.role == "tool"
assert m.name == "get_current_weather"


@pytest.mark.asyncio
async def test__do_completion_async_with_tools(ext, mocked_choices_with_tools):
async def test__do_completion_async_with_tools(ext, mocked_choices_with_tools, tools):
ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}")
ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"]))
ext._body_to_messages = mock.MagicMock(
return_value=(
[ChatMessage(role="user", content="message1"), ChatMessage(role="user", content="message2")],
tools,
)
)
with mock.patch("banks.extensions.completion.acompletion") as mocked_completion:
mocked_completion.return_value.choices = mocked_choices_with_tools
await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}')
calls = mocked_completion.call_args_list
assert len(calls) == 2 # complete query, complete with tool results
assert calls[0].kwargs["tools"] == ["tool1", "tool2"]
assert "tools" not in calls[1].kwargs
assert calls[0].kwargs["tools"] == [t.model_dump() for t in tools]
for m in calls[1].kwargs["messages"]:
if type(m) is ChatMessage:
assert m.role == "tool"
Expand Down Expand Up @@ -190,7 +195,9 @@ async def test__do_completion_async_no_prompt_no_tools(ext, mocked_choices_no_to
mocked_completion.return_value.choices = mocked_choices_no_tools
await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}')
mocked_completion.assert_called_with(
model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[]
model="test-model",
messages=[{"role": "user", "content": "hello", "tool_call_id": None, "name": None}],
tools=None,
)


Expand Down

0 comments on commit 0e808cd

Please sign in to comment.