diff --git a/changes/2311.enhance.md b/changes/2311.enhance.md new file mode 100644 index 0000000000..e328a3256d --- /dev/null +++ b/changes/2311.enhance.md @@ -0,0 +1 @@ +Revamp update mechanism of session & kernel status by reconciliation. diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index dd0e1b123d..3a5d9287b5 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -794,7 +794,12 @@ async def get_session_to_determine_status( .where(SessionRow.id == session_id) .options( selectinload(SessionRow.kernels).options( - load_only(KernelRow.status, KernelRow.cluster_role, KernelRow.status_info) + load_only( + KernelRow.status, + KernelRow.cluster_role, + KernelRow.status_info, + KernelRow.occupied_slots, + ) ), ) ) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index f5e1252aea..5c17384ea9 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -13,7 +13,7 @@ import uuid import zlib from collections import defaultdict -from collections.abc import Mapping +from collections.abc import Iterable, Mapping, MutableMapping, Sequence from datetime import datetime from decimal import Decimal from io import BytesIO @@ -23,9 +23,7 @@ Dict, List, Literal, - MutableMapping, Optional, - Sequence, Tuple, TypeAlias, Union, @@ -3099,7 +3097,7 @@ async def _set_status(db_session: AsyncSession) -> None: ) await execute_with_txn_retry(_set_status, self.db.begin_session, db_conn) - await self.set_status_updatable_session(session_id) + await self.set_status_updatable_session([session_id]) async def mark_kernel_pulling( self, @@ -3121,7 +3119,7 @@ async def _transit_status(db_session: AsyncSession) -> bool: transited = await execute_with_txn_retry(_transit_status, self.db.begin_session, db_conn) if transited: - await self.set_status_updatable_session(session_id) + await self.set_status_updatable_session([session_id]) async def mark_kernel_running( self, @@ -3157,21 +3155,9 @@ async def _get_and_transit(db_session: AsyncSession) -> bool: transited = await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn) - async def _update_session(db_session: AsyncSession) -> None: - _stmt = sa.select(SessionRow).where(SessionRow.id == session_id).with_for_update() - session_row = cast(SessionRow | None, await db_session.scalar(_stmt)) - if session_row is None: - return - session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots}) - session_occupying_slots.sync_keys(actual_allocs) - for key, val in session_occupying_slots.items(): - session_occupying_slots[key] = str(Decimal(val) + Decimal(actual_allocs[key])) - session_row.occupying_slots = session_occupying_slots - if transited: - await execute_with_txn_retry(_update_session, self.db.begin_session, db_conn) self._kernel_actual_allocated_resources[kernel_id] = actual_allocs - await self.set_status_updatable_session(session_id) + await self.set_status_updatable_session([session_id]) async def mark_kernel_terminated( self, @@ -3242,9 +3228,9 @@ async def _recalc(db_session: AsyncSession) -> None: # it may take a while to fetch stats from Redis. # await self.sync_kernel_stats([kernel_id]) - await self.set_status_updatable_session(session_id) + await self.set_status_updatable_session([session_id]) - async def transit_session_status( + async def _transit_session_status( self, db_conn: SAConnection, session_id: SessionId, @@ -3257,11 +3243,29 @@ async def _get_and_transit( ) -> tuple[SessionRow, bool]: session_row = await SessionRow.get_session_to_determine_status(db_session, session_id) transited = session_row.determine_and_set_status(status_changed_at=now) + + def _calculate_session_occupied_slots(session_row: SessionRow): + session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots}) + for row in session_row.kernels: + kernel_row = cast(KernelRow, row) + kernel_allocs = kernel_row.occupied_slots + session_occupying_slots.sync_keys(kernel_allocs) + for key, val in session_occupying_slots.items(): + session_occupying_slots[key] = str( + Decimal(val) + Decimal(kernel_allocs[key]) + ) + session_row.occupying_slots = session_occupying_slots + + match session_row.status: + case SessionStatus.PREPARING: + _calculate_session_occupied_slots(session_row) + case SessionStatus.RUNNING if transited: + _calculate_session_occupied_slots(session_row) return session_row, transited return await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn) - async def post_status_transition( + async def _post_status_transition( self, session_row: SessionRow, ) -> None: @@ -3293,11 +3297,30 @@ async def post_status_transition( case _: pass - async def set_status_updatable_session(self, session_id: SessionId) -> None: + async def transit_session_status( + self, + session_id: SessionId, + status_changed_at: datetime | None = None, + ) -> None: + now = status_changed_at or datetime.now(tzutc()) + async with self.db.connect() as db_conn: + row, is_transited = await self._transit_session_status(db_conn, session_id, now) + if is_transited: + await self._post_status_transition(row) + + async def set_status_updatable_session(self, session_ids: Iterable[SessionId]) -> None: + sadd_session_ids_script = textwrap.dedent(""" + local key = KEYS[1] + local values = ARGV + return redis.call('SADD', key, unpack(values)) + """) try: - await redis_helper.execute( + await redis_helper.execute_script( self.redis_stat, - lambda r: r.sadd("session_status_update", msgpack.packb(session_id)), + "session_status_update", + sadd_session_ids_script, + ["session_status_update"], + [sid.bytes for sid in session_ids], ) except redis.exceptions.ResponseError: log.warning("Failed to update session status to redis, skip.") @@ -3323,7 +3346,7 @@ async def get_status_updatable_sessions(self) -> list[SessionId]: result: list[SessionId] = [] for raw_session_id in raw_result: - result.append(SessionId(msgpack.unpackb(raw_session_id))) + result.append(SessionId(uuid.UUID(bytes=raw_session_id))) return result async def _get_user_email( diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 3ef063e9f0..073156880f 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -1601,15 +1601,9 @@ async def update_session_status( log.debug("update_session_status(): triggered") candidates = await self.registry.get_status_updatable_sessions() - async def _transit(session_id: SessionId): - async with self.db.connect() as db_conn: - row, is_transited = await self.registry.transit_session_status(db_conn, session_id) - if is_transited: - await self.registry.post_status_transition(row) - async with aiotools.TaskGroup() as tg: for session_id in candidates: - tg.create_task(_transit(session_id)) + tg.create_task(self.registry.transit_session_status(session_id)) async def start_session( self,