Skip to content

Commit

Permalink
replace asyncio.gather() with aiotools.TaskGroup()
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Sep 2, 2024
1 parent b59f950 commit 6ad9a17
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 41 deletions.
1 change: 1 addition & 0 deletions changes/2311.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Revamp the update mechanism for session and kernel statuses to use reconciliation.
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ async def get_kernel_to_update_status(
db_session: SASession,
kernel_id: KernelId,
) -> KernelRow:
_stmt = sa.select(KernelRow).where(KernelRow.id == kernel_id)
_stmt = sa.select(KernelRow).where(KernelRow.id == kernel_id).with_for_update()
kernel_row = cast(KernelRow | None, await db_session.scalar(_stmt))
if kernel_row is None:
raise KernelNotFound(f"Kernel not found (id:{kernel_id})")
Expand Down
7 changes: 6 additions & 1 deletion src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,12 @@ async def get_session_to_determine_status(
.where(SessionRow.id == session_id)
.options(
selectinload(SessionRow.kernels).options(
load_only(KernelRow.status, KernelRow.cluster_role, KernelRow.status_info)
load_only(
KernelRow.status,
KernelRow.cluster_role,
KernelRow.status_info,
KernelRow.occupied_slots,
)
),
)
)
Expand Down
96 changes: 67 additions & 29 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import uuid
import zlib
from collections import defaultdict
from collections.abc import Mapping
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from datetime import datetime
from decimal import Decimal
from io import BytesIO
Expand All @@ -23,9 +23,7 @@
Dict,
List,
Literal,
MutableMapping,
Optional,
Sequence,
Tuple,
TypeAlias,
Union,
Expand Down Expand Up @@ -3099,7 +3097,7 @@ async def _set_status(db_session: AsyncSession) -> None:
)

await execute_with_txn_retry(_set_status, self.db.begin_session, db_conn)
await self.set_status_updatable_session(session_id)
await self.set_status_updatable_session([session_id])

async def mark_kernel_pulling(
self,
Expand All @@ -3121,7 +3119,7 @@ async def _transit_status(db_session: AsyncSession) -> bool:

transited = await execute_with_txn_retry(_transit_status, self.db.begin_session, db_conn)
if transited:
await self.set_status_updatable_session(session_id)
await self.set_status_updatable_session([session_id])

async def mark_kernel_running(
self,
Expand Down Expand Up @@ -3157,21 +3155,9 @@ async def _get_and_transit(db_session: AsyncSession) -> bool:

transited = await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn)

async def _update_session(db_session: AsyncSession) -> None:
_stmt = sa.select(SessionRow).where(SessionRow.id == session_id).with_for_update()
session_row = cast(SessionRow | None, await db_session.scalar(_stmt))
if session_row is None:
return
session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots})
session_occupying_slots.sync_keys(actual_allocs)
for key, val in session_occupying_slots.items():
session_occupying_slots[key] = str(Decimal(val) + Decimal(actual_allocs[key]))
session_row.occupying_slots = session_occupying_slots

if transited:
await execute_with_txn_retry(_update_session, self.db.begin_session, db_conn)
self._kernel_actual_allocated_resources[kernel_id] = actual_allocs
await self.set_status_updatable_session(session_id)
await self.set_status_updatable_session([session_id])

async def mark_kernel_terminated(
self,
Expand Down Expand Up @@ -3242,9 +3228,9 @@ async def _recalc(db_session: AsyncSession) -> None:
# it may take a while to fetch stats from Redis.

# await self.sync_kernel_stats([kernel_id])
await self.set_status_updatable_session(session_id)
await self.set_status_updatable_session([session_id])

async def transit_session_status(
async def _transit_session_status(
self,
db_conn: SAConnection,
session_id: SessionId,
Expand All @@ -3257,11 +3243,27 @@ async def _get_and_transit(
) -> tuple[SessionRow, bool]:
session_row = await SessionRow.get_session_to_determine_status(db_session, session_id)
transited = session_row.determine_and_set_status(status_changed_at=now)

# Recompute the session-level `occupying_slots` from the `occupied_slots` of
# the session's kernels and store the result back on `session_row`.
#
# The running total is seeded with a copy of the session's current
# `occupying_slots`, and each kernel's per-key allocation is added on top.
# NOTE(review): because the total starts from the existing value rather than
# from zero, running this more than once for the same session (the `match`
# right after this definition calls it for both PREPARING and RUNNING) looks
# like it would add the kernels' slots again -- confirm this is intended.
def _calculate_session_occupied_slots(session_row: SessionRow):
# `{**...}` makes a shallow copy so the row's current mapping is not mutated
# while the new total is being built.
session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots})
for row in session_row.kernels:
kernel_row = cast(KernelRow, row)
kernel_allocs = kernel_row.occupied_slots
# Presumably aligns the key sets of both slot mappings so the
# `kernel_allocs[key]` lookup below cannot miss -- verify against
# ResourceSlot.sync_keys.
session_occupying_slots.sync_keys(kernel_allocs)
for key, val in session_occupying_slots.items():
# Sum as Decimal, then store the result as a string.
session_occupying_slots[key] = str(
Decimal(val) + Decimal(kernel_allocs[key])
)
session_row.occupying_slots = session_occupying_slots

match session_row.status:
case SessionStatus.PREPARING | SessionStatus.RUNNING:
_calculate_session_occupied_slots(session_row)
return session_row, transited

return await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn)

