chore: Impl status-update interface for Session/Kernel ORM objects (#…
…2310)

## 1. DB lifetime Convention
Now that #2088 has been merged, we can control the lifetime of DB connections and DB transactions at a fine granularity. This section proposes how we should manage them.
### Manager Entrypoints
There are two main entrypoints of the code paths in the Manager.
- API handlers (webapps)
- Event dispatch handlers (consumers, subscribers)
These entrypoints invoke inner layers like AgentRegistry and the database model functions.
### Connection
Generally, it is better to reuse a single connection to perform multiple transactions.
- Ideally, we should keep only one database connection for each code path starting from an entrypoint.
- However, connection reuse should be implemented via connection pools rather than in the application logic, particularly when the application (like Backend.AI) cannot predict how long an operation will take.
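
To make the pool-based reuse concrete, here is a minimal sketch using only the Python standard library. `TinyPool` and the `kernels` table are hypothetical names for illustration, not part of the Manager; in practice the SQLAlchemy engine already maintains a connection pool, so application code would not write its own. The point is only that each transaction borrows a connection and returns it immediately, so the same physical connection is reused without the application holding it.

```python
import queue
import sqlite3
import tempfile
from contextlib import contextmanager
from pathlib import Path


class TinyPool:
    """Hypothetical fixed-size pool; stands in for the pool a real engine provides."""

    def __init__(self, db_path: str, size: int = 1) -> None:
        self._idle = queue.Queue()
        self.created = 0
        for _ in range(size):
            self._idle.put(sqlite3.connect(db_path))
            self.created += 1

    @contextmanager
    def transaction(self):
        conn = self._idle.get()  # borrow an idle connection
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            self._idle.put(conn)  # return it as soon as the transaction ends


db_path = str(Path(tempfile.mkdtemp()) / "demo.db")
pool = TinyPool(db_path, size=1)

# Three separate transactions on one code path, all served by one pooled connection.
with pool.transaction() as conn:
    conn.execute("CREATE TABLE kernels (id TEXT PRIMARY KEY, status TEXT)")
    conn.execute("INSERT INTO kernels VALUES ('k1', 'PENDING')")
with pool.transaction() as conn:
    conn.execute("UPDATE kernels SET status = 'RUNNING' WHERE id = 'k1'")
with pool.transaction() as conn:
    status = conn.execute("SELECT status FROM kernels WHERE id = 'k1'").fetchone()[0]
```
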
### Mutations
All database mutations should be expressed as explicit transactions.
- If there are business-logic branches between database updates, that code and those queries should go inside a single transaction.
- If the code inside a transaction involves a long-running interaction with other systems (e.g., agent RPC, storage-proxy operations), that interaction should be moved out of the transaction block. There are multiple ways to handle such a situation:
  - Introduce a state machine to mark the transitional states and the target states. It is also better to acquire/release transient database connections to make state transitions, instead of holding a single connection for an arbitrarily long time.
  - Defer the long-running operation as a background task and return the bgtask ID as the API response for further client-side tracking.
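
The state-machine option can be sketched as follows. This is an illustration under assumptions, not the Manager's implementation: it uses the standard-library `sqlite3` module instead of the real database layer, and `run_txn`, `slow_agent_call`, `terminate_session`, the `sessions` table, and the `TERMINATING`/`ERROR` status names are all hypothetical. The transitional state and the target state are each written in their own short transaction on a transient connection, and no connection is held during the slow call.

```python
import sqlite3
import tempfile
from contextlib import closing
from pathlib import Path

DB_PATH = str(Path(tempfile.mkdtemp()) / "sessions.db")


def run_txn(sql: str, params: tuple = ()) -> None:
    """Open a transient connection, run one short transaction, then release it."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:  # commit on success, roll back on error
            conn.execute(sql, params)


def slow_agent_call(session_id: str) -> bool:
    """Placeholder for a long-running RPC; no DB connection is open while it runs."""
    return True


def terminate_session(session_id: str) -> None:
    # 1) Record the transitional state in its own short transaction.
    run_txn("UPDATE sessions SET status = 'TERMINATING' WHERE id = ?", (session_id,))
    # 2) Do the slow work outside of any transaction.
    ok = slow_agent_call(session_id)
    # 3) Record the target state in a second short transaction.
    run_txn(
        "UPDATE sessions SET status = ? WHERE id = ?",
        ("TERMINATED" if ok else "ERROR", session_id),
    )


def get_status(session_id: str) -> str:
    with closing(sqlite3.connect(DB_PATH)) as conn:
        row = conn.execute(
            "SELECT status FROM sessions WHERE id = ?", (session_id,)
        ).fetchone()
        return row[0]


run_txn("CREATE TABLE sessions (id TEXT PRIMARY KEY, status TEXT)")
run_txn("INSERT INTO sessions VALUES (?, ?)", ("s1", "RUNNING"))
terminate_session("s1")
```

If the process dies during step 2, the row is left in the transitional state, which is what lets a later recovery pass find and resolve it.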

## 2. About ORM
A SQLAlchemy ORM object can be updated within the session it was fetched from.
```python
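# Example (assumes `Session` is a configured sessionmaker and `Kernel` is a mapped ORM class)
from sqlalchemy import select

with Session() as session:
    kernel_row = session.scalar(select(Kernel).where(Kernel.id == "kernel-id"))
    kernel_row.name = "new-name"
    session.commit()
    # The commit emits an UPDATE, so the kernel record's name is now "new-name" in the DB,
    # and the `kernel_row` object reflects the change as well.
    # Check it while the session is still open: commit expires loaded attributes by default,
    # so reading them after the session closes fails unless `expire_on_commit=False` is set.
    assert kernel_row.name == "new-name"  # No assertion error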
```
- A SQLAlchemy Session should be opened before fetching any ORM object.
- We do not recommend passing a SQLAlchemy Engine, Connection, or Session object to model-layer APIs unless the API is an ORM-object fetch function. Fetch functions need one of these objects, but update functions do not: we can simply assign values to ORM attributes, as in the code above.

## 3. Set-state & Transit-state
- **Set-state API** is called by external requests, such as client requests. It changes the current state to the state specified by the API parameter, regardless of any condition or rule defined in the state machine. (We can set conditions to allow or block certain states **BEFORE calling** the Set-state API.)
- **Transit-state API** is called by external requests or by the state machine itself. This API does not have a "state" parameter. When it is called, the state machine determines the next state by its own conditions or rules and moves to that state.
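
The difference between the two APIs can be sketched with a toy state machine. Everything here is hypothetical and chosen for illustration (`Status`, `TRANSITIONS`, `Machine`, and the `work_done` rule); it only loosely mirrors the shape of `set_status` and `determine_and_set_status` in this commit.

```python
import enum


class Status(enum.Enum):
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    TERMINATED = "TERMINATED"


# Hypothetical transition map: current status -> statuses reachable from it.
TRANSITIONS = {
    Status.PENDING: {Status.RUNNING, Status.TERMINATED},
    Status.RUNNING: {Status.TERMINATED},
    Status.TERMINATED: set(),
}


class Machine:
    def __init__(self, status: Status = Status.PENDING) -> None:
        self.status = status
        self.work_done = False

    def set_state(self, status: Status) -> None:
        """Set-state: apply the requested status unconditionally."""
        self.status = status

    def _determine(self) -> Status:
        """The machine's own rule for which status it should be in."""
        return Status.TERMINATED if self.work_done else Status.RUNNING

    def transit_state(self) -> bool:
        """Transit-state: takes no status; moves only if the map allows it."""
        target = self._determine()
        if target not in TRANSITIONS[self.status]:
            return False
        self.status = target
        return True


m = Machine()
first = m.transit_state()    # PENDING -> RUNNING is allowed
second = m.transit_state()   # RUNNING -> RUNNING is not in the map
m.work_done = True
third = m.transit_state()    # RUNNING -> TERMINATED is allowed
m.set_state(Status.PENDING)  # bypasses the map entirely
```

Because `set_state` skips the map, any validation of the requested status has to happen in the caller before the call.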

**Checklist:** (if applicable)

- [x] Milestone metadata specifying the target backport version
- [x] Documentation
  - Contents in the `docs` directory
  - docstrings in public interfaces and type annotations

fregataa committed Jul 11, 2024
1 parent 845b1fe commit 55d44af
Showing 3 changed files with 144 additions and 5 deletions.
62 changes: 60 additions & 2 deletions src/ai/backend/manager/models/kernel.py
@@ -4,18 +4,19 @@
import enum
import logging
import uuid
from collections.abc import Mapping
from contextlib import asynccontextmanager as actxmgr
from datetime import datetime
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Mapping,
    Optional,
    Sequence,
    Type,
    TypedDict,
    TypeVar,
    cast,
)

import graphene
@@ -51,6 +52,7 @@
    KernelCreationFailed,
    KernelDestructionFailed,
    KernelExecutionFailed,
    KernelNotFound,
    KernelRestartFailed,
    SessionNotFound,
)
@@ -77,7 +79,7 @@
from .minilang.ordering import ColumnMapType, QueryOrderParser
from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter
from .user import users
from .utils import ExtendedAsyncSAEngine, execute_with_retry, sql_json_merge
from .utils import ExtendedAsyncSAEngine, JSONCoalesceExpr, execute_with_retry, sql_json_merge

if TYPE_CHECKING:
    from .gql import GraphQueryContext
@@ -643,6 +645,62 @@ async def _query():

        return await execute_with_retry(_query)

    @classmethod
    async def get_kernel_to_update_status(
        cls,
        db_session: SASession,
        kernel_id: KernelId,
    ) -> KernelRow:
        _stmt = sa.select(KernelRow).where(KernelRow.id == kernel_id)
        kernel_row = cast(KernelRow | None, await db_session.scalar(_stmt))
        if kernel_row is None:
            raise KernelNotFound(f"Kernel not found (id:{kernel_id})")
        return kernel_row

    def transit_status(
        self,
        status: KernelStatus,
        status_info: str | None = None,
        status_data: Mapping[str, Any] | JSONCoalesceExpr | None = None,
        status_changed_at: datetime | None = None,
    ) -> bool:
        """
        Check whether the transition from a current status to the given status is valid or not.
        Set the status if it is valid and return True.
        Else, return False.
        """
        if status not in KERNEL_STATUS_TRANSITION_MAP[self.status]:
            return False
        self.set_status(status, status_info, status_data, status_changed_at)
        return True

    def set_status(
        self,
        status: KernelStatus,
        status_info: str | None = None,
        status_data: Mapping[str, Any] | JSONCoalesceExpr | None = None,
        status_changed_at: datetime | None = None,
    ) -> None:
        """
        Set the status of the kernel.
        """
        now = status_changed_at or datetime.now(tzutc())
        if status in (KernelStatus.CANCELLED, KernelStatus.TERMINATED):
            self.terminated_at = now
        self.status_changed = now
        self.status = status
        self.status_history = sql_json_merge(
            KernelRow.status_history,
            (),
            {
                status.name: now.isoformat(),
            },
        )
        if status_info is not None:
            self.status_info = status_info
        if status_data is not None:
            self.status_data = status_data

    @classmethod
    async def set_kernel_status(
        cls,
79 changes: 78 additions & 1 deletion src/ai/backend/manager/models/session.py
@@ -14,6 +14,7 @@
    List,
    Optional,
    Union,
    cast,
)
from uuid import UUID

@@ -73,7 +74,13 @@
from .minilang.ordering import ColumnMapType, QueryOrderParser
from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter
from .user import UserRow
from .utils import ExtendedAsyncSAEngine, agg_to_array, execute_with_retry, sql_json_merge
from .utils import (
    ExtendedAsyncSAEngine,
    JSONCoalesceExpr,
    agg_to_array,
    execute_with_retry,
    sql_json_merge,
)

if TYPE_CHECKING:
    from sqlalchemy.engine import Row
@@ -816,6 +823,76 @@ async def _check_and_update() -> SessionStatus | None:

        return await execute_with_retry(_check_and_update)

    @classmethod
    async def get_session_to_determine_status(
        cls, db_session: SASession, session_id: SessionId
    ) -> SessionRow:
        stmt = (
            sa.select(SessionRow)
            .where(SessionRow.id == session_id)
            .options(
                selectinload(SessionRow.kernels).options(
                    load_only(KernelRow.status, KernelRow.cluster_role, KernelRow.status_info)
                ),
            )
        )
        session_row = cast(SessionRow | None, await db_session.scalar(stmt))
        if session_row is None:
            raise SessionNotFound(f"Session not found (id:{session_id})")
        return session_row

    def determine_and_set_status(
        self,
        status_info: str | None = None,
        status_data: Mapping[str, Any] | JSONCoalesceExpr | None = None,
        status_changed_at: datetime | None = None,
    ) -> bool:
        """
        Determine the current status of a session based on its sibling kernels.
        If it is possible to transit from the current status to the determined status, set status.
        Else, do nothing.
        Return True if a transition happened, else return False.
        """

        determined_status = determine_session_status(self.kernels)
        if determined_status not in SESSION_STATUS_TRANSITION_MAP[self.status]:
            return False

        self.set_status(determined_status, status_info, status_data, status_changed_at)
        return True

    def set_status(
        self,
        status: SessionStatus,
        status_info: str | None = None,
        status_data: Mapping[str, Any] | JSONCoalesceExpr | None = None,
        status_changed_at: datetime | None = None,
    ) -> None:
        """
        Set the status of the session.
        """
        now = status_changed_at or datetime.now(tzutc())
        if status in (SessionStatus.CANCELLED, SessionStatus.TERMINATED):
            self.terminated_at = now
        self.status = status
        self.status_history = sql_json_merge(
            SessionRow.status_history,
            (),
            {
                status.name: now.isoformat(),
            },
        )
        if status_data is not None:
            self.status_data = status_data

        _status_info: str | None = None
        if status_info is None:
            _status_info = self.main_kernel.status_info
        else:
            _status_info = status_info
        if _status_info is not None:
            self.status_info = _status_info

    @staticmethod
    async def set_session_status(
        db: ExtendedAsyncSAEngine,
8 changes: 6 additions & 2 deletions src/ai/backend/manager/models/utils.py
@@ -16,6 +16,7 @@
    Mapping,
    ParamSpec,
    Tuple,
    TypeAlias,
    TypeVar,
    overload,
)
@@ -412,13 +413,16 @@ async def execute_with_retry(txn_func: Callable[[], Awaitable[TQueryResult]]) ->
    return result


JSONCoalesceExpr: TypeAlias = sa.sql.elements.BinaryExpression


def sql_json_merge(
    col,
    key: Tuple[str, ...],
    obj: Mapping[str, Any],
    *,
    _depth: int = 0,
):
) -> JSONCoalesceExpr:
    """
    Generate an SQLAlchemy column update expression that merges the given object with
    the existing object at a specific (nested) key of the given JSONB column,
@@ -454,7 +458,7 @@ def sql_json_increment(
    *,
    parent_updates: Mapping[str, Any] = None,
    _depth: int = 0,
):
) -> JSONCoalesceExpr:
    """
    Generate an SQLAlchemy column update expression that increments the value at a specific
    (nested) key of the given JSONB column,