From 9c6f51d223ed33bfb83b754792093c161693d8e8 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Wed, 3 Jul 2024 13:02:41 +0900 Subject: [PATCH] feat: GQL resolvers for querying agent's total resource allocation (#2254) Co-authored-by: Sanghun Lee --- changes/2254.feature.md | 1 + src/ai/backend/manager/api/schema.graphql | 16 +++++ src/ai/backend/manager/models/agent.py | 3 +- .../backend/manager/models/scaling_group.py | 69 ++++++++++++++++++- 4 files changed, 86 insertions(+), 3 deletions(-) create mode 100644 changes/2254.feature.md diff --git a/changes/2254.feature.md b/changes/2254.feature.md new file mode 100644 index 0000000000..86fa74e28c --- /dev/null +++ b/changes/2254.feature.md @@ -0,0 +1 @@ +Add `scaling_group.agent_count_by_status` and `scaling_group.agent_total_resource_slots_by_status` GQL fields to query the count and the resource allocation of agents that belong to a scaling group. diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index f02301fec3..d09f07d037 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -693,6 +693,22 @@ type ScalingGroup { scheduler: String scheduler_opts: JSONString use_host_network: Boolean + + """Added in 24.03.7.""" + agent_count_by_status( + """ + Possible states of an agent. Should be one of ['ALIVE', 'LOST', 'RESTARTING', 'TERMINATED']. Default is 'ALIVE'. + """ + status: String = "ALIVE" + ): Int + + """Added in 24.03.7.""" + agent_total_resource_slots_by_status( + """ + Possible states of an agent. Should be one of ['ALIVE', 'LOST', 'RESTARTING', 'TERMINATED']. Default is 'ALIVE'. + """ + status: String = "ALIVE" + ): JSONString } type StorageVolume implements Item { diff --git a/src/ai/backend/manager/models/agent.py b/src/ai/backend/manager/models/agent.py index cd64a3cf78..88cff60670 100644 --- a/src/ai/backend/manager/models/agent.py +++ b/src/ai/backend/manager/models/agent.py @@ -36,7 +36,6 @@ from .keypair import keypairs from .minilang.ordering import OrderSpecItem, QueryOrderParser from .minilang.queryfilter import FieldSpecItem, QueryFilterParser, enum_field_getter -from .scaling_group import query_allowed_sgroups from .user import UserRole, users if TYPE_CHECKING: @@ -432,6 +431,8 @@ async def _append_sgroup_from_clause( domain_name: str | None, scaling_group: str | None = None, ) -> sa.sql.Select: + from .scaling_group import query_allowed_sgroups + if scaling_group is not None: query = query.where(agents.c.scaling_group == scaling_group) else: diff --git a/src/ai/backend/manager/models/scaling_group.py b/src/ai/backend/manager/models/scaling_group.py index 522c8870c5..11b35ad4f0 100644 --- a/src/ai/backend/manager/models/scaling_group.py +++ b/src/ai/backend/manager/models/scaling_group.py @@ -22,12 +22,18 @@ from sqlalchemy.dialects import postgresql as pgsql from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection -from sqlalchemy.orm import relationship +from sqlalchemy.orm import load_only, relationship from sqlalchemy.sql.expression import true from ai.backend.common import validators as tx -from ai.backend.common.types import AgentSelectionStrategy, JSONSerializableMixin, SessionTypes +from ai.backend.common.types import ( + AgentSelectionStrategy, + JSONSerializableMixin, + ResourceSlot, + SessionTypes, +) +from .agent import AgentStatus from .base import ( Base, IDColumn, @@ -315,6 +321,65 @@ class ScalingGroup(graphene.ObjectType): scheduler_opts = graphene.JSONString() use_host_network = graphene.Boolean() + # Dynamic fields. + agent_count_by_status = graphene.Field( + graphene.Int, + description="Added in 24.03.7.", + status=graphene.String( + default_value=AgentStatus.ALIVE.name, + description=f"Possible states of an agent. Should be one of {[s.name for s in AgentStatus]}. Default is 'ALIVE'.", + ), + ) + + agent_total_resource_slots_by_status = graphene.Field( + graphene.JSONString, + description="Added in 24.03.7.", + status=graphene.String( + default_value=AgentStatus.ALIVE.name, + description=f"Possible states of an agent. Should be one of {[s.name for s in AgentStatus]}. Default is 'ALIVE'.", + ), + ) + + async def resolve_agent_count_by_status( + self, info: graphene.ResolveInfo, status: str = AgentStatus.ALIVE.name + ) -> int: + from .agent import Agent + + return await Agent.load_count( + info.context, + raw_status=status, + scaling_group=self.name, + ) + + async def resolve_agent_total_resource_slots_by_status( + self, info: graphene.ResolveInfo, status: str = AgentStatus.ALIVE.name + ) -> Mapping[str, Any]: + from .agent import AgentRow, AgentStatus + + graph_ctx = info.context + async with graph_ctx.db.begin_readonly_session() as db_session: + query_stmt = ( + sa.select(AgentRow) + .where( + (AgentRow.scaling_group == self.name) & (AgentRow.status == AgentStatus[status]) + ) + .options(load_only(AgentRow.occupied_slots, AgentRow.available_slots)) + ) + result = (await db_session.scalars(query_stmt)).all() + agent_rows = cast(list[AgentRow], result) + + total_occupied_slots = ResourceSlot() + total_available_slots = ResourceSlot() + + for agent_row in agent_rows: + total_occupied_slots += agent_row.occupied_slots + total_available_slots += agent_row.available_slots + + return { + "occupied_slots": total_occupied_slots.to_json(), + "available_slots": total_available_slots.to_json(), + } + @classmethod def from_row( cls,