replace asyncio.gather() with aiotools.TaskGroup()
fregataa committed Sep 12, 2024
1 parent 6cfb21e commit 5a17c2f
Showing 5 changed files with 236 additions and 120 deletions.
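The diff below does not show the call sites where the swap happens, but the pattern named in the commit title is roughly the following. This is a minimal sketch with placeholder coroutine names (check_session, reconcile_all), not code from this commit:

import asyncio
import aiotools

async def check_session(session_id: str) -> None:
    # placeholder for the per-session work the manager fans out
    await asyncio.sleep(0)

async def reconcile_all(session_ids: list[str]) -> None:
    # Before: asyncio.gather() propagates the first exception but leaves the
    # sibling tasks running unless the caller cancels them explicitly.
    # await asyncio.gather(*(check_session(sid) for sid in session_ids))

    # After: aiotools.TaskGroup() cancels the remaining siblings when one task
    # fails and re-raises the error when the async context manager exits.
    async with aiotools.TaskGroup() as tg:
        for sid in session_ids:
            tg.create_task(check_session(sid))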
changes/2311.enhance.md (1 addition, 0 deletions)
@@ -0,0 +1 @@
Revamp the session and kernel status update mechanism to use reconciliation.
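The reconciliation approach keeps a set of "dirty" session IDs in Redis: event handlers merely register IDs, and a separate pass drains the set and recomputes each session's status from its kernel rows. The concrete implementation is the SessionLifecycleManager added to session.py below; the following is only a sketch of that shape using plain redis.asyncio, with hypothetical names (mark_dirty, reconcile_once):

import redis.asyncio as redis

DIRTY_SET_KEY = "session_status_update"  # same key name used by the manager below

async def mark_dirty(r: redis.Redis, session_id: str) -> None:
    # event handlers only record that a session needs a status check
    await r.sadd(DIRTY_SET_KEY, session_id)

async def reconcile_once(r: redis.Redis) -> None:
    # drain the set, then recompute each session's status from the database
    count = await r.scard(DIRTY_SET_KEY)
    session_ids = await r.spop(DIRTY_SET_KEY, count) if count else []
    for sid in session_ids:
        ...  # load kernel rows, determine the session status, emit events

The actual code below wraps the SCARD/SPOP pair (and SADD/SREM) in small Lua scripts via redis_helper.execute_script, so each drain is atomic and takes a single round trip.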
src/ai/backend/manager/models/kernel.py (4 additions, 0 deletions)
@@ -651,8 +651,12 @@ async def get_kernel_to_update_status(
cls,
db_session: SASession,
kernel_id: KernelId,
*,
for_update: bool = True,
) -> KernelRow:
_stmt = sa.select(KernelRow).where(KernelRow.id == kernel_id)
if for_update:
_stmt = _stmt.with_for_update()
kernel_row = cast(KernelRow | None, await db_session.scalar(_stmt))
if kernel_row is None:
raise KernelNotFound(f"Kernel not found (id:{kernel_id})")
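The new keyword-only for_update flag lets callers skip the SELECT ... FOR UPDATE row lock when they only need to read the kernel row. A hypothetical call site, not part of this diff:

# default: lock the row before mutating kernel status
kernel = await KernelRow.get_kernel_to_update_status(db_session, kernel_id)

# read-only path, e.g. when merely inspecting status during reconciliation
kernel = await KernelRow.get_kernel_to_update_status(
    db_session, kernel_id, for_update=False,
)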
src/ai/backend/manager/models/session.py (208 additions, 1 deletion)
@@ -3,10 +3,12 @@
import asyncio
import enum
import logging
import textwrap
from collections.abc import Iterable, Mapping, Sequence
from contextlib import asynccontextmanager as actxmgr
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from typing import (
TYPE_CHECKING,
Any,
@@ -21,6 +23,7 @@

import aiotools
import graphene
import redis.exceptions
import sqlalchemy as sa
from dateutil.parser import parse as dtparse
from dateutil.tz import tzutc
@@ -30,10 +33,20 @@
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, noload, relationship, selectinload

from ai.backend.common import redis_helper
from ai.backend.common.events import (
EventDispatcher,
EventProducer,
SessionStartedEvent,
SessionTerminatedEvent,
)
from ai.backend.common.plugin.hook import HookPluginContext
from ai.backend.common.types import (
AccessKey,
ClusterMode,
KernelId,
RedisConnectionInfo,
ResourceSlot,
SessionId,
SessionResult,
SessionTypes,
@@ -95,6 +108,7 @@
JSONCoalesceExpr,
agg_to_array,
execute_with_retry,
execute_with_txn_retry,
sql_json_merge,
)

@@ -804,7 +818,12 @@ async def get_session_to_determine_status(
.where(SessionRow.id == session_id)
.options(
selectinload(SessionRow.kernels).options(
load_only(KernelRow.status, KernelRow.cluster_role, KernelRow.status_info)
load_only(
KernelRow.status,
KernelRow.cluster_role,
KernelRow.status_info,
KernelRow.occupied_slots,
)
),
)
)
@@ -1162,6 +1181,194 @@ async def get_sgroup_managed_sessions(
return result.scalars().all()


class SessionLifecycleManager:
status_set_key = "session_status_update"

def __init__(
self,
db: ExtendedAsyncSAEngine,
redis_obj: RedisConnectionInfo,
event_dispatcher: EventDispatcher,
event_producer: EventProducer,
hook_plugin_ctx: HookPluginContext,
) -> None:
self.db = db
self.redis_obj = redis_obj
self.event_dispatcher = event_dispatcher
self.event_producer = event_producer
self.hook_plugin_ctx = hook_plugin_ctx

def _encode(sid: SessionId) -> bytes:
return sid.bytes

def _decode(raw_sid: bytes) -> SessionId:
return SessionId(UUID(bytes=raw_sid))

self._encoder = _encode
self._decoder = _decode

async def _transit_session_status(
self,
db_conn: SAConnection,
session_id: SessionId,
status_changed_at: datetime | None = None,
) -> tuple[SessionRow, bool]:
now = status_changed_at or datetime.now(tzutc())

async def _get_and_transit(
db_session: SASession,
) -> tuple[SessionRow, bool]:
session_row = await SessionRow.get_session_to_determine_status(db_session, session_id)
transited = session_row.determine_and_set_status(status_changed_at=now)

def _calculate_session_occupied_slots(session_row: SessionRow):
session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots})
for row in session_row.kernels:
kernel_row = cast(KernelRow, row)
kernel_allocs = kernel_row.occupied_slots
session_occupying_slots.sync_keys(kernel_allocs)
for key, val in session_occupying_slots.items():
session_occupying_slots[key] = str(
Decimal(val) + Decimal(kernel_allocs[key])
)
session_row.occupying_slots = session_occupying_slots

match session_row.status:
case SessionStatus.PREPARING | SessionStatus.RUNNING:
_calculate_session_occupied_slots(session_row)
return session_row, transited

return await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn)

async def _post_status_transition(
self,
session_row: SessionRow,
) -> None:
match session_row.status:
case SessionStatus.RUNNING:
log.debug(
"Producing SessionStartedEvent({}, {})",
session_row.id,
session_row.creation_id,
)
await self.event_producer.produce_event(
SessionStartedEvent(session_row.id, session_row.creation_id),
)
await self.hook_plugin_ctx.notify(
"POST_START_SESSION",
(
session_row.id,
session_row.name,
session_row.access_key,
),
)
case SessionStatus.TERMINATED:
await self.event_producer.produce_event(
SessionTerminatedEvent(session_row.id, session_row.main_kernel.status_info),
)
case _:
pass

async def transit_session_status(
self,
session_ids: Iterable[SessionId],
status_changed_at: datetime | None = None,
) -> list[SessionRow]:
if not session_ids:
return []
now = status_changed_at or datetime.now(tzutc())
transited_sessions: list[SessionRow] = []
async with self.db.connect() as db_conn:
for sid in session_ids:
row, is_transited = await self._transit_session_status(db_conn, sid, now)
if is_transited:
transited_sessions.append(row)
for row in transited_sessions:
await self._post_status_transition(row)
return transited_sessions

async def register_status_updatable_session(self, session_ids: Iterable[SessionId]) -> None:
sadd_session_ids_script = textwrap.dedent("""
local key = KEYS[1]
local values = ARGV
return redis.call('SADD', key, unpack(values))
""")
try:
await redis_helper.execute_script(
self.redis_obj,
"session_status_update",
sadd_session_ids_script,
[self.status_set_key],
[self._encoder(sid) for sid in session_ids],
)
except (
redis.exceptions.RedisError,
redis.exceptions.RedisClusterException,
redis.exceptions.ChildDeadlockedError,
) as e:
log.warning(f"Failed to update session status to redis, skip. (e:{repr(e)})")

async def get_status_updatable_sessions(self) -> list[SessionId]:
pop_all_session_id_script = textwrap.dedent("""
local key = KEYS[1]
local count = redis.call('SCARD', key)
return redis.call('SPOP', key, count)
""")
try:
raw_result = await redis_helper.execute_script(
self.redis_obj,
"pop_all_session_id_to_update_status",
pop_all_session_id_script,
[self.status_set_key],
[],
)
except (
redis.exceptions.RedisError,
redis.exceptions.RedisClusterException,
redis.exceptions.ChildDeadlockedError,
) as e:
log.warning(f"Failed to fetch session status data from redis, skip. (e:{repr(e)})")
return []
raw_result = cast(list[bytes], raw_result)
result: list[SessionId] = []
for raw_session_id in raw_result:
try:
result.append(self._decoder(raw_session_id))
except (ValueError, SyntaxError):
log.warning(f"Cannot parse session id, skip. (id:{raw_session_id})")
continue
return result

async def deregister_status_updatable_session(
self,
session_ids: Iterable[SessionId],
) -> int:
srem_session_ids_script = textwrap.dedent("""
local key = KEYS[1]
local values = ARGV
return redis.call('SREM', key, unpack(values))
""")
try:
ret = await redis_helper.execute_script(
self.redis_obj,
"session_status_update",
srem_session_ids_script,
[self.status_set_key],
[self._encoder(sid) for sid in session_ids],
)
except (
redis.exceptions.RedisError,
redis.exceptions.RedisClusterException,
redis.exceptions.ChildDeadlockedError,
) as e:
log.warning(f"Failed to remove session status data from redis, skip. (e:{repr(e)})")
return 0
return ret


class SessionDependencyRow(Base):
__tablename__ = "session_dependencies"
session_id = sa.Column(
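Taken together, the Redis-backed register/pop/deregister methods and transit_session_status form a drain-and-reconcile cycle. A sketch of how a periodic task might drive it; the timer itself is not part of this diff, and reconcile_session_statuses is a hypothetical name:

async def reconcile_session_statuses(mgr: SessionLifecycleManager) -> None:
    # event handlers call mgr.register_status_updatable_session([...]) whenever
    # a kernel event makes a session's status potentially stale
    session_ids = await mgr.get_status_updatable_sessions()  # atomically pops the set
    if not session_ids:
        return
    # recompute each session's status from its kernel rows and emit
    # SessionStarted/SessionTerminated events for rows that actually changed
    await mgr.transit_session_status(session_ids)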
(Diffs for the remaining two changed files are not shown here.)