Skip to content

Commit

Permalink
refactor: Move get_first_timestamp_for_status() from common.utils to …
Browse files Browse the repository at this point in the history
…manager.models.utils

- This eliminates the necessity to use stringified enum arguments in the
  `ai.backend.manager.models` codes.

- In the client SDK, add a copy of it using relaxed str-only arguments.
  Since this is a fairly simple logic, I think it is not worth to
  introduce additional complexity to share and reuse the code between
  the client SDK and the manager.
  (Note that originally `ai.backend.common` was not the dependency of
  the client SDK...)

- I think it would be better to introduce a JSON-fied TypedDict of each status
  history records.

- Also fix several merge errors.
  • Loading branch information
achimnol committed Jul 15, 2024
1 parent 553430c commit 0157e60
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
from ai.backend.client.cli.types import CLIContext
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH
from ai.backend.common.types import ClusterMode
from ai.backend.common.utils import get_first_timestamp_for_status

from ...compat import asyncio_run
from ...exceptions import BackendAPIError
from ...func.session import ComputeSession
from ...output.fields import session_fields
from ...output.types import FieldSpec
from ...session import AsyncSession, Session
from ...utils import get_first_timestamp_for_status
from .. import events
from ..pretty import (
ProgressViewer,
Expand Down
14 changes: 14 additions & 0 deletions src/ai/backend/client/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

import io
import os
from datetime import datetime

from dateutil.parser import parse as dtparse
from tqdm import tqdm


Expand Down Expand Up @@ -48,3 +52,13 @@ def readinto1(self, *args, **kwargs):
count = super().readinto1(*args, **kwargs)
self.tqdm.set_postfix(file=self._filename, refresh=False)
self.tqdm.update(count)


def get_first_timestamp_for_status(
status_history: list[dict[str, str]],
status: str,
) -> datetime | None:
for rec in status_history:
if rec["status"] == status:
return dtparse(rec["timestamp"])
return None
16 changes: 1 addition & 15 deletions src/ai/backend/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
import uuid
from collections import OrderedDict
from datetime import datetime, timedelta
from datetime import timedelta
from itertools import chain
from pathlib import Path
from typing import (
Expand All @@ -25,7 +25,6 @@

import aiofiles
from async_timeout import timeout
from dateutil.parser import parse as dtparse

if TYPE_CHECKING:
from decimal import Decimal
Expand Down Expand Up @@ -405,16 +404,3 @@ async def umount(
fstab = Fstab(fp)
await fstab.remove_by_mountpoint(str(mountpoint))
return True


def get_first_timestamp_for_status(
status_history_records: list[dict[str, str]], status: str
) -> datetime | None:
"""
Get the first occurrence time of the given status from the status history records.
"""

for status_history in status_history_records:
if status_history["status"] == status:
return dtparse(status_history["timestamp"])
return None
7 changes: 4 additions & 3 deletions src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
SessionTypes,
VFolderMount,
)
from ai.backend.common.utils import get_first_timestamp_for_status

from ..api.exceptions import (
BackendError,
Expand Down Expand Up @@ -83,7 +82,9 @@
ExtendedAsyncSAEngine,
JSONCoalesceExpr,
execute_with_retry,
get_first_timestamp_for_status,
sql_append_dict_to_list,
sql_json_merge,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -927,8 +928,8 @@ def parse_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Mapping[str, Any]:
hide_agents = False
else:
hide_agents = ctx.local_config["manager"]["hide-agents"]
status_history = row.status_history
scheduled_at = get_first_timestamp_for_status(status_history, KernelStatus.SCHEDULED.name)
status_history = cast(list[dict[str, str]], row.status_history)
scheduled_at = get_first_timestamp_for_status(status_history, KernelStatus.SCHEDULED)

return {
# identity
Expand Down
14 changes: 9 additions & 5 deletions src/ai/backend/manager/models/resource_usage.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import json
from datetime import datetime
from enum import Enum
from typing import Any, Mapping, Optional, Sequence
from typing import Any, Mapping, Optional, Sequence, cast
from uuid import UUID

import attrs
Expand All @@ -14,14 +15,15 @@
from sqlalchemy.orm import joinedload, load_only

from ai.backend.common import redis_helper
from ai.backend.common.json import ExtendedJSONEncoder
from ai.backend.common.types import RedisConnectionInfo
from ai.backend.common.utils import get_first_timestamp_for_status, nmget
from ai.backend.common.utils import nmget

from .group import GroupRow
from .kernel import LIVE_STATUS, RESOURCE_USAGE_KERNEL_STATUSES, KernelRow, KernelStatus
from .session import SessionRow
from .user import UserRow
from .utils import ExtendedAsyncSAEngine
from .utils import ExtendedAsyncSAEngine, get_first_timestamp_for_status

__all__: Sequence[str] = (
"ResourceGroupUnit",
Expand Down Expand Up @@ -517,7 +519,9 @@ async def _pipe_builder(r: Redis) -> RedisPipeline:
created_at=kern.created_at,
terminated_at=kern.terminated_at,
scheduled_at=str(
get_first_timestamp_for_status(kern.status_history, KernelStatus.SCHEDULED.name)
get_first_timestamp_for_status(
cast(list[dict[str, str]], kern.status_history), KernelStatus.SCHEDULED
)
),
used_time=kern.used_time,
used_days=kern.get_used_days(local_tz),
Expand All @@ -536,7 +540,7 @@ async def _pipe_builder(r: Redis) -> RedisPipeline:
images={kern.image},
agents={kern.agent},
status=kern.status.name,
status_history=kern.status_history,
status_history=json.dumps(kern.status_history, cls=ExtendedJSONEncoder),
cluster_mode=kern.cluster_mode,
status_info=kern.status_info,
group_unit=ResourceGroupUnit.KERNEL,
Expand Down
7 changes: 3 additions & 4 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
SessionTypes,
VFolderMount,
)
from ai.backend.common.utils import get_first_timestamp_for_status

from ..api.exceptions import (
AgentError,
Expand Down Expand Up @@ -80,7 +79,9 @@
JSONCoalesceExpr,
agg_to_array,
execute_with_retry,
get_first_timestamp_for_status,
sql_append_dict_to_list,
sql_json_merge,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1324,9 +1325,7 @@ def parse_row(cls, ctx: GraphQueryContext, row: Row) -> Mapping[str, Any]:
full_name = getattr(row, "full_name")
group_name = getattr(row, "group_name")
row = row.SessionRow
scheduled_at = get_first_timestamp_for_status(
row.status_history, SessionStatus.SCHEDULED.name
)
scheduled_at = get_first_timestamp_for_status(row.status_history, SessionStatus.SCHEDULED)

return {
# identity
Expand Down
20 changes: 20 additions & 0 deletions src/ai/backend/manager/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
from contextlib import AbstractAsyncContextManager as AbstractAsyncCtxMgr
from contextlib import asynccontextmanager as actxmgr
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -23,6 +24,7 @@
from urllib.parse import quote_plus as urlquote

import sqlalchemy as sa
from dateutil.parser import parse as dtparse
from sqlalchemy.dialects import postgresql as psql
from sqlalchemy.engine import create_engine as _create_engine
from sqlalchemy.exc import DBAPIError
Expand All @@ -44,6 +46,10 @@

if TYPE_CHECKING:
from ..config import LocalConfig
from . import (
KernelStatus,
SessionStatus,
)

from ..defs import LockID
from ..types import Sentinel
Expand Down Expand Up @@ -536,3 +542,17 @@ async def vacuum_db(
vacuum_sql = "VACUUM FULL" if vacuum_full else "VACUUM"
log.info(f"Perfoming {vacuum_sql} operation...")
await conn.exec_driver_sql(vacuum_sql)


def get_first_timestamp_for_status(
status_history_records: list[dict[str, str]],
status: KernelStatus | SessionStatus,
) -> datetime | None:
"""
Get the first occurrence time of the given status from the status history records.
"""

for status_history in status_history_records:
if status_history["status"] == status.name:
return dtparse(status_history["timestamp"])
return None
2 changes: 2 additions & 0 deletions src/ai/backend/manager/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Optional
from uuid import UUID

Expand Down

0 comments on commit 0157e60

Please sign in to comment.