Skip to content

Commit

Permalink
check authority
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Aug 5, 2024
1 parent 9c5b19e commit 671e3e3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 31 deletions.
1 change: 1 addition & 0 deletions changes/2312.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add session status check & update API.
47 changes: 17 additions & 30 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@
session_templates,
vfolders,
)
from ..models.utils import execute_with_txn_retry
from ..types import UserScope
from ..utils import query_userinfo as _query_userinfo
from .auth import auth_required
Expand Down Expand Up @@ -973,52 +972,40 @@ async def sync_agent_registry(request: web.Request, params: Any) -> web.StreamRe

class TransitSessionStatusRequestModel(BaseModel):
id: uuid.UUID = Field(
validation_alias=AliasChoices("id", "session_id", "SessionId"),
description="ID of the session to transit status.",
validation_alias=AliasChoices("id", "session_id", "sessionId", "SessionId"),
description="ID of the session to check and transit status.",
)


class SessionStatusResponseModel(BaseResponseModel):
session_status: str | None
session_status: str


@auth_required
@server_status_required(ALL_ALLOWED)
@pydantic_params_api_handler(TransitSessionStatusRequestModel)
async def transit_status(
async def check_and_transit_status(
request: web.Request, params: TransitSessionStatusRequestModel
) -> SessionStatusResponseModel:
root_ctx: RootContext = request.app["_root.context"]
session_id = SessionId(params.id)
user_role = cast(UserRole, request["user"]["role"])
user_id = cast(uuid.UUID, request["user"]["uuid"])
requester_access_key, owner_access_key = await get_access_key_scopes(request)
log.info("TRANSIT_STATUS (ak:{}/{}, s:{})", requester_access_key, owner_access_key, session_id)
if requester_access_key != owner_access_key and request["user"]["role"] not in (

async with root_ctx.db.begin_readonly_session() as db_session:
session_row = await SessionRow.get_session_to_determine_status(db_session, session_id)
if session_row.user_uuid != user_id and user_role not in (
UserRole.ADMIN,
UserRole.SUPERADMIN,
):
raise InsufficientPrivilege("You are not allowed to transit others's sessions status")

async with root_ctx.db.connect() as db_conn:
async with root_ctx.db.begin_readonly_session(db_conn) as db_session:
session_row = await db_session.scalar(
sa.select(SessionRow).where(SessionRow.id == session_id)
)
session_row = cast(SessionRow | None, session_row)
if session_row is None:
raise SessionNotFound(f"Session (id={session_id}) does not exist.")
if requester_access_key != session_row.access_key:
raise InvalidAPIParameters(
f"Access key does not own the session (ak:{requester_access_key}, s:{session_id})"
)

async def _transit(db_session: SASession) -> SessionStatus | None:
return await root_ctx.registry.transit_session_status(db_session, session_id)

new_status = await execute_with_txn_retry(_transit, root_ctx.db.begin_session, db_conn)

return SessionStatusResponseModel(
session_status=new_status.name if new_status is not None else new_status
)
raise InvalidAPIParameters(
f"You are not allowed to transit others's sessions status (s:{session_id})"
)
now = datetime.now(tzutc())
row = await root_ctx.registry.transit_session_status(session_id, now)
return SessionStatusResponseModel(session_status=row.status.name)


@server_status_required(ALL_ALLOWED)
Expand Down Expand Up @@ -2365,7 +2352,7 @@ def create_app(
cors.add(app.router.add_route("POST", "/_/create-cluster", create_cluster))
cors.add(app.router.add_route("GET", "/_/match", match_sessions))
cors.add(app.router.add_route("POST", "/_/sync-agent-registry", sync_agent_registry))
cors.add(app.router.add_route("POST", "/transit-status", transit_status))
cors.add(app.router.add_route("PATCH", "/_/transit-status", check_and_transit_status))
session_resource = cors.add(app.router.add_resource(r"/{session_name}"))
cors.add(session_resource.add_route("GET", get_info))
cors.add(session_resource.add_route("PATCH", restart))
Expand Down
3 changes: 2 additions & 1 deletion src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3303,12 +3303,13 @@ async def transit_session_status(
self,
session_id: SessionId,
status_changed_at: datetime | None = None,
) -> None:
) -> SessionRow:
now = status_changed_at or datetime.now(tzutc())
async with self.db.connect() as db_conn:
row, is_transited = await self._transit_session_status(db_conn, session_id, now)
if is_transited:
await self._post_status_transition(row)
return row

async def set_status_updatable_session(self, session_id: SessionId) -> None:
try:
Expand Down

0 comments on commit 671e3e3

Please sign in to comment.