diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py index 0ea40f29faab..3a0b705dd02d 100644 --- a/openhands/core/config/sandbox_config.py +++ b/openhands/core/config/sandbox_config.py @@ -41,7 +41,7 @@ class SandboxConfig: remote_runtime_api_url: str = 'http://localhost:8000' local_runtime_url: str = 'http://localhost' - keep_runtime_alive: bool = False + keep_runtime_alive: bool = True rm_all_containers: bool = False api_key: str | None = None base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime @@ -60,7 +60,7 @@ class SandboxConfig: runtime_startup_env_vars: dict[str, str] = field(default_factory=dict) browsergym_eval_env: str | None = None platform: str | None = None - close_delay: int = 15 + close_delay: int = 900 remote_runtime_resource_factor: int = 1 enable_gpu: bool = False docker_runtime_kwargs: str | None = None diff --git a/openhands/runtime/builder/remote.py b/openhands/runtime/builder/remote.py index a728460a374e..b2e869eca3bf 100644 --- a/openhands/runtime/builder/remote.py +++ b/openhands/runtime/builder/remote.py @@ -9,7 +9,6 @@ from openhands.core.logger import openhands_logger as logger from openhands.runtime.builder import RuntimeBuilder from openhands.runtime.utils.request import send_request -from openhands.utils.http_session import HttpSession from openhands.utils.shutdown_listener import ( should_continue, sleep_if_should_continue, @@ -19,10 +18,12 @@ class RemoteRuntimeBuilder(RuntimeBuilder): """This class interacts with the remote Runtime API for building and managing container images.""" - def __init__(self, api_url: str, api_key: str, session: HttpSession | None = None): + def __init__( + self, api_url: str, api_key: str, session: requests.Session | None = None + ): self.api_url = api_url self.api_key = api_key - self.session = session or HttpSession() + self.session = session or requests.Session() self.session.headers.update({'X-API-Key': self.api_key}) def build( diff --git a/openhands/runtime/impl/action_execution/action_execution_client.py b/openhands/runtime/impl/action_execution/action_execution_client.py index 4965fc1752af..24fb8250b30e 100644 --- a/openhands/runtime/impl/action_execution/action_execution_client.py +++ b/openhands/runtime/impl/action_execution/action_execution_client.py @@ -35,7 +35,6 @@ from openhands.runtime.base import Runtime from openhands.runtime.plugins import PluginRequirement from openhands.runtime.utils.request import send_request -from openhands.utils.http_session import HttpSession class ActionExecutionClient(Runtime): @@ -56,7 +55,7 @@ def __init__( attach_to_existing: bool = False, headless_mode: bool = True, ): - self.session = HttpSession() + self.session = requests.Session() self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time self._runtime_initialized: bool = False self._vscode_token: str | None = None # initial dummy value diff --git a/openhands/runtime/utils/request.py b/openhands/runtime/utils/request.py index 0117e019a6a8..e05a083e7b0d 100644 --- a/openhands/runtime/utils/request.py +++ b/openhands/runtime/utils/request.py @@ -4,7 +4,6 @@ import requests from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential -from openhands.utils.http_session import HttpSession from openhands.utils.tenacity_stop import stop_if_should_exit @@ -35,7 +34,7 @@ def is_retryable_error(exception): wait=wait_exponential(multiplier=1, min=4, max=60), ) def send_request( - session: HttpSession, + session: requests.Session, method: str, url: str, timeout: int = 10, @@ -49,11 +48,11 @@ def send_request( _json = response.json() except (requests.exceptions.JSONDecodeError, json.decoder.JSONDecodeError): _json = None + finally: + response.close() raise RequestHTTPError( e, response=e.response, detail=_json.get('detail') if _json is not None else None, ) from e - finally: - response.close() return response diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index f622c5bad9cf..5cfc0ba82d52 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -130,7 +130,7 @@ async def search_conversations( for conversation in conversation_metadata_result_set.results if hasattr(conversation, 'created_at') ) - running_conversations = await session_manager.get_running_agent_loops( + running_conversations = await session_manager.get_agent_loop_running( get_user_id(request), set(conversation_ids) ) result = ConversationInfoResultSet( diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 285acccbfbe4..70bf6eeca6bb 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -1,5 +1,4 @@ import asyncio -import time from typing import Callable, Optional from openhands.controller import AgentController @@ -17,10 +16,10 @@ from openhands.runtime.base import Runtime from openhands.security import SecurityAnalyzer, options from openhands.storage.files import FileStore -from openhands.utils.async_utils import call_sync_from_async +from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async from openhands.utils.shutdown_listener import should_continue -WAIT_TIME_BEFORE_CLOSE = 90 +WAIT_TIME_BEFORE_CLOSE = 300 WAIT_TIME_BEFORE_CLOSE_INTERVAL = 5 @@ -37,8 +36,7 @@ class AgentSession: controller: AgentController | None = None runtime: Runtime | None = None security_analyzer: SecurityAnalyzer | None = None - _starting: bool = False - _started_at: float = 0 + _initializing: bool = False _closed: bool = False loop: asyncio.AbstractEventLoop | None = None @@ -90,8 +88,7 @@ async def start( if self._closed: logger.warning('Session closed before starting') return - self._starting = True - self._started_at = time.time() + self._initializing = True self._create_security_analyzer(config.security.security_analyzer) await self._create_runtime( runtime_name=runtime_name, @@ -112,19 +109,24 @@ async def start( self.event_stream.add_event( ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT ) - self._starting = False + self._initializing = False - async def close(self): + def close(self): """Closes the Agent session""" if self._closed: return self._closed = True - while self._starting and should_continue(): + call_async_from_sync(self._close) + + async def _close(self): + seconds_waited = 0 + while self._initializing and should_continue(): logger.debug( f'Waiting for initialization to finish before closing session {self.sid}' ) await asyncio.sleep(WAIT_TIME_BEFORE_CLOSE_INTERVAL) - if time.time() <= self._started_at + WAIT_TIME_BEFORE_CLOSE: + seconds_waited += WAIT_TIME_BEFORE_CLOSE_INTERVAL + if seconds_waited > WAIT_TIME_BEFORE_CLOSE: logger.error( f'Waited too long for initialization to finish before closing session {self.sid}' ) @@ -309,12 +311,3 @@ def _maybe_restore_state(self) -> State | None: else: logger.debug('No events found, no state to restore') return restored_state - - def get_state(self) -> AgentState | None: - controller = self.controller - if controller: - return controller.state.agent_state - if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE: - # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong - return AgentState.ERROR - return None diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index 3c4d929a72de..67358f61fbe8 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -2,7 +2,6 @@ import json import time from dataclasses import dataclass, field -from typing import Generic, Iterable, TypeVar from uuid import uuid4 import socketio @@ -10,28 +9,26 @@ from openhands.core.config import AppConfig from openhands.core.exceptions import AgentRuntimeUnavailableError from openhands.core.logger import openhands_logger as logger -from openhands.core.schema.agent import AgentState from openhands.events.stream import EventStream, session_exists from openhands.server.session.conversation import Conversation from openhands.server.session.session import ROOM_KEY, Session from openhands.server.settings import Settings from openhands.storage.files import FileStore -from openhands.utils.async_utils import wait_all +from openhands.utils.async_utils import call_sync_from_async from openhands.utils.shutdown_listener import should_continue _REDIS_POLL_TIMEOUT = 1.5 _CHECK_ALIVE_INTERVAL = 15 _CLEANUP_INTERVAL = 15 -MAX_RUNNING_CONVERSATIONS = 3 -T = TypeVar('T') +_CLEANUP_EXCEPTION_WAIT_TIME = 15 @dataclass -class _ClusterQuery(Generic[T]): - query_id: str - request_ids: set[str] | None - result: T +class _SessionIsRunningCheck: + request_id: str + request_sids: list[str] + running_sids: set[str] = field(default_factory=set) flag: asyncio.Event = field(default_factory=asyncio.Event) @@ -41,10 +38,10 @@ class SessionManager: config: AppConfig file_store: FileStore _local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict) - _local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict) + local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict) _last_alive_timestamps: dict[str, float] = field(default_factory=dict) _redis_listen_task: asyncio.Task | None = None - _running_sid_queries: dict[str, _ClusterQuery[set[str]]] = field( + _session_is_running_checks: dict[str, _SessionIsRunningCheck] = field( default_factory=dict ) _active_conversations: dict[str, tuple[Conversation, int]] = field( @@ -55,7 +52,7 @@ class SessionManager: ) _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock) _cleanup_task: asyncio.Task | None = None - _connection_queries: dict[str, _ClusterQuery[dict[str, str]]] = field( + _has_remote_connections_flags: dict[str, asyncio.Event] = field( default_factory=dict ) @@ -63,7 +60,7 @@ async def __aenter__(self): redis_client = self._get_redis_client() if redis_client: self._redis_listen_task = asyncio.create_task(self._redis_subscribe()) - self._cleanup_task = asyncio.create_task(self._cleanup_stale()) + self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations()) return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -85,7 +82,7 @@ async def _redis_subscribe(self): logger.debug('_redis_subscribe') redis_client = self._get_redis_client() pubsub = redis_client.pubsub() - await pubsub.subscribe('session_msg') + await pubsub.subscribe('oh_event') while should_continue(): try: message = await pubsub.get_message( @@ -111,71 +108,59 @@ async def _process_message(self, message: dict): session = self._local_agent_loops_by_sid.get(sid) if session: await session.dispatch(data['data']) - elif message_type == 'running_agent_loops_query': + elif message_type == 'is_session_running': # Another node in the cluster is asking if the current node is running the session given. - query_id = data['query_id'] - sids = self._get_running_agent_loops_locally( - data.get('user_id'), data.get('filter_to_sids') - ) + request_id = data['request_id'] + sids = [ + sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid + ] if sids: await self._get_redis_client().publish( - 'session_msg', + 'oh_event', json.dumps( { - 'query_id': query_id, - 'sids': list(sids), - 'message_type': 'running_agent_loops_response', + 'request_id': request_id, + 'sids': sids, + 'message_type': 'session_is_running', } ), ) - elif message_type == 'running_agent_loops_response': - query_id = data['query_id'] + elif message_type == 'session_is_running': + request_id = data['request_id'] for sid in data['sids']: self._last_alive_timestamps[sid] = time.time() - running_query = self._running_sid_queries.get(query_id) - if running_query: - running_query.result.update(data['sids']) - if running_query.request_ids is not None and len( - running_query.request_ids - ) == len(running_query.result): - running_query.flag.set() - elif message_type == 'connections_query': + check = self._session_is_running_checks.get(request_id) + if check: + check.running_sids.update(data['sids']) + if len(check.request_sids) == len(check.running_sids): + check.flag.set() + elif message_type == 'has_remote_connections_query': # Another node in the cluster is asking if the current node is connected to a session - query_id = data['query_id'] - connections = self._get_connections_locally( - data.get('user_id'), data.get('filter_to_sids') - ) - if connections: + sid = data['sid'] + required = sid in self.local_connection_id_to_session_id.values() + if required: await self._get_redis_client().publish( - 'session_msg', + 'oh_event', json.dumps( - { - 'query_id': query_id, - 'connections': connections, - 'message_type': 'connections_response', - } + {'sid': sid, 'message_type': 'has_remote_connections_response'} ), ) - elif message_type == 'connections_response': - query_id = data['query_id'] - connection_query = self._connection_queries.get(query_id) - if connection_query: - connection_query.result.update(**data['connections']) - if connection_query.request_ids is not None and len( - connection_query.request_ids - ) == len(connection_query.result): - connection_query.flag.set() + elif message_type == 'has_remote_connections_response': + sid = data['sid'] + flag = self._has_remote_connections_flags.get(sid) + if flag: + flag.set() elif message_type == 'close_session': sid = data['sid'] if sid in self._local_agent_loops_by_sid: - await self._close_session(sid) + await self._on_close_session(sid) elif message_type == 'session_closing': # Session closing event - We only get this in the event of graceful shutdown, # which can't be guaranteed - nodes can simply vanish unexpectedly! sid = data['sid'] logger.debug(f'session_closing:{sid}') # Create a list of items to process to avoid modifying dict during iteration - items = list(self._local_connection_id_to_session_id.items()) + items = list(self.local_connection_id_to_session_id.items()) for connection_id, local_sid in items: if sid == local_sid: logger.warning( @@ -223,7 +208,7 @@ async def join_conversation( ): logger.info(f'join_conversation:{sid}:{connection_id}') await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid)) - self._local_connection_id_to_session_id[connection_id] = sid + self.local_connection_id_to_session_id[connection_id] = sid event_stream = await self._get_event_stream(sid) if not event_stream: return await self.maybe_start_agent_loop(sid, settings, user_id) @@ -241,7 +226,7 @@ async def detach_from_conversation(self, conversation: Conversation): self._active_conversations.pop(sid) self._detached_conversations[sid] = (conversation, time.time()) - async def _cleanup_stale(self): + async def _cleanup_detached_conversations(self): while should_continue(): if self._get_redis_client(): # Debug info for HA envs @@ -255,7 +240,7 @@ async def _cleanup_stale(self): f'Running agent loops: {len(self._local_agent_loops_by_sid)}' ) logger.info( - f'Local connections: {len(self._local_connection_id_to_session_id)}' + f'Local connections: {len(self.local_connection_id_to_session_id)}' ) try: async with self._conversations_lock: @@ -265,176 +250,97 @@ async def _cleanup_stale(self): await conversation.disconnect() self._detached_conversations.pop(sid, None) - close_threshold = time.time() - self.config.sandbox.close_delay - running_loops = list(self._local_agent_loops_by_sid.items()) - running_loops.sort(key=lambda item: item[1].last_active_ts) - sid_to_close: list[str] = [] - for sid, session in running_loops: - state = session.agent_session.get_state() - if session.last_active_ts < close_threshold and state not in [ - AgentState.RUNNING, - None, - ]: - sid_to_close.append(sid) - - connections = self._get_connections_locally( - filter_to_sids=set(sid_to_close) - ) - connected_sids = {sid for _, sid in connections.items()} - sid_to_close = [ - sid for sid in sid_to_close if sid not in connected_sids - ] - - if sid_to_close: - connections = await self._get_connections_remotely( - filter_to_sids=set(sid_to_close) - ) - connected_sids = {sid for _, sid in connections.items()} - sid_to_close = [ - sid for sid in sid_to_close if sid not in connected_sids - ] - - await wait_all(self._close_session(sid) for sid in sid_to_close) await asyncio.sleep(_CLEANUP_INTERVAL) except asyncio.CancelledError: async with self._conversations_lock: for conversation, _ in self._detached_conversations.values(): await conversation.disconnect() self._detached_conversations.clear() - await wait_all( - self._close_session(sid) for sid in self._local_agent_loops_by_sid - ) return except Exception as e: - logger.warning(f'error_cleaning_stale: {str(e)}') - await asyncio.sleep(_CLEANUP_INTERVAL) + logger.warning(f'error_cleaning_detached_conversations: {str(e)}') + await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME) + + async def get_agent_loop_running(self, user_id, sids: set[str]) -> set[str]: + running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid) + check_cluster_sids = [sid for sid in sids if sid not in running_sids] + running_cluster_sids = await self.get_agent_loop_running_in_cluster( + check_cluster_sids + ) + running_sids.union(running_cluster_sids) + return running_sids async def is_agent_loop_running(self, sid: str) -> bool: - sids = await self.get_running_agent_loops(filter_to_sids={sid}) - return bool(sids) - - async def get_running_agent_loops( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> set[str]: - """Get the running session ids. If a user is supplied, then the results are limited to session ids for that user. If a set of filter_to_sids is supplied, then results are limited to these ids of interest.""" - sids = self._get_running_agent_loops_locally(user_id, filter_to_sids) - remote_sids = await self._get_running_agent_loops_remotely( - user_id, filter_to_sids - ) - return sids.union(remote_sids) - - def _get_running_agent_loops_locally( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> set[str]: - items: Iterable[tuple[str, Session]] = self._local_agent_loops_by_sid.items() - if filter_to_sids is not None: - items = (item for item in items if item[0] in filter_to_sids) - if user_id: - items = (item for item in items if item[1].user_id == user_id) - sids = {sid for sid, _ in items} - return sids - - async def _get_running_agent_loops_remotely( - self, - user_id: str | None = None, - filter_to_sids: set[str] | None = None, - ) -> set[str]: + if await self.is_agent_loop_running_locally(sid): + return True + if await self.is_agent_loop_running_in_cluster(sid): + return True + return False + + async def is_agent_loop_running_locally(self, sid: str) -> bool: + return sid in self._local_agent_loops_by_sid + + async def is_agent_loop_running_in_cluster(self, sid: str) -> bool: + running_sids = await self.get_agent_loop_running_in_cluster([sid]) + return bool(running_sids) + + async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]: """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply""" redis_client = self._get_redis_client() if not redis_client: return set() flag = asyncio.Event() - query_id = str(uuid4()) - query = _ClusterQuery[set[str]]( - query_id=query_id, request_ids=filter_to_sids, result=set() - ) - self._running_sid_queries[query_id] = query + request_id = str(uuid4()) + check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids) + self._session_is_running_checks[request_id] = check try: - logger.debug( - f'publish:_get_running_agent_loops_remotely_query:{user_id}:{filter_to_sids}' + logger.debug(f'publish:is_session_running:{sids}') + await redis_client.publish( + 'oh_event', + json.dumps( + { + 'request_id': request_id, + 'sids': sids, + 'message_type': 'is_session_running', + } + ), ) - data: dict = { - 'query_id': query_id, - 'message_type': 'running_agent_loops_query', - } - if user_id: - data['user_id'] = user_id - if filter_to_sids: - data['filter_to_sids'] = list(filter_to_sids) - await redis_client.publish('session_msg', json.dumps(data)) async with asyncio.timeout(_REDIS_POLL_TIMEOUT): await flag.wait() - return query.result + return check.running_sids except TimeoutError: # Nobody replied in time - return query.result + return check.running_sids finally: - self._running_sid_queries.pop(query_id, None) - - async def get_connections( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> dict[str, str]: - connection_ids = self._get_connections_locally(user_id, filter_to_sids) - remote_connection_ids = await self._get_connections_remotely( - user_id, filter_to_sids - ) - connection_ids.update(**remote_connection_ids) - return connection_ids - - def _get_connections_locally( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> dict[str, str]: - connections = dict(**self._local_connection_id_to_session_id) - if filter_to_sids is not None: - connections = { - connection_id: sid - for connection_id, sid in connections.items() - if sid in filter_to_sids - } - if user_id: - for connection_id, sid in list(connections.items()): - session = self._local_agent_loops_by_sid.get(sid) - if not session or session.user_id != user_id: - connections.pop(connection_id) - return connections - - async def _get_connections_remotely( - self, user_id: str | None = None, filter_to_sids: set[str] | None = None - ) -> dict[str, str]: - redis_client = self._get_redis_client() - if not redis_client: - return {} + self._session_is_running_checks.pop(request_id, None) + async def _has_remote_connections(self, sid: str) -> bool: + """As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply""" + # Create a flag for the callback flag = asyncio.Event() - query_id = str(uuid4()) - query = _ClusterQuery[dict[str, str]]( - query_id=query_id, request_ids=filter_to_sids, result={} - ) - self._connection_queries[query_id] = query + self._has_remote_connections_flags[sid] = flag try: - logger.debug( - f'publish:get_connections_remotely_query:{user_id}:{filter_to_sids}' + await self._get_redis_client().publish( + 'oh_event', + json.dumps( + { + 'sid': sid, + 'message_type': 'has_remote_connections_query', + } + ), ) - data: dict = { - 'query_id': query_id, - 'message_type': 'connections_query', - } - if user_id: - data['user_id'] = user_id - if filter_to_sids: - data['filter_to_sids'] = list(filter_to_sids) - await redis_client.publish('session_msg', json.dumps(data)) async with asyncio.timeout(_REDIS_POLL_TIMEOUT): await flag.wait() - return query.result + result = flag.is_set() + return result except TimeoutError: # Nobody replied in time - return query.result + return False finally: - self._connection_queries.pop(query_id, None) + self._has_remote_connections_flags.pop(sid, None) async def maybe_start_agent_loop( self, sid: str, settings: Settings, user_id: str | None @@ -443,18 +349,8 @@ async def maybe_start_agent_loop( session: Session | None = None if not await self.is_agent_loop_running(sid): logger.info(f'start_agent_loop:{sid}') - - response_ids = await self.get_running_agent_loops(user_id) - if len(response_ids) >= MAX_RUNNING_CONVERSATIONS: - logger.info('too_many_sessions_for:{user_id}') - await self.close_session(next(iter(response_ids))) - session = Session( - sid=sid, - file_store=self.file_store, - config=self.config, - sio=self.sio, - user_id=user_id, + sid=sid, file_store=self.file_store, config=self.config, sio=self.sio ) self._local_agent_loops_by_sid[sid] = session asyncio.create_task(session.initialize_agent(settings)) @@ -463,6 +359,7 @@ async def maybe_start_agent_loop( if not event_stream: logger.error(f'No event stream after starting agent loop: {sid}') raise RuntimeError(f'no_event_stream:{sid}') + asyncio.create_task(self._cleanup_session_later(sid)) return event_stream async def _get_event_stream(self, sid: str) -> EventStream | None: @@ -472,7 +369,7 @@ async def _get_event_stream(self, sid: str) -> EventStream | None: logger.info(f'found_local_agent_loop:{sid}') return session.agent_session.event_stream - if await self._get_running_agent_loops_remotely(filter_to_sids={sid}): + if await self.is_agent_loop_running_in_cluster(sid): logger.info(f'found_remote_agent_loop:{sid}') return EventStream(sid, self.file_store) @@ -480,7 +377,7 @@ async def _get_event_stream(self, sid: str) -> EventStream | None: async def send_to_event_stream(self, connection_id: str, data: dict): # If there is a local session running, send to that - sid = self._local_connection_id_to_session_id.get(connection_id) + sid = self.local_connection_id_to_session_id.get(connection_id) if not sid: raise RuntimeError(f'no_connected_session:{connection_id}') @@ -496,11 +393,11 @@ async def send_to_event_stream(self, connection_id: str, data: dict): next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL if ( next_alive_check > time.time() - or await self._get_running_agent_loops_remotely(filter_to_sids={sid}) + or await self.is_agent_loop_running_in_cluster(sid) ): # Send the event to the other pod await redis_client.publish( - 'session_msg', + 'oh_event', json.dumps( { 'sid': sid, @@ -514,37 +411,75 @@ async def send_to_event_stream(self, connection_id: str, data: dict): raise RuntimeError(f'no_connected_session:{connection_id}:{sid}') async def disconnect_from_session(self, connection_id: str): - sid = self._local_connection_id_to_session_id.pop(connection_id, None) + sid = self.local_connection_id_to_session_id.pop(connection_id, None) logger.info(f'disconnect_from_session:{connection_id}:{sid}') if not sid: # This can occur if the init action was never run. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}') return + if should_continue(): + asyncio.create_task(self._cleanup_session_later(sid)) + else: + await self._on_close_session(sid) + + async def _cleanup_session_later(self, sid: str): + # Once there have been no connections to a session for a reasonable period, we close it + try: + await asyncio.sleep(self.config.sandbox.close_delay) + finally: + # If the sleep was cancelled, we still want to close these + await self._cleanup_session(sid) + + async def _cleanup_session(self, sid: str) -> bool: + # Get local connections + logger.info(f'_cleanup_session:{sid}') + has_local_connections = next( + (True for v in self.local_connection_id_to_session_id.values() if v == sid), + False, + ) + if has_local_connections: + return False + + # If no local connections, get connections through redis + redis_client = self._get_redis_client() + if redis_client and await self._has_remote_connections(sid): + return False + + # We alert the cluster in case they are interested + if redis_client: + await redis_client.publish( + 'oh_event', + json.dumps({'sid': sid, 'message_type': 'session_closing'}), + ) + + await self._on_close_session(sid) + return True + async def close_session(self, sid: str): session = self._local_agent_loops_by_sid.get(sid) if session: - await self._close_session(sid) + await self._on_close_session(sid) redis_client = self._get_redis_client() if redis_client: await redis_client.publish( - 'session_msg', + 'oh_event', json.dumps({'sid': sid, 'message_type': 'close_session'}), ) - async def _close_session(self, sid: str): + async def _on_close_session(self, sid: str): logger.info(f'_close_session:{sid}') # Clear up local variables connection_ids_to_remove = list( connection_id - for connection_id, conn_sid in self._local_connection_id_to_session_id.items() + for connection_id, conn_sid in self.local_connection_id_to_session_id.items() if sid == conn_sid ) logger.info(f'removing connections: {connection_ids_to_remove}') for connnnection_id in connection_ids_to_remove: - self._local_connection_id_to_session_id.pop(connnnection_id, None) + self.local_connection_id_to_session_id.pop(connnnection_id, None) session = self._local_agent_loops_by_sid.pop(sid, None) if not session: @@ -553,17 +488,12 @@ async def _close_session(self, sid: str): logger.info(f'closing_session:{session.sid}') # We alert the cluster in case they are interested - try: - redis_client = self._get_redis_client() - if redis_client: - await redis_client.publish( - 'session_msg', - json.dumps({'sid': session.sid, 'message_type': 'session_closing'}), - ) - except Exception: - logger.info( - 'error_publishing_close_session_event', exc_info=True, stack_info=True + redis_client = self._get_redis_client() + if redis_client: + await redis_client.publish( + 'oh_event', + json.dumps({'sid': session.sid, 'message_type': 'session_closing'}), ) - await session.close() + await call_sync_from_async(session.close) logger.info(f'closed_session:{session.sid}') diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index e77a77101b20..8318ab773129 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -62,17 +62,9 @@ def __init__( self.loop = asyncio.get_event_loop() self.user_id = user_id - async def close(self): - if self.sio: - await self.sio.emit( - 'oh_event', - event_to_dict( - AgentStateChangedObservation('', AgentState.STOPPED.value) - ), - to=ROOM_KEY.format(sid=self.sid), - ) + def close(self): self.is_alive = False - await self.agent_session.close() + self.agent_session.close() async def initialize_agent( self, diff --git a/openhands/utils/http_session.py b/openhands/utils/http_session.py deleted file mode 100644 index 4edc4e6546c3..000000000000 --- a/openhands/utils/http_session.py +++ /dev/null @@ -1,24 +0,0 @@ -from dataclasses import dataclass, field - -import requests - - -@dataclass -class HttpSession: - """ - request.Session is reusable after it has been closed. This behavior makes it - likely to leak file descriptors (Especially when combined with tenacity). - We wrap the session to make it unusable after being closed - """ - - session: requests.Session | None = field(default_factory=requests.Session) - - def __getattr__(self, name): - if self.session is None: - raise ValueError('session_was_closed') - return object.__getattribute__(self.session, name) - - def close(self): - if self.session is not None: - self.session.close() - self.session = None diff --git a/tests/unit/test_manager.py b/tests/unit/test_manager.py index cd2ddf6ba0a6..f0ac68ff8361 100644 --- a/tests/unit/test_manager.py +++ b/tests/unit/test_manager.py @@ -44,28 +44,28 @@ async def test_session_not_running_in_cluster(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - result = await session_manager._get_running_agent_loops_remotely( - filter_to_sids={'non-existant-session'} + result = await session_manager.is_agent_loop_running_in_cluster( + 'non-existant-session' ) - assert result == set() + assert result is False assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'session_msg', - '{"query_id": "' + 'oh_event', + '{"request_id": "' + str(id) - + '", "message_type": "running_agent_loops_query", "filter_to_sids": ["non-existant-session"]}', + + '", "sids": ["non-existant-session"], "message_type": "is_session_running"}', ) @pytest.mark.asyncio -async def test_get_running_agent_loops_remotely(): +async def test_session_is_running_in_cluster(): id = uuid4() sio = get_mock_sio( GetMessageMock( { - 'query_id': str(id), + 'request_id': str(id), 'sids': ['existing-session'], - 'message_type': 'running_agent_loops_response', + 'message_type': 'session_is_running', } ) ) @@ -76,16 +76,16 @@ async def test_get_running_agent_loops_remotely(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - result = await session_manager._get_running_agent_loops_remotely( - 1, {'existing-session'} + result = await session_manager.is_agent_loop_running_in_cluster( + 'existing-session' ) - assert result == {'existing-session'} + assert result is True assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'session_msg', - '{"query_id": "' + 'oh_event', + '{"request_id": "' + str(id) - + '", "message_type": "running_agent_loops_query", "user_id": 1, "filter_to_sids": ["existing-session"]}', + + '", "sids": ["existing-session"], "message_type": "is_session_running"}', ) @@ -96,8 +96,8 @@ async def test_init_new_local_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = set() + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = False with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1), @@ -106,8 +106,8 @@ async def test_init_new_local_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.get_running_agent_loops', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -130,8 +130,8 @@ async def test_join_local_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = set() + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = False with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -140,8 +140,8 @@ async def test_join_local_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.get_running_agent_loops', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -167,8 +167,8 @@ async def test_join_cluster_session(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = {'new-session-id'} + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = True with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -177,8 +177,8 @@ async def test_join_cluster_session(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -198,8 +198,8 @@ async def test_add_to_local_event_stream(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = set() + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = False with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -208,8 +208,8 @@ async def test_add_to_local_event_stream(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager.get_running_agent_loops', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -234,8 +234,8 @@ async def test_add_to_cluster_event_stream(): mock_session = MagicMock() mock_session.return_value = session_instance sio = get_mock_sio() - get_running_agent_loops_mock = AsyncMock() - get_running_agent_loops_mock.return_value = {'new-session-id'} + is_agent_loop_running_in_cluster_mock = AsyncMock() + is_agent_loop_running_in_cluster_mock.return_value = True with ( patch('openhands.server.session.manager.Session', mock_session), patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01), @@ -244,8 +244,8 @@ async def test_add_to_cluster_event_stream(): AsyncMock(), ), patch( - 'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely', - get_running_agent_loops_mock, + 'openhands.server.session.manager.SessionManager.is_agent_loop_running_in_cluster', + is_agent_loop_running_in_cluster_mock, ), ): async with SessionManager( @@ -259,7 +259,7 @@ async def test_add_to_cluster_event_stream(): ) assert sio.manager.redis.publish.await_count == 1 sio.manager.redis.publish.assert_called_once_with( - 'session_msg', + 'oh_event', '{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}', ) @@ -277,7 +277,7 @@ async def test_cleanup_session_connections(): async with SessionManager( sio, AppConfig(), InMemoryFileStore() ) as session_manager: - session_manager._local_connection_id_to_session_id.update( + session_manager.local_connection_id_to_session_id.update( { 'conn1': 'session1', 'conn2': 'session1', @@ -286,9 +286,9 @@ async def test_cleanup_session_connections(): } ) - await session_manager._close_session('session1') + await session_manager._on_close_session('session1') - remaining_connections = session_manager._local_connection_id_to_session_id + remaining_connections = session_manager.local_connection_id_to_session_id assert 'conn1' not in remaining_connections assert 'conn2' not in remaining_connections assert 'conn3' in remaining_connections