Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add session status checker API #2312

Merged
merged 2 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2312.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add session status check & update API.
53 changes: 53 additions & 0 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
KernelId,
MountPermission,
MountTypes,
SessionId,
SessionTypes,
VFolderID,
)
Expand Down Expand Up @@ -969,6 +970,57 @@ async def sync_agent_registry(request: web.Request, params: Any) -> web.StreamRe
return web.json_response({}, status=200)


class TransitSessionStatusRequestModel(BaseModel):
    # Request parameters accepted by the `check_and_transit_status` handler.
    # (Kept as `#` comments so that no class docstring is attached to the model.)
    ids: list[uuid.UUID] = Field(
        description="ID array of sessions to check and transit status.",
        # The ID list may be supplied under any one of these key spellings.
        validation_alias=AliasChoices("ids", "session_ids", "sessionIds", "SessionIds"),
    )


# Response model returned by the `check_and_transit_status` handler.
class SessionStatusResponseModel(BaseResponseModel):
# Maps a session ID to a status string; the handler fills it with
# `row.status.name` keyed by `row.id`.
session_status_map: dict[SessionId, str]


@auth_required
@server_status_required(ALL_ALLOWED)
@pydantic_params_api_handler(TransitSessionStatusRequestModel)
async def check_and_transit_status(
    request: web.Request, params: TransitSessionStatusRequestModel
) -> SessionStatusResponseModel:
    """
    Check the requested sessions and ask the lifecycle manager to transit their status.

    A session is processed only when it is owned by the requesting user, or
    when the requester has the admin or superadmin role.  Any other session
    is skipped with a warning log instead of failing the whole request.

    :param request: The incoming request; its ``user`` mapping supplies the
        requester's role and UUID.
    :param params: The validated request parameters holding the session IDs.
    :return: A response model mapping each processed session ID to the name of
        the status reported by the lifecycle manager.  The map is empty when no
        requested session is accessible.
    """
    root_ctx: RootContext = request.app["_root.context"]
    # `raw_id` instead of `id`, so that the builtin is not shadowed.
    session_ids = [SessionId(raw_id) for raw_id in params.ids]
    user_role = cast(UserRole, request["user"]["role"])
    user_id = cast(uuid.UUID, request["user"]["uuid"])
    requester_access_key, owner_access_key = await get_access_key_scopes(request)
    log.info("TRANSIT_STATUS (ak:{}/{}, s:{})", requester_access_key, owner_access_key, session_ids)

    # The role does not depend on the session, so evaluate it once up front.
    is_admin = user_role in (UserRole.ADMIN, UserRole.SUPERADMIN)

    accessible_session_ids: list[SessionId] = []
    async with root_ctx.db.begin_readonly_session() as db_session:
        for sid in session_ids:
            # The row is looked up for every ID, including for admins, so a
            # failing lookup surfaces the same way regardless of the role.
            session_row = await SessionRow.get_session_to_determine_status(db_session, sid)
            if session_row.user_uuid == user_id or is_admin:
                accessible_session_ids.append(sid)
            else:
                # Brace-style arguments, matching the `log.info()` call above.
                log.warning(
                    "You are not allowed to transit others' sessions status, skip (s:{})",
                    sid,
                )

    if not accessible_session_ids:
        return SessionStatusResponseModel(session_status_map={})

    now = datetime.now(tzutc())
    lifecycle_mgr = root_ctx.registry.session_lifecycle_mgr
    session_rows = await lifecycle_mgr.transit_session_status(accessible_session_ids, now)
    # Only the sessions whose status actually changed are deregistered.
    await lifecycle_mgr.deregister_status_updatable_session([
        row.id for row, is_transited in session_rows if is_transited
    ])
    result: dict[SessionId, str] = {row.id: row.status.name for row, _ in session_rows}
    return SessionStatusResponseModel(session_status_map=result)


