diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py
index 8643248b1..3e5903af8 100644
--- a/src/agents/lifecycle.py
+++ b/src/agents/lifecycle.py
@@ -1,6 +1,7 @@
-from typing import Any, Generic
+from typing import Any, Generic, Optional
 
 from .agent import Agent
+from .items import ModelResponse, TResponseInputItem
 from .run_context import RunContextWrapper, TContext
 from .tool import Tool
 
@@ -10,6 +11,25 @@ class RunHooks(Generic[TContext]):
     override the methods you need.
     """
 
+    # Two new hook methods added to the RunHooks class to handle LLM start and end events.
+    # These methods allow you to perform actions just before and after the LLM call for an agent.
+    # This is useful for logging, monitoring, or modifying the context before and after the LLM call.
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        """Called just before invoking the LLM for this agent."""
+        pass
+
+    async def on_llm_end(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], response: ModelResponse
+    ) -> None:
+        """Called immediately after the LLM call returns for this agent."""
+        pass
+
     async def on_agent_start(
         self, context: RunContextWrapper[TContext], agent: Agent[TContext]
     ) -> None:
@@ -103,3 +123,22 @@ async def on_tool_end(
     ) -> None:
         """Called after a tool is invoked."""
         pass
+
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        """Called immediately before the agent issues an LLM call."""
+        pass
+
+    async def on_llm_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        response: ModelResponse,
+    ) -> None:
+        """Called immediately after the agent receives the LLM response."""
+        pass
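The two callbacks above are the whole surface area of the change: on_llm_start receives the resolved system prompt and the input items for the turn, and on_llm_end receives the ModelResponse. As a minimal sketch of a consumer (not part of the patch; the class name, logger, and timing bookkeeping are illustrative, and the imports mirror the ones used by the new test file further down):

import logging
import time
from typing import Optional

from agents.agent import Agent
from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
from agents.run_context import RunContextWrapper, TContext

logger = logging.getLogger(__name__)


class LlmTimingHooks(AgentHooks):
    """Illustrative agent-level hooks: time each LLM call and log the response usage."""

    def __init__(self) -> None:
        self._started_at: Optional[float] = None

    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        # Record when the model call begins; input_items is the input list for this turn.
        self._started_at = time.monotonic()
        logger.info("LLM call starting for %s with %d input items", agent.name, len(input_items))

    async def on_llm_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        response: ModelResponse,
    ) -> None:
        # Log latency and the usage object attached to the ModelResponse.
        elapsed = time.monotonic() - self._started_at if self._started_at is not None else 0.0
        logger.info("LLM call for %s finished in %.2fs (usage: %s)", agent.name, elapsed, response.usage)

Attaching an instance via Agent(..., hooks=...), exactly as the tests below do, is all the runner needs in order to invoke it.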
diff --git a/src/agents/run.py b/src/agents/run.py
index e5f9378ec..406e1b549 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -812,7 +812,9 @@ async def _run_single_turn_streamed(
 
         input = ItemHelpers.input_to_new_input_list(streamed_result.input)
         input.extend([item.to_input_item() for item in streamed_result.new_items])
-
+        # Call hook just before the model is invoked, with the correct system_prompt.
+        if agent.hooks:
+            await agent.hooks.on_llm_start(context_wrapper, agent, system_prompt, input)
         # 1. Stream the output events
         async for event in model.stream_response(
             system_prompt,
@@ -849,6 +851,10 @@
 
             streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
 
+        # Call hook just after the model response is finalized.
+        if agent.hooks:
+            await agent.hooks.on_llm_end(context_wrapper, agent, final_response)
+
         # 2. At this point, the streaming is complete for this turn of the agent loop.
         if not final_response:
             raise ModelBehaviorError("Model did not produce a final response!")
@@ -1067,6 +1073,9 @@ async def _get_new_response(
         model = cls._get_model(agent, run_config)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
         model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
+        # If the agent has hooks, call on_llm_start just before the LLM call.
+        if agent.hooks:
+            await agent.hooks.on_llm_start(context_wrapper, agent, system_prompt, input)
 
         new_response = await model.get_response(
             system_instructions=system_prompt,
@@ -1081,6 +1090,9 @@
             previous_response_id=previous_response_id,
             prompt=prompt_config,
         )
+        # If the agent has hooks, call on_llm_end just after the LLM call returns.
+        if agent.hooks:
+            await agent.hooks.on_llm_end(context_wrapper, agent, new_response)
 
         context_wrapper.usage.add(new_response.usage)
 
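As the if agent.hooks: guards above show, the callbacks only fire when the agent has a hooks object attached; both the streamed and non-streamed paths call on_llm_start immediately before the model call and on_llm_end once the response is available. The same two methods are also added to the run-level RunHooks base class, but only the agent-level agent.hooks invocation is wired up in this patch. A small usage sketch follows; the agent name, instructions, and printed messages are placeholders, and the Agent(name=..., instructions=...) constructor plus default model resolution are assumed from the SDK rather than shown in this diff:

import asyncio

from agents.agent import Agent
from agents.lifecycle import AgentHooks
from agents.run import Runner


class PrintLlmHooks(AgentHooks):
    """Illustrative hooks that announce each LLM round trip."""

    async def on_llm_start(self, context, agent, system_prompt, input_items) -> None:
        print(f"[{agent.name}] LLM call starting with {len(input_items)} input items")

    async def on_llm_end(self, context, agent, response) -> None:
        print(f"[{agent.name}] LLM call finished")


async def main() -> None:
    # Attaching hooks to the agent is what makes the `if agent.hooks:` branches above execute.
    agent = Agent(name="Assistant", instructions="Reply concisely.", hooks=PrintLlmHooks())
    await Runner.run(agent, input="Hello!")


if __name__ == "__main__":
    asyncio.run(main())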
diff --git a/tests/test_agent_llm_hooks.py b/tests/test_agent_llm_hooks.py
new file mode 100644
index 000000000..2eb2cfb03
--- /dev/null
+++ b/tests/test_agent_llm_hooks.py
@@ -0,0 +1,130 @@
+from collections import defaultdict
+from typing import Any, Optional
+
+import pytest
+
+from agents.agent import Agent
+from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
+from agents.lifecycle import AgentHooks
+from agents.run import Runner
+from agents.run_context import RunContextWrapper, TContext
+from agents.tool import Tool
+
+from .fake_model import FakeModel
+from .test_responses import (
+    get_function_tool,
+    get_text_message,
+)
+
+
+class AgentHooksForTests(AgentHooks):
+    def __init__(self):
+        self.events: dict[str, int] = defaultdict(int)
+
+    def reset(self):
+        self.events.clear()
+
+    async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
+        self.events["on_start"] += 1
+
+    async def on_end(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
+    ) -> None:
+        self.events["on_end"] += 1
+
+    async def on_handoff(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]
+    ) -> None:
+        self.events["on_handoff"] += 1
+
+    async def on_tool_start(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
+    ) -> None:
+        self.events["on_tool_start"] += 1
+
+    async def on_tool_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        tool: Tool,
+        result: str,
+    ) -> None:
+        self.events["on_tool_end"] += 1
+
+    # NEW: LLM hooks
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        self.events["on_llm_start"] += 1
+
+    async def on_llm_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        response: ModelResponse,
+    ) -> None:
+        self.events["on_llm_end"] += 1
+
+
+# Async test: the hooks fire once around the single (fake) LLM call.
+@pytest.mark.asyncio
+async def test_async_agent_hooks_with_llm():
+    hooks = AgentHooksForTests()
+    model = FakeModel()
+    agent = Agent(
+        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
+    )
+    # Simulate a single LLM call producing an output:
+    model.set_next_output([get_text_message("hello")])
+    await Runner.run(agent, input="hello")
+    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
+    assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
+
+
+# Sync variant of the same scenario, driven through Runner.run_sync.
+def test_sync_agent_hook_with_llm():
+    hooks = AgentHooksForTests()
+    model = FakeModel()
+    agent = Agent(
+        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
+    )
+    # Simulate a single LLM call producing an output:
+    model.set_next_output([get_text_message("hello")])
+    Runner.run_sync(agent, input="hello")
+    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
+    assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
+
+
+# Streamed variant of the same scenario, consuming the event stream.
+@pytest.mark.asyncio
+async def test_streamed_agent_hooks_with_llm():
+    hooks = AgentHooksForTests()
+    model = FakeModel()
+    agent = Agent(
+        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
+    )
+    # Simulate a single LLM call producing an output:
+    model.set_next_output([get_text_message("hello")])
+    stream = Runner.run_streamed(agent, input="hello")
+
+    async for event in stream.stream_events():
+        if event.type == "raw_response_event":
+            continue
+        if event.type == "agent_updated_stream_event":
+            print(f"[EVENT] agent_updated → {event.new_agent.name}")
+        elif event.type == "run_item_stream_event":
+            item = event.item
+            if item.type == "tool_call_item":
+                print("[EVENT] tool_call_item")
+            elif item.type == "tool_call_output_item":
+                print(f"[EVENT] tool_call_output_item → {item.output}")
+            elif item.type == "message_output_item":
+                text = ItemHelpers.text_message_output(item)
+                print(f"[EVENT] message_output_item → {text}")
+
+    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
+    assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}