Skip to content

Commit da3a057

Browse files
committed
feat: support streaming nested tool events in Agent.as_tool
- Add `stream_inner_events` flag to allow sub-agent tool call visibility - Use `Runner.run_streamed` for streaming inner agents - Emit nested tool_called/tool_output events in parent stream - Add comprehensive test coverage for inner streaming behavior
1 parent db85a6d commit da3a057

File tree

6 files changed

+380
-14
lines changed

6 files changed

+380
-14
lines changed

src/agents/agent.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
if TYPE_CHECKING:
2828
from .lifecycle import AgentHooks
2929
from .mcp import MCPServer
30-
from .result import RunResult
30+
from .result import RunResult, RunResultStreaming
3131

3232

3333
@dataclass
@@ -199,7 +199,9 @@ def as_tool(
199199
self,
200200
tool_name: str | None,
201201
tool_description: str | None,
202+
*,
202203
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
204+
stream_inner_events: bool = False,
203205
) -> Tool:
204206
"""Transform this agent into a tool, callable by other agents.
205207
@@ -224,17 +226,36 @@ def as_tool(
224226
async def run_agent(context: RunContextWrapper, input: str) -> str:
225227
from .run import Runner
226228

227-
output = await Runner.run(
228-
starting_agent=self,
229-
input=input,
230-
context=context.context,
231-
)
229+
output_run: RunResult | RunResultStreaming
230+
if stream_inner_events:
231+
from .stream_events import RunItemStreamEvent
232+
233+
sub_run = Runner.run_streamed(
234+
self,
235+
input=input,
236+
context=context.context,
237+
)
238+
parent_queue = getattr(context, "_event_queue", None)
239+
async for ev in sub_run.stream_events():
240+
if parent_queue is not None and isinstance(ev, RunItemStreamEvent):
241+
if ev.name in ("tool_called", "tool_output"):
242+
parent_queue.put_nowait(ev)
243+
output_run = sub_run
244+
else:
245+
output_run = await Runner.run(
246+
starting_agent=self,
247+
input=input,
248+
context=context.context,
249+
)
250+
232251
if custom_output_extractor:
233-
return await custom_output_extractor(output)
252+
return await custom_output_extractor(cast(Any, output_run))
234253

235-
return ItemHelpers.text_message_outputs(output.new_items)
254+
return ItemHelpers.text_message_outputs(output_run.new_items)
236255

237-
return run_agent
256+
tool = run_agent
257+
tool.stream_inner_events = stream_inner_events
258+
return tool
238259

239260
async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
240261
"""Get the system prompt for the agent."""
@@ -256,9 +277,7 @@ async def get_prompt(
256277
"""Get the prompt for the agent."""
257278
return await PromptUtil.to_model_input(self.prompt, run_context, self)
258279

259-
async def get_mcp_tools(
260-
self, run_context: RunContextWrapper[TContext]
261-
) -> list[Tool]:
280+
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
262281
"""Fetches the available tools from the MCP servers."""
263282
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
264283
return await MCPUtil.get_all_function_tools(

src/agents/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,7 @@ def run_streamed(
551551
trace=new_trace,
552552
context_wrapper=context_wrapper,
553553
)
554+
context_wrapper._event_queue = streamed_result._event_queue
554555

555556
# Kick off the actual agent loop in the background and return the streamed result object.
556557
streamed_result._run_impl_task = asyncio.create_task(

src/agents/run_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from dataclasses import dataclass, field
23
from typing import Any, Generic
34

@@ -24,3 +25,5 @@ class RunContextWrapper(Generic[TContext]):
2425
"""The usage of the agent run so far. For streamed responses, the usage will be stale until the
2526
last chunk of the stream is processed.
2627
"""
28+
29+
_event_queue: asyncio.Queue[Any] | None = field(default=None, init=False, repr=False)

src/agents/tool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from .util._types import MaybeAwaitable
3131

3232
if TYPE_CHECKING:
33-
3433
from .agent import Agent
3534

3635
ToolParams = ParamSpec("ToolParams")
@@ -93,6 +92,9 @@ class FunctionTool:
9392
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
9493
based on your context/state."""
9594

95+
stream_inner_events: bool = False
96+
"""Whether to stream inner events when used as an agent tool."""
97+
9698

9799
@dataclass
98100
class FileSearchTool:

src/agents/tool_context.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from dataclasses import dataclass, field, fields
23
from typing import Any
34

@@ -15,6 +16,8 @@ class ToolContext(RunContextWrapper[TContext]):
1516
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
1617
"""The ID of the tool call."""
1718

19+
_event_queue: asyncio.Queue[Any] | None = field(default=None, init=False, repr=False)
20+
1821
@classmethod
1922
def from_agent_context(
2023
cls, context: RunContextWrapper[TContext], tool_call_id: str
@@ -26,4 +29,7 @@ def from_agent_context(
2629
base_values: dict[str, Any] = {
2730
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
2831
}
29-
return cls(tool_call_id=tool_call_id, **base_values)
32+
obj = cls(tool_call_id=tool_call_id, **base_values)
33+
if hasattr(context, "_event_queue"):
34+
obj._event_queue = context._event_queue
35+
return obj

0 commit comments

Comments
 (0)