Skip to content

Commit

Permalink
follow-up update
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Jun 18, 2024
1 parent 63c8574 commit 3753e07
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3096,10 +3096,8 @@ async def _transit_status(db_session: AsyncSession) -> bool:
await db_session.commit()
return is_pulling

update_successful = await execute_with_txn_retry(
_transit_status, self.db.begin_session, db_conn
)
if update_successful:
transited = await execute_with_txn_retry(_transit_status, self.db.begin_session, db_conn)
if transited:
await self.set_status_updatable_session(session_id)

async def mark_kernel_running(
Expand Down Expand Up @@ -3134,9 +3132,7 @@ async def _get_and_transit(db_session: AsyncSession) -> bool:
await db_session.commit()
return is_running

update_successful = await execute_with_txn_retry(
_get_and_transit, self.db.begin_session, db_conn
)
transited = await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn)

async def _update_session(db_session: AsyncSession) -> None:
_stmt = sa.select(SessionRow).where(SessionRow.id == session_id).with_for_update()
Expand All @@ -3149,7 +3145,7 @@ async def _update_session(db_session: AsyncSession) -> None:
session_occupying_slots[key] = str(Decimal(val) + Decimal(actual_allocs[key]))
session_row.occupying_slots = session_occupying_slots

if update_successful:
if transited:
await execute_with_txn_retry(_update_session, self.db.begin_session, db_conn)
self._kernel_actual_allocated_resources[kernel_id] = actual_allocs
await self.set_status_updatable_session(session_id)
Expand Down Expand Up @@ -3226,22 +3222,27 @@ async def _recalc(db_session: AsyncSession) -> None:
await self.set_status_updatable_session(session_id)

async def transit_session_status(
self, db_conn: SAConnection, session_id: SessionId, update_time: datetime | None = None
) -> SessionStatus | None:
now = update_time or datetime.now(tzutc())
self,
db_conn: SAConnection,
session_id: SessionId,
status_changed_at: datetime | None = None,
) -> None:
now = status_changed_at or datetime.now(tzutc())

async def _get_and_transit(
db_session: AsyncSession,
) -> tuple[SessionRow, SessionStatus | None]:
) -> tuple[SessionRow, bool]:
session_row = await SessionRow.get_session_to_determine_status(db_session, session_id)
new_status, _ = session_row.determine_and_set_status(status_changed_at=now)
return session_row, new_status
transited = session_row.determine_and_set_status(status_changed_at=now)
return session_row, transited

session_row, new_status = await execute_with_txn_retry(
session_row, transited = await execute_with_txn_retry(
_get_and_transit, self.db.begin_session, db_conn
)

match new_status:
if not transited:
return
match session_row.status:
case SessionStatus.RUNNING:
log.debug(
"Producing SessionStartedEvent({}, {})",
Expand All @@ -3268,7 +3269,6 @@ async def _get_and_transit(
)
case _:
pass
return new_status

async def set_status_updatable_session(self, session_id: SessionId) -> None:
await redis_helper.execute(
Expand Down

0 comments on commit 3753e07

Please sign in to comment.