Skip to content

Commit

Permalink
Merge branch 'main' into massi/fix-notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
masci authored Nov 25, 2024
2 parents 2175ed0 + 52e6a5e commit c2a8883
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ ignore = [
"ISC001",
# Magic numbers
"PLR2004",
# __all__ sorted
"RUF022",
]
unfixable = [
# Don't touch unused imports
Expand Down
6 changes: 4 additions & 2 deletions src/banks/extensions/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ def _do_completion(self, model_name, caller):
Helper callback.
"""
messages, tools = self._body_to_messages(caller())
messages_as_dict = [m.model_dump() for m in messages]

response = cast(ModelResponse, completion(model=model_name, messages=messages, tools=tools))
response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict, tools=tools or None))
choices = cast(list[Choices], response.choices)
tool_calls = choices[0].message.tool_calls
if not tool_calls:
Expand All @@ -112,7 +113,8 @@ def _do_completion(self, model_name, caller):
)
)

response = cast(ModelResponse, completion(model=model_name, messages=messages))
messages_as_dict = [m.model_dump() for m in messages]
response = cast(ModelResponse, completion(model=model_name, messages=messages_as_dict))
choices = cast(list[Choices], response.choices)
return choices[0].message.content

Expand Down
11 changes: 8 additions & 3 deletions tests/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test__do_completion_no_tools(ext, mocked_choices_no_tools):
mocked_completion.return_value.choices = mocked_choices_no_tools
ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}')
mocked_completion.assert_called_with(
model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[]
model="test-model", messages=[ChatMessage(role="user", content="hello").model_dump()], tools=None
)


Expand All @@ -131,13 +131,18 @@ async def test__do_completion_async_no_tools(ext, mocked_choices_no_tools):

def test__do_completion_with_tools(ext, mocked_choices_with_tools):
ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}")
ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"]))
ext._body_to_messages = mock.MagicMock(
return_value=(
[ChatMessage(role="user", content="message1"), ChatMessage(role="user", content="message2")],
[mock.MagicMock(), mock.MagicMock()],
)
)
with mock.patch("banks.extensions.completion.completion") as mocked_completion:
mocked_completion.return_value.choices = mocked_choices_with_tools
ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}')
calls = mocked_completion.call_args_list
assert len(calls) == 2 # complete query, complete with tool results
assert calls[0].kwargs["tools"] == ["tool1", "tool2"]
assert len(calls[0].kwargs["tools"]) == 2
assert "tools" not in calls[1].kwargs
for m in calls[1].kwargs["messages"]:
if type(m) is ChatMessage:
Expand Down

0 comments on commit c2a8883

Please sign in to comment.