Making ToolSelector require a tool call by default (#59)
jamesbraza authored Oct 6, 2024
1 parent 7879ffe · commit c7d5223
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions src/aviary/tools/utils.py
@@ -1,7 +1,7 @@
 from collections.abc import Callable
 from enum import StrEnum
 from functools import partial
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, ClassVar, cast
 
 from pydantic import BaseModel, Field
@@ -145,26 +145,49 @@ def __init__(
         self._model_name = model_name
         self._bound_acompletion = partial(cast(Callable, acompletion), model_name)
 
+    # SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
+    # > `required` means the model must call one or more tools.
+    TOOL_CHOICE_REQUIRED: ClassVar[str] = "required"
+
     async def __call__(
-        self, messages: list[Message], tools: list[Tool]
+        self,
+        messages: list[Message],
+        tools: list[Tool],
+        tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
     ) -> ToolRequestMessage:
         """Run a completion that selects a tool in tools given the messages."""
+        kwargs = {}
+        if tool_choice is not None:
+            kwargs["tool_choice"] = (
+                {
+                    "type": "function",
+                    "function": {"name": tool_choice.info.name},
+                }
+                if isinstance(tool_choice, Tool)
+                else tool_choice
+            )
         model_response = await self._bound_acompletion(
             messages=MessagesAdapter.dump_python(
                 messages, exclude_none=True, by_alias=True
             ),
             tools=ToolsAdapter.dump_python(tools, exclude_none=True, by_alias=True),
+            **kwargs,
         )
-        if (
-            len(model_response.choices) != 1
-            or model_response.choices[0].finish_reason != "tool_calls"
-        ):
+
+        if (num_choices := len(model_response.choices)) != 1:
+            raise MalformedMessageError(
+                f"Expected one choice in LiteLLM model response, got {num_choices}"
+                f" choices, full response was {model_response}."
+            )
+        choice = model_response.choices[0]
+        if choice.finish_reason != "tool_calls":
             raise MalformedMessageError(
-                f"Unexpected shape of LiteLLM model response {model_response}."
+                "Expected finish reason 'tool_calls' in LiteLLM model response, got"
+                f" {choice.finish_reason!r}, full response was {model_response}."
             )
         usage = model_response.usage
         return ToolRequestMessage(
-            **model_response.choices[0].message.model_dump(),
+            **choice.message.model_dump(),
             info={
                 "usage": (usage.prompt_tokens, usage.completion_tokens),
                 "model": self._model_name,

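An end-to-end usage sketch of the updated selector follows; the `lookup` tool, model name, and import paths are assumptions for illustration, not part of this commit:

    import asyncio

    from aviary.message import Message
    from aviary.tools import Tool, ToolSelector

    def lookup(term: str) -> str:
        """Look up the definition of a term (toy example tool)."""
        return f"definition of {term}"

    async def main() -> None:
        tool = Tool.from_function(lookup)
        selector = ToolSelector(model_name="gpt-4o-mini")
        messages = [Message(content="Define 'aviary'.")]

        # New default: tool_choice="required", so the model must call a tool
        request = await selector(messages, [tool])

        # Force the lookup tool specifically by passing the Tool itself
        request = await selector(messages, [tool], tool_choice=tool)

        # Opt out: None omits tool_choice from the completion request entirely
        request = await selector(messages, [tool], tool_choice=None)
        print(request.tool_calls)

    asyncio.run(main())

Each call returns a ToolRequestMessage whose info records token usage and the model name, as shown at the end of the diff.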