Skip to content

Commit

Permalink
fix: Add missing status_history GQL field
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Dec 8, 2024
1 parent bb0ce11 commit b6e059c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
18 changes: 13 additions & 5 deletions src/ai/backend/manager/models/gql_models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TYPE_CHECKING,
Any,
Self,
cast,
)

import graphene
Expand All @@ -14,15 +15,17 @@

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import AgentId, KernelId, SessionId
from ai.backend.manager.models.base import (

from ..base import (
batch_multiresult_in_scalar_stream,
batch_multiresult_in_session,
)

from ..gql_relay import AsyncNode, Connection
from ..kernel import KernelRow, KernelStatus
from ..kernel import KernelRow
from ..user import UserRole
from ..utils import get_lastest_timestamp_for_status
from .image import ImageNode
from .session import SessionStatus

if TYPE_CHECKING:
from ..gql import GraphQueryContext
Expand Down Expand Up @@ -113,7 +116,12 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
hide_agents = False
else:
hide_agents = ctx.local_config["manager"]["hide-agents"]
status_history = row.status_history or {}

timestamp = get_lastest_timestamp_for_status(
cast(list[dict[str, str]], row.status_history), SessionStatus.SCHEDULED
)
scheduled_at = str(timestamp) if timestamp is not None else None

return KernelNode(
id=row.id, # auto-converted to Relay global ID
row_id=row.id,
Expand All @@ -129,7 +137,7 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
created_at=row.created_at,
terminated_at=row.terminated_at,
starts_at=row.starts_at,
scheduled_at=status_history.get(KernelStatus.SCHEDULED.name),
scheduled_at=scheduled_at,
occupied_slots=row.occupied_slots.to_json(),
agent_id=row.agent if not hide_agents else None,
agent_addr=row.agent_addr if not hide_agents else None,
Expand Down
14 changes: 8 additions & 6 deletions src/ai/backend/manager/models/gql_models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
get_permission_ctx,
)
from ..user import UserRole
from ..utils import execute_with_txn_retry
from ..utils import execute_with_txn_retry, get_lastest_timestamp_for_status
from .kernel import KernelConnection, KernelNode

if TYPE_CHECKING:
Expand Down Expand Up @@ -174,7 +174,7 @@ class Meta:
# status_changed = GQLDateTime() # FIXME: generated attribute
status_info = graphene.String()
status_data = graphene.JSONString()
status_history = graphene.JSONString()
status_history = graphene.JSONString(description="Added in 24.12.0.")
created_at = GQLDateTime()
terminated_at = GQLDateTime()
starts_at = GQLDateTime()
Expand Down Expand Up @@ -225,7 +225,11 @@ def from_row(
permissions: Optional[Iterable[ComputeSessionPermission]] = None,
) -> Self:
status_history = row.status_history or {}
raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name)
timestamp = get_lastest_timestamp_for_status(
cast(list[dict[str, str]], status_history), SessionStatus.SCHEDULED
)
scheduled_at = str(timestamp) if timestamp is not None else None

result = cls(
# identity
id=row.id, # auto-converted to Relay global ID
Expand All @@ -251,9 +255,7 @@ def from_row(
created_at=row.created_at,
starts_at=row.starts_at,
terminated_at=row.terminated_at,
scheduled_at=datetime.fromisoformat(raw_scheduled_at)
if raw_scheduled_at is not None
else None,
scheduled_at=scheduled_at,
startup_command=row.startup_command,
result=row.result.name,
# resources
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,7 @@ class Meta:
# status
status = graphene.String()
status_changed = GQLDateTime()
status_history = graphene.JSONString(description="Added in 24.12.0.")
status_info = graphene.String()
status_data = graphene.JSONString()
created_at = GQLDateTime()
Expand Down Expand Up @@ -931,6 +932,7 @@ def parse_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Mapping[str, Any]:
# status
"status": row.status.name,
"status_changed": row.status_changed,
"status_history": row.status_history,
"status_info": row.status_info,
"status_data": row.status_data,
"created_at": row.created_at,
Expand Down

0 comments on commit b6e059c

Please sign in to comment.