Skip to content

Commit

Permalink
replace asyncio.gather() with aiotools.TaskGroup()
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Aug 25, 2024
1 parent f333144 commit 24f9ec9
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 33 deletions.
1 change: 1 addition & 0 deletions changes/2311.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Revamp the mechanism for updating session and kernel statuses so that it works by reconciliation.
7 changes: 6 additions & 1 deletion src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,12 @@ async def get_session_to_determine_status(
.where(SessionRow.id == session_id)
.options(
selectinload(SessionRow.kernels).options(
load_only(KernelRow.status, KernelRow.cluster_role, KernelRow.status_info)
load_only(
KernelRow.status,
KernelRow.cluster_role,
KernelRow.status_info,
KernelRow.occupied_slots,
)
),
)
)
Expand Down
73 changes: 48 additions & 25 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import uuid
import zlib
from collections import defaultdict
from collections.abc import Mapping
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from datetime import datetime
from decimal import Decimal
from io import BytesIO
Expand All @@ -23,9 +23,7 @@
Dict,
List,
Literal,
MutableMapping,
Optional,
Sequence,
Tuple,
TypeAlias,
Union,
Expand Down Expand Up @@ -3099,7 +3097,7 @@ async def _set_status(db_session: AsyncSession) -> None:
)

await execute_with_txn_retry(_set_status, self.db.begin_session, db_conn)
await self.set_status_updatable_session(session_id)
await self.set_status_updatable_session([session_id])

async def mark_kernel_pulling(
self,
Expand All @@ -3121,7 +3119,7 @@ async def _transit_status(db_session: AsyncSession) -> bool:

transited = await execute_with_txn_retry(_transit_status, self.db.begin_session, db_conn)
if transited:
await self.set_status_updatable_session(session_id)
await self.set_status_updatable_session([session_id])

async def mark_kernel_running(
self,
Expand Down Expand Up @@ -3157,21 +3155,9 @@ async def _get_and_transit(db_session: AsyncSession) -> bool:

transited = await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn)

async def _update_session(db_session: AsyncSession) -> None:
_stmt = sa.select(SessionRow).where(SessionRow.id == session_id).with_for_update()
session_row = cast(SessionRow | None, await db_session.scalar(_stmt))
if session_row is None:
return
session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots})
session_occupying_slots.sync_keys(actual_allocs)
for key, val in session_occupying_slots.items():
session_occupying_slots[key] = str(Decimal(val) + Decimal(actual_allocs[key]))
session_row.occupying_slots = session_occupying_slots

if transited:
await execute_with_txn_retry(_update_session, self.db.begin_session, db_conn)
self._kernel_actual_allocated_resources[kernel_id] = actual_allocs
await self.set_status_updatable_session(session_id)
await self.set_status_updatable_session([session_id])

async def mark_kernel_terminated(
self,
Expand Down Expand Up @@ -3242,9 +3228,9 @@ async def _recalc(db_session: AsyncSession) -> None:
# it may take a while to fetch stats from Redis.

# await self.sync_kernel_stats([kernel_id])
await self.set_status_updatable_session(session_id)
await self.set_status_updatable_session([session_id])

async def transit_session_status(
async def _transit_session_status(
self,
db_conn: SAConnection,
session_id: SessionId,
Expand All @@ -3257,11 +3243,29 @@ async def _get_and_transit(
) -> tuple[SessionRow, bool]:
session_row = await SessionRow.get_session_to_determine_status(db_session, session_id)
transited = session_row.determine_and_set_status(status_changed_at=now)

# Nested helper: add every kernel's `occupied_slots` onto the session's
# `occupying_slots` and store the result back on the session row.
# NOTE(review): the sum starts from the session's current `occupying_slots`
# rather than from zero, and the `match` statement that follows calls this
# helper for PREPARING sessions and again for sessions that just became
# RUNNING -- confirm that a session's kernels cannot be counted twice.
def _calculate_session_occupied_slots(session_row: SessionRow):
# Build a fresh ResourceSlot from a shallow copy of the row's current value.
session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots})
for row in session_row.kernels:
kernel_row = cast(KernelRow, row)
kernel_allocs = kernel_row.occupied_slots
# Presumably aligns the key sets of both slots so the per-key lookup below
# cannot miss -- confirm against ResourceSlot.sync_keys().
session_occupying_slots.sync_keys(kernel_allocs)
for key, val in session_occupying_slots.items():
# Amounts are added as Decimal and written back as strings.
session_occupying_slots[key] = str(
Decimal(val) + Decimal(kernel_allocs[key])
)
session_row.occupying_slots = session_occupying_slots

match session_row.status:
case SessionStatus.PREPARING:
_calculate_session_occupied_slots(session_row)
case SessionStatus.RUNNING if transited:
_calculate_session_occupied_slots(session_row)
return session_row, transited

return await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn)

