diff --git a/src/agents/agent.py b/src/agents/agent.py index 6c87297f1..dff7815e1 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -222,16 +222,22 @@ def as_tool( description_override=tool_description or "", ) async def run_agent(context: RunContextWrapper, input: str) -> str: - from .run import Runner + from .run import Runner, get_current_run_config + + # Get the current run_config from context if available + run_config = None + current_run_config = get_current_run_config() + if current_run_config and current_run_config.pass_run_config_to_sub_agents: + run_config = current_run_config output = await Runner.run( starting_agent=self, input=input, context=context.context, + run_config=run_config, ) if custom_output_extractor: return await custom_output_extractor(output) - return ItemHelpers.text_message_outputs(output.new_items) return run_agent diff --git a/src/agents/run.py b/src/agents/run.py index e5f9378ec..7d3800593 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextvars import copy import inspect from dataclasses import dataclass, field @@ -55,6 +56,10 @@ from .usage import Usage from .util import _coro, _error_tracing +_current_run_config: contextvars.ContextVar[RunConfig | None] = contextvars.ContextVar( + "current_run_config", default=None +) + DEFAULT_MAX_TURNS = 10 DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore @@ -79,6 +84,21 @@ def get_default_agent_runner() -> AgentRunner: return DEFAULT_AGENT_RUNNER +def get_current_run_config() -> RunConfig | None: + """Get the current run config from context.""" + return _current_run_config.get() + + +def set_current_run_config(run_config: RunConfig | None) -> contextvars.Token[RunConfig | None]: + """Set the current run config in context.""" + return _current_run_config.set(run_config) + + +def reset_current_run_config(token: contextvars.Token[RunConfig | None]) -> None: + """Reset the current run config in context.""" + _current_run_config.reset(token) + + @dataclass class RunConfig: """Configures settings for the entire agent run.""" @@ -137,6 +157,11 @@ class RunConfig: An optional dictionary of additional metadata to include with the trace. """ + pass_run_config_to_sub_agents: bool = False + """ + Whether to pass this run configuration to sub-agents when using as_tool(). + If True, sub-agents will inherit the parent's run configuration. + """ class RunOptions(TypedDict, Generic[TContext]): """Arguments for ``AgentRunner`` methods.""" @@ -332,77 +357,99 @@ async def run( tool_use_tracker = AgentToolUseTracker() - with TraceCtxManager( - workflow_name=run_config.workflow_name, - trace_id=run_config.trace_id, - group_id=run_config.group_id, - metadata=run_config.trace_metadata, - disabled=run_config.tracing_disabled, - ): - current_turn = 0 - original_input: str | list[TResponseInputItem] = copy.deepcopy(input) - generated_items: list[RunItem] = [] - model_responses: list[ModelResponse] = [] + # Set the run_config context variable if enabled + run_config_token = None + if run_config.pass_run_config_to_sub_agents: + run_config_token = set_current_run_config(run_config) - context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( - context=context, # type: ignore - ) + try: + with TraceCtxManager( + workflow_name=run_config.workflow_name, + trace_id=run_config.trace_id, + group_id=run_config.group_id, + metadata=run_config.trace_metadata, + disabled=run_config.tracing_disabled, + ): + current_turn = 0 + original_input: str | list[TResponseInputItem] = copy.deepcopy(input) + generated_items: list[RunItem] = [] + model_responses: list[ModelResponse] = [] + + context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( + context=context, # type: ignore + ) - input_guardrail_results: list[InputGuardrailResult] = [] - - current_span: Span[AgentSpanData] | None = None - current_agent = starting_agent - should_run_agent_start_hooks = True - - try: - while True: - all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper) - - # Start an agent span if we don't have one. This span is ended if the current - # agent changes, or if the agent loop ends. - if current_span is None: - handoff_names = [ - h.agent_name - for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) - ] - if output_schema := AgentRunner._get_output_schema(current_agent): - output_type_name = output_schema.name() - else: - output_type_name = "str" + input_guardrail_results: list[InputGuardrailResult] = [] - current_span = agent_span( - name=current_agent.name, - handoffs=handoff_names, - output_type=output_type_name, - ) - current_span.start(mark_as_current=True) - current_span.span_data.tools = [t.name for t in all_tools] + current_span: Span[AgentSpanData] | None = None + current_agent = starting_agent + should_run_agent_start_hooks = True - current_turn += 1 - if current_turn > max_turns: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Max turns exceeded", - data={"max_turns": max_turns}, - ), - ) - raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded") + try: + while True: + all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper) + + # Start an agent span if we don't have one. This span is ended if the + # current agent changes, or if the agent loop ends. + if current_span is None: + handoff_names = [ + h.agent_name + for h in await AgentRunner._get_handoffs( + current_agent, context_wrapper + ) + ] + if output_schema := AgentRunner._get_output_schema(current_agent): + output_type_name = output_schema.name() + else: + output_type_name = "str" + + current_span = agent_span( + name=current_agent.name, + handoffs=handoff_names, + output_type=output_type_name, + ) + current_span.start(mark_as_current=True) + current_span.span_data.tools = [t.name for t in all_tools] + + current_turn += 1 + if current_turn > max_turns: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Max turns exceeded", + data={"max_turns": max_turns}, + ), + ) + raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded") - logger.debug( - f"Running agent {current_agent.name} (turn {current_turn})", - ) + logger.debug( + f"Running agent {current_agent.name} (turn {current_turn})", + ) - if current_turn == 1: - input_guardrail_results, turn_result = await asyncio.gather( - self._run_input_guardrails( - starting_agent, - starting_agent.input_guardrails - + (run_config.input_guardrails or []), - copy.deepcopy(input), - context_wrapper, - ), - self._run_single_turn( + if current_turn == 1: + input_guardrail_results, turn_result = await asyncio.gather( + self._run_input_guardrails( + starting_agent, + starting_agent.input_guardrails + + (run_config.input_guardrails or []), + copy.deepcopy(input), + context_wrapper, + ), + self._run_single_turn( + agent=current_agent, + all_tools=all_tools, + original_input=original_input, + generated_items=generated_items, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + should_run_agent_start_hooks=should_run_agent_start_hooks, + tool_use_tracker=tool_use_tracker, + previous_response_id=previous_response_id, + ), + ) + else: + turn_result = await self._run_single_turn( agent=current_agent, all_tools=all_tools, original_input=original_input, @@ -413,69 +460,61 @@ async def run( should_run_agent_start_hooks=should_run_agent_start_hooks, tool_use_tracker=tool_use_tracker, previous_response_id=previous_response_id, - ), - ) - else: - turn_result = await self._run_single_turn( - agent=current_agent, - all_tools=all_tools, - original_input=original_input, - generated_items=generated_items, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - should_run_agent_start_hooks=should_run_agent_start_hooks, - tool_use_tracker=tool_use_tracker, - previous_response_id=previous_response_id, - ) - should_run_agent_start_hooks = False + ) + should_run_agent_start_hooks = False - model_responses.append(turn_result.model_response) - original_input = turn_result.original_input - generated_items = turn_result.generated_items + model_responses.append(turn_result.model_response) + original_input = turn_result.original_input + generated_items = turn_result.generated_items - if isinstance(turn_result.next_step, NextStepFinalOutput): - output_guardrail_results = await self._run_output_guardrails( - current_agent.output_guardrails + (run_config.output_guardrails or []), - current_agent, - turn_result.next_step.output, - context_wrapper, - ) - return RunResult( - input=original_input, - new_items=generated_items, - raw_responses=model_responses, - final_output=turn_result.next_step.output, - _last_agent=current_agent, - input_guardrail_results=input_guardrail_results, - output_guardrail_results=output_guardrail_results, - context_wrapper=context_wrapper, - ) - elif isinstance(turn_result.next_step, NextStepHandoff): - current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) + if isinstance(turn_result.next_step, NextStepFinalOutput): + output_guardrail_results = await self._run_output_guardrails( + current_agent.output_guardrails + ( + run_config.output_guardrails or [] + ), + current_agent, + turn_result.next_step.output, + context_wrapper, + ) + return RunResult( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + final_output=turn_result.next_step.output, + _last_agent=current_agent, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=output_guardrail_results, + context_wrapper=context_wrapper, + ) + elif isinstance(turn_result.next_step, NextStepHandoff): + current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) + current_span.finish(reset_current=True) + current_span = None + should_run_agent_start_hooks = True + elif isinstance(turn_result.next_step, NextStepRunAgain): + pass + else: + raise AgentsException( + f"Unknown next step type: {type(turn_result.next_step)}" + ) + except AgentsException as exc: + exc.run_data = RunErrorDetails( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=[], + ) + raise + finally: + if current_span: current_span.finish(reset_current=True) - current_span = None - should_run_agent_start_hooks = True - elif isinstance(turn_result.next_step, NextStepRunAgain): - pass - else: - raise AgentsException( - f"Unknown next step type: {type(turn_result.next_step)}" - ) - except AgentsException as exc: - exc.run_data = RunErrorDetails( - input=original_input, - new_items=generated_items, - raw_responses=model_responses, - last_agent=current_agent, - context_wrapper=context_wrapper, - input_guardrail_results=input_guardrail_results, - output_guardrail_results=[], - ) - raise - finally: - if current_span: - current_span.finish(reset_current=True) + finally: + # Always clean up the context variable + if run_config_token is not None: + reset_current_run_config(run_config_token) def run_sync( self, @@ -516,56 +555,67 @@ def run_streamed( if run_config is None: run_config = RunConfig() - # If there's already a trace, we don't create a new one. In addition, we can't end the - # trace here, because the actual work is done in `stream_events` and this method ends - # before that. - new_trace = ( - None - if get_current_trace() - else trace( - workflow_name=run_config.workflow_name, - trace_id=run_config.trace_id, - group_id=run_config.group_id, - metadata=run_config.trace_metadata, - disabled=run_config.tracing_disabled, - ) - ) + # Set the run_config context variable if enabled + run_config_token = None + if run_config.pass_run_config_to_sub_agents: + run_config_token = set_current_run_config(run_config) - output_schema = AgentRunner._get_output_schema(starting_agent) - context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( - context=context # type: ignore - ) + try: + # If there's already a trace, we don't create a new one. In addition, we can't end the + # trace here, because the actual work is done in `stream_events` and this method ends + # before that. + new_trace = ( + None + if get_current_trace() + else trace( + workflow_name=run_config.workflow_name, + trace_id=run_config.trace_id, + group_id=run_config.group_id, + metadata=run_config.trace_metadata, + disabled=run_config.tracing_disabled, + ) + ) - streamed_result = RunResultStreaming( - input=copy.deepcopy(input), - new_items=[], - current_agent=starting_agent, - raw_responses=[], - final_output=None, - is_complete=False, - current_turn=0, - max_turns=max_turns, - input_guardrail_results=[], - output_guardrail_results=[], - _current_agent_output_schema=output_schema, - trace=new_trace, - context_wrapper=context_wrapper, - ) + output_schema = AgentRunner._get_output_schema(starting_agent) + context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( + context=context # type: ignore + ) - # Kick off the actual agent loop in the background and return the streamed result object. - streamed_result._run_impl_task = asyncio.create_task( - self._start_streaming( - starting_input=input, - streamed_result=streamed_result, - starting_agent=starting_agent, + streamed_result = RunResultStreaming( + input=copy.deepcopy(input), + new_items=[], + current_agent=starting_agent, + raw_responses=[], + final_output=None, + is_complete=False, + current_turn=0, max_turns=max_turns, - hooks=hooks, + input_guardrail_results=[], + output_guardrail_results=[], + _current_agent_output_schema=output_schema, + trace=new_trace, context_wrapper=context_wrapper, - run_config=run_config, - previous_response_id=previous_response_id, ) - ) - return streamed_result + + # Kick off the actual agent loop in the background and return the + # streamed result object. + streamed_result._run_impl_task = asyncio.create_task( + self._start_streaming( + starting_input=input, + streamed_result=streamed_result, + starting_agent=starting_agent, + max_turns=max_turns, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + previous_response_id=previous_response_id, + ) + ) + return streamed_result + finally: + # Always reset the context variable + if run_config_token is not None: + reset_current_run_config(run_config_token) @classmethod async def _run_input_guardrails_with_queue( diff --git a/tests/test_run_config_inheritance.py b/tests/test_run_config_inheritance.py new file mode 100644 index 000000000..8886a37e4 --- /dev/null +++ b/tests/test_run_config_inheritance.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from typing import cast + +import pytest + +from agents import Agent, RunConfig, Runner +from agents.run import get_current_run_config, reset_current_run_config, set_current_run_config +from agents.tool import function_tool + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + + +@pytest.mark.asyncio +async def test_run_config_inheritance_enabled(): + """Test that run_config is inherited when pass_run_config_to_sub_agents=True""" + inherited_configs = [] + + @function_tool + async def config_capture_tool() -> str: + """Tool that captures the current run config""" + current_config = get_current_run_config() + inherited_configs.append(current_config) + return "config_captured" + + sub_agent = Agent( + name="SubAgent", + instructions="You are a sub agent", + model=FakeModel(), + tools=[config_capture_tool], + ) + + sub_fake_model = cast(FakeModel, sub_agent.model) + sub_fake_model.add_multiple_turn_outputs( + [ + [get_function_tool_call("config_capture_tool", "{}")], + [get_text_message("sub_agent_response")], + ] + ) + + parent_agent = Agent( + name="ParentAgent", + instructions="You are a parent agent", + model=FakeModel(), + tools=[ + sub_agent.as_tool( + tool_name="sub_agent_tool", tool_description="Call the sub agent" + ) + ], + ) + + parent_fake_model = cast(FakeModel, parent_agent.model) + parent_fake_model.add_multiple_turn_outputs( + [ + [get_function_tool_call("sub_agent_tool", '{"input": "test"}')], + [get_text_message("parent_response")], + ] + ) + + run_config = RunConfig(pass_run_config_to_sub_agents=True) + + assert get_current_run_config() is None + + await Runner.run( + starting_agent=parent_agent, + input="Use the sub agent tool", + run_config=run_config, + ) + + assert get_current_run_config() is None + assert len(inherited_configs) == 1 + assert inherited_configs[0] is run_config + assert inherited_configs[0].pass_run_config_to_sub_agents is True + + +@pytest.mark.asyncio +async def test_run_config_inheritance_disabled(): + """Test that run_config is not inherited when pass_run_config_to_sub_agents=False""" + inherited_configs = [] + + @function_tool + async def config_capture_tool() -> str: + """Tool that captures the current run config""" + current_config = get_current_run_config() + inherited_configs.append(current_config) + return "config_captured" + + sub_agent = Agent( + name="SubAgent", + instructions="You are a sub agent", + model=FakeModel(), + tools=[config_capture_tool], + ) + + sub_fake_model = cast(FakeModel, sub_agent.model) + sub_fake_model.add_multiple_turn_outputs( + [ + [get_function_tool_call("config_capture_tool", "{}")], + [get_text_message("sub_agent_response")], + ] + ) + + parent_agent = Agent( + name="ParentAgent", + instructions="You are a parent agent", + model=FakeModel(), + tools=[ + sub_agent.as_tool( + tool_name="sub_agent_tool", tool_description="Call the sub agent" + ) + ], + ) + + parent_fake_model = cast(FakeModel, parent_agent.model) + parent_fake_model.add_multiple_turn_outputs( + [ + [get_function_tool_call("sub_agent_tool", '{"input": "test"}')], + [get_text_message("parent_response")], + ] + ) + + run_config = RunConfig() + + await Runner.run( + starting_agent=parent_agent, + input="Use the sub agent tool", + run_config=run_config, + ) + + assert get_current_run_config() is None + assert len(inherited_configs) == 1 + assert inherited_configs[0] is None + + +@pytest.mark.asyncio +async def test_context_variable_cleanup_on_error(): + """Test that context variable is cleaned up even when errors occur""" + failing_model = FakeModel() + failing_model.set_next_output(RuntimeError("Intentional test failure")) + + failing_agent = Agent( + name="FailingAgent", + instructions="Fail", + model=failing_model, + ) + + run_config = RunConfig(pass_run_config_to_sub_agents=True) + + assert get_current_run_config() is None + + with pytest.raises(RuntimeError, match="Intentional test failure"): + await Runner.run( + starting_agent=failing_agent, + input="This should fail", + run_config=run_config, + ) + + assert get_current_run_config() is None + + +@pytest.mark.asyncio +async def test_scope_methods_directly(): + """Test the Scope class methods directly for RunConfig management""" + run_config = RunConfig(pass_run_config_to_sub_agents=True) + + assert get_current_run_config() is None + + token = set_current_run_config(run_config) + assert get_current_run_config() is run_config + + reset_current_run_config(token) + assert get_current_run_config() is None + + token = set_current_run_config(None) + assert get_current_run_config() is None + reset_current_run_config(token)