Skip to content

Added support for passing tool_call_id via the RunContextWrapper #766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,23 +539,26 @@ async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
with function_span(func_tool.name) as span_fn:
tool_context_wrapper = dataclasses.replace(
context_wrapper, tool_call_id=tool_call.call_id
)
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
_, _, result = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, func_tool),
hooks.on_tool_start(tool_context_wrapper, agent, func_tool),
(
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
agent.hooks.on_tool_start(tool_context_wrapper, agent, func_tool)
if agent.hooks
else _coro.noop_coroutine()
),
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
func_tool.on_invoke_tool(tool_context_wrapper, tool_call.arguments),
)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, func_tool, result),
hooks.on_tool_end(tool_context_wrapper, agent, func_tool, result),
(
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
agent.hooks.on_tool_end(tool_context_wrapper, agent, func_tool, result)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down
3 changes: 3 additions & 0 deletions src/agents/run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ class RunContextWrapper(Generic[TContext]):
"""The usage of the agent run so far. For streamed responses, the usage will be stale until the
last chunk of the stream is processed.
"""

tool_call_id: str | None = None
"""The ID of the tool call for the current tool execution."""
6 changes: 4 additions & 2 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ def _foo() -> str:
)


def get_function_tool_call(name: str, arguments: str | None = None) -> ResponseOutputItem:
def get_function_tool_call(
name: str, arguments: str | None = None, call_id: str | None = None
) -> ResponseOutputItem:
return ResponseFunctionToolCall(
id="1",
call_id="2",
call_id=call_id or "2",
type="function_call",
name=name,
arguments=arguments or "",
Expand Down
38 changes: 38 additions & 0 deletions tests/test_run_step_execution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from typing import Any

import pytest
Expand All @@ -26,6 +27,7 @@
RunImpl,
SingleStepResult,
)
from agents.tool import function_tool

from .test_responses import (
get_final_output_message,
Expand Down Expand Up @@ -158,6 +160,42 @@ async def test_multiple_tool_calls():
assert isinstance(result.next_step, NextStepRunAgain)


@pytest.mark.asyncio
async def test_multiple_tool_calls_with_tool_context():
async def _fake_tool(agent_context: RunContextWrapper[str], value: str) -> str:
return f"{value}-{agent_context.tool_call_id}"

tool = function_tool(_fake_tool, name_override="fake_tool", failure_error_function=None)

agent = Agent(
name="test",
tools=[tool],
)
response = ModelResponse(
output=[
get_function_tool_call("fake_tool", json.dumps({"value": "123"}), call_id="1"),
get_function_tool_call("fake_tool", json.dumps({"value": "456"}), call_id="2"),
],
usage=Usage(),
response_id=None,
)

result = await get_execute_result(agent, response)
assert result.original_input == "hello"

# 4 items: new message, 2 tool calls, 2 tool call outputs
assert len(result.generated_items) == 4
assert isinstance(result.next_step, NextStepRunAgain)

items = result.generated_items
assert_item_is_function_tool_call(items[0], "fake_tool", json.dumps({"value": "123"}))
assert_item_is_function_tool_call(items[1], "fake_tool", json.dumps({"value": "456"}))
assert_item_is_function_tool_call_output(items[2], "123-1")
assert_item_is_function_tool_call_output(items[3], "456-2")

assert isinstance(result.next_step, NextStepRunAgain)


@pytest.mark.asyncio
async def test_handoff_output_leads_to_handoff_next_step():
agent_1 = Agent(name="test_1")
Expand Down