diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index e97efbed6a23..cf66e89383f1 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -51,7 +51,7 @@ NullObservation, Observation, ) -from openhands.events.serialization.event import truncate_content +from openhands.events.serialization.event import event_to_trajectory, truncate_content from openhands.llm.llm import LLM # note: RESUME is only available on web GUI @@ -149,12 +149,13 @@ def __init__( # replay-related self._replay_manager = ReplayManager(replay_events) - async def close(self) -> None: + async def close(self, set_stop_state=True) -> None: """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream. Note that it's fairly important that this closes properly, otherwise the state is incomplete. """ - await self.set_agent_state_to(AgentState.STOPPED) + if set_stop_state: + await self.set_agent_state_to(AgentState.STOPPED) # we made history, now is the time to rewrite it! # the final state.history will be used by external scripts like evals, tests, etc. @@ -701,22 +702,7 @@ async def _step(self) -> None: or isinstance(e, ContextWindowExceededError) ): if self.agent.config.enable_history_truncation: - # When context window is exceeded, keep roughly half of agent interactions - self.state.history = self._apply_conversation_window( - self.state.history - ) - - # Save the ID of the first event in our truncated history for future reloading - if self.state.history: - self.state.start_id = self.state.history[0].id - - # Add an error event to trigger another step by the agent - self.event_stream.add_event( - AgentCondensationObservation( - content='Trimming prompt to meet context window limitations' - ), - EventSource.AGENT, - ) + self._handle_long_context_error() return else: raise LLMContextWindowExceedError() @@ -848,6 +834,11 @@ def set_initial_state( # Always load from the event stream to avoid losing history self._init_history() + def get_trajectory(self) -> list[dict]: + # state history could be partially hidden/truncated before controller is closed + assert self._closed + return [event_to_trajectory(event) for event in self.state.history] + def _init_history(self) -> None: """Initializes the agent's history from the event stream. @@ -973,6 +964,22 @@ def _init_history(self) -> None: # make sure history is in sync self.state.start_id = start_id + def _handle_long_context_error(self) -> None: + # When context window is exceeded, keep roughly half of agent interactions + self.state.history = self._apply_conversation_window(self.state.history) + + # Save the ID of the first event in our truncated history for future reloading + if self.state.history: + self.state.start_id = self.state.history[0].id + + # Add an error event to trigger another step by the agent + self.event_stream.add_event( + AgentCondensationObservation( + content='Trimming prompt to meet context window limitations' + ), + EventSource.AGENT, + ) + def _apply_conversation_window(self, events: list[Event]) -> list[Event]: """Cuts history roughly in half when context window is exceeded, preserving action-observation pairs and ensuring the first user message is always included. diff --git a/openhands/core/main.py b/openhands/core/main.py index 12e0c4e7876c..a5192b9c5ff2 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -27,7 +27,6 @@ from openhands.events.event import Event from openhands.events.observation import AgentStateChangedObservation from openhands.events.serialization import event_from_dict -from openhands.events.serialization.event import event_to_trajectory from openhands.io import read_input, read_task from openhands.runtime.base import Runtime @@ -167,6 +166,8 @@ def on_event(event: Event): # NOTE: the saved state does not include delegates events end_state.save_to_session(event_stream.sid, event_stream.file_store) + await controller.close(set_stop_state=False) + state = controller.get_state() # save trajectories if applicable @@ -177,7 +178,7 @@ def on_event(event: Event): else: file_path = config.save_trajectory_path os.makedirs(os.path.dirname(file_path), exist_ok=True) - histories = [event_to_trajectory(event) for event in state.history] + histories = controller.get_trajectory() with open(file_path, 'w') as f: json.dump(histories, f) diff --git a/tests/unit/test_truncation.py b/tests/unit/test_truncation.py index 08e7d8f7be71..c8218194ceda 100644 --- a/tests/unit/test_truncation.py +++ b/tests/unit/test_truncation.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import MagicMock import pytest @@ -72,6 +73,53 @@ def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent): if isinstance(event, CmdOutputObservation): assert any(e._id == event._cause for e in truncated[: i + 1]) + def test_truncation_does_not_impact_trajectory(self, mock_event_stream, mock_agent): + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test_truncation', + confirmation_mode=False, + headless_mode=True, + ) + + # Create a sequence of events with IDs + first_msg = MessageAction(content='Hello, start task', wait_for_response=False) + first_msg._source = EventSource.USER + first_msg._id = 1 + + pairs = 10 + history_len = 1 + 2 * pairs + events = [first_msg] + for i in range(pairs): + cmd = CmdRunAction(command=f'cmd{i}') + cmd._id = i + 2 + obs = CmdOutputObservation( + command=f'cmd{i}', content=f'output{i}', command_id=cmd._id + ) + obs._cause = cmd._id + events.extend([cmd, obs]) + + # patch events to history for testing purpose + controller.state.history = events + + # Update mock event stream + mock_event_stream.get_events.return_value = controller.state.history + + assert len(controller.state.history) == history_len + + # Force apply truncation + controller._handle_long_context_error() + + # Check that the history has been truncated before closing the controller + assert len(controller.state.history) == 13 < history_len + + # Check that after properly closing the controller, history is recovered + asyncio.run(controller.close()) + assert len(controller.event_stream.get_events()) == history_len + assert len(controller.state.history) == history_len + assert len(controller.get_trajectory()) == history_len + def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent): controller = AgentController( agent=mock_agent,