diff --git a/changes/2844.feature.md b/changes/2844.feature.md new file mode 100644 index 0000000000..cb5e815878 --- /dev/null +++ b/changes/2844.feature.md @@ -0,0 +1 @@ +Add dependee/dependent/graph ComputeSessionNode connection queries diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index 5ee859b60a..186711f290 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -927,6 +927,9 @@ type ComputeSessionNode implements Node { num_queries: BigInt inference_metrics: JSONString kernel_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): KernelConnection + dependents(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection + dependees(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection + graph(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection } """ @@ -986,11 +989,6 @@ type KernelNode implements Node { preopen_ports: [Int] } -""" -Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of "<node type name>:<node id>". UUID or string type values are also allowed. -""" -scalar GlobalIDField - """Added in 24.09.0.""" type ComputeSessionConnection { """Pagination data for this connection.""" @@ -1014,6 +1012,11 @@ type ComputeSessionEdge { cursor: String! } +""" +Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of "<node type name>:<node id>". UUID or string type values are also allowed. +""" +scalar GlobalIDField + type ComputeSessionList implements PaginatedList { items: [ComputeSession]! total_count: Int! 
diff --git a/src/ai/backend/manager/models/gql_models/session.py b/src/ai/backend/manager/models/gql_models/session.py index 028a65c235..cd637ed408 100644 --- a/src/ai/backend/manager/models/gql_models/session.py +++ b/src/ai/backend/manager/models/gql_models/session.py @@ -26,6 +26,7 @@ FilterExprArg, OrderExprArg, PaginatedConnectionField, + batch_multiresult_in_session, generate_sql_info_for_gql_connection, set_if_set, ) @@ -51,7 +52,7 @@ get_permission_ctx, ) from ..user import UserRole -from .kernel import KernelConnection +from .kernel import KernelConnection, KernelNode if TYPE_CHECKING: from ..gql import GraphQueryContext @@ -197,6 +198,18 @@ class Meta: kernel_nodes = PaginatedConnectionField( KernelConnection, ) + dependents = PaginatedConnectionField( + "ai.backend.manager.models.gql_models.session.ComputeSessionConnection", + description="Added in 24.09.0.", + ) + dependees = PaginatedConnectionField( + "ai.backend.manager.models.gql_models.session.ComputeSessionConnection", + description="Added in 24.09.0.", + ) + graph = PaginatedConnectionField( + "ai.backend.manager.models.gql_models.session.ComputeSessionConnection", + description="Added in 24.09.0.", + ) @classmethod def from_row( @@ -260,7 +273,7 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> dict[str, Any async def resolve_kernel_nodes( self, info: graphene.ResolveInfo, - ) -> ConnectionResolverResult: + ) -> ConnectionResolverResult[KernelNode]: ctx: GraphQueryContext = info.context loader = ctx.dataloader_manager.get_loader(ctx, "KernelNode.by_session_id") kernels = await loader.load(self.row_id) @@ -272,6 +285,76 @@ async def resolve_kernel_nodes( total_count=len(kernels), ) + async def resolve_dependees( + self, + info: graphene.ResolveInfo, + ) -> ConnectionResolverResult[Self]: + ctx: GraphQueryContext = info.context + # Get my dependees (myself is the dependent) + loader = ctx.dataloader_manager.get_loader(ctx, "ComputeSessionNode.by_dependent_id") + sessions = 
await loader.load(self.row_id) + return ConnectionResolverResult( + sessions, + None, + None, + None, + total_count=len(sessions), + ) + + async def resolve_dependents( + self, + info: graphene.ResolveInfo, + ) -> ConnectionResolverResult[Self]: + ctx: GraphQueryContext = info.context + # Get my dependents (myself is the dependee) + loader = ctx.dataloader_manager.get_loader(ctx, "ComputeSessionNode.by_dependee_id") + sessions = await loader.load(self.row_id) + return ConnectionResolverResult( + sessions, + None, + None, + None, + total_count=len(sessions), + ) + + async def resolve_graph( + self, + info: graphene.ResolveInfo, + ) -> ConnectionResolverResult[Self]: + from ..session import SessionDependencyRow, SessionRow + + ctx: GraphQueryContext = info.context + + async with ctx.db.begin_readonly_session() as db_sess: + dependency_cte = ( + sa.select(SessionRow.id) + .filter(SessionRow.id == self.row_id) + .cte(name="dependency_cte", recursive=True) + ) + dependee = sa.select(SessionDependencyRow.depends_on).join( + dependency_cte, SessionDependencyRow.session_id == dependency_cte.c.id + ) + dependent = sa.select(SessionDependencyRow.session_id).join( + dependency_cte, SessionDependencyRow.depends_on == dependency_cte.c.id + ) + dependency_cte = dependency_cte.union_all(dependee).union_all(dependent) + # Get the session IDs in the graph + query = sa.select(dependency_cte.c.id) + session_ids = (await db_sess.execute(query)).scalars().all() + # Get the session rows in the graph + query = sa.select(SessionRow).where(SessionRow.id.in_(session_ids)) + session_rows = (await db_sess.execute(query)).scalars().all() + + # Convert into GraphQL node objects + sessions = [type(self).from_row(ctx, r) for r in session_rows] + return ConnectionResolverResult( + sessions, + None, + None, + None, + total_count=len(sessions), + ) + @classmethod async def batch_load_idle_checks( cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId] @@ -279,6 +362,54 @@ async def 
batch_load_idle_checks( check_result = await ctx.idle_checker_host.get_batch_idle_check_report(session_ids) return [check_result[sid] for sid in session_ids] + @classmethod + async def batch_load_by_dependee_id( + cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId] + ) -> Sequence[Sequence[Self]]: + from ..session import SessionDependencyRow, SessionRow + + async with ctx.db.begin_readonly_session() as db_sess: + j = sa.join( + SessionRow, SessionDependencyRow, SessionRow.id == SessionDependencyRow.depends_on + ) + query = ( + sa.select(SessionRow) + .select_from(j) + .where(SessionDependencyRow.session_id.in_(session_ids)) + ) + return await batch_multiresult_in_session( + ctx, + db_sess, + query, + cls, + session_ids, + lambda row: row.id, + ) + + @classmethod + async def batch_load_by_dependent_id( + cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId] + ) -> Sequence[Sequence[Self]]: + from ..session import SessionDependencyRow, SessionRow + + async with ctx.db.begin_readonly_session() as db_sess: + j = sa.join( + SessionRow, SessionDependencyRow, SessionRow.id == SessionDependencyRow.session_id + ) + query = ( + sa.select(SessionRow) + .select_from(j) + .where(SessionDependencyRow.depends_on.in_(session_ids)) + ) + return await batch_multiresult_in_session( + ctx, + db_sess, + query, + cls, + session_ids, + lambda row: row.id, + ) + @classmethod async def get_accessible_node( cls, @@ -325,7 +456,7 @@ async def get_accessible_connection( first: int | None = None, before: str | None = None, last: int | None = None, - ) -> ConnectionResolverResult[ComputeSessionNode]: + ) -> ConnectionResolverResult[Self]: graph_ctx: GraphQueryContext = info.context _filter_arg = ( FilterExprArg(filter_expr, QueryFilterParser(_queryfilter_fieldspec)) @@ -373,7 +504,7 @@ async def get_accessible_connection( async with graph_ctx.db.begin_readonly_session(db_conn) as db_session: session_rows = (await db_session.scalars(query)).all() total_cnt = await 
db_session.scalar(cnt_query) - result: list[ComputeSessionNode] = [ + result: list[Self] = [ cls.from_row( graph_ctx, row, @@ -412,7 +543,7 @@ async def mutate_and_get_payload( root: Any, info: graphene.ResolveInfo, **input, - ) -> ModifyComputeSession: + ) -> Self: graph_ctx: GraphQueryContext = info.context _, raw_session_id = cast(ResolvedGlobalID, input["id"]) session_id = SessionId(uuid.UUID(raw_session_id)) @@ -434,7 +565,7 @@ async def mutate_and_get_payload( ) result = await db_sess.execute(query) session_row = result.fetchone() - return ModifyComputeSession( + return cls( ComputeSessionNode.from_row(graph_ctx, session_row), input.get("client_mutation_id"), )