Skip to content

feat: #864 support streaming nested tool events in Agent.as_tool #1057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
if TYPE_CHECKING:
from .lifecycle import AgentHooks
from .mcp import MCPServer
from .result import RunResult
from .result import RunResult, RunResultStreaming


@dataclass
Expand Down Expand Up @@ -199,7 +199,9 @@ def as_tool(
self,
tool_name: str | None,
tool_description: str | None,
*,
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
stream_inner_events: bool = False,
) -> Tool:
"""Transform this agent into a tool, callable by other agents.

Expand All @@ -224,17 +226,36 @@ def as_tool(
async def run_agent(context: RunContextWrapper, input: str) -> str:
from .run import Runner

output = await Runner.run(
starting_agent=self,
input=input,
context=context.context,
)
output_run: RunResult | RunResultStreaming
if stream_inner_events:
from .stream_events import RunItemStreamEvent

sub_run = Runner.run_streamed(
self,
input=input,
context=context.context,
)
parent_queue = getattr(context, "_event_queue", None)
async for ev in sub_run.stream_events():
if parent_queue is not None and isinstance(ev, RunItemStreamEvent):
if ev.name in ("tool_called", "tool_output"):
parent_queue.put_nowait(ev)
output_run = sub_run
else:
output_run = await Runner.run(
starting_agent=self,
input=input,
context=context.context,
)

if custom_output_extractor:
return await custom_output_extractor(output)
return await custom_output_extractor(cast(Any, output_run))

return ItemHelpers.text_message_outputs(output.new_items)
return ItemHelpers.text_message_outputs(output_run.new_items)

return run_agent
tool = run_agent
tool.stream_inner_events = stream_inner_events
return tool

async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
"""Get the system prompt for the agent."""
Expand Down
1 change: 1 addition & 0 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ def run_streamed(
trace=new_trace,
context_wrapper=context_wrapper,
)
context_wrapper._event_queue = streamed_result._event_queue

# Kick off the actual agent loop in the background and return the streamed result object.
streamed_result._run_impl_task = asyncio.create_task(
Expand Down
5 changes: 4 additions & 1 deletion src/agents/run_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from dataclasses import dataclass, field
from typing import Any, Generic
from typing import Any, Generic, Optional

from typing_extensions import TypeVar

Expand All @@ -24,3 +25,5 @@ class RunContextWrapper(Generic[TContext]):
"""The usage of the agent run so far. For streamed responses, the usage will be stale until the
last chunk of the stream is processed.
"""

_event_queue: Optional[asyncio.Queue[Any]] = field(default=None, init=False, repr=False)
3 changes: 3 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ class FunctionTool:
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
based on your context/state."""

stream_inner_events: bool = False
"""Whether to stream inner events when used as an agent tool."""


@dataclass
class FileSearchTool:
Expand Down
10 changes: 8 additions & 2 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from dataclasses import dataclass, field, fields
from typing import Any
from typing import Any, Optional

from .run_context import RunContextWrapper, TContext

Expand All @@ -15,6 +16,8 @@ class ToolContext(RunContextWrapper[TContext]):
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
"""The ID of the tool call."""

_event_queue: Optional[asyncio.Queue[Any]] = field(default=None, init=False, repr=False)

@classmethod
def from_agent_context(
cls, context: RunContextWrapper[TContext], tool_call_id: str
Expand All @@ -26,4 +29,7 @@ def from_agent_context(
base_values: dict[str, Any] = {
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
}
return cls(tool_call_id=tool_call_id, **base_values)
obj = cls(tool_call_id=tool_call_id, **base_values)
if hasattr(context, "_event_queue"):
obj._event_queue = context._event_queue
return obj
Loading