async def post_status_transition(
async def _post_status_transition(
self,
session_row: SessionRow,
) -> None:
Expand Down Expand Up @@ -3293,14 +3295,42 @@ async def post_status_transition(
case _:
pass

async def set_status_updatable_session(self, session_id: SessionId) -> None:
# Determine and apply status transitions for the given sessions, then run
# `_post_status_transition` for every session whose status actually changed.
#
# Parameters:
#   session_ids: ids of the sessions to evaluate; they are awaited one at a
#     time, in iteration order.
#   status_changed_at: timestamp passed to every transition; when it is falsy
#     (e.g. None), `datetime.now(tzutc())` is used instead.
#
# Returns the rows of the sessions that were transited; sessions whose status
# did not change are not included.
async def transit_session_status(
self,
session_ids: Iterable[SessionId],
status_changed_at: datetime | None = None,
) -> list[SessionRow]:
# One timestamp is computed up front and shared by all sessions in this call.
now = status_changed_at or datetime.now(tzutc())
transited_sessions: list[SessionRow] = []
# A single DB connection is opened and handed to every per-session transition.
async with self.db.connect() as db_conn:
for sid in session_ids:
row, is_transited = await self._transit_session_status(db_conn, sid, now)
if is_transited:
transited_sessions.append(row)
# Post-transition handling is applied only to the sessions collected above.
for row in transited_sessions:
await self._post_status_transition(row)
return transited_sessions

async def set_status_updatable_session(self, session_ids: Iterable[SessionId]) -> None:
sadd_session_ids_script = textwrap.dedent("""
local key = KEYS[1]
local values = ARGV
return redis.call('SADD', key, unpack(values))
""")
try:
await redis_helper.execute(
await redis_helper.execute_script(
self.redis_stat,
lambda r: r.sadd("session_status_update", msgpack.packb(session_id)),
"session_status_update",
sadd_session_ids_script,
["session_status_update"],
[sid.bytes for sid in session_ids],
)
except redis.exceptions.ResponseError:
log.warning("Failed to update session status to redis, skip.")
except (
redis.exceptions.RedisError,
redis.exceptions.RedisClusterException,
redis.exceptions.ChildDeadlockedError,
) as e:
log.warning(f"Failed to update session status to redis, skip. (e:{repr(e)})")

async def get_status_updatable_sessions(self) -> list[SessionId]:
pop_all_session_id_script = textwrap.dedent("""
Expand All @@ -3316,14 +3346,22 @@ async def get_status_updatable_sessions(self) -> list[SessionId]:
["session_status_update"],
[],
)
except redis.exceptions.ResponseError:
log.warning("Failed to fetch session data from redis, skip.")
except (
redis.exceptions.RedisError,
redis.exceptions.RedisClusterException,
redis.exceptions.ChildDeadlockedError,
) as e:
log.warning(f"Failed to fetch session status data from redis, skip. (e:{repr(e)})")
return []
raw_result = cast(list[bytes], raw_result)

result: list[SessionId] = []
for raw_session_id in raw_result:
result.append(SessionId(msgpack.unpackb(raw_session_id)))
try:
result.append(SessionId(uuid.UUID(bytes=raw_session_id)))
except (ValueError, SyntaxError):
log.warning(f"Cannot parse session id, skip. (id:{raw_session_id})")
continue
return result

async def _get_user_email(
Expand Down
11 changes: 1 addition & 10 deletions src/ai/backend/manager/scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,16 +1600,7 @@ async def update_session_status(
) -> None:
log.debug("update_session_status(): triggered")
candidates = await self.registry.get_status_updatable_sessions()

async def _transit(session_id: SessionId):
async with self.db.connect() as db_conn:
row, is_transited = await self.registry.transit_session_status(db_conn, session_id)
if is_transited:
await self.registry.post_status_transition(row)

async with aiotools.TaskGroup() as tg:
for session_id in candidates:
tg.create_task(_transit(session_id))
await self.registry.transit_session_status(candidates)

async def start_session(
self,
Expand Down

0 comments on commit 6ad9a17

Please sign in to comment.