diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index 32f548dd257..e1b30920b40 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -733,62 +733,6 @@ async def get_session_id_by_kernel( async with db.begin_readonly_session() as db_session: return await db_session.scalar(query) - @classmethod - async def transit_session_status( - cls, - db: ExtendedAsyncSAEngine, - session_id: SessionId, - *, - status_info: str | None = None, - ) -> SessionStatus | None: - """ - Check status of session's sibling kernels and transit the status of session. - Return the new status of session. - """ - now = datetime.now(tzutc()) - - async def _check_and_update() -> SessionStatus | None: - async with db.begin_session() as db_session: - session_query = ( - sa.select(SessionRow) - .where(SessionRow.id == session_id) - .with_for_update() - .options( - noload("*"), - load_only(SessionRow.status), - selectinload(SessionRow.kernels).options( - noload("*"), load_only(KernelRow.status, KernelRow.cluster_role) - ), - ) - ) - session_row: SessionRow = (await db_session.scalars(session_query)).first() - determined_status = determine_session_status(session_row.kernels) - if determined_status not in SESSION_STATUS_TRANSITION_MAP[session_row.status]: - # TODO: log or raise error - return None - - update_values = { - "status": determined_status, - "status_history": sql_json_merge( - SessionRow.status_history, - (), - { - determined_status.name: now.isoformat(), - }, - ), - } - if determined_status in (SessionStatus.CANCELLED, SessionStatus.TERMINATED): - update_values["terminated_at"] = now - if status_info is not None: - update_values["status_info"] = status_info - update_query = ( - sa.update(SessionRow).where(SessionRow.id == session_id).values(**update_values) - ) - await db_session.execute(update_query) - return determined_status - - return await execute_with_retry(_check_and_update) - @classmethod async def get_session_to_determine_status( cls, db_session: SASession, session_id: SessionId diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 45f21e97356..0da09a06892 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -1553,98 +1553,6 @@ def convert_resource_spec_to_resource_slot( slots[slot_name] = str(sum(total_allocs)) return slots - async def finalize_running( - self, kernel_id: KernelId, session_id: SessionId, created_info: Mapping[str, Any] - ) -> None: - try: - agent_host = URL(created_info["agent_addr"]).host - kernel_host = created_info.get("kernel_host", agent_host) - service_ports = created_info.get("service_ports", []) - actual_allocs = self.convert_resource_spec_to_resource_slot( - created_info["resource_spec"]["allocations"] - ) - new_status = KernelStatus.RUNNING - update_data = { - "occupied_slots": actual_allocs, - "scaling_group": created_info["scaling_group"], - "container_id": created_info["container_id"], - "occupied_shares": {}, - "attached_devices": created_info.get("attached_devices", {}), - "kernel_host": kernel_host, - "repl_in_port": created_info["repl_in_port"], - "repl_out_port": created_info["repl_out_port"], - "stdin_port": created_info["stdin_port"], - "stdout_port": created_info["stdout_port"], - "service_ports": service_ports, - "status_history": sql_json_merge( - kernels.c.status_history, - (), - { - new_status.name: datetime.now(tzutc()).isoformat(), - }, - ), - } - self._kernel_actual_allocated_resources[kernel_id] = actual_allocs - - async def _update_session_occupying_slots(db_session: AsyncSession) -> None: - _stmt = sa.select(SessionRow).where(SessionRow.id == session_id) - session_row = cast(SessionRow | None, await db_session.scalar(_stmt)) - if session_row is None: - raise SessionNotFound(f"Failed to fetch session (id:{session_id})") - session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots}) - session_occupying_slots.sync_keys(actual_allocs) - for key, val in session_occupying_slots.items(): - session_occupying_slots[key] = str(Decimal(val) + Decimal(actual_allocs[key])) - session_row.occupying_slots = session_occupying_slots - - async with self.db.connect() as db_conn: - await execute_with_txn_retry( - _update_session_occupying_slots, self.db.begin_session, db_conn - ) - kernel_did_update = await KernelRow.update_kernel( - self.db, kernel_id, new_status, update_data=update_data - ) - if not kernel_did_update: - return - new_session_status = await SessionRow.transit_session_status(self.db, session_id) - if new_session_status is None or new_session_status != SessionStatus.RUNNING: - return - query = ( - sa.select(SessionRow) - .where(SessionRow.id == session_id) - .options( - noload("*"), - load_only( - SessionRow.id, - SessionRow.name, - SessionRow.creation_id, - SessionRow.access_key, - ), - ) - ) - async with self.db.begin_readonly_session() as db_session: - updated_session = (await db_session.scalars(query)).first() - - log.debug( - "Producing SessionStartedEvent({}, {})", - updated_session.id, - updated_session.creation_id, - ) - await self.event_producer.produce_event( - SessionStartedEvent(updated_session.id, updated_session.creation_id), - ) - await self.hook_plugin_ctx.notify( - "POST_START_SESSION", - ( - updated_session.id, - updated_session.name, - updated_session.access_key, - ), - ) - except Exception: - log.exception("error while executing _finalize_running") - raise - async def _create_kernels_in_one_agent( self, agent_alloc_ctx: AgentAllocationContext,