@server_status_required(ALL_ALLOWED)
@auth_required
@check_api_params(
Expand Down Expand Up @@ -2315,6 +2367,7 @@ def create_app(
cors.add(app.router.add_route("POST", "/_/create-cluster", create_cluster))
cors.add(app.router.add_route("GET", "/_/match", match_sessions))
cors.add(app.router.add_route("POST", "/_/sync-agent-registry", sync_agent_registry))
cors.add(app.router.add_route("POST", "/_/transit-status", check_and_transit_status))
session_resource = cors.add(app.router.add_resource(r"/{session_name}"))
cors.add(session_resource.add_route("GET", get_info))
cors.add(session_resource.add_route("PATCH", restart))
Expand Down
34 changes: 34 additions & 0 deletions src/ai/backend/manager/models/gql_models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,40 @@ def parse(
result.permissions = permissions
return result

@classmethod
async def parse_status_only(
    cls,
    info: graphene.ResolveInfo,
    row: SessionRow,
) -> ComputeSessionNode:
    """
    Build a node from a session row using its identity, status and resource fields.

    ``info`` is accepted but not read by this method.  A missing (falsy)
    ``status_history`` on the row is treated as an empty mapping, and
    ``scheduled_at`` is ``None`` when that mapping has no entry for the
    SCHEDULED status.
    """
    history = row.status_history or {}
    scheduled_raw = history.get(SessionStatus.SCHEDULED.name)
    if scheduled_raw is None:
        scheduled_at = None
    else:
        scheduled_at = datetime.fromisoformat(scheduled_raw)

    identity_fields = {
        "id": row.id,
        "row_id": row.id,
        "name": row.name,
    }
    status_fields = {
        "status": row.status.name,
        "status_changed": row.status_changed,
        "status_info": row.status_info,
        "status_data": row.status_data,
        "status_history": history,
        "created_at": row.created_at,
        "starts_at": row.starts_at,
        "terminated_at": row.terminated_at,
        "scheduled_at": scheduled_at,
        "result": row.result.name,
    }
    resource_fields = {
        "agent_ids": row.agent_ids,
        "scaling_group": row.scaling_group_name,
        "vfolder_mounts": row.vfolder_mounts,
        "occupied_slots": row.occupying_slots.to_json(),
        "requested_slots": row.requested_slots.to_json(),
    }
    return cls(**identity_fields, **status_fields, **resource_fields)

@classmethod
async def get_accessible_node(
cls,
Expand Down
36 changes: 21 additions & 15 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,13 +866,10 @@ def set_status(
if status in (SessionStatus.CANCELLED, SessionStatus.TERMINATED):
self.terminated_at = now
self.status = status
self.status_history = sql_json_merge(
SessionRow.status_history,
(),
{
status.name: now.isoformat(),
},
)
self.status_history = {
**self.status_history,
status.name: now.isoformat(),
}
if status_data is not None:
self.status_data = status_data

Expand Down Expand Up @@ -1234,8 +1231,11 @@ def _calculate_session_occupied_slots(session_row: SessionRow):
session_row.occupying_slots = session_occupying_slots

match session_row.status:
case SessionStatus.PREPARING | SessionStatus.RUNNING:
case SessionStatus.PREPARING:
_calculate_session_occupied_slots(session_row)
case SessionStatus.RUNNING if transited:
_calculate_session_occupied_slots(session_row)

return session_row, transited

return await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn)
Expand Down Expand Up @@ -1276,21 +1276,24 @@ async def transit_session_status(
self,
session_ids: Iterable[SessionId],
status_changed_at: datetime | None = None,
) -> list[SessionRow]:
) -> list[tuple[SessionRow, bool]]:
if not session_ids:
return []
now = status_changed_at or datetime.now(tzutc())
transited_sessions: list[SessionRow] = []
result: list[tuple[SessionRow, bool]] = []
async with self.db.connect() as db_conn:
for sid in session_ids:
row, is_transited = await self._transit_session_status(db_conn, sid, now)
if is_transited:
transited_sessions.append(row)
for row in transited_sessions:
await self._post_status_transition(row)
return transited_sessions
result.append((row, is_transited))
for row, is_transited in result:
if is_transited:
await self._post_status_transition(row)
return result

async def register_status_updatable_session(self, session_ids: Iterable[SessionId]) -> None:
if not session_ids:
return

sadd_session_ids_script = textwrap.dedent("""
local key = KEYS[1]
local values = ARGV
Expand Down Expand Up @@ -1346,6 +1349,9 @@ async def deregister_status_updatable_session(
self,
session_ids: Iterable[SessionId],
) -> int:
if not session_ids:
return 0

srem_session_ids_script = textwrap.dedent("""
local key = KEYS[1]
local values = ARGV
Expand Down
Loading