From 9faa2c75c22f11d1673441c81c337baa69a0416c Mon Sep 17 00:00:00 2001
From: jjallaire-aisi
Date: Thu, 12 Sep 2024 07:36:11 -0400
Subject: [PATCH] use_tools: accept a variadic list of `Tool` (#374)

Co-authored-by: aisi-inspect <166920645+aisi-inspect@users.noreply.github.com>
---
  Reviewer notes (free text after the three-dash separator; not part of the
  commit message):

  * `use_tools` now takes `*tools: Tool | list[Tool]`. Inside `solve`, every
    positional argument is flattened into one list: list arguments are
    extended into it, anything else is appended as a single tool.
  * `state.tools` is only assigned when the flattened list is non-empty, so
    both `use_tools()` and `use_tools([])` leave the current tools unchanged.
  * Because `tool_choice` now follows `*tools` in the signature, it can only
    be passed by keyword. A second positional argument is treated as a tool.
  * The old `None` default is gone: a positional `None` is not a list, so it
    would be appended to the tool list as-is. Callers that passed `None` or
    used the `tools=` keyword presumably need updating; verify call sites.
  * In the new test, the first two `use_tools([addition_tool, read_file_tool])`
    calls are identical. NOTE(review): possibly one of them was meant to
    cover a different call shape; confirm with the author.

 CHANGELOG.md                        |  1 +
 src/inspect_ai/solver/_use_tools.py | 20 ++++++++-----
 tests/tools/test_use_tools.py       | 44 +++++++++++++++++++++++++++++
 3 files changed, 58 insertions(+), 7 deletions(-)
 create mode 100644 tests/tools/test_use_tools.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 300249d48..8ea82ca23 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@
 - [score()](https://inspect.ai-safety-institute.org.uk/solvers.html#sec-scoring-in-solvers) function for accessing scoring logic from within solvers.
 - `system_message()` now supports custom parameters and interpolation of `metadata` values from `Sample`.
 - `generate()` solver now accepts arbitrary generation config params.
+- `use_tools()` now accepts a variadic list of `Tool` in addition to literal `list[Tool]`.
 - `bash()` and `python()` tools now have a `user` parameter for choosing an alternate user to run code as.
 - Input event for recording screen input in sample transcripts.
 - Record to sample function for CSV and JSON dataset readers can now return multiple samples.
diff --git a/src/inspect_ai/solver/_use_tools.py b/src/inspect_ai/solver/_use_tools.py
index cb88c1056..5b40d1061 100644
--- a/src/inspect_ai/solver/_use_tools.py
+++ b/src/inspect_ai/solver/_use_tools.py
@@ -6,15 +6,15 @@
 
 @solver
 def use_tools(
-    tools: Tool | list[Tool] | None = None, tool_choice: ToolChoice | None = "auto"
+    *tools: Tool | list[Tool], tool_choice: ToolChoice | None = "auto"
 ) -> Solver:
     """
     Inject tools into the task state to be used in generate().
 
     Args:
-        tools (Tool | list[Tool] | None): One or more tools to make available
-          to the model. If `None` is passed, then no change to the currently
-          available set of `tools` is made.
+        *tools (Tool | list[Tool]): One or more tools or lists of tools
+          to make available to the model. If no tools are passed, then
+          no change to the currently available set of `tools` is made.
         tool_choice (ToolChoice | None): Directive indicating which tools
           the model should use. If `None` is passed, then no change to
           `tool_choice` is made.
@@ -24,9 +24,15 @@ def use_tools(
     """
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
-        # set tools if specified
-        if tools is not None:
-            state.tools = tools if isinstance(tools, list) else [tools]
+        # build up tools
+        tools_update: list[Tool] = []
+        for tool in tools:
+            if isinstance(tool, list):
+                tools_update.extend(tool)
+            else:
+                tools_update.append(tool)
+        if len(tools_update) > 0:
+            state.tools = tools_update
 
         # set tool choice if specified
         if tool_choice is not None:
diff --git a/tests/tools/test_use_tools.py b/tests/tools/test_use_tools.py
new file mode 100644
index 000000000..90a2f2083
--- /dev/null
+++ b/tests/tools/test_use_tools.py
@@ -0,0 +1,44 @@
+from typing import Literal
+
+import pytest
+from test_helpers.tools import addition, read_file
+from test_helpers.utils import simple_task_state
+from typing_extensions import Unpack
+
+from inspect_ai.model import CachePolicy, GenerateConfigArgs
+from inspect_ai.solver import TaskState, use_tools
+
+
+@pytest.mark.asyncio
+async def test_use_tools():
+    def generate(
+        state: TaskState,
+        tool_calls: Literal["loop", "single", "none"] = "loop",
+        cache: bool | CachePolicy = False,
+        **kwargs: Unpack[GenerateConfigArgs],
+    ) -> TaskState:
+        return state
+
+    state = simple_task_state()
+
+    addition_tool = addition()
+    read_file_tool = read_file()
+
+    state = await (use_tools([addition_tool, read_file_tool]))(state, generate)
+    assert state.tools == [addition_tool, read_file_tool]
+
+    state = await (use_tools([addition_tool, read_file_tool]))(state, generate)
+    assert state.tools == [addition_tool, read_file_tool]
+
+    state = await (use_tools(addition_tool, read_file_tool))(state, generate)
+    assert state.tools == [addition_tool, read_file_tool]
+
+    state = await (use_tools([addition_tool]))(state, generate)
+    assert state.tools == [addition_tool]
+
+    state = await (use_tools(addition_tool))(state, generate)
+    assert state.tools == [addition_tool]
+
+    state = await (use_tools(tool_choice="auto"))(state, generate)
+    assert state.tools == [addition_tool]
+    assert state.tool_choice == "auto"