Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix function calling with the latest litellm version #29

Merged
merged 2 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions docs/prompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,6 @@ Banks supports the following ones, specific for prompt engineering.
show_signature_annotations: false
heading_level: 3

::: banks.extensions.docs.generate
options:
show_root_full_path: false
show_symbol_type_heading: false
show_signature_annotations: false
heading_level: 3

### `canary_word`

Insert into the prompt a canary word that can be checked later with `Prompt.canary_leaked()`
Expand Down
30 changes: 18 additions & 12 deletions src/banks/extensions/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,16 @@ def _do_completion(self, model_name, caller):
Helper callback.
"""
messages, tools = self._body_to_messages(caller())
messages_as_dict = [m.model_dump() for m in messages]
message_dicts = [m.model_dump() for m in messages]
tool_dicts = [t.model_dump() for t in tools] or None

response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict, tools=tools or None))
response = cast(ModelResponse, completion(model=model_name, messages=message_dicts, tools=tool_dicts))
choices = cast(list[Choices], response.choices)
tool_calls = choices[0].message.tool_calls
if not tool_calls:
return choices[0].message.content

messages.append(choices[0].message) # type:ignore
message_dicts.append(choices[0].message.model_dump())
for tool_call in tool_calls:
if not tool_call.function.name:
msg = "Malformed response: function name is empty"
Expand All @@ -107,14 +108,13 @@ def _do_completion(self, model_name, caller):

function_args = json.loads(tool_call.function.arguments)
function_response = func(**function_args)
messages.append(
message_dicts.append(
ChatMessage(
tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response
)
).model_dump()
)

messages_as_dict = [m.model_dump() for m in messages]
response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict))
response = cast(ModelResponse, completion(model=model_name, messages=message_dicts, tools=tool_dicts))
choices = cast(list[Choices], response.choices)
return choices[0].message.content

Expand All @@ -123,45 +123,51 @@ async def _do_completion_async(self, model_name, caller):
Helper callback.
"""
messages, tools = self._body_to_messages(caller())
message_dicts = [m.model_dump() for m in messages]
tool_dicts = [t.model_dump() for t in tools] or None

response = cast(ModelResponse, await acompletion(model=model_name, messages=messages, tools=tools))
response = cast(ModelResponse, await acompletion(model=model_name, messages=message_dicts, tools=tool_dicts))
choices = cast(list[Choices], response.choices)
tool_calls = choices[0].message.tool_calls or []
if not tool_calls:
return choices[0].message.content

messages.append(choices[0].message) # type:ignore
message_dicts.append(choices[0].message.model_dump())
for tool_call in tool_calls:
if not tool_call.function.name:
msg = "Function name is empty"
raise LLMError(msg)

func = self._get_tool_callable(tools, tool_call)

messages.append(
message_dicts.append(
ChatMessage(
tool_call_id=tool_call.id,
role="tool",
name=tool_call.function.name,
content=func(**json.loads(tool_call.function.arguments)),
)
).model_dump()
)

response = cast(ModelResponse, await acompletion(model=model_name, messages=messages))
response = cast(ModelResponse, await acompletion(model=model_name, messages=message_dicts, tools=tool_dicts))
choices = cast(list[Choices], response.choices)
return choices[0].message.content

def _body_to_messages(self, body: str) -> tuple[list[ChatMessage], list[Tool]]:
"""Converts each line in the body of a block into a chat message."""
body = body.strip()
messages = []
tools = []
for line in body.split("\n"):
try:
# Try to parse a chat message
messages.append(ChatMessage.model_validate_json(line))
except ValidationError: # pylint: disable=R0801
try:
# If not a chat message, try to parse a tool
tools.append(Tool.model_validate_json(line))
except ValidationError:
# Give up
pass

if not messages:
Expand Down
54 changes: 54 additions & 0 deletions tests/e2e/test_function_calling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import platform

import pytest

from banks import Prompt

from .conftest import anthropic_api_key_set, openai_api_key_set


