feat(agents): Add on_llm_start and on_llm_end Lifecycle Hooks #987

Open · wants to merge 7 commits into main
41 changes: 40 additions & 1 deletion src/agents/lifecycle.py
@@ -1,6 +1,7 @@
from typing import Any, Generic
from typing import Any, Generic, Optional

from .agent import Agent
from .items import ModelResponse, TResponseInputItem
from .run_context import RunContextWrapper, TContext
from .tool import Tool

@@ -10,6 +11,25 @@ class RunHooks(Generic[TContext]):
override the methods you need.
"""

# Two new hooks on RunHooks that fire around the LLM call for an agent:
# on_llm_start runs just before the model is invoked and on_llm_end runs just after it returns.
# Useful for logging, monitoring, or inspecting/modifying the context around the LLM call.
async def on_llm_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
system_prompt: Optional[str],
input_items: list[TResponseInputItem],
) -> None:
"""Called just before invoking the LLM for this agent."""
pass

async def on_llm_end(
self, context: RunContextWrapper[TContext], agent: Agent[TContext], response: ModelResponse
) -> None:
"""Called immediately after the LLM call returns for this agent."""
pass

async def on_agent_start(
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
) -> None:
@@ -103,3 +123,22 @@ async def on_tool_end(
) -> None:
"""Called after a tool is invoked."""
pass

async def on_llm_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
system_prompt: Optional[str],
input_items: list[TResponseInputItem],
) -> None:
"""Called immediately before the agent issues an LLM call."""
pass

async def on_llm_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
response: ModelResponse,
) -> None:
"""Called immediately after the agent receives the LLM response."""
pass
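
For illustration (not part of the diff): a minimal sketch of how a caller might subclass AgentHooks to use the new hooks for logging, based on the signatures above and the Agent/Runner usage in the tests below. The LoggingHooks class, the printed messages, and the commented-out usage are hypothetical.

from typing import Optional

from agents.agent import Agent
from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
from agents.run_context import RunContextWrapper, TContext


class LoggingHooks(AgentHooks):
    """Hypothetical hooks that log every LLM call made by an agent."""

    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        # Runs just before the model is invoked for this agent.
        print(f"[{agent.name}] LLM call starting with {len(input_items)} input item(s)")

    async def on_llm_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        response: ModelResponse,
    ) -> None:
        # Runs immediately after the model response is received.
        print(f"[{agent.name}] LLM call finished")


# Hypothetical usage, mirroring the tests below:
#   agent = Agent(name="A", model=FakeModel(), hooks=LoggingHooks())
#   await Runner.run(agent, input="hello")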
14 changes: 13 additions & 1 deletion src/agents/run.py
@@ -812,7 +812,9 @@ async def _run_single_turn_streamed(

input = ItemHelpers.input_to_new_input_list(streamed_result.input)
input.extend([item.to_input_item() for item in streamed_result.new_items])

# Call hook just before the model is invoked, with the correct system_prompt.
if agent.hooks:
await agent.hooks.on_llm_start(context_wrapper, agent, system_prompt, input)
# 1. Stream the output events
async for event in model.stream_response(
system_prompt,
@@ -849,6 +851,10 @@ async def _run_single_turn_streamed(

streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

# Call hook just after the model response is finalized.
if agent.hooks:
await agent.hooks.on_llm_end(context_wrapper, agent, final_response)

# 2. At this point, the streaming is complete for this turn of the agent loop.
if not final_response:
raise ModelBehaviorError("Model did not produce a final response!")
@@ -1067,6 +1073,9 @@ async def _get_new_response(
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
# If the agent has hooks, call on_llm_start just before the LLM call.
if agent.hooks:
await agent.hooks.on_llm_start(context_wrapper, agent, system_prompt, input)

new_response = await model.get_response(
system_instructions=system_prompt,
@@ -1081,6 +1090,9 @@ async def _get_new_response(
previous_response_id=previous_response_id,
prompt=prompt_config,
)
# If the agent has hooks, call on_llm_end just after the LLM call returns.
if agent.hooks:
await agent.hooks.on_llm_end(context_wrapper, agent, new_response)

context_wrapper.usage.add(new_response.usage)

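As another illustration (not part of the diff): a sketch of a hook pair that measures LLM latency, relying only on the invocation order shown above, with on_llm_start called immediately before the model call and on_llm_end immediately after it returns. The TimingHooks class is hypothetical.

import time
from typing import Optional

from agents.agent import Agent
from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
from agents.run_context import RunContextWrapper, TContext


class TimingHooks(AgentHooks):
    """Hypothetical hooks that record the wall-clock latency of each LLM call."""

    def __init__(self) -> None:
        self._started_at: Optional[float] = None
        self.latencies: list[float] = []

    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        # Mark the start of the LLM call.
        self._started_at = time.monotonic()

    async def on_llm_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        response: ModelResponse,
    ) -> None:
        # Record the elapsed time for the call that just finished.
        if self._started_at is not None:
            self.latencies.append(time.monotonic() - self._started_at)
            self._started_at = None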
130 changes: 130 additions & 0 deletions tests/test_agent_llm_hooks.py
@@ -0,0 +1,130 @@
from collections import defaultdict
from typing import Any, Optional

import pytest

from agents.agent import Agent
from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
from agents.run import Runner
from agents.run_context import RunContextWrapper, TContext
from agents.tool import Tool

from .fake_model import FakeModel
from .test_responses import (
get_function_tool,
get_text_message,
)


class AgentHooksForTests(AgentHooks):
def __init__(self):
self.events: dict[str, int] = defaultdict(int)

def reset(self):
self.events.clear()

async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
self.events["on_start"] += 1

async def on_end(
self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
) -> None:
self.events["on_end"] += 1

async def on_handoff(
self, context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]
) -> None:
self.events["on_handoff"] += 1

async def on_tool_start(
self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
) -> None:
self.events["on_tool_start"] += 1

async def on_tool_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
result: str,
) -> None:
self.events["on_tool_end"] += 1

# NEW: LLM hooks
async def on_llm_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
system_prompt: Optional[str],
input_items: list[TResponseInputItem],
) -> None:
self.events["on_llm_start"] += 1

async def on_llm_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
response: ModelResponse,
) -> None:
self.events["on_llm_end"] += 1


# Example test using the above hooks:
@pytest.mark.asyncio
async def test_async_agent_hooks_with_llm():
hooks = AgentHooksForTests()
model = FakeModel()
agent = Agent(
name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
)
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
await Runner.run(agent, input="hello")
# Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}


# test_sync_agent_hook_with_llm()
def test_sync_agent_hook_with_llm():
hooks = AgentHooksForTests()
model = FakeModel()
agent = Agent(
name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
)
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
Runner.run_sync(agent, input="hello")
# Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}


# test_streamed_agent_hooks_with_llm():
@pytest.mark.asyncio
async def test_streamed_agent_hooks_with_llm():
hooks = AgentHooksForTests()
model = FakeModel()
agent = Agent(
name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
)
# Simulate a single LLM call producing an output:
model.set_next_output([get_text_message("hello")])
stream = Runner.run_streamed(agent, input="hello")

async for event in stream.stream_events():
if event.type == "raw_response_event":
continue
if event.type == "agent_updated_stream_event":
print(f"[EVENT] agent_updated → {event.new_agent.name}")
elif event.type == "run_item_stream_event":
item = event.item
if item.type == "tool_call_item":
print("[EVENT] tool_call_item")
elif item.type == "tool_call_output_item":
print(f"[EVENT] tool_call_output_item → {item.output}")
elif item.type == "message_output_item":
text = ItemHelpers.text_message_output(item)
print(f"[EVENT] message_output_item → {text}")

# Expect one on_start, one on_llm_start, one on_llm_end, and one on_end
assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}