Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save complete trajectory in presence of history truncation #6751

Merged
merged 11 commits into from
Feb 21, 2025
38 changes: 21 additions & 17 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
NullObservation,
Observation,
)
from openhands.events.serialization.event import truncate_content
from openhands.events.serialization.event import event_to_trajectory, truncate_content
from openhands.llm.llm import LLM

# note: RESUME is only available on web GUI
Expand Down Expand Up @@ -687,22 +687,7 @@ async def _step(self) -> None:
or 'prompt is too long' in error_str
or isinstance(e, ContextWindowExceededError)
):
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(
self.state.history
)

# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id

# Add an error event to trigger another step by the agent
self.event_stream.add_event(
AgentCondensationObservation(
content='Trimming prompt to meet context window limitations'
),
EventSource.AGENT,
)
self._handle_long_context_error()
return
raise e

Expand Down Expand Up @@ -831,6 +816,9 @@ def set_initial_state(
# Always load from the event stream to avoid losing history
self._init_history()

def get_trajectory(self) -> list[dict]:
return [event_to_trajectory(event) for event in self.state.history]

def _init_history(self) -> None:
"""Initializes the agent's history from the event stream.

Expand Down Expand Up @@ -956,6 +944,22 @@ def _init_history(self) -> None:
# make sure history is in sync
self.state.start_id = start_id

def _handle_long_context_error(self) -> None:
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(self.state.history)

# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id

# Add an error event to trigger another step by the agent
self.event_stream.add_event(
AgentCondensationObservation(
content='Trimming prompt to meet context window limitations'
),
EventSource.AGENT,
)

def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
"""Cuts history roughly in half when context window is exceeded, preserving action-observation pairs
and ensuring the first user message is always included.
Expand Down
5 changes: 3 additions & 2 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.events.serialization.event import event_to_trajectory
from openhands.runtime.base import Runtime


Expand Down Expand Up @@ -194,6 +193,8 @@ def on_event(event: Event):
# NOTE: the saved state does not include delegates events
end_state.save_to_session(event_stream.sid, event_stream.file_store)

await controller.close()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are entirely correct

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be messing up with the event loop... causing tests failure. Will investigate later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't figure this out, so I cheated. Setting agent state to STOPPED seems unnecessary for headless mode anyways, so I think it's okay...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, thank you. Sorry I didn't get to focus on this, but I'll give it a try with the next refactoring on history: FWIW my thought is we may need to close the stream which closes subscribers, but more importantly, maybe we shouldn't do the reassembly of history back again in the first place, after all. 😅

But for this PR it is as it is, and it's fine IMO to work around somewhat to get it done.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!


state = controller.get_state()

# save trajectories if applicable
Expand All @@ -204,7 +205,7 @@ def on_event(event: Event):
else:
file_path = config.save_trajectory_path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
histories = [event_to_trajectory(event) for event in state.history]
histories = controller.get_trajectory()
with open(file_path, 'w') as f:
json.dump(histories, f)

Expand Down
4 changes: 1 addition & 3 deletions openhands/server/routes/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ async def get_trajectory(request: Request):
events.
"""
try:
async_stream = AsyncEventStreamWrapper(
request.state.conversation.event_stream, filter_hidden=True
)
async_stream = AsyncEventStreamWrapper(request.state.conversation.event_stream)
trajectory = []
async for event in async_stream:
trajectory.append(event_to_trajectory(event))
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/test_truncation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -72,6 +73,53 @@ def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
if isinstance(event, CmdOutputObservation):
assert any(e._id == event._cause for e in truncated[: i + 1])

def test_truncation_does_not_impact_trajectory(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)

# Create a sequence of events with IDs
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1

pairs = 10
history_len = 1 + 2 * pairs
events = [first_msg]
for i in range(pairs):
cmd = CmdRunAction(command=f'cmd{i}')
cmd._id = i + 2
obs = CmdOutputObservation(
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
)
obs._cause = cmd._id
events.extend([cmd, obs])

# patch events to history for testing purpose
controller.state.history = events

# Update mock event stream
mock_event_stream.get_events.return_value = controller.state.history

assert len(controller.state.history) == history_len

# Force apply truncation
controller._handle_long_context_error()

# Check that the history has been truncated before closing the controller
assert len(controller.state.history) == 13 < history_len

# Check that after properly closing the controller, history is recovered
asyncio.run(controller.close())
assert len(controller.event_stream.get_events()) == history_len
assert len(controller.state.history) == history_len
assert len(controller.get_trajectory()) == history_len

def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
Expand Down