Skip to content

Commit

Permalink
feat: GQL resolvers for querying agent's total resource allocation (#…
Browse files Browse the repository at this point in the history
…2254)

Co-authored-by: Sanghun Lee <[email protected]>
  • Loading branch information
jopemachine and fregataa authored Jul 3, 2024
1 parent df57a09 commit 9c6f51d
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 3 deletions.
1 change: 1 addition & 0 deletions changes/2254.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `scaling_group.agent_count_by_status` and `scaling_group.agent_total_resource_slots_by_status` GQL fields to query the count and the resource allocation of agents that belong to a scaling group.
16 changes: 16 additions & 0 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,22 @@ type ScalingGroup {
scheduler: String
scheduler_opts: JSONString
use_host_network: Boolean

"""Added in 24.03.7."""
agent_count_by_status(
"""
Possible states of an agent. Should be one of ['ALIVE', 'LOST', 'RESTARTING', 'TERMINATED']. Default is 'ALIVE'.
"""
status: String = "ALIVE"
): Int

"""Added in 24.03.7."""
agent_total_resource_slots_by_status(
"""
Possible states of an agent. Should be one of ['ALIVE', 'LOST', 'RESTARTING', 'TERMINATED']. Default is 'ALIVE'.
"""
status: String = "ALIVE"
): JSONString
}

type StorageVolume implements Item {
Expand Down
3 changes: 2 additions & 1 deletion src/ai/backend/manager/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from .keypair import keypairs
from .minilang.ordering import OrderSpecItem, QueryOrderParser
from .minilang.queryfilter import FieldSpecItem, QueryFilterParser, enum_field_getter
from .scaling_group import query_allowed_sgroups
from .user import UserRole, users

if TYPE_CHECKING:
Expand Down Expand Up @@ -432,6 +431,8 @@ async def _append_sgroup_from_clause(
domain_name: str | None,
scaling_group: str | None = None,
) -> sa.sql.Select:
from .scaling_group import query_allowed_sgroups

if scaling_group is not None:
query = query.where(agents.c.scaling_group == scaling_group)
else:
Expand Down
69 changes: 67 additions & 2 deletions src/ai/backend/manager/models/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@
from sqlalchemy.dialects import postgresql as pgsql
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.orm import relationship
from sqlalchemy.orm import load_only, relationship
from sqlalchemy.sql.expression import true

from ai.backend.common import validators as tx
from ai.backend.common.types import AgentSelectionStrategy, JSONSerializableMixin, SessionTypes
from ai.backend.common.types import (
AgentSelectionStrategy,
JSONSerializableMixin,
ResourceSlot,
SessionTypes,
)

from .agent import AgentStatus
from .base import (
Base,
IDColumn,
Expand Down Expand Up @@ -315,6 +321,65 @@ class ScalingGroup(graphene.ObjectType):
scheduler_opts = graphene.JSONString()
use_host_network = graphene.Boolean()

# Dynamic fields.
# These two fields are not plain columns copied from a row: each one accepts a
# `status` argument and has a matching `resolve_<field name>` coroutine on this
# class that computes its value when the field is requested.

# Integer field taking an agent status name; the argument defaults to the
# name of AgentStatus.ALIVE.  The description enumerates every AgentStatus
# member name at class-definition time.
agent_count_by_status = graphene.Field(
graphene.Int,
description="Added in 24.03.7.",
status=graphene.String(
default_value=AgentStatus.ALIVE.name,
description=f"Possible states of an agent. Should be one of {[s.name for s in AgentStatus]}. Default is 'ALIVE'.",
),
)

# JSON field taking the same `status` argument; its resolver returns a mapping
# with "occupied_slots" and "available_slots" keys.
agent_total_resource_slots_by_status = graphene.Field(
graphene.JSONString,
description="Added in 24.03.7.",
status=graphene.String(
default_value=AgentStatus.ALIVE.name,
description=f"Possible states of an agent. Should be one of {[s.name for s in AgentStatus]}. Default is 'ALIVE'.",
),
)

async def resolve_agent_count_by_status(
    self, info: graphene.ResolveInfo, status: str = AgentStatus.ALIVE.name
) -> int:
    """Resolve the ``agent_count_by_status`` field.

    Delegates to ``Agent.load_count``, restricting the count to this
    scaling group (``self.name``) and forwarding ``status`` unchanged as
    the ``raw_status`` keyword.

    :param info: GraphQL resolve info; its ``context`` is handed to
        ``Agent.load_count``.
    :param status: Agent status name to count; defaults to ``"ALIVE"``.
    :return: Whatever ``Agent.load_count`` yields for the given filters.
    """
    # Function-scope import -- presumably kept local to sidestep an import
    # cycle between the agent and scaling_group modules.
    from .agent import Agent

    agent_count = await Agent.load_count(
        info.context,
        raw_status=status,
        scaling_group=self.name,
    )
    return agent_count

async def resolve_agent_total_resource_slots_by_status(
    self, info: graphene.ResolveInfo, status: str = AgentStatus.ALIVE.name
) -> Mapping[str, Any]:
    """Resolve the ``agent_total_resource_slots_by_status`` field.

    Selects every agent row whose scaling group is ``self.name`` and whose
    status equals ``AgentStatus[status]``, then sums their occupied and
    available resource slots.

    :param info: GraphQL resolve info; ``info.context.db`` provides the
        read-only database session.
    :param status: Name of an ``AgentStatus`` member; defaults to
        ``"ALIVE"``.  A name that is not a member makes the
        ``AgentStatus[status]`` lookup fail.
    :return: A mapping with ``"occupied_slots"`` and ``"available_slots"``,
        each the ``to_json()`` form of the summed ``ResourceSlot``.
    """
    from .agent import AgentRow, AgentStatus

    ctx = info.context
    async with ctx.db.begin_readonly_session() as session:
        # Build the filter inside the session scope so the status lookup
        # happens at the same point as before.
        in_this_group = AgentRow.scaling_group == self.name
        has_requested_status = AgentRow.status == AgentStatus[status]

        # Only the two slot columns are needed for the aggregation.
        stmt = sa.select(AgentRow).where(in_this_group & has_requested_status)
        stmt = stmt.options(load_only(AgentRow.occupied_slots, AgentRow.available_slots))

        fetched = await session.scalars(stmt)
        rows = cast(list[AgentRow], fetched.all())

        occupied_sum = ResourceSlot()
        available_sum = ResourceSlot()
        for row in rows:
            occupied_sum += row.occupied_slots
            available_sum += row.available_slots

        summary = {
            "occupied_slots": occupied_sum.to_json(),
            "available_slots": available_sum.to_json(),
        }
        return summary

@classmethod
def from_row(
cls,
Expand Down

0 comments on commit 9c6f51d

Please sign in to comment.