Skip to content

Commit

Permalink
Save complete trajectory in presence of history truncation
Browse files Browse the repository at this point in the history
  • Loading branch information
li-boxuan committed Feb 17, 2025
1 parent 30e39e8 commit 3d10df4
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 22 deletions.
39 changes: 22 additions & 17 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
NullObservation,
Observation,
)
from openhands.events.serialization.event import truncate_content
from openhands.events.serialization.event import event_to_trajectory, truncate_content
from openhands.llm.llm import LLM

# note: RESUME is only available on web GUI
Expand Down Expand Up @@ -687,22 +687,7 @@ async def _step(self) -> None:
or 'prompt is too long' in error_str
or isinstance(e, ContextWindowExceededError)
):
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(
self.state.history
)

# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id

# Add an error event to trigger another step by the agent
self.event_stream.add_event(
AgentCondensationObservation(
content='Trimming prompt to meet context window limitations'
),
EventSource.AGENT,
)
self._handle_long_context_error()
return
raise e

Expand Down Expand Up @@ -831,6 +816,10 @@ def set_initial_state(
# Always load from the event stream to avoid losing history
self._init_history()

def get_trajectory(self) -> list[dict]:
events = self.event_stream.get_events()
return [event_to_trajectory(event) for event in events]

def _init_history(self) -> None:
"""Initializes the agent's history from the event stream.
Expand Down Expand Up @@ -956,6 +945,22 @@ def _init_history(self) -> None:
# make sure history is in sync
self.state.start_id = start_id

def _handle_long_context_error(self) -> None:
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(self.state.history)

# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id

# Add an error event to trigger another step by the agent
self.event_stream.add_event(
AgentCondensationObservation(
content='Trimming prompt to meet context window limitations'
),
EventSource.AGENT,
)

def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
"""Cuts history roughly in half when context window is exceeded, preserving action-observation pairs
and ensuring the first user message is always included.
Expand Down
3 changes: 1 addition & 2 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.events.serialization.event import event_to_trajectory
from openhands.runtime.base import Runtime


Expand Down Expand Up @@ -204,7 +203,7 @@ def on_event(event: Event):
else:
file_path = config.save_trajectory_path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
histories = [event_to_trajectory(event) for event in state.history]
histories = controller.get_trajectory()
with open(file_path, 'w') as f:
json.dump(histories, f)

Expand Down
4 changes: 1 addition & 3 deletions openhands/server/routes/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ async def get_trajectory(request: Request):
events.
"""
try:
async_stream = AsyncEventStreamWrapper(
request.state.conversation.event_stream, filter_hidden=True
)
async_stream = AsyncEventStreamWrapper(request.state.conversation.event_stream)
trajectory = []
async for event in async_stream:
trajectory.append(event_to_trajectory(event))
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/test_truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,50 @@ def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
if isinstance(event, CmdOutputObservation):
assert any(e._id == event._cause for e in truncated[: i + 1])

def test_truncation_does_not_impact_trajectory(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)

# Create a sequence of events with IDs
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1

pairs = 10
history_len = 1 + 2 * pairs
events = [first_msg]
for i in range(pairs):
cmd = CmdRunAction(command=f'cmd{i}')
cmd._id = i + 2
obs = CmdOutputObservation(
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
)
obs._cause = cmd._id
events.extend([cmd, obs])

# patch events to history for testing purpose
controller.state.history = events

# Update mock event stream
mock_event_stream.get_events.return_value = controller.state.history

assert len(controller.state.history) == history_len

# Force apply truncation
controller._handle_long_context_error()

# Check that the history has been truncated
assert len(controller.state.history) == 13 < history_len

assert len(controller.event_stream.get_events()) == history_len
assert len(controller.get_trajectory()) == history_len

def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
Expand Down

0 comments on commit 3d10df4

Please sign in to comment.