diff --git a/changes/2844.feature.md b/changes/2844.feature.md new file mode 100644 index 0000000000..cb5e815878 --- /dev/null +++ b/changes/2844.feature.md @@ -0,0 +1 @@ +Add dependee/dependent/graph ComputeSessionNode connection queries diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index dffeac9a02..2ca0c6274e 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -923,6 +923,9 @@ type ComputeSessionNode implements Node { num_queries: BigInt inference_metrics: JSONString kernel_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): KernelConnection + dependents(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection + dependees(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection + graph(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection } """ @@ -982,11 +985,6 @@ type KernelNode implements Node { preopen_ports: [Int] } -""" -Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of ":". UUID or string type values are also allowed. -""" -scalar GlobalIDField - """Added in 24.09.0.""" type ComputeSessionConnection { """Pagination data for this connection.""" @@ -1010,6 +1008,11 @@ type ComputeSessionEdge { cursor: String! } +""" +Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of ":". UUID or string type values are also allowed. +""" +scalar GlobalIDField + type ComputeSessionList implements PaginatedList { items: [ComputeSession]! total_count: Int! diff --git a/src/ai/backend/manager/models/gql_models/session.py b/src/ai/backend/manager/models/gql_models/session.py index 9064864902..2f40cdb7bf 100644 --- a/src/ai/backend/manager/models/gql_models/session.py +++ b/src/ai/backend/manager/models/gql_models/session.py @@ -26,6 +26,7 @@ FilterExprArg, OrderExprArg, PaginatedConnectionField, + batch_multiresult_in_session, generate_sql_info_for_gql_connection, set_if_set, ) @@ -44,7 +45,7 @@ from ..rbac.permission_defs import ComputeSessionPermission from ..session import SessionRow, SessionStatus, SessionTypes, get_permission_ctx from ..user import UserRole -from .kernel import KernelConnection +from .kernel import KernelConnection, KernelNode if TYPE_CHECKING: from ..gql import GraphQueryContext @@ -188,6 +189,15 @@ class Meta: kernel_nodes = PaginatedConnectionField( KernelConnection, ) + dependents = PaginatedConnectionField( + "ai.backend.manager.models.gql_models.session.ComputeSessionConnection", + ) + dependees = PaginatedConnectionField( + "ai.backend.manager.models.gql_models.session.ComputeSessionConnection", + ) + graph = PaginatedConnectionField( + "ai.backend.manager.models.gql_models.session.ComputeSessionConnection", + ) @classmethod def from_row( @@ -251,7 +261,7 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> dict[str, Any async def resolve_kernel_nodes( self, info: graphene.ResolveInfo, - ) -> ConnectionResolverResult: + ) -> ConnectionResolverResult[KernelNode]: ctx: GraphQueryContext = info.context loader = ctx.dataloader_manager.get_loader(ctx, "KernelNode.by_session_id") kernels = await loader.load(self.row_id) @@ -263,6 +273,74 @@ async def resolve_kernel_nodes( total_count=len(kernels), ) + async def resolve_dependees( + self, + info: graphene.ResolveInfo, + ) -> ConnectionResolverResult[Self]: + ctx: GraphQueryContext = info.context + loader = ctx.dataloader_manager.get_loader(ctx, "ComputeSessionNode.by_dependee_id") + sessions = await loader.load(self.row_id) + return ConnectionResolverResult( + sessions, + None, + None, + None, + total_count=len(sessions), + ) + + async def resolve_dependents( + self, + info: graphene.ResolveInfo, + ) -> ConnectionResolverResult[Self]: + ctx: GraphQueryContext = info.context + loader = ctx.dataloader_manager.get_loader(ctx, "ComputeSessionNode.by_dependent_id") + sessions = await loader.load(self.row_id) + return ConnectionResolverResult( + sessions, + None, + None, + None, + total_count=len(sessions), + ) + + async def resolve_graph( + self, + info: graphene.ResolveInfo, + ) -> ConnectionResolverResult[Self]: + from ..session import SessionDependencyRow, SessionRow + + ctx: GraphQueryContext = info.context + + async with ctx.db.begin_readonly_session() as db_sess: + dependency_cte = ( + sa.select(SessionRow.id) + .filter(SessionRow.id == self.row_id) + .cte(name="dependency_cte", recursive=True) + ) + dependee = sa.select(SessionDependencyRow.depends_on).join( + dependency_cte, SessionDependencyRow.session_id == dependency_cte.c.id + ) + dependent = sa.select(SessionDependencyRow.session_id).join( + dependency_cte, SessionDependencyRow.depends_on == dependency_cte.c.id + ) + dependency_cte = dependency_cte.union_all(dependee).union_all(dependent) + # Get the session IDs in the graph + query = sa.select(dependency_cte.c.id) + session_ids = (await db_sess.execute(query)).scalars().all() + # Get the session rows in the graph + query = sa.select(SessionRow).where(SessionRow.id.in_(session_ids)) + session_rows = (await db_sess.execute(query)).scalars().all() + + # Convert into GraphQL node objects + sessions = [type(self).from_row(ctx, r) for r in session_rows] + return ConnectionResolverResult( + sessions, + None, + None, + None, + total_count=len(sessions), + ) + @classmethod async def batch_load_idle_checks( cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId] @@ -270,6 +348,54 @@ async def batch_load_idle_checks( check_result = await ctx.idle_checker_host.get_batch_idle_check_report(session_ids) return [check_result[sid] for sid in session_ids] + @classmethod + async def batch_load_by_dependee_id( + cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId] + ) -> Sequence[Sequence[Self]]: + from ..session import SessionDependencyRow, SessionRow + + async with ctx.db.begin_readonly_session() as db_sess: + j = sa.join( + SessionRow, SessionDependencyRow, SessionRow.id == SessionDependencyRow.depends_on + ) + query = ( + sa.select(SessionRow) + .select_from(j) + .where(SessionDependencyRow.session_id.in_(session_ids)) + ) + return await batch_multiresult_in_session( + ctx, + db_sess, + query, + cls, + session_ids, + lambda row: row.id, + ) + + @classmethod + async def batch_load_by_dependent_id( + cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId] + ) -> Sequence[Sequence[Self]]: + from ..session import SessionDependencyRow, SessionRow + + async with ctx.db.begin_readonly_session() as db_sess: + j = sa.join( + SessionRow, SessionDependencyRow, SessionRow.id == SessionDependencyRow.session_id + ) + query = ( + sa.select(SessionRow) + .select_from(j) + .where(SessionDependencyRow.depends_on.in_(session_ids)) + ) + return await batch_multiresult_in_session( + ctx, + db_sess, + query, + cls, + session_ids, + lambda row: row.id, + ) + @classmethod async def get_accessible_node( cls, @@ -316,7 +442,7 @@ async def get_accessible_connection( first: int | None = None, before: str | None = None, last: int | None = None, - ) -> ConnectionResolverResult[ComputeSessionNode]: + ) -> ConnectionResolverResult[Self]: graph_ctx: GraphQueryContext = info.context _filter_arg = ( FilterExprArg(filter_expr, QueryFilterParser(_queryfilter_fieldspec)) @@ -364,7 +490,7 @@ async def get_accessible_connection( async with graph_ctx.db.begin_readonly_session(db_conn) as db_session: session_rows = (await db_session.scalars(query)).all() total_cnt = await db_session.scalar(cnt_query) - result: list[ComputeSessionNode] = [ + result: list[Self] = [ cls.from_row( graph_ctx, row, @@ -400,7 +526,7 @@ async def mutate_and_get_payload( root: Any, info: graphene.ResolveInfo, **input, - ) -> ModifyComputeSession: + ) -> Self: graph_ctx: GraphQueryContext = info.context _, raw_session_id = cast(ResolvedGlobalID, input["id"]) session_id = SessionId(uuid.UUID(raw_session_id)) @@ -416,7 +542,7 @@ async def mutate_and_get_payload( ) result = await db_sess.execute(query) session_row = result.fetchone() - return ModifyComputeSession( + return cls( ComputeSessionNode.from_row(graph_ctx, session_row), input.get("client_mutation_id"), )