From 3318e874d007a4e080b938a6385d8826baf9c01b Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Fri, 2 Aug 2024 12:12:49 +0000 Subject: [PATCH] chore: Add frament --- changes/2615.feature.md | 1 + .../manager/models/container_registry.py | 360 ++++++++++-------- src/ai/backend/manager/models/gql.py | 25 +- 3 files changed, 206 insertions(+), 180 deletions(-) create mode 100644 changes/2615.feature.md diff --git a/changes/2615.feature.md b/changes/2615.feature.md new file mode 100644 index 00000000000..b4ff29372fd --- /dev/null +++ b/changes/2615.feature.md @@ -0,0 +1 @@ +Implement ID-based client workflow to ContainerRegistry API. diff --git a/src/ai/backend/manager/models/container_registry.py b/src/ai/backend/manager/models/container_registry.py index 47b9fcc84e7..7c6b73a152e 100644 --- a/src/ai/backend/manager/models/container_registry.py +++ b/src/ai/backend/manager/models/container_registry.py @@ -3,12 +3,14 @@ import enum import logging import uuid -from typing import TYPE_CHECKING, Any, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, cast import graphene import graphql import sqlalchemy as sa from graphene.types import Scalar +from graphql import Undefined from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.exc import NoResultFound @@ -22,8 +24,6 @@ OrderExprArg, StrEnumType, generate_sql_info_for_gql_connection, - privileged_mutation, - set_if_set, ) from .gql_relay import AsyncNode, Connection, ConnectionResolverResult from .minilang.ordering import OrderSpecItem, QueryOrderParser @@ -51,25 +51,6 @@ class ContainerRegistryType(enum.StrEnum): LOCAL = "local" -class ContainerRegistryTypeField(Scalar): - """Added in 24.09.0.""" - - allowed_values = tuple(t.value for t in ContainerRegistryType) - - @staticmethod - def serialize(val: ContainerRegistryType) -> str: - return val.value - - @staticmethod - def parse_literal(node, _variables=None): - if isinstance(node, graphql.language.ast.StringValueNode): - return ContainerRegistryType(node.value) - - @staticmethod - def parse_value(value: str) -> ContainerRegistryType: - return ContainerRegistryType(value) - - class ContainerRegistryRow(Base): __tablename__ = "container_registries" id = IDColumn() @@ -122,47 +103,94 @@ async def list_by_registry_name( return rows +class ContainerRegistryTypeField(Scalar): + """Added in 24.09.0.""" + + allowed_values = tuple(t.value for t in ContainerRegistryType) + + @staticmethod + def serialize(val: ContainerRegistryType) -> str: + return val.value + + @staticmethod + def parse_literal(node, _variables=None): + if isinstance(node, graphql.language.ast.StringValueNode): + return ContainerRegistryType(node.value) + + @staticmethod + def parse_value(value: str) -> ContainerRegistryType: + return ContainerRegistryType(value) + + +# Legacy class CreateContainerRegistryInput(graphene.InputObjectType): - url = graphene.String(required=True, description="Added in 24.09.0.") - type = ContainerRegistryTypeField( - required=True, - description=f"Registry type. One of {ContainerRegistryTypeField.allowed_values}. Added in 24.09.0.", - ) - registry_name = graphene.String(required=True, description="Added in 24.09.0.") - is_global = graphene.Boolean(description="Added in 24.09.0.") - project = graphene.String(description="Added in 24.09.0.") - username = graphene.String(description="Added in 24.09.0.") - password = graphene.String(description="Added in 24.09.0.") - ssl_verify = graphene.Boolean(description="Added in 24.09.0.") + url = graphene.String(required=True) + type = graphene.String(required=True) + project = graphene.List(graphene.String) + username = graphene.String() + password = graphene.String() + ssl_verify = graphene.Boolean() +# Legacy class ModifyContainerRegistryInput(graphene.InputObjectType): - id = graphene.String( - required=True, - description="Object id. Can be either global id or object id. Added in 24.09.0.", - ) - url = graphene.String(description="Added in 24.09.0.") - type = ContainerRegistryTypeField( - description=f"Registry type. One of {ContainerRegistryTypeField.allowed_values}. Added in 24.09.0." - ) - registry_name = graphene.String(description="Added in 24.09.0.") + url = graphene.String() + type = graphene.String() + project = graphene.List(graphene.String) + username = graphene.String() + password = graphene.String() + ssl_verify = graphene.Boolean() + + +# Legacy +class ContainerRegistryConfig(graphene.ObjectType): + url = graphene.String(required=True) + type = graphene.String(required=True) + project = graphene.List(graphene.String) + username = graphene.String() + password = graphene.String() + ssl_verify = graphene.Boolean() is_global = graphene.Boolean(description="Added in 24.09.0.") - project = graphene.String(description="Added in 24.09.0.") - username = graphene.String(description="Added in 24.09.0.") - password = graphene.String(description="Added in 24.09.0.") - ssl_verify = graphene.Boolean(description="Added in 24.09.0.") -class DeleteContainerRegistryInput(graphene.InputObjectType): - """Added in 24.09.0.""" +# Legacy +class ContainerRegistry(graphene.ObjectType): + hostname = graphene.String() + config = graphene.Field(ContainerRegistryConfig) - id = graphene.String( - required=True, - description="Object id. Can be either global id or object id. Added in 24.09.0.", - ) + class Meta: + interfaces = (AsyncNode,) + @classmethod + async def load_by_hostname(cls, ctx: GraphQueryContext, hostname: str) -> ContainerRegistry: + async with ctx.db.begin_readonly_session() as session: + return cls.from_row( + ctx, + await ContainerRegistryRow.get_by_hostname( + session, + hostname, + ), + ) -class ContainerRegistryConfig(graphene.ObjectType): + @classmethod + async def load_all( + cls, + ctx: GraphQueryContext, + ) -> Sequence[ContainerRegistry]: + async with ctx.db.begin_readonly_session() as session: + rows = await session.execute(sa.select(ContainerRegistryRow)) + return [cls.from_row(ctx, row) for row in rows] + + +class ContainerRegistryNode(graphene.ObjectType): + class Meta: + interfaces = (AsyncNode,) + description = "Added in 24.09.0." + + row_id = graphene.UUID( + description="Added in 24.09.0. The undecoded UUID type id of DB container_registries row." + ) + name = graphene.String() url = graphene.String(required=True, description="Added in 24.09.0.") type = ContainerRegistryTypeField(required=True, description="Added in 24.09.0.") registry_name = graphene.String(required=True, description="Added in 24.09.0.") @@ -172,31 +200,22 @@ class ContainerRegistryConfig(graphene.ObjectType): password = graphene.String(description="Added in 24.09.0.") ssl_verify = graphene.Boolean(description="Added in 24.09.0.") - -class ContainerRegistry(graphene.ObjectType): - class Meta: - interfaces = (AsyncNode,) - - config = graphene.Field(ContainerRegistryConfig) - _queryfilter_fieldspec: dict[str, FieldSpecItem] = { - "id": ("id", None), + "row_id": ("id", None), "registry_name": ("registry_name", None), } - _queryorder_colmap: dict[str, OrderSpecItem] = { - "id": ("id", None), + "row_id": ("id", None), "registry_name": ("registry_name", None), } @classmethod - async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ContainerRegistry: + async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ContainerRegistryNode: graph_ctx: GraphQueryContext = info.context - _, reg_id = AsyncNode.resolve_global_id(info, id) select_stmt = sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) async with graph_ctx.db.begin_readonly_session() as db_session: - reg_row = await db_session.scalar(select_stmt) + reg_row = cast(ContainerRegistryRow | None, await db_session.scalar(select_stmt)) if reg_row is None: raise ValueError(f"Container registry not found (id: {reg_id})") return cls.from_row(graph_ctx, reg_row) @@ -226,8 +245,8 @@ async def get_connection( ) ( query, + cnt_query, _, - conditions, cursor, pagination_order, page_size, @@ -243,99 +262,82 @@ async def get_connection( before=before, last=last, ) - cnt_query = sa.select(sa.func.count()).select_from(ContainerRegistryRow) - for cond in conditions: - cnt_query = cnt_query.where(cond) async with graph_ctx.db.begin_readonly_session() as db_session: - reg_rows = (await db_session.scalars(query)).all() - result = [cls.from_row(graph_ctx, row) for row in reg_rows] - + reg_rows = await db_session.scalars(query) total_cnt = await db_session.scalar(cnt_query) - return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) + result = [cls.from_row(graph_ctx, cast(ContainerRegistryRow, row)) for row in reg_rows] + return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) @classmethod - def from_row(cls, ctx: GraphQueryContext, row: ContainerRegistryRow) -> ContainerRegistry: + def from_row(cls, ctx: GraphQueryContext, row: ContainerRegistryRow) -> ContainerRegistryNode: return cls( - id=row.id, - config=ContainerRegistryConfig( - url=row.url, - type=row.type, - registry_name=row.registry_name, - project=row.project, - username=row.username, - password=PASSWORD_PLACEHOLDER if row.password is not None else None, - ssl_verify=row.ssl_verify, - is_global=row.is_global, - ), + row_id=row.id, + url=row.url, + type=row.type, + registry_name=row.registry_name, + project=row.project, + username=row.username, + password=PASSWORD_PLACEHOLDER if row.password is not None else None, + ssl_verify=row.ssl_verify, + is_global=row.is_global, ) - @classmethod - async def load(cls, ctx: GraphQueryContext, id: str | uuid.UUID) -> ContainerRegistry: - async with ctx.db.begin_readonly_session() as session: - return cls.from_row( - ctx, - await ContainerRegistryRow.get( - session, - id, - ), - ) - - @classmethod - async def load_all( - cls, - ctx: GraphQueryContext, - ) -> Sequence[ContainerRegistry]: - async with ctx.db.begin_readonly_session() as session: - rows = await session.execute(sa.select(ContainerRegistryRow)) - return [cls.from_row(ctx, row) for row in rows] - - @classmethod - async def list_by_registry_name( - cls, - ctx: GraphQueryContext, - registry_name: str, - ) -> Sequence[ContainerRegistry]: - async with ctx.db.begin_readonly_session() as session: - rows = await ContainerRegistryRow.list_by_registry_name(session, registry_name) - return [cls.from_row(ctx, row) for row in rows] - class ContainerRegistryConnection(Connection): """Added in 24.09.0.""" class Meta: - node = ContainerRegistry + node = ContainerRegistryNode class CreateContainerRegistry(graphene.Mutation): allowed_roles = (UserRole.SUPERADMIN,) - id = graphene.UUID(required=True) - container_registry = graphene.Field(ContainerRegistry) + container_registry = graphene.Field(ContainerRegistryNode) class Arguments: - props = CreateContainerRegistryInput(required=True, description="Added in 24.09.0.") + url = graphene.String(required=True, description="Added in 24.09.0.") + type = ContainerRegistryTypeField( + required=True, + description=f"Added in 24.09.0. Registry type. One of {ContainerRegistryTypeField.allowed_values}.", + ) + registry_name = graphene.String(required=True, description="Added in 24.09.0.") + is_global = graphene.Boolean(description="Added in 24.09.0.") + project = graphene.String(description="Added in 24.09.0.") + username = graphene.String(description="Added in 24.09.0.") + password = graphene.String(description="Added in 24.09.0.") + ssl_verify = graphene.Boolean(description="Added in 24.09.0.") @classmethod - @privileged_mutation( - UserRole.SUPERADMIN, - lambda id, **kwargs: (None, id), - ) async def mutate( - cls, root, info: graphene.ResolveInfo, props: CreateContainerRegistryInput + cls, + root, + info: graphene.ResolveInfo, + url: str, + type: ContainerRegistryType, + registry_name: str, + is_global: bool, + project: str, + username: str, + password: str, + ssl_verify: bool, ) -> CreateContainerRegistry: ctx: GraphQueryContext = info.context input_config: dict[str, Any] = { - "registry_name": props.registry_name, - "url": props.url, - "type": props.type, + "registry_name": registry_name, + "url": url, + "type": type, } - set_if_set(props, input_config, "project") - set_if_set(props, input_config, "username") - set_if_set(props, input_config, "password") - set_if_set(props, input_config, "ssl_verify") - set_if_set(props, input_config, "is_global") + def _set_if_set(name: str, val: Any) -> None: + if val is not Undefined: + input_config[name] = val + + _set_if_set("project", project) + _set_if_set("username", username) + _set_if_set("password", password) + _set_if_set("ssl_verify", ssl_verify) + _set_if_set("is_global", is_global) async with ctx.db.begin_session() as db_session: reg_row = ContainerRegistryRow(**input_config) @@ -344,44 +346,64 @@ async def mutate( await db_session.refresh(reg_row) return cls( - id=reg_row.id, - container_registry=ContainerRegistry.from_row(ctx, reg_row), + container_registry=ContainerRegistryNode.from_row(ctx, reg_row), ) class ModifyContainerRegistry(graphene.Mutation): allowed_roles = (UserRole.SUPERADMIN,) - container_registry = graphene.Field(ContainerRegistry) + container_registry = graphene.Field(ContainerRegistryNode) class Arguments: - props = ModifyContainerRegistryInput(required=True) + id = graphene.String( + required=True, + description="Object id. Can be either global id or object id. Added in 24.09.0.", + ) + url = graphene.String(description="Added in 24.09.0.") + type = ContainerRegistryTypeField( + description=f"Registry type. One of {ContainerRegistryTypeField.allowed_values}. Added in 24.09.0." + ) + registry_name = graphene.String(description="Added in 24.09.0.") + is_global = graphene.Boolean(description="Added in 24.09.0.") + project = graphene.String(description="Added in 24.09.0.") + username = graphene.String(description="Added in 24.09.0.") + password = graphene.String(description="Added in 24.09.0.") + ssl_verify = graphene.Boolean(description="Added in 24.09.0.") @classmethod - @privileged_mutation( - UserRole.SUPERADMIN, - lambda id, **kwargs: (None, id), - ) async def mutate( cls, root, info: graphene.ResolveInfo, - props: ModifyContainerRegistryInput, + id: str, + url: str, + type: ContainerRegistryType, + registry_name: str, + is_global: bool, + project: str, + username: str, + password: str, + ssl_verify: bool, ) -> ModifyContainerRegistry: ctx: GraphQueryContext = info.context input_config: dict[str, Any] = {} - set_if_set(props, input_config, "url") - set_if_set(props, input_config, "type") - set_if_set(props, input_config, "registry_name") - set_if_set(props, input_config, "is_global") - set_if_set(props, input_config, "username") - set_if_set(props, input_config, "password") - set_if_set(props, input_config, "project") - set_if_set(props, input_config, "ssl_verify") + def _set_if_set(name: str, val: Any) -> None: + if val is not Undefined: + input_config[name] = val - _, _id = AsyncNode.resolve_global_id(info, props.id) - reg_id = uuid.UUID(_id) if _id else uuid.UUID(props.id) + _set_if_set("url", url) + _set_if_set("type", type) + _set_if_set("registry_name", registry_name) + _set_if_set("username", username) + _set_if_set("password", password) + _set_if_set("project", project) + _set_if_set("ssl_verify", ssl_verify) + _set_if_set("is_global", is_global) + + _, _id = AsyncNode.resolve_global_id(info, id) + reg_id = uuid.UUID(_id) if _id else uuid.UUID(id) async with ctx.db.begin_session() as session: stmt = sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) @@ -391,34 +413,42 @@ async def mutate( for field, val in input_config.items(): setattr(reg_row, field, val) - return cls(container_registry=ContainerRegistry.from_row(ctx, reg_row)) + return cls(container_registry=ContainerRegistryNode.from_row(ctx, reg_row)) class DeleteContainerRegistry(graphene.Mutation): allowed_roles = (UserRole.SUPERADMIN,) - container_registry = graphene.Field(ContainerRegistry) + container_registry = graphene.Field(ContainerRegistryNode) class Arguments: - props = DeleteContainerRegistryInput(required=True) + id = graphene.String( + required=True, + description="Object id. Can be either global id or object id. Added in 24.09.0.", + ) @classmethod - @privileged_mutation( - UserRole.SUPERADMIN, - lambda id, **kwargs: (None, id), - ) async def mutate( cls, root, info: graphene.ResolveInfo, - props: DeleteContainerRegistryInput, + id: str, ) -> DeleteContainerRegistry: ctx: GraphQueryContext = info.context - _, _id = AsyncNode.resolve_global_id(info, props.id) - reg_id = uuid.UUID(_id) if _id else uuid.UUID(props.id) - container_registry = await ContainerRegistry.load(ctx, reg_id) - async with ctx.db.begin_session() as session: - await session.execute( + _, _id = AsyncNode.resolve_global_id(info, id) + reg_id = uuid.UUID(_id) if _id else uuid.UUID(id) + async with ctx.db.begin_session() as db_session: + reg_row = await ContainerRegistry.load(ctx, reg_id) + reg_row = cast( + ContainerRegistryRow | None, + await db_session.scalar( + sa.select(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) + ), + ) + if reg_row is None: + raise ValueError(f"Container registry not found (id:{reg_id})") + container_registry = ContainerRegistryNode.from_row(ctx, reg_row) + await db_session.execute( sa.delete(ContainerRegistryRow).where(ContainerRegistryRow.id == reg_id) ) diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 9ac07449c97..f878a7bb6cf 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -19,6 +19,7 @@ from .container_registry import ( ContainerRegistry, ContainerRegistryConnection, + ContainerRegistryNode, CreateContainerRegistry, DeleteContainerRegistry, ModifyContainerRegistry, @@ -707,18 +708,12 @@ class Queries(graphene.ObjectType): quota_scope_id=graphene.String(required=True), ) - container_registry = graphene.Field( - ContainerRegistry, id=graphene.UUID(required=True), description="Added in 24.09.0." - ) + container_registry = graphene.Field(ContainerRegistry, hostname=graphene.String(required=True)) - container_registries = graphene.List( - ContainerRegistry, - registry_name=graphene.String(required=True), - description="Added in 24.09.0.", - ) + container_registries = graphene.List(ContainerRegistry) container_registry_node = graphene.Field( - ContainerRegistry, id=graphene.String(required=True), description="Added in 24.09.0." + ContainerRegistryNode, id=graphene.String(required=True), description="Added in 24.09.0." ) container_registry_nodes = PaginatedConnectionField( @@ -2210,10 +2205,10 @@ async def resolve_quota_scope( async def resolve_container_registry( root: Any, info: graphene.ResolveInfo, - id: graphene.UUID, + hostname: str, ) -> ContainerRegistry: ctx: GraphQueryContext = info.context - return await ContainerRegistry.load(ctx, id) + return await ContainerRegistry.load_by_hostname(ctx, hostname) @staticmethod @privileged_query(UserRole.SUPERADMIN) @@ -2223,7 +2218,7 @@ async def resolve_container_registries( registry_name: graphene.String, ) -> Sequence[ContainerRegistry]: ctx: GraphQueryContext = info.context - return await ContainerRegistry.list_by_registry_name(ctx, registry_name) + return await ContainerRegistry.load_all(ctx) @staticmethod @privileged_query(UserRole.SUPERADMIN) @@ -2231,8 +2226,8 @@ async def resolve_container_registry_node( root: Any, info: graphene.ResolveInfo, id: str, - ) -> ContainerRegistry: - return await ContainerRegistry.get_node(info, id) + ) -> ContainerRegistryNode: + return await ContainerRegistryNode.get_node(info, id) @staticmethod @privileged_query(UserRole.SUPERADMIN) @@ -2248,7 +2243,7 @@ async def resolve_container_registry_nodes( before: str | None = None, last: int | None = None, ) -> ConnectionResolverResult: - return await ContainerRegistry.get_connection( + return await ContainerRegistryNode.get_connection( info, filter, order,