Skip to content

Commit

Permalink
feat: Add dependee/dependent/graph ComputeSessionNode connection queries
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol committed Sep 21, 2024
1 parent 53ae837 commit cca8221
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 11 deletions.
1 change: 1 addition & 0 deletions changes/2844.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add dependee/dependent/graph ComputeSessionNode connection queries
13 changes: 8 additions & 5 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,9 @@ type ComputeSessionNode implements Node {
num_queries: BigInt
inference_metrics: JSONString
kernel_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): KernelConnection
dependents(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection
dependees(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection
graph(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection
}

"""
Expand Down Expand Up @@ -986,11 +989,6 @@ type KernelNode implements Node {
preopen_ports: [Int]
}

"""
Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of "<node type name>:<node id>". UUID or string type values are also allowed.
"""
scalar GlobalIDField

"""Added in 24.09.0."""
type ComputeSessionConnection {
"""Pagination data for this connection."""
Expand All @@ -1014,6 +1012,11 @@ type ComputeSessionEdge {
cursor: String!
}

"""
Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of "<node type name>:<node id>". UUID or string type values are also allowed.
"""
scalar GlobalIDField

type ComputeSessionList implements PaginatedList {
items: [ComputeSession]!
total_count: Int!
Expand Down
143 changes: 137 additions & 6 deletions src/ai/backend/manager/models/gql_models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
FilterExprArg,
OrderExprArg,
PaginatedConnectionField,
batch_multiresult_in_session,
generate_sql_info_for_gql_connection,
set_if_set,
)
Expand All @@ -51,7 +52,7 @@
get_permission_ctx,
)
from ..user import UserRole
from .kernel import KernelConnection
from .kernel import KernelConnection, KernelNode

if TYPE_CHECKING:
from ..gql import GraphQueryContext
Expand Down Expand Up @@ -197,6 +198,18 @@ class Meta:
kernel_nodes = PaginatedConnectionField(
KernelConnection,
)
dependents = PaginatedConnectionField(
"ai.backend.manager.models.gql_models.session.ComputeSessionConnection",
description="Added in 24.09.0.",
)
dependees = PaginatedConnectionField(
"ai.backend.manager.models.gql_models.session.ComputeSessionConnection",
description="Added in 24.09.0.",
)
graph = PaginatedConnectionField(
"ai.backend.manager.models.gql_models.session.ComputeSessionConnection",
description="Added in 24.09.0.",
)

@classmethod
def from_row(
Expand Down Expand Up @@ -260,7 +273,7 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> dict[str, Any
async def resolve_kernel_nodes(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult:
) -> ConnectionResolverResult[KernelNode]:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader(ctx, "KernelNode.by_session_id")
kernels = await loader.load(self.row_id)
Expand All @@ -272,13 +285,131 @@ async def resolve_kernel_nodes(
total_count=len(kernels),
)

async def resolve_dependees(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[Self]:
ctx: GraphQueryContext = info.context
# Get my dependees (myself is the dependent)
loader = ctx.dataloader_manager.get_loader(ctx, "ComputeSessionNode.by_dependent_id")
sessions = await loader.load(self.row_id)
return ConnectionResolverResult(
sessions,
None,
None,
None,
total_count=len(sessions),
)

async def resolve_dependents(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[Self]:
ctx: GraphQueryContext = info.context
# Get my dependents (myself is the dependee)
loader = ctx.dataloader_manager.get_loader(ctx, "ComputeSessionNode.by_dependee_id")
sessions = await loader.load(self.row_id)
return ConnectionResolverResult(
sessions,
None,
None,
None,
total_count=len(sessions),
)

async def resolve_graph(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[Self]:
from ..session import SessionDependencyRow, SessionRow

ctx: GraphQueryContext = info.context

async with ctx.db.begin_readonly_session() as db_sess:
dependency_cte = (
sa.select(SessionRow.id)
.filter(SessionRow.id == self.row_id)
.cte(name="dependency_cte", recursive=True)
)
dependee = sa.select(SessionDependencyRow.depends_on).join(
dependency_cte, SessionDependencyRow.session_id == dependency_cte.c.id
)
dependent = sa.select(SessionDependencyRow.session_id).join(
dependency_cte, SessionDependencyRow.depends_on == dependency_cte.c.id
)
dependency_cte = dependency_cte.union_all(dependee).union_all(dependent)
# Get the session IDs in the graph
query = sa.select(dependency_cte.c.id)
session_ids = (await db_sess.execute(query)).scalars().all()
# Get the session rows in the graph
query = sa.select(SessionRow).where(SessionRow.id.in_(session_ids))
session_rows = (await db_sess.execute(query)).scalars().all()

# Convert into GraphQL node objects
sessions = [type(self).from_row(ctx, r) for r in session_rows]
return ConnectionResolverResult(
sessions,
None,
None,
None,
total_count=len(sessions),
)

@classmethod
async def batch_load_idle_checks(
cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId]
) -> list[dict[str, ReportInfo]]:
check_result = await ctx.idle_checker_host.get_batch_idle_check_report(session_ids)
return [check_result[sid] for sid in session_ids]

@classmethod
async def batch_load_by_dependee_id(
cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId]
) -> Sequence[Sequence[Self]]:
from ..session import SessionDependencyRow, SessionRow

async with ctx.db.begin_readonly_session() as db_sess:
j = sa.join(
SessionRow, SessionDependencyRow, SessionRow.id == SessionDependencyRow.depends_on
)
query = (
sa.select(SessionRow)
.select_from(j)
.where(SessionDependencyRow.session_id.in_(session_ids))
)
return await batch_multiresult_in_session(
ctx,
db_sess,
query,
cls,
session_ids,
lambda row: row.id,
)

@classmethod
async def batch_load_by_dependent_id(
cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId]
) -> Sequence[Sequence[Self]]:
from ..session import SessionDependencyRow, SessionRow

async with ctx.db.begin_readonly_session() as db_sess:
j = sa.join(
SessionRow, SessionDependencyRow, SessionRow.id == SessionDependencyRow.session_id
)
query = (
sa.select(SessionRow)
.select_from(j)
.where(SessionDependencyRow.depends_on.in_(session_ids))
)
return await batch_multiresult_in_session(
ctx,
db_sess,
query,
cls,
session_ids,
lambda row: row.id,
)

@classmethod
async def get_accessible_node(
cls,
Expand Down Expand Up @@ -325,7 +456,7 @@ async def get_accessible_connection(
first: int | None = None,
before: str | None = None,
last: int | None = None,
) -> ConnectionResolverResult[ComputeSessionNode]:
) -> ConnectionResolverResult[Self]:
graph_ctx: GraphQueryContext = info.context
_filter_arg = (
FilterExprArg(filter_expr, QueryFilterParser(_queryfilter_fieldspec))
Expand Down Expand Up @@ -373,7 +504,7 @@ async def get_accessible_connection(
async with graph_ctx.db.begin_readonly_session(db_conn) as db_session:
session_rows = (await db_session.scalars(query)).all()
total_cnt = await db_session.scalar(cnt_query)
result: list[ComputeSessionNode] = [
result: list[Self] = [
cls.from_row(
graph_ctx,
row,
Expand Down Expand Up @@ -412,7 +543,7 @@ async def mutate_and_get_payload(
root: Any,
info: graphene.ResolveInfo,
**input,
) -> ModifyComputeSession:
) -> Self:
graph_ctx: GraphQueryContext = info.context
_, raw_session_id = cast(ResolvedGlobalID, input["id"])
session_id = SessionId(uuid.UUID(raw_session_id))
Expand All @@ -434,7 +565,7 @@ async def mutate_and_get_payload(
)
result = await db_sess.execute(query)
session_row = result.fetchone()
return ModifyComputeSession(
return cls(
ComputeSessionNode.from_row(graph_ctx, session_row),
input.get("client_mutation_id"),
)

0 comments on commit cca8221

Please sign in to comment.