def get_laptop_info():
    """Get information about the user laptop.

    For example, it returns the operating system and version, along with hardware and network specs."""
    # NOTE: the docstring above doubles as the tool description handed to the
    # LLM via the `tool` filter, so it is kept verbatim.
    uname_result = platform.uname()
    return str(uname_result)


@pytest.mark.e2e
@openai_api_key_set
def test_function_call_openai():
    # End-to-end: a completion block exposing `get_laptop_info` as a tool
    # should drive the full tool-call round-trip and yield a final answer.
    template = """
    {% set response %}
    {% completion model="gpt-3.5-turbo-0125" %}
    {% chat role="user" %}{{ query }}{% endchat %}
    {{ get_laptop_info | tool }}
    {% endcompletion %}
    {% endset %}

    {# the variable 'response' contains the result #}

    {{ response }}
    """
    prompt = Prompt(template)
    result = prompt.text({"query": "Can you guess the name of my laptop?", "get_laptop_info": get_laptop_info})
    assert result


@pytest.mark.e2e
@anthropic_api_key_set
def test_function_call_anthropic():
    # Same end-to-end flow as the OpenAI case, exercised against Anthropic
    # to cover litellm's provider-specific tool-call handling.
    template = """
    {% set response %}
    {% completion model="claude-3-5-sonnet-20240620" %}
    {% chat role="user" %}{{ query }}{% endchat %}
    {{ get_laptop_info | tool }}
    {% endcompletion %}
    {% endset %}

    {# the variable 'response' contains the result #}

    {{ response }}
    """
    prompt = Prompt(template)
    result = prompt.text({"query": "Can you guess the name of my laptop? Use tools.", "get_laptop_info": get_laptop_info})
    assert result
21 changes: 14 additions & 7 deletions tests/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ async def test__do_completion_async_no_tools(ext, mocked_choices_no_tools):
mocked_completion.return_value.choices = mocked_choices_no_tools
await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}')
mocked_completion.assert_called_with(
model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[]
model="test-model",
messages=[{"role": "user", "content": "hello", "tool_call_id": None, "name": None}],
tools=None,
)


Expand All @@ -143,24 +145,27 @@ def test__do_completion_with_tools(ext, mocked_choices_with_tools):
calls = mocked_completion.call_args_list
assert len(calls) == 2 # complete query, complete with tool results
assert len(calls[0].kwargs["tools"]) == 2
assert "tools" not in calls[1].kwargs
for m in calls[1].kwargs["messages"]:
if type(m) is ChatMessage:
assert m.role == "tool"
assert m.name == "get_current_weather"


@pytest.mark.asyncio
async def test__do_completion_async_with_tools(ext, mocked_choices_with_tools):
async def test__do_completion_async_with_tools(ext, mocked_choices_with_tools, tools):
ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}")
ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"]))
ext._body_to_messages = mock.MagicMock(
return_value=(
[ChatMessage(role="user", content="message1"), ChatMessage(role="user", content="message2")],
tools,
)
)
with mock.patch("banks.extensions.completion.acompletion") as mocked_completion:
mocked_completion.return_value.choices = mocked_choices_with_tools
await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}')
calls = mocked_completion.call_args_list
assert len(calls) == 2 # complete query, complete with tool results
assert calls[0].kwargs["tools"] == ["tool1", "tool2"]
assert "tools" not in calls[1].kwargs
assert calls[0].kwargs["tools"] == [t.model_dump() for t in tools]
for m in calls[1].kwargs["messages"]:
if type(m) is ChatMessage:
assert m.role == "tool"
Expand Down Expand Up @@ -190,7 +195,9 @@ async def test__do_completion_async_no_prompt_no_tools(ext, mocked_choices_no_to
mocked_completion.return_value.choices = mocked_choices_no_tools
await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}')
mocked_completion.assert_called_with(
model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[]
model="test-model",
messages=[{"role": "user", "content": "hello", "tool_call_id": None, "name": None}],
tools=None,
)


Expand Down