async def post_status_transition(
async def _post_status_transition(
self,
session_row: SessionRow,
) -> None:
Expand Down Expand Up @@ -3293,11 +3297,30 @@ async def post_status_transition(
case _:
pass

async def set_status_updatable_session(self, session_id: SessionId) -> None:
# Public entry point for transiting the status of a single session.
# It opens its own DB connection through `self.db.connect()`, delegates the
# actual transition to `_transit_session_status()`, and calls
# `_post_status_transition()` only when that call reports a status change.
#
# Parameters:
#   session_id: the session whose status should be (re)determined.
#   status_changed_at: timestamp to use for the transition; when it is
#       None (or otherwise falsy) the current time from
#       `datetime.now(tzutc())` is used.
async def transit_session_status(
self,
session_id: SessionId,
status_changed_at: datetime | None = None,
) -> None:
# Fall back to "now" when the caller did not supply a timestamp.
now = status_changed_at or datetime.now(tzutc())
async with self.db.connect() as db_conn:
row, is_transited = await self._transit_session_status(db_conn, session_id, now)
# Post-transition handling is skipped when the status did not change.
if is_transited:
await self._post_status_transition(row)

async def set_status_updatable_session(self, session_ids: Iterable[SessionId]) -> None:
sadd_session_ids_script = textwrap.dedent("""
local key = KEYS[1]
local values = ARGV
return redis.call('SADD', key, unpack(values))
""")
try:
await redis_helper.execute(
await redis_helper.execute_script(
self.redis_stat,
lambda r: r.sadd("session_status_update", msgpack.packb(session_id)),
"session_status_update",
sadd_session_ids_script,
["session_status_update"],
[sid.bytes for sid in session_ids],
)
except redis.exceptions.ResponseError:
log.warning("Failed to update session status to redis, skip.")
Expand All @@ -3323,7 +3346,7 @@ async def get_status_updatable_sessions(self) -> list[SessionId]:

result: list[SessionId] = []
for raw_session_id in raw_result:
result.append(SessionId(msgpack.unpackb(raw_session_id)))
result.append(SessionId(uuid.UUID(bytes=raw_session_id)))
return result

async def _get_user_email(
Expand Down
8 changes: 1 addition & 7 deletions src/ai/backend/manager/scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,15 +1601,9 @@ async def update_session_status(
log.debug("update_session_status(): triggered")
candidates = await self.registry.get_status_updatable_sessions()

async def _transit(session_id: SessionId):
async with self.db.connect() as db_conn:
row, is_transited = await self.registry.transit_session_status(db_conn, session_id)
if is_transited:
await self.registry.post_status_transition(row)

async with aiotools.TaskGroup() as tg:
for session_id in candidates:
tg.create_task(_transit(session_id))
tg.create_task(self.registry.transit_session_status(session_id))

async def start_session(
self,
Expand Down

0 comments on commit 24f9ec9

Please sign in to comment.