diff --git a/changes/2312.feature.md b/changes/2312.feature.md new file mode 100644 index 00000000000..c25e29fadf9 --- /dev/null +++ b/changes/2312.feature.md @@ -0,0 +1 @@ +Add session status check & update API. diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 5a5e78a1b50..51849af72e9 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -101,7 +101,6 @@ session_templates, vfolders, ) -from ..models.utils import execute_with_txn_retry from ..types import UserScope from ..utils import query_userinfo as _query_userinfo from .auth import auth_required @@ -973,52 +972,40 @@ async def sync_agent_registry(request: web.Request, params: Any) -> web.StreamRe class TransitSessionStatusRequestModel(BaseModel): id: uuid.UUID = Field( - validation_alias=AliasChoices("id", "session_id", "SessionId"), - description="ID of the session to transit status.", + validation_alias=AliasChoices("id", "session_id", "sessionId", "SessionId"), + description="ID of the session to check and transit status.", ) class SessionStatusResponseModel(BaseResponseModel): - session_status: str | None + session_status: str @auth_required @server_status_required(ALL_ALLOWED) @pydantic_params_api_handler(TransitSessionStatusRequestModel) -async def transit_status( +async def check_and_transit_status( request: web.Request, params: TransitSessionStatusRequestModel ) -> SessionStatusResponseModel: root_ctx: RootContext = request.app["_root.context"] session_id = SessionId(params.id) + user_role = cast(UserRole, request["user"]["role"]) + user_id = cast(uuid.UUID, request["user"]["uuid"]) requester_access_key, owner_access_key = await get_access_key_scopes(request) log.info("TRANSIT_STATUS (ak:{}/{}, s:{})", requester_access_key, owner_access_key, session_id) - if requester_access_key != owner_access_key and request["user"]["role"] not in ( + + async with root_ctx.db.begin_readonly_session() as db_session: + session_row = await SessionRow.get_session_to_determine_status(db_session, session_id) + if session_row.user_uuid != user_id and user_role not in ( UserRole.ADMIN, UserRole.SUPERADMIN, ): - raise InsufficientPrivilege("You are not allowed to transit others's sessions status") - - async with root_ctx.db.connect() as db_conn: - async with root_ctx.db.begin_readonly_session(db_conn) as db_session: - session_row = await db_session.scalar( - sa.select(SessionRow).where(SessionRow.id == session_id) - ) - session_row = cast(SessionRow | None, session_row) - if session_row is None: - raise SessionNotFound(f"Session (id={session_id}) does not exist.") - if requester_access_key != session_row.access_key: - raise InvalidAPIParameters( - f"Access key does not own the session (ak:{requester_access_key}, s:{session_id})" - ) - - async def _transit(db_session: SASession) -> SessionStatus | None: - return await root_ctx.registry.transit_session_status(db_session, session_id) - - new_status = await execute_with_txn_retry(_transit, root_ctx.db.begin_session, db_conn) - - return SessionStatusResponseModel( - session_status=new_status.name if new_status is not None else new_status - ) + raise InvalidAPIParameters( + f"You are not allowed to transit others's sessions status (s:{session_id})" + ) + now = datetime.now(tzutc()) + row = await root_ctx.registry.transit_session_status(session_id, now) + return SessionStatusResponseModel(session_status=row.status.name) @server_status_required(ALL_ALLOWED) @@ -2365,7 +2352,7 @@ def create_app( cors.add(app.router.add_route("POST", "/_/create-cluster", create_cluster)) cors.add(app.router.add_route("GET", "/_/match", match_sessions)) cors.add(app.router.add_route("POST", "/_/sync-agent-registry", sync_agent_registry)) - cors.add(app.router.add_route("POST", "/transit-status", transit_status)) + cors.add(app.router.add_route("PATCH", "/_/transit-status", check_and_transit_status)) session_resource = cors.add(app.router.add_resource(r"/{session_name}")) cors.add(session_resource.add_route("GET", get_info)) cors.add(session_resource.add_route("PATCH", restart)) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 99ec7224045..35324dad4dd 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -3303,12 +3303,13 @@ async def transit_session_status( self, session_id: SessionId, status_changed_at: datetime | None = None, - ) -> None: + ) -> SessionRow: now = status_changed_at or datetime.now(tzutc()) async with self.db.connect() as db_conn: row, is_transited = await self._transit_session_status(db_conn, session_id, now) if is_transited: await self._post_status_transition(row) + return row async def set_status_updatable_session(self, session_id: SessionId) -> None: try: