From c7d5223a75dbaec9d2e23141df3e3da818749054 Mon Sep 17 00:00:00 2001 From: James Braza Date: Sat, 5 Oct 2024 17:46:54 -0700 Subject: [PATCH] Making `ToolSelector` require a tool call by default (#59) --- src/aviary/tools/utils.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/aviary/tools/utils.py b/src/aviary/tools/utils.py index f5504dbd..e5c0fb24 100644 --- a/src/aviary/tools/utils.py +++ b/src/aviary/tools/utils.py @@ -1,7 +1,7 @@ from collections.abc import Callable from enum import StrEnum from functools import partial -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, ClassVar, cast from pydantic import BaseModel, Field @@ -145,26 +145,49 @@ def __init__( self._model_name = model_name self._bound_acompletion = partial(cast(Callable, acompletion), model_name) + # SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice + # > `required` means the model must call one or more tools. + TOOL_CHOICE_REQUIRED: ClassVar[str] = "required" + async def __call__( - self, messages: list[Message], tools: list[Tool] + self, + messages: list[Message], + tools: list[Tool], + tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED, ) -> ToolRequestMessage: """Run a completion that selects a tool in tools given the messages.""" + kwargs = {} + if tool_choice is not None: + kwargs["tool_choice"] = ( + { + "type": "function", + "function": {"name": tool_choice.info.name}, + } + if isinstance(tool_choice, Tool) + else tool_choice + ) model_response = await self._bound_acompletion( messages=MessagesAdapter.dump_python( messages, exclude_none=True, by_alias=True ), tools=ToolsAdapter.dump_python(tools, exclude_none=True, by_alias=True), + **kwargs, ) - if ( - len(model_response.choices) != 1 - or model_response.choices[0].finish_reason != "tool_calls" - ): + + if (num_choices := len(model_response.choices)) != 1: + raise MalformedMessageError( + f"Expected one choice in LiteLLM model response, got {num_choices}" + f" choices, full response was {model_response}." + ) + choice = model_response.choices[0] + if choice.finish_reason != "tool_calls": raise MalformedMessageError( - f"Unexpected shape of LiteLLM model response {model_response}." + "Expected finish reason 'tool_calls' in LiteLLM model response, got" + f" {choice.finish_reason!r}, full response was {model_response}." ) usage = model_response.usage return ToolRequestMessage( - **model_response.choices[0].message.model_dump(), + **choice.message.model_dump(), info={ "usage": (usage.prompt_tokens, usage.completion_tokens), "model": self._model_name,