From 25523831113641c28e32e29866c01dbb10ab7b15 Mon Sep 17 00:00:00 2001 From: "J.J. Allaire" Date: Mon, 23 Dec 2024 06:30:37 -0500 Subject: [PATCH] add tests for optional tool args --- tests/tools/test_tool_types.py | 39 ++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/tools/test_tool_types.py b/tests/tools/test_tool_types.py index b38c3ae27..712db0a38 100644 --- a/tests/tools/test_tool_types.py +++ b/tests/tools/test_tool_types.py @@ -107,6 +107,25 @@ async def execute(extracted: list[Word]): return execute +@tool +def sound_check(): + async def execute(sound: str, extra: str = "stuff"): + """ + Accepts the extracted nouns and adjectives from a sentence + + Args: + sound: The sound to check + extra: Optional extra stuff + + + Returns: + The sound that was passed to check. + """ + return f"{sound} ({extra})" + + return execute + + def check_point(model: str, tool: Tool, function_name: str) -> None: task = Task( dataset=MemoryDataset( @@ -173,11 +192,31 @@ def check_list_of_objects(model: str) -> None: verify_tool_call(log, "quick:") +def check_optional_args(model: str) -> None: + task = Task( + dataset=MemoryDataset( + [ + Sample( + input="Please call the sound_check tool with single argument 'boo' and report its output." + ) + ] + ), + solver=[ + use_tools([sound_check()], tool_choice=ToolFunction("sound_check")), + generate(), + ], + ) + + log = eval(task, model=model)[0] + verify_tool_call(log, "stuff") + + def check_tool_types(model: str): check_typed_dict(model) check_dataclass(model) check_list_of_numbers(model) check_list_of_objects(model) + check_optional_args(model) @skip_if_no_openai