diff --git a/changes/3201.feature.md b/changes/3201.feature.md
new file mode 100644
index 00000000000..8333b460be5
--- /dev/null
+++ b/changes/3201.feature.md
@@ -0,0 +1 @@
+Change the type of `status_history` from a mapping of statuses to timestamps to a list of log entries, each holding a status and a timestamp, so that earlier timestamps are preserved when a session/kernel revisits a status (e.g., after session restarts).
\ No newline at end of file
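To make the changelog concrete, the shape change looks like this (timestamps invented for illustration):

```python
# Before: one timestamp per status; revisiting a status overwrote the old entry.
status_history_legacy = {
    "PENDING": "2024-12-01T10:00:00+00:00",
    "RUNNING": "2024-12-01T10:05:00+00:00",
}

# After: an append-only log; repeated statuses (e.g., across restarts) are all kept.
status_history_log = [
    {"status": "PENDING", "timestamp": "2024-12-01T10:00:00+00:00"},
    {"status": "RUNNING", "timestamp": "2024-12-01T10:05:00+00:00"},
    {"status": "RESTARTING", "timestamp": "2024-12-02T09:00:00+00:00"},
    {"status": "RUNNING", "timestamp": "2024-12-02T09:01:00+00:00"},
]
```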
""" - with Session() as session: - print_wait("Retrieving status history...") - kernel = session.ComputeSession(session_id) - try: - status_history = kernel.get_status_history().get("result") - print_info(f"status_history: {status_history}") - if (preparing := status_history.get("preparing")) is None: - result = { - "result": { - "seconds": 0, - "microseconds": 0, - }, - } - elif (terminated := status_history.get("terminated")) is None: - alloc_time_until_now: timedelta = datetime.now(tzutc()) - isoparse(preparing) - result = { - "result": { - "seconds": alloc_time_until_now.seconds, - "microseconds": alloc_time_until_now.microseconds, - }, - } - else: - alloc_time: timedelta = isoparse(terminated) - isoparse(preparing) - result = { - "result": { - "seconds": alloc_time.seconds, - "microseconds": alloc_time.microseconds, - }, - } - print_done(f"Actual Resource Allocation Time: {result}") - except Exception as e: - print_error(e) - sys.exit(ExitCode.FAILURE) + + async def cmd_main() -> None: + async with AsyncSession() as session: + print_wait("Retrieving status history...") + + kernel = session.ComputeSession(session_id_or_name) + try: + resp = await kernel.get_status_history() + status_history = resp["result"] + + prev_time = None + + for status_record in status_history: + timestamp = datetime.fromisoformat(status_record["timestamp"]) + + if prev_time: + time_diff = timestamp - prev_time + status_record["time_elapsed"] = str(time_diff) + + prev_time = timestamp + + ctx.output.print_list( + status_history, + [FieldSpec("status"), FieldSpec("timestamp"), FieldSpec("time_elapsed")], + ) + + if ( + preparing := get_latest_timestamp_for_status(status_history, "PREPARING") + ) is None: + elapsed = timedelta() + elif ( + terminated := get_latest_timestamp_for_status(status_history, "TERMINATED") + ) is None: + elapsed = datetime.now(tzutc()) - preparing + else: + elapsed = terminated - preparing + + print_done(f"Actual Resource Allocation Time: {elapsed.total_seconds()}") + except Exception as e: + print_error(e) + sys.exit(ExitCode.FAILURE) + + try: + asyncio.run(cmd_main()) + except Exception as e: + print_error(e) + sys.exit(ExitCode.FAILURE) @session.command() diff --git a/src/ai/backend/client/output/fields.py b/src/ai/backend/client/output/fields.py index ab45df52e25..b37448d6ef5 100644 --- a/src/ai/backend/client/output/fields.py +++ b/src/ai/backend/client/output/fields.py @@ -186,6 +186,8 @@ FieldSpec("created_user_id"), FieldSpec("status"), FieldSpec("status_info"), + FieldSpec("status_history"), + FieldSpec("status_history_log"), FieldSpec("status_data", formatter=nested_dict_formatter), FieldSpec("status_changed", "Last Updated"), FieldSpec("created_at"), diff --git a/src/ai/backend/client/utils.py b/src/ai/backend/client/utils.py index fac5bc1c752..944c0a4203e 100644 --- a/src/ai/backend/client/utils.py +++ b/src/ai/backend/client/utils.py @@ -1,6 +1,10 @@ import io import os import textwrap +from datetime import datetime +from typing import Optional + +from dateutil.parser import parse as dtparse def dedent(text: str) -> str: @@ -53,3 +57,13 @@ def readinto1(self, *args, **kwargs): count = super().readinto1(*args, **kwargs) self.tqdm.set_postfix(file=self._filename, refresh=False) self.tqdm.update(count) + + +def get_latest_timestamp_for_status( + status_history: list[dict[str, str]], + status: str, +) -> Optional[datetime]: + for item in reversed(status_history): + if item["status"] == status: + return dtparse(item["timestamp"]) + return None diff --git 
diff --git a/src/ai/backend/manager/api/resource.py b/src/ai/backend/manager/api/resource.py
index 4489d28454d..61a80b13781 100644
--- a/src/ai/backend/manager/api/resource.py
+++ b/src/ai/backend/manager/api/resource.py
@@ -466,7 +466,7 @@ async def _pipe_builder(r: Redis) -> RedisPipeline:
                 "status": row["status"].name,
                 "status_info": row["status_info"],
                 "status_changed": str(row["status_changed"]),
-                "status_history": row["status_history"] or {},
+                "status_history": row["status_history"],
                 "cluster_mode": row["cluster_mode"],
             }
             if group_id not in objs_per_group:
diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql
index cfdf9a29eae..7e6ed21ccfa 100644
--- a/src/ai/backend/manager/api/schema.graphql
+++ b/src/ai/backend/manager/api/schema.graphql
@@ -306,6 +306,9 @@ type ComputeContainer implements Item {
   registry: String
   status: String
   status_changed: DateTime
+
+  """Added in 24.12.0."""
+  status_history: JSONString
   status_info: String
   status_data: JSONString
   created_at: DateTime
@@ -925,7 +928,10 @@ type ComputeSession implements Item {
   status_changed: DateTime
   status_info: String
   status_data: JSONString
-  status_history: JSONString
+  status_history: JSONString @deprecated(reason: "Deprecated since 25.1.0; use `status_history_log`")
+
+  """Added in 25.1.0."""
+  status_history_log: JSONString
   created_at: DateTime
   terminated_at: DateTime
   starts_at: DateTime
@@ -1204,6 +1210,8 @@ type ComputeSessionNode implements Node {
   status: String
   status_info: String
  status_data: JSONString
+
+  """Changed in 25.1.0: the value is now a list of status log entries."""
   status_history: JSONString
   created_at: DateTime
   terminated_at: DateTime
diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py
index d119ce39e89..eacd6567eb4 100644
--- a/src/ai/backend/manager/api/session.py
+++ b/src/ai/backend/manager/api/session.py
@@ -2335,6 +2335,35 @@ async def get_container_logs(
     return web.json_response(resp, status=200)
 
 
+@server_status_required(READ_ALLOWED)
+@auth_required
+@check_api_params(
+    t.Dict({
+        t.Key("owner_access_key", default=None): t.Null | t.String,
+    })
+)
+async def get_status_history(request: web.Request, params: Any) -> web.Response:
+    root_ctx: RootContext = request.app["_root.context"]
+    session_name: str = request.match_info["session_name"]
+    requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
+    log.info(
+        "GET_STATUS_HISTORY (ak:{}/{}, s:{})", requester_access_key, owner_access_key, session_name
+    )
+    resp: dict[str, list[Mapping]] = {"result": []}
+
+    async with root_ctx.db.begin_readonly_session() as db_sess:
+        compute_session = await SessionRow.get_session(
+            db_sess,
+            session_name,
+            owner_access_key,
+            allow_stale=True,
+            kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS,
+        )
+        resp["result"] = compute_session.status_history
+
+    return web.json_response(resp, status=200)
+
+
 @server_status_required(READ_ALLOWED)
 @auth_required
 @check_api_params(
@@ -2474,6 +2503,7 @@ def create_app(
         app.router.add_route("GET", "/{session_name}/direct-access-info", get_direct_access_info)
     )
     cors.add(app.router.add_route("GET", "/{session_name}/logs", get_container_logs))
+    cors.add(app.router.add_route("GET", "/{session_name}/status-history", get_status_history))
     cors.add(app.router.add_route("POST", "/{session_name}/rename", rename_session))
     cors.add(app.router.add_route("POST", "/{session_name}/interrupt", interrupt))
     cors.add(app.router.add_route("POST", "/{session_name}/complete", complete))
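The new endpoint is what the SDK's `get_status_history()` (used by the CLI above) calls under the hood. A minimal usage sketch, assuming API credentials are configured in the environment; the session name is hypothetical:

```python
import asyncio

from ai.backend.client.session import AsyncSession


async def main() -> None:
    async with AsyncSession() as api_session:
        compute_session = api_session.ComputeSession("mysess")
        # Issues GET /session/mysess/status-history on the manager.
        resp = await compute_session.get_status_history()
        for record in resp["result"]:
            print(record["status"], record["timestamp"])


asyncio.run(main())
```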
diff --git a/src/ai/backend/manager/models/alembic/versions/8c8e90aebacd_replace_status_history_to_list.py b/src/ai/backend/manager/models/alembic/versions/8c8e90aebacd_replace_status_history_to_list.py
new file mode 100644
index 00000000000..942582889f5
--- /dev/null
+++ b/src/ai/backend/manager/models/alembic/versions/8c8e90aebacd_replace_status_history_to_list.py
@@ -0,0 +1,110 @@
+"""Replace the sessions and kernels tables' status_history mapping with a list
+
+Revision ID: 8c8e90aebacd
+Revises: 0bb88d5a46bf
+Create Date: 2024-12-05 11:19:23.075014
+"""
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "8c8e90aebacd"
+down_revision = "0bb88d5a46bf"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.execute(
+        """
+        WITH data AS (
+            SELECT id,
+                   (jsonb_each(status_history)).key AS status,
+                   (jsonb_each(status_history)).value AS timestamp
+            FROM kernels
+            WHERE jsonb_typeof(status_history) = 'object'
+        )
+        UPDATE kernels
+        SET status_history = (
+            SELECT jsonb_agg(
+                jsonb_build_object('status', status, 'timestamp', timestamp)
+                ORDER BY timestamp
+            )
+            FROM data
+            WHERE data.id = kernels.id
+        )
+        WHERE jsonb_typeof(kernels.status_history) = 'object';
+        """
+    )
+    op.execute("UPDATE kernels SET status_history = '[]'::jsonb WHERE status_history IS NULL;")
+    op.alter_column("kernels", "status_history", nullable=False)
+
+    op.execute(
+        """
+        WITH data AS (
+            SELECT id,
+                   (jsonb_each(status_history)).key AS status,
+                   (jsonb_each(status_history)).value AS timestamp
+            FROM sessions
+            WHERE jsonb_typeof(status_history) = 'object'
+        )
+        UPDATE sessions
+        SET status_history = (
+            SELECT jsonb_agg(
+                jsonb_build_object('status', status, 'timestamp', timestamp)
+                ORDER BY timestamp
+            )
+            FROM data
+            WHERE data.id = sessions.id
+        )
+        WHERE jsonb_typeof(sessions.status_history) = 'object';
+        """
+    )
+    op.execute("UPDATE sessions SET status_history = '[]'::jsonb WHERE status_history IS NULL;")
+    op.alter_column("sessions", "status_history", nullable=False)
+
+
+def downgrade():
+    op.execute(
+        """
+        WITH data AS (
+            SELECT id,
+                   jsonb_object_agg(
+                       elem->>'status', elem->>'timestamp'
+                   ) AS new_status_history
+            FROM kernels,
+                 jsonb_array_elements(status_history) AS elem
+            WHERE jsonb_typeof(status_history) = 'array'
+            GROUP BY id
+        )
+        UPDATE kernels
+        SET status_history = data.new_status_history
+        FROM data
+        WHERE data.id = kernels.id
+        AND jsonb_typeof(kernels.status_history) = 'array';
+        """
+    )
+    op.alter_column("kernels", "status_history", nullable=True)
+    op.execute("UPDATE kernels SET status_history = NULL WHERE status_history = '[]'::jsonb;")
+
+    op.execute(
+        """
+        WITH data AS (
+            SELECT id,
+                   jsonb_object_agg(
+                       elem->>'status', elem->>'timestamp'
+                   ) AS new_status_history
+            FROM sessions,
+                 jsonb_array_elements(status_history) AS elem
+            WHERE jsonb_typeof(status_history) = 'array'
+            GROUP BY id
+        )
+        UPDATE sessions
+        SET status_history = data.new_status_history
+        FROM data
+        WHERE data.id = sessions.id
+        AND jsonb_typeof(sessions.status_history) = 'array';
+        """
+    )
+    op.alter_column("sessions", "status_history", nullable=True)
+    op.execute("UPDATE sessions SET status_history = NULL WHERE status_history = '[]'::jsonb;")
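The migration's SQL is easier to sanity-check against a plain-Python equivalent. A sketch of the same transformation and its inverse (note that the downgrade is lossy by design: only the last visit to each status survives):

```python
def to_log(legacy: dict[str, str]) -> list[dict[str, str]]:
    # upgrade(): one entry per status, ordered by timestamp (matches ORDER BY timestamp)
    return [
        {"status": status, "timestamp": ts}
        for status, ts in sorted(legacy.items(), key=lambda kv: kv[1])
    ]


def to_legacy(log: list[dict[str, str]]) -> dict[str, str]:
    # downgrade(): later entries overwrite earlier ones, like jsonb_object_agg
    return {entry["status"]: entry["timestamp"] for entry in log}


legacy = {"PENDING": "2024-12-01T10:00:00", "RUNNING": "2024-12-01T10:05:00"}
assert to_legacy(to_log(legacy)) == legacy
```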
diff --git a/src/ai/backend/manager/models/gql_models/kernel.py b/src/ai/backend/manager/models/gql_models/kernel.py
index cbb52db6858..8fb9ade2363 100644
--- a/src/ai/backend/manager/models/gql_models/kernel.py
+++ b/src/ai/backend/manager/models/gql_models/kernel.py
@@ -5,6 +5,7 @@
     TYPE_CHECKING,
     Any,
     Self,
+    cast,
 )
 
 import graphene
@@ -14,14 +15,15 @@
 from ai.backend.common import msgpack, redis_helper
 from ai.backend.common.types import AgentId, KernelId, SessionId
-from ai.backend.manager.models.base import (
+
+from ..base import (
     batch_multiresult_in_scalar_stream,
     batch_multiresult_in_session,
 )
-
 from ..gql_relay import AsyncNode, Connection
 from ..kernel import KernelRow, KernelStatus
 from ..user import UserRole
+from ..utils import get_latest_timestamp_for_status
 from .image import ImageNode
 
 if TYPE_CHECKING:
@@ -113,7 +115,11 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
             hide_agents = False
         else:
             hide_agents = ctx.local_config["manager"]["hide-agents"]
-        status_history = row.status_history or {}
+
+        # Note: the helper compares against the status *name*, so pass `.name`, not the enum.
+        scheduled_at = get_latest_timestamp_for_status(
+            cast(list[dict[str, str]], row.status_history), KernelStatus.SCHEDULED.name
+        )
+
         return KernelNode(
             id=row.id,  # auto-converted to Relay global ID
             row_id=row.id,
@@ -129,7 +136,7 @@
             created_at=row.created_at,
             terminated_at=row.terminated_at,
             starts_at=row.starts_at,
-            scheduled_at=status_history.get(KernelStatus.SCHEDULED.name),
+            scheduled_at=scheduled_at,
             occupied_slots=row.occupied_slots.to_json(),
             agent_id=row.agent if not hide_agents else None,
             agent_addr=row.agent_addr if not hide_agents else None,
diff --git a/src/ai/backend/manager/models/gql_models/session.py b/src/ai/backend/manager/models/gql_models/session.py
index 14e8cb9134a..788ee1ced2b 100644
--- a/src/ai/backend/manager/models/gql_models/session.py
+++ b/src/ai/backend/manager/models/gql_models/session.py
@@ -56,7 +56,7 @@
     get_permission_ctx,
 )
 from ..user import UserRole
-from ..utils import execute_with_txn_retry
+from ..utils import execute_with_txn_retry, get_latest_timestamp_for_status
 from .kernel import KernelConnection, KernelNode
 
 if TYPE_CHECKING:
@@ -174,7 +174,9 @@ class Meta:
     # status_changed = GQLDateTime()  # FIXME: generated attribute
     status_info = graphene.String()
     status_data = graphene.JSONString()
-    status_history = graphene.JSONString()
+    status_history = graphene.JSONString(
+        description="Changed in 25.1.0: the value is now a list of status log entries."
+    )
     created_at = GQLDateTime()
     terminated_at = GQLDateTime()
     starts_at = GQLDateTime()
@@ -225,7 +227,9 @@ def from_row(
         permissions: Optional[Iterable[ComputeSessionPermission]] = None,
     ) -> Self:
-        status_history = row.status_history or {}
-        raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name)
+        scheduled_at = get_latest_timestamp_for_status(
+            cast(list[dict[str, str]], row.status_history or []), SessionStatus.SCHEDULED.name
+        )
+
         result = cls(
             # identity
             id=row.id,  # auto-converted to Relay global ID
@@ -251,9 +255,7 @@
             created_at=row.created_at,
             starts_at=row.starts_at,
             terminated_at=row.terminated_at,
-            scheduled_at=datetime.fromisoformat(raw_scheduled_at)
-            if raw_scheduled_at is not None
-            else None,
+            scheduled_at=scheduled_at,
             startup_command=row.startup_command,
             result=row.result.name,
             # resources
diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py
index 623f12c70a3..48731ebcc48 100644
--- a/src/ai/backend/manager/models/kernel.py
+++ b/src/ai/backend/manager/models/kernel.py
@@ -79,11 +79,17 @@
 from .gql_models.image import ImageNode
 from .group import groups
 from .image import ImageRow
-from .minilang import JSONFieldItem
+from .minilang import JSONArrayFieldItem
 from .minilang.ordering import ColumnMapType, QueryOrderParser
 from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter
 from .user import users
-from .utils import ExtendedAsyncSAEngine, JSONCoalesceExpr, execute_with_retry, sql_json_merge
+from .utils import (
+    ExtendedAsyncSAEngine,
+    JSONCoalesceExpr,
+    execute_with_retry,
+    get_latest_timestamp_for_status,
+    sql_append_dict_to_list,
+)
 
 if TYPE_CHECKING:
     from .gql import GraphQueryContext
@@ -524,7 +530,14 @@
     #     // used to prevent duplication of SessionTerminatedEvent
     #   }
     # }
-    status_history = sa.Column("status_history", pgsql.JSONB(), nullable=True, default=sa.null())
+    status_history = sa.Column("status_history", pgsql.JSONB(), nullable=False, default=[])
+    # status_history records all status changes, in order of occurrence:
+    # e.g.)
+    # [
+    #     {"status": "PENDING", "timestamp": "2022-10-22T10:22:30"},
+    #     {"status": "SCHEDULED", "timestamp": "2022-10-22T11:40:30"},
+    #     {"status": "PREPARING", "timestamp": "2022-10-25T10:22:30"},
+    # ]
     callback_url = sa.Column("callback_url", URLColumn, nullable=True, default=sa.null())
     startup_command = sa.Column("startup_command", sa.Text, nullable=True)
     result = sa.Column(
@@ -676,10 +689,11 @@
             self.terminated_at = now
         self.status_changed = now
         self.status = status
-        self.status_history = {
-            **self.status_history,
-            status.name: now.isoformat(),
-        }
+        # Reassign (instead of appending in place) so SQLAlchemy detects the change.
+        self.status_history = [
+            *self.status_history,
+            {"status": status.name, "timestamp": now.isoformat()},
+        ]
         if status_info is not None:
             self.status_info = status_info
         if status_data is not None:
@@ -706,12 +720,9 @@
         data = {
             "status": status,
             "status_changed": now,
-            "status_history": sql_json_merge(
-                kernels.c.status_history,
-                (),
-                {
-                    status.name: now.isoformat(),  # ["PULLING", "CREATING"]
-                },
+            "status_history": sql_append_dict_to_list(
+                KernelRow.status_history,
+                {"status": status.name, "timestamp": now.isoformat()},
             ),
         }
         if status_data is not None:
@@ -757,12 +768,9 @@ async def _update() -> bool:
             if update_data is None:
                 update_values = {
                     "status": new_status,
-                    "status_history": sql_json_merge(
+                    "status_history": sql_append_dict_to_list(
                         KernelRow.status_history,
-                        (),
-                        {
-                            new_status.name: now.isoformat(),
-                        },
+                        {"status": new_status.name, "timestamp": now.isoformat()},
                     ),
                 }
             else:
@@ -876,6 +884,7 @@ class Meta:
     # status
     status = graphene.String()
     status_changed = GQLDateTime()
+    status_history = graphene.JSONString(description="Added in 24.12.0.")
     status_info = graphene.String()
     status_data = graphene.JSONString()
     created_at = GQLDateTime()
@@ -904,7 +913,11 @@ def parse_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Mapping[str, Any]:
             hide_agents = False
         else:
             hide_agents = ctx.local_config["manager"]["hide-agents"]
-        status_history = row.status_history or {}
+        status_history = cast(list[dict[str, str]], row.status_history)
+        scheduled_at = get_latest_timestamp_for_status(
+            status_history, KernelStatus.SCHEDULED.name
+        )
+
         return {
             # identity
             "id": row.id,
@@ -925,12 +938,13 @@
             # status
             "status": row.status.name,
             "status_changed": row.status_changed,
+            "status_history": row.status_history,
             "status_info": row.status_info,
             "status_data": row.status_data,
             "created_at": row.created_at,
             "terminated_at": row.terminated_at,
             "starts_at": row.starts_at,
-            "scheduled_at": status_history.get(KernelStatus.SCHEDULED.name),
+            "scheduled_at": scheduled_at,
             "occupied_slots": row.occupied_slots.to_json(),
             # resources
             "agent": row.agent if not hide_agents else None,
"scheduled_at": status_history.get(KernelStatus.SCHEDULED.name), + "scheduled_at": scheduled_at, "occupied_slots": row.occupied_slots.to_json(), # resources "agent": row.agent if not hide_agents else None, @@ -984,7 +993,14 @@ async def resolve_abusing_report( "created_at": ("created_at", dtparse), "status_changed": ("status_changed", dtparse), "terminated_at": ("terminated_at", dtparse), - "scheduled_at": (JSONFieldItem("status_history", KernelStatus.SCHEDULED.name), dtparse), + "scheduled_at": ( + JSONArrayFieldItem( + column_name="status_history", + conditions={"status": KernelStatus.SCHEDULED.name}, + key_name="timestamp", + ), + dtparse, + ), } _queryorder_colmap: ColumnMapType = { @@ -1001,7 +1017,7 @@ async def resolve_abusing_report( "status_changed": ("status_changed", None), "created_at": ("created_at", None), "terminated_at": ("terminated_at", None), - "scheduled_at": (JSONFieldItem("status_history", KernelStatus.SCHEDULED.name), None), + "scheduled_at": ("scheduled_at", None), } @classmethod diff --git a/src/ai/backend/manager/models/minilang/__init__.py b/src/ai/backend/manager/models/minilang/__init__.py index 74f3c0ea996..ff046868c43 100644 --- a/src/ai/backend/manager/models/minilang/__init__.py +++ b/src/ai/backend/manager/models/minilang/__init__.py @@ -13,6 +13,12 @@ class JSONFieldItem(NamedTuple): key_name: str +class JSONArrayFieldItem(NamedTuple): + column_name: str + conditions: dict[str, str] + key_name: str + + TEnum = TypeVar("TEnum", bound=Enum) @@ -22,10 +28,12 @@ class EnumFieldItem(NamedTuple, Generic[TEnum]): FieldSpecItem = tuple[ - str | ArrayFieldItem | JSONFieldItem | EnumFieldItem, Callable[[str], Any] | None + str | ArrayFieldItem | JSONFieldItem | EnumFieldItem | JSONArrayFieldItem, + Callable[[str], Any] | None, ] OrderSpecItem = tuple[ - str | ArrayFieldItem | JSONFieldItem | EnumFieldItem, Callable[[sa.Column], Any] | None + str | ArrayFieldItem | JSONFieldItem | EnumFieldItem | JSONArrayFieldItem, + Callable[[sa.Column], Any] | None, ] diff --git a/src/ai/backend/manager/models/minilang/ordering.py b/src/ai/backend/manager/models/minilang/ordering.py index 0643392628f..63c54fefa43 100644 --- a/src/ai/backend/manager/models/minilang/ordering.py +++ b/src/ai/backend/manager/models/minilang/ordering.py @@ -5,7 +5,7 @@ from lark import Lark, LarkError, Transformer from lark.lexer import Token -from . import JSONFieldItem, OrderSpecItem, get_col_from_table +from . import JSONArrayFieldItem, JSONFieldItem, OrderSpecItem, get_col_from_table __all__ = ( "ColumnMapType", @@ -56,6 +56,10 @@ def _get_col(self, col_name: str) -> sa.Column: case JSONFieldItem(_col, _key): _column = get_col_from_table(self._sa_table, _col) matched_col = _column.op("->>")(_key) + case JSONArrayFieldItem(_col_name, _conditions, _key_name): + # TODO: Implement this. + pass + # ... case _: raise ValueError("Invalid type of field name", col_name) col = func(matched_col) if func is not None else matched_col diff --git a/src/ai/backend/manager/models/minilang/queryfilter.py b/src/ai/backend/manager/models/minilang/queryfilter.py index 4973ffd54ce..761b63f8d0d 100644 --- a/src/ai/backend/manager/models/minilang/queryfilter.py +++ b/src/ai/backend/manager/models/minilang/queryfilter.py @@ -4,8 +4,16 @@ import sqlalchemy as sa from lark import Lark, LarkError, Transformer, Tree from lark.lexer import Token +from sqlalchemy.dialects.postgresql import JSONB -from . import ArrayFieldItem, EnumFieldItem, FieldSpecItem, JSONFieldItem, get_col_from_table +from . 
diff --git a/src/ai/backend/manager/models/minilang/queryfilter.py b/src/ai/backend/manager/models/minilang/queryfilter.py
index 4973ffd54ce..761b63f8d0d 100644
--- a/src/ai/backend/manager/models/minilang/queryfilter.py
+++ b/src/ai/backend/manager/models/minilang/queryfilter.py
@@ -4,8 +4,16 @@
 import sqlalchemy as sa
 from lark import Lark, LarkError, Transformer, Tree
 from lark.lexer import Token
+from sqlalchemy.dialects.postgresql import JSONB
 
-from . import ArrayFieldItem, EnumFieldItem, FieldSpecItem, JSONFieldItem, get_col_from_table
+from . import (
+    ArrayFieldItem,
+    EnumFieldItem,
+    FieldSpecItem,
+    JSONArrayFieldItem,
+    JSONFieldItem,
+    get_col_from_table,
+)
 
 __all__ = (
     "FieldSpecType",
@@ -172,6 +180,34 @@ def build_expr(op: str, col, val):
                 # to retrieve the value used in the expression.
                 col = get_col_from_table(self._sa_table, col_name).op("->>")(obj_key)
                 expr = build_expr(op, col, val)
+            case JSONArrayFieldItem(col_name, conditions, key_name):
+                col = get_col_from_table(self._sa_table, col_name)
+                json_array = sa.func.jsonb_array_elements(col.cast(JSONB)).alias("item")
+
+                condition_list = []
+                for key, expected_value in conditions.items():
+                    condition_list.append(
+                        sa.column("item").op("->>")(key) == expected_value
+                    )
+
+                element_timestamp = (
+                    sa.column("item")
+                    .op("->>")(key_name)
+                    .cast(sa.types.TIMESTAMP(timezone=True))
+                )
+
+                combined_conditions = sa.and_(*condition_list)
+
+                subq = (
+                    sa.select([sa.literal(1)])
+                    .select_from(json_array)
+                    .where(
+                        sa.and_(combined_conditions, build_expr(op, element_timestamp, val))
+                    )
+                )
+
+                expr = sa.exists(subq)
+
             case EnumFieldItem(col_name, enum_cls):
                 col = get_col_from_table(self._sa_table, col_name)
                 # allow both key and value of enum to be specified on variable `val`
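The `EXISTS` subquery generated above has a simple Python analogue, which may help when reasoning about filter semantics (sketch only; the real code paths stay in SQL):

```python
from datetime import datetime


def matches_scheduled_at_ge(status_history: list[dict[str, str]], bound: datetime) -> bool:
    # Python analogue of the EXISTS subquery: the row matches if *any* array
    # element has the required status and a timestamp satisfying the comparison.
    return any(
        item["status"] == "SCHEDULED" and datetime.fromisoformat(item["timestamp"]) >= bound
        for item in status_history
    )
```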
diff --git a/src/ai/backend/manager/models/resource_usage.py b/src/ai/backend/manager/models/resource_usage.py
index d8997867d7a..2f3c18cdb3f 100644
--- a/src/ai/backend/manager/models/resource_usage.py
+++ b/src/ai/backend/manager/models/resource_usage.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
+import json
 from datetime import datetime
 from enum import Enum
-from typing import Any, Mapping, Optional, Sequence
+from typing import Any, Mapping, Optional, Sequence, cast
 from uuid import UUID
 
 import attrs
@@ -14,6 +15,7 @@
 from sqlalchemy.orm import joinedload, load_only
 
 from ai.backend.common import redis_helper
+from ai.backend.common.json import ExtendedJSONEncoder
 from ai.backend.common.types import RedisConnectionInfo
 from ai.backend.common.utils import nmget
 
@@ -21,7 +23,7 @@
 from .kernel import LIVE_STATUS, RESOURCE_USAGE_KERNEL_STATUSES, KernelRow, KernelStatus
 from .session import SessionRow
 from .user import UserRow
-from .utils import ExtendedAsyncSAEngine
+from .utils import ExtendedAsyncSAEngine, get_latest_timestamp_for_status
 
 __all__: Sequence[str] = (
     "ResourceGroupUnit",
@@ -509,14 +511,20 @@ async def _pipe_builder(r: Redis) -> RedisPipeline:
                 continue
             stat_map[kern_id] = msgpack.unpackb(raw_stat)
 
-        return [
-            BaseResourceUsageGroup(
+        result = []
+        for kern in kernels:
+            timestamp = get_latest_timestamp_for_status(
+                cast(list[dict[str, str]], kern.status_history), KernelStatus.SCHEDULED.name
+            )
+            scheduled_at = timestamp.isoformat() if timestamp is not None else None
+
+            resource_usage_group = BaseResourceUsageGroup(
                 kernel_row=kern,
                 project_row=kern.session.group,
                 session_row=kern.session,
                 created_at=kern.created_at,
                 terminated_at=kern.terminated_at,
-                scheduled_at=kern.status_history.get(KernelStatus.SCHEDULED.name),
+                scheduled_at=scheduled_at,
                 used_time=kern.used_time,
                 used_days=kern.get_used_days(local_tz),
                 last_stat=stat_map[kern.id],
@@ -534,14 +542,15 @@
                 images={kern.image},
                 agents={kern.agent},
                 status=kern.status.name,
-                status_history=kern.status_history,
+                status_history=json.dumps(kern.status_history, cls=ExtendedJSONEncoder),
                 cluster_mode=kern.cluster_mode,
                 status_info=kern.status_info,
                 group_unit=ResourceGroupUnit.KERNEL,
                 total_usage=parse_resource_usage(kern, stat_map[kern.id]),
             )
-            for kern in kernels
-        ]
+            result.append(resource_usage_group)
+
+        return result
 
 
 SESSION_RESOURCE_SELECT_COLS = (
diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py
index 041e5e4b5e8..e0784d23343 100644
--- a/src/ai/backend/manager/models/session.py
+++ b/src/ai/backend/manager/models/session.py
@@ -89,7 +89,7 @@
 from .group import GroupRow
 from .image import ImageRow
 from .kernel import ComputeContainer, KernelRow, KernelStatus
-from .minilang import ArrayFieldItem, JSONFieldItem
+from .minilang import ArrayFieldItem, JSONArrayFieldItem
 from .minilang.ordering import ColumnMapType, QueryOrderParser
 from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter
 from .network import NetworkRow, NetworkType
@@ -113,7 +113,9 @@
     agg_to_array,
     execute_with_retry,
     execute_with_txn_retry,
-    sql_json_merge,
+    get_latest_timestamp_for_status,
+    get_legacy_status_history,
+    sql_append_dict_to_list,
 )
 
 if TYPE_CHECKING:
@@ -786,7 +788,14 @@ class SessionRow(Base):
     #     // used to prevent duplication of SessionTerminatedEvent
     #   }
     # }
-    status_history = sa.Column("status_history", pgsql.JSONB(), nullable=True, default=sa.null())
+    status_history = sa.Column("status_history", pgsql.JSONB(), nullable=False, default=[])
+    # status_history records all status changes, in order of occurrence:
+    # e.g.)
+    # [
+    #     {"status": "PENDING", "timestamp": "2022-10-22T10:22:30"},
+    #     {"status": "SCHEDULED", "timestamp": "2022-10-22T11:40:30"},
+    #     {"status": "PREPARING", "timestamp": "2022-10-25T10:22:30"},
+    # ]
     callback_url = sa.Column("callback_url", URLColumn, nullable=True, default=sa.null())
     startup_command = sa.Column("startup_command", sa.Text, nullable=True)
 
@@ -842,12 +851,7 @@ def main_kernel(self) -> KernelRow:
 
     @property
     def status_changed(self) -> Optional[datetime]:
-        if self.status_history is None:
-            return None
-        try:
-            return datetime.fromisoformat(self.status_history[self.status.name])
-        except KeyError:
-            return None
+        return get_latest_timestamp_for_status(self.status_history, self.status.name)
 
     @property
     def resource_opts(self) -> dict[str, Any]:
@@ -961,10 +965,11 @@ def set_status(
         if status in (SessionStatus.CANCELLED, SessionStatus.TERMINATED):
             self.terminated_at = now
         self.status = status
-        self.status_history = {
-            **self.status_history,
-            status.name: now.isoformat(),
-        }
+        # Reassign (instead of appending in place) so SQLAlchemy detects the change.
+        self.status_history = [
+            *self.status_history,
+            {"status": status.name, "timestamp": now.isoformat()},
+        ]
         if status_data is not None:
             self.status_data = status_data
 
@@ -992,11 +997,11 @@ async def set_session_status(
     now = status_changed_at
     data = {
         "status": status,
-        "status_history": sql_json_merge(
+        "status_history": sql_append_dict_to_list(
             SessionRow.status_history,
-            (),
             {
-                status.name: datetime.now(tzutc()).isoformat(),
+                "status": status.name,
+                "timestamp": datetime.now(tzutc()).isoformat(),
             },
         ),
     }
@@ -1588,7 +1593,11 @@ class Meta:
     status_changed = GQLDateTime()
     status_info = graphene.String()
     status_data = graphene.JSONString()
-    status_history = graphene.JSONString()
+    status_history = graphene.JSONString(
+        deprecation_reason="Deprecated since 25.1.0; use `status_history_log`"
+    )
+    status_history_log = graphene.JSONString(description="Added in 25.1.0.")
+
     created_at = GQLDateTime()
     terminated_at = GQLDateTime()
     starts_at = GQLDateTime()
@@ -1629,8 +1638,10 @@ def parse_row(cls, ctx: GraphQueryContext, row: Row) -> Mapping[str, Any]:
         full_name = getattr(row, "full_name")
         group_name = getattr(row, "group_name")
         row = row.SessionRow
-        status_history = row.status_history or {}
-        raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name)
+        scheduled_at = get_latest_timestamp_for_status(
+            row.status_history, SessionStatus.SCHEDULED.name
+        )
+
         return {
             # identity
             "id": row.id,
@@ -1663,13 +1674,12 @@
             "status_changed": row.status_changed,
             "status_info": row.status_info,
             "status_data": row.status_data,
-            "status_history": status_history,
+            "status_history": get_legacy_status_history(row.status_history),
+            "status_history_log": row.status_history,
             "created_at": row.created_at,
             "terminated_at": row.terminated_at,
             "starts_at": row.starts_at,
-            "scheduled_at": (
-                datetime.fromisoformat(raw_scheduled_at) if raw_scheduled_at is not None else None
-            ),
+            "scheduled_at": scheduled_at,
             "startup_command": row.startup_command,
             "result": row.result.name,
             # resources
@@ -1774,7 +1785,11 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> Mapping[str, Any]:
         "terminated_at": ("sessions_terminated_at", dtparse),
         "starts_at": ("sessions_starts_at", dtparse),
         "scheduled_at": (
-            JSONFieldItem("sessions_status_history", SessionStatus.SCHEDULED.name),
+            JSONArrayFieldItem(
+                column_name="sessions_status_history",
+                conditions={"status": SessionStatus.SCHEDULED.name},
+                key_name="timestamp",
+            ),
             dtparse,
         ),
         "startup_command": ("sessions_startup_command", None),
@@ -1805,10 +1820,7 @@
         "created_at": ("sessions_created_at", None),
         "terminated_at": ("sessions_terminated_at", None),
         "starts_at": ("sessions_starts_at", None),
-        "scheduled_at": (
-            JSONFieldItem("sessions_status_history", SessionStatus.SCHEDULED.name),
-            None,
-        ),
+        "scheduled_at": ("sessions_scheduled_at", None),
     }
 
     @classmethod
diff --git a/src/ai/backend/manager/models/utils.py b/src/ai/backend/manager/models/utils.py
index 2731b9fce4d..3ed3792f45e 100644
--- a/src/ai/backend/manager/models/utils.py
+++ b/src/ai/backend/manager/models/utils.py
@@ -6,6 +6,7 @@
 import logging
 from contextlib import AbstractAsyncContextManager as AbstractAsyncCtxMgr
 from contextlib import asynccontextmanager as actxmgr
+from datetime import datetime
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -24,6 +25,7 @@
 from urllib.parse import quote_plus as urlquote
 
 import sqlalchemy as sa
+from dateutil.parser import parse as dtparse
 from sqlalchemy.dialects import postgresql as psql
 from sqlalchemy.engine import create_engine as _create_engine
 from sqlalchemy.exc import DBAPIError
@@ -476,6 +478,17 @@ def sql_json_merge(
     return expr
 
 
+def sql_append_dict_to_list(col, arg: dict):
+    """
+    Generate an SQLAlchemy column update expression that appends a dictionary to
+    the existing JSONB array.
+    """
+    # Bind the JSON payload as a parameter instead of interpolating it into raw SQL,
+    # so values containing quotes can neither break nor inject into the statement.
+    new_item = sa.cast(sa.literal(json.dumps([arg])), psql.JSONB)
+    return col.op("||")(new_item)
+
+
 def sql_json_increment(
     col,
     key: Tuple[str, ...],
@@ -550,3 +563,34 @@ async def vacuum_db(
         vacuum_sql = "VACUUM FULL" if vacuum_full else "VACUUM"
         log.info(f"Perfoming {vacuum_sql} operation...")
         await conn.exec_driver_sql(vacuum_sql)
+
+
+def get_latest_timestamp_for_status(
+    status_history_log: list[dict[str, str]],
+    status: str,
+) -> Optional[datetime]:
+    """
+    Get the last occurrence time of the given status from the status history records.
+    If the status is not found, return None.
+    """
+    for item in reversed(status_history_log):
+        if item["status"] == status:
+            return dtparse(item["timestamp"])
+    return None
+
+
+def get_legacy_status_history(status_history_log: list[dict[str, str]]) -> dict[str, str]:
+    """
+    Collapse the status history log into the legacy mapping of status to timestamp,
+    keeping only the latest occurrence of each status.
+    """
+    statuses = set(item["status"] for item in status_history_log)
+    result_dict = {}
+
+    for status in statuses:
+        latest_time = get_latest_timestamp_for_status(status_history_log, status)
+        if latest_time:
+            result_dict[status] = latest_time.isoformat()
+
+    return result_dict
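To make the legacy shim's semantics concrete, here is a small illustration (sample data invented for the example):

```python
log = [
    {"status": "PENDING", "timestamp": "2024-12-01T10:00:00+00:00"},
    {"status": "RUNNING", "timestamp": "2024-12-01T10:05:00+00:00"},
    {"status": "RESTARTING", "timestamp": "2024-12-02T09:00:00+00:00"},
    {"status": "RUNNING", "timestamp": "2024-12-02T09:01:00+00:00"},
]

# The legacy view keeps only the *latest* visit to each status, so the first
# RUNNING entry is dropped -- exactly the information loss this PR stops storing.
assert get_legacy_status_history(log) == {
    "PENDING": "2024-12-01T10:00:00+00:00",
    "RUNNING": "2024-12-02T09:01:00+00:00",
    "RESTARTING": "2024-12-02T09:00:00+00:00",
}
```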
+ """ + + for item in reversed(status_history_log): + if item["status"] == status: + return dtparse(item["timestamp"]) + return None + + +def get_legacy_status_history(status_history_log: list[dict[str, str]]) -> dict[str, str]: + """ + Get the last occurrence time of each status from the status history records. + This function is used to retrieve legacy status_history from status_history_log. + """ + statuses = set(item["status"] for item in status_history_log) + result_dict = {} + + for status in statuses: + latest_time = get_latest_timestamp_for_status(status_history_log, status) + if latest_time: + result_dict[status] = latest_time.isoformat() + + return result_dict diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index a3489c42476..5df1ac0140f 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -198,6 +198,7 @@ is_db_retry_error, reenter_txn, reenter_txn_session, + sql_append_dict_to_list, sql_json_merge, ) from .models.vfolder import VFolderOperationStatus, update_vfolder_status @@ -1077,9 +1078,12 @@ async def enqueue_session( "id": session_id, "priority": priority, "status": SessionStatus.PENDING, - "status_history": { - SessionStatus.PENDING.name: datetime.now(tzutc()).isoformat(), - }, + "status_history": [ + { + "status": SessionStatus.PENDING.name, + "timestamp": datetime.now(tzutc()).isoformat(), + } + ], "creation_id": session_creation_id, "name": session_name, "session_type": session_type, @@ -1102,9 +1106,12 @@ async def enqueue_session( kernel_shared_data = { "status": KernelStatus.PENDING, - "status_history": { - KernelStatus.PENDING.name: datetime.now(tzutc()).isoformat(), - }, + "status_history": [ + { + "status": KernelStatus.PENDING.name, + "timestamp": datetime.now(tzutc()).isoformat(), + } + ], "session_creation_id": session_creation_id, "session_id": session_id, "session_name": session_name, @@ -1912,14 +1919,12 @@ async def _update_failure() -> None: status_info=f"other-error ({ex!r})", status_changed=now, terminated_at=now, - status_history=sql_json_merge( + status_history=sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.ERROR.name: ( - now.isoformat() - ), # ["PULLING", "CREATING"] - }, + "status": KernelStatus.ERROR.name, + "timestamp": now.isoformat(), + }, # ["PULLING", "CREATING"] ), status_data=err_info, ) @@ -2345,22 +2350,19 @@ async def _destroy(db_session: AsyncSession) -> SessionRow: kern.status = kernel_target_status kern.terminated_at = current_time kern.status_info = destroy_reason - kern.status_history = sql_json_merge( + kern.status_history = sql_append_dict_to_list( KernelRow.status_history, - (), { - kernel_target_status.name: current_time.isoformat(), + "status": kernel_target_status.name, + "timestamp": current_time.isoformat(), }, ) session_row.status = target_status session_row.terminated_at = current_time session_row.status_info = destroy_reason - session_row.status_history = sql_json_merge( + session_row.status_history = sql_append_dict_to_list( SessionRow.status_history, - (), - { - target_status.name: current_time.isoformat(), - }, + {"status": target_status.name, "timestamp": current_time.isoformat()}, ) return session_row @@ -2554,11 +2556,11 @@ async def _update() -> None: "status_info": reason, "status_changed": now, "terminated_at": now, - "status_history": sql_json_merge( + "status_history": sql_append_dict_to_list( KernelRow.status_history, - (), { - KernelStatus.TERMINATED.name: now.isoformat(), + "status": 
@@ -2554,11 +2556,11 @@ async def _update() -> None:
                 "status_info": reason,
                 "status_changed": now,
                 "terminated_at": now,
-                "status_history": sql_json_merge(
+                "status_history": sql_append_dict_to_list(
                     KernelRow.status_history,
-                    (),
                     {
-                        KernelStatus.TERMINATED.name: now.isoformat(),
+                        "status": KernelStatus.TERMINATED.name,
+                        "timestamp": now.isoformat(),
                     },
                 ),
             }
@@ -2586,11 +2588,11 @@ async def _update() -> None:
                     "kernel": {"exit_code": None},
                     "session": {"status": "terminating"},
                 },
-                "status_history": sql_json_merge(
+                "status_history": sql_append_dict_to_list(
                     KernelRow.status_history,
-                    (),
                     {
-                        KernelStatus.TERMINATING.name: now.isoformat(),
+                        "status": KernelStatus.TERMINATING.name,
+                        "timestamp": now.isoformat(),
                     },
                 ),
             }
@@ -2736,11 +2738,11 @@ async def _restarting_session() -> None:
                 sa.update(SessionRow)
                 .values(
                     status=SessionStatus.RESTARTING,
-                    status_history=sql_json_merge(
+                    status_history=sql_append_dict_to_list(
                         SessionRow.status_history,
-                        (),
                         {
-                            SessionStatus.RESTARTING.name: datetime.now(tzutc()).isoformat(),
+                            "status": SessionStatus.RESTARTING.name,
+                            "timestamp": datetime.now(tzutc()).isoformat(),
                         },
                     ),
                 )
@@ -2781,12 +2783,9 @@ async def _restart_kernel(kernel: KernelRow) -> None:
                 "stdin_port": kernel_info["stdin_port"],
                 "stdout_port": kernel_info["stdout_port"],
                 "service_ports": kernel_info.get("service_ports", []),
-                "status_history": sql_json_merge(
+                "status_history": sql_append_dict_to_list(
                     KernelRow.status_history,
-                    (),
-                    {
-                        KernelStatus.RUNNING.name: now.isoformat(),
-                    },
+                    {"status": KernelStatus.RUNNING.name, "timestamp": now.isoformat()},
                 ),
             }
             await KernelRow.update_kernel(
diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py
index 396abb0dd9e..76e12844f4c 100644
--- a/src/ai/backend/manager/scheduler/dispatcher.py
+++ b/src/ai/backend/manager/scheduler/dispatcher.py
@@ -100,6 +100,7 @@
     execute_with_retry,
     execute_with_txn_retry,
     retry_txn,
+    sql_append_dict_to_list,
     sql_json_increment,
     sql_json_merge,
 )
@@ -900,11 +901,11 @@ async def _finalize_scheduled() -> None:
                     status_info="scheduled",
                     status_data={},
                     status_changed=now,
-                    status_history=sql_json_merge(
+                    status_history=sql_append_dict_to_list(
                         KernelRow.status_history,
-                        (),
                         {
-                            KernelStatus.SCHEDULED.name: now.isoformat(),
+                            "status": KernelStatus.SCHEDULED.name,
+                            "timestamp": now.isoformat(),
                         },
                     ),
                 )
@@ -922,12 +923,9 @@ async def _finalize_scheduled() -> None:
                     status=SessionStatus.SCHEDULED,
                     status_info="scheduled",
                     status_data={},
-                    status_history=sql_json_merge(
+                    status_history=sql_append_dict_to_list(
                         SessionRow.status_history,
-                        (),
-                        {
-                            SessionStatus.SCHEDULED.name: now.isoformat(),
-                        },
+                        {"status": SessionStatus.SCHEDULED.name, "timestamp": now.isoformat()},
                     ),
                 )
                 .where(SessionRow.id == sess_ctx.id)
@@ -1135,11 +1133,11 @@ async def _finalize_scheduled() -> None:
                     status_info="scheduled",
                     status_data={},
                     status_changed=now,
-                    status_history=sql_json_merge(
+                    status_history=sql_append_dict_to_list(
                         KernelRow.status_history,
-                        (),
                         {
-                            KernelStatus.SCHEDULED.name: now.isoformat(),
+                            "status": KernelStatus.SCHEDULED.name,
+                            "timestamp": now.isoformat(),
                         },
                     ),
                 )
@@ -1158,12 +1156,9 @@ async def _finalize_scheduled() -> None:
                     status_info="scheduled",
                     status_data={},
                     # status_changed=now,
-                    status_history=sql_json_merge(
+                    status_history=sql_append_dict_to_list(
                         SessionRow.status_history,
-                        (),
-                        {
-                            SessionStatus.SCHEDULED.name: now.isoformat(),
-                        },
+                        {"status": SessionStatus.SCHEDULED.name, "timestamp": now.isoformat()},
                     ),
                 )
                 .where(SessionRow.id == sess_ctx.id)
@@ -1609,11 +1604,11 @@ async def _mark_session_cancelled() -> None:
                     status_info="failed-to-start",
                     status_data=status_data,
                     terminated_at=now,
-                    status_history=sql_json_merge(
+                    status_history=sql_append_dict_to_list(
                         KernelRow.status_history,
-                        (),
                         {
-                            KernelStatus.CANCELLED.name: now.isoformat(),
+                            "status": KernelStatus.CANCELLED.name,
+                            "timestamp": now.isoformat(),
                         },
                     ),
                 )
@@ -1628,11 +1623,11 @@ async def _mark_session_cancelled() -> None:
                     status_info="failed-to-start",
                     status_data=status_data,
                     terminated_at=now,
-                    status_history=sql_json_merge(
+                    status_history=sql_append_dict_to_list(
                         SessionRow.status_history,
-                        (),
                         {
-                            SessionStatus.CANCELLED.name: now.isoformat(),
+                            "status": SessionStatus.CANCELLED.name,
+                            "timestamp": now.isoformat(),
                         },
                     ),
                 )
@@ -1778,12 +1773,9 @@ async def _apply_cancellation(
                 status=KernelStatus.CANCELLED,
                 status_info=reason,
                 terminated_at=now,
-                status_history=sql_json_merge(
+                status_history=sql_append_dict_to_list(
                     KernelRow.status_history,
-                    (),
-                    {
-                        KernelStatus.CANCELLED.name: now.isoformat(),
-                    },
+                    {"status": KernelStatus.CANCELLED.name, "timestamp": now.isoformat()},
                 ),
             )
             .where(KernelRow.session_id.in_(session_ids))
@@ -1795,12 +1787,9 @@ async def _apply_cancellation(
                 status=SessionStatus.CANCELLED,
                 status_info=reason,
                 terminated_at=now,
-                status_history=sql_json_merge(
+                status_history=sql_append_dict_to_list(
                     SessionRow.status_history,
-                    (),
-                    {
-                        SessionStatus.CANCELLED.name: now.isoformat(),
-                    },
+                    {"status": SessionStatus.CANCELLED.name, "timestamp": now.isoformat()},
                 ),
             )
             .where(SessionRow.id.in_(session_ids))
diff --git a/src/ai/backend/manager/server.py b/src/ai/backend/manager/server.py
index c695a4e1317..2a735688df8 100644
--- a/src/ai/backend/manager/server.py
+++ b/src/ai/backend/manager/server.py
@@ -582,7 +582,6 @@ async def hanging_session_scanner_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
     import sqlalchemy as sa
     from dateutil.relativedelta import relativedelta
-    from dateutil.tz import tzutc
     from sqlalchemy.orm import load_only, noload
 
     from .config import session_hang_tolerance_iv
@@ -600,13 +599,20 @@ async def _fetch_hanging_sessions(
             sa.select(SessionRow)
             .where(SessionRow.status == status)
             .where(
-                (
-                    datetime.now(tz=tzutc())
-                    - SessionRow.status_history[status.name].astext.cast(
-                        sa.types.DateTime(timezone=True)
-                    )
-                )
-                > threshold
+                # TODO: Can we replace the following query with SQLAlchemy ORM for better readability?
+                sa.text(
+                    """
+                    EXISTS (
+                        SELECT 1
+                        FROM jsonb_array_elements(status_history) AS session_history
+                        WHERE
+                            session_history->>'status' = :status_name AND
+                            (
+                                now() - CAST(session_history->>'timestamp' AS TIMESTAMP WITH TIME ZONE)
+                            ) > :threshold
+                    )
+                    """
+                ).bindparams(status_name=status.name, threshold=threshold)
             )
             .options(
                 noload("*"),
diff --git a/tests/manager/models/test_utils.py b/tests/manager/models/test_utils.py
index 7323b929624..c50ba1ee486 100644
--- a/tests/manager/models/test_utils.py
+++ b/tests/manager/models/test_utils.py
@@ -1,243 +1,8 @@
-import uuid
-from datetime import datetime
-from typing import Any, Dict, Optional, Union
-
 import pytest
-import sqlalchemy
 import sqlalchemy as sa
-from dateutil.tz import tzutc
-from sqlalchemy.engine import Row
 
 from ai.backend.manager.models import KernelRow, SessionRow, kernels
-from ai.backend.manager.models.utils import agg_to_array, agg_to_str, sql_json_merge
-
-
-async def _select_kernel_row(
-    conn: sqlalchemy.ext.asyncio.engine.AsyncConnection,
-    session_id: Union[str, uuid.UUID],
-) -> Row:
-    query = kernels.select().select_from(kernels).where(kernels.c.session_id == session_id)
-    kernel, *_ = await conn.execute(query)
-    return kernel
-
-
-@pytest.mark.asyncio
-async def test_sql_json_merge__default(session_info) -> None:
-    session_id, conn = session_info
-    expected: Optional[Dict[str, Any]] = None
-    kernel = await _select_kernel_row(conn, session_id)
-    assert kernel is not None
-    assert kernel.status_history == expected
-
-
-@pytest.mark.asyncio
-async def test_sql_json_merge__deeper_object(session_info) -> None:
-    session_id, conn = session_info
-    timestamp = datetime.now(tzutc()).isoformat()
-    expected = {
-        "kernel": {
-            "session": {
-                "PENDING": timestamp,
-                "PREPARING": timestamp,
-            },
-        },
-    }
-    query = (
-        kernels.update()
-        .values({
-            "status_history": sql_json_merge(
-                kernels.c.status_history,
-                ("kernel", "session"),
-                {
-                    "PENDING": timestamp,
-                    "PREPARING": timestamp,
-                },
-            ),
-        })
-        .where(kernels.c.session_id == session_id)
-    )
-    await conn.execute(query)
-    kernel = await _select_kernel_row(conn, session_id)
-    assert kernel is not None
-    assert kernel.status_history == expected
-
-
-@pytest.mark.asyncio
-async def test_sql_json_merge__append_values(session_info) -> None:
-    session_id, conn = session_info
-    timestamp = datetime.now(tzutc()).isoformat()
-    expected = {
-        "kernel": {
-            "session": {
-                "PENDING": timestamp,
-                "PREPARING": timestamp,
-                "TERMINATED": timestamp,
-                "TERMINATING": timestamp,
-            },
-        },
-    }
-    query = (
-        kernels.update()
-        .values({
-            "status_history": sql_json_merge(
-                kernels.c.status_history,
-                ("kernel", "session"),
-                {
-                    "PENDING": timestamp,
-                    "PREPARING": timestamp,
-                },
-            ),
-        })
-        .where(kernels.c.session_id == session_id)
-    )
-    await conn.execute(query)
-    query = (
-        kernels.update()
-        .values({
-            "status_history": sql_json_merge(
-                kernels.c.status_history,
-                ("kernel", "session"),
-                {
-                    "TERMINATING": timestamp,
-                    "TERMINATED": timestamp,
-                },
-            ),
-        })
-        .where(kernels.c.session_id == session_id)
-    )
-    await conn.execute(query)
-    kernel = await _select_kernel_row(conn, session_id)
-    assert kernel is not None
-    assert kernel.status_history == expected
-
-
-@pytest.mark.asyncio
-async def test_sql_json_merge__kernel_status_history(session_info) -> None:
-    session_id, conn = session_info
-    timestamp = datetime.now(tzutc()).isoformat()
-    expected = {
-        "PENDING": timestamp,
-        "PREPARING": timestamp,
-        "TERMINATING": timestamp,
-        "TERMINATED": timestamp,
-    }
-    query = (
-        kernels.update()
-        .values({
-            # "status_history": sqlalchemy.func.coalesce(sqlalchemy.text("'{}'::jsonb")).concat(
-            #     sqlalchemy.func.cast(
-            #         {"PENDING": timestamp, "PREPARING": timestamp},
-            #         sqlalchemy.dialects.postgresql.JSONB,
-            #     ),
-            # ),
-            "status_history": sql_json_merge(
-                kernels.c.status_history,
-                (),
-                {
-                    "PENDING": timestamp,
-                    "PREPARING": timestamp,
-                },
-            ),
-        })
-        .where(kernels.c.session_id == session_id)
-    )
-    await conn.execute(query)
-    query = (
-        kernels.update()
-        .values({
-            "status_history": sql_json_merge(
-                kernels.c.status_history,
-                (),
-                {
-                    "TERMINATING": timestamp,
-                    "TERMINATED": timestamp,
-                },
-            ),
-        })
-        .where(kernels.c.session_id == session_id)
-    )
-    await conn.execute(query)
-    kernel = await _select_kernel_row(conn, session_id)
-    assert kernel is not None
-    assert kernel.status_history == expected
-
-
-@pytest.mark.asyncio
-async def test_sql_json_merge__mixed_formats(session_info) -> None:
-    session_id, conn = session_info
-    timestamp = datetime.now(tzutc()).isoformat()
-    expected = {
-        "PENDING": timestamp,
-        "kernel": {
-            "PREPARING": timestamp,
-        },
-    }
-    query = (
-        kernels.update()
-        .values({
-            "status_history": sql_json_merge(
-                kernels.c.status_history,
-                (),
-                {
-                    "PENDING": timestamp,
-                },
-            ),
-        })
-        .where(kernels.c.session_id == session_id)
-    )
-    await conn.execute(query)
-    kernel = await _select_kernel_row(conn, session_id)
-    query = (
-        kernels.update()
-        .values({
-            "status_history": sql_json_merge(
-                kernels.c.status_history,
-                ("kernel",),
-                {
-                    "PREPARING": timestamp,
-                },
-            ),
-        })
-        .where(kernels.c.session_id == session_id)
-    )
-    await conn.execute(query)
-    kernel = await _select_kernel_row(conn, session_id)
-    assert kernel is not None
-    assert kernel.status_history == expected
-
-
-@pytest.mark.asyncio
-async def test_sql_json_merge__json_serializable_types(session_info) -> None:
-    session_id, conn = session_info
-    expected = {
-        "boolean": True,
-        "integer": 10101010,
-        "float": 1010.1010,
-        "string": "10101010",
-        # "bytes": b"10101010",
-        "list": [
-            10101010,
-            "10101010",
-        ],
-        "dict": {
-            "10101010": 10101010,
-        },
-    }
-    query = (
-        kernels.update()
-        .values({
-            "status_history": sql_json_merge(
-                kernels.c.status_history,
-                (),
-                expected,
-            ),
-        })
-        .where(kernels.c.session_id == session_id)
-    )
-    await conn.execute(query)
-    kernel = await _select_kernel_row(conn, session_id)
-    assert kernel is not None
-    assert kernel.status_history == expected
+from ai.backend.manager.models.utils import agg_to_array, agg_to_str
 
 
 @pytest.mark.asyncio