Skip to content

Commit

Permalink
Add gql WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Mar 17, 2024
1 parent b5c82b7 commit 7505e1e
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 9 deletions.
10 changes: 5 additions & 5 deletions python.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4214,19 +4214,19 @@
"artifacts": [
{
"algorithm": "sha256",
"hash": "ec5c8db0af0207aabbea93f9c25b859103e1ce4303f0dc70b0c7f4f927417310",
"url": "https://files.pythonhosted.org/packages/62/b0/9493d0429c293eb2d8a1318d739fd18608dfcdc9e69257c4e1161248ee34/types_setuptools-69.2.0.20240316-py3-none-any.whl"
"hash": "cf91ff7c87ab7bf0625c3f0d4d90427c9da68561f3b0feab77977aaf0bbf7531",
"url": "https://files.pythonhosted.org/packages/1f/22/904934a3344fa5f332ecab887003f3f033c1272432a4af877007b75b0bd3/types_setuptools-69.2.0.20240317-py3-none-any.whl"
},
{
"algorithm": "sha256",
"hash": "58c3172aef76a368d6434bc406c08adec91779770b6bdda59ae56af43f29a61e",
"url": "https://files.pythonhosted.org/packages/37/f1/72d90c354f15ba5d465293e59aaf80cc11fe070c9609ace09be3b1a97a87/types-setuptools-69.2.0.20240316.tar.gz"
"hash": "b607c4c48842ef3ee49dc0c7fe9c1bad75700b071e1018bb4d7e3ac492d47048",
"url": "https://files.pythonhosted.org/packages/2d/06/0de7b539346aaa8758b3c80375c4841dc2764ef92c5e743f1ebe9789da54/types-setuptools-69.2.0.20240317.tar.gz"
}
],
"project_name": "types-setuptools",
"requires_dists": [],
"requires_python": ">=3.8",
"version": "69.2.0.20240316"
"version": "69.2.0.20240317"
},
{
"artifacts": [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
from __future__ import annotations

import logging
import uuid
from typing import Sequence
from typing import TYPE_CHECKING, Any, Sequence

import graphene
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.exc import NoResultFound

from ai.backend.common.logging_utils import BraceStyleAdapter
from ai.backend.manager.models.base import GUID, Base, IDColumn

if TYPE_CHECKING:
from .gql import GraphQueryContext

from .base import GUID, Base, IDColumn, privileged_mutation
from .gql_relay import AsyncNode
from .user import UserRole

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore

__all__: Sequence[str] = ("AssociationContainerRegistriesUsers",)
__all__: Sequence[str] = (
"AssociationContainerRegistriesUsersRow",
"AssociationContainerRegistriesUsers",
"CreateAssociationContainerRegistriesUsersInput",
"CreateAssociationContainerRegistriesUsers",
"DeleteAssociationContainerRegistriesUsersInput",
"DeleteAssociationContainerRegistriesUsers",
)


class AssociationContainerRegistriesUsers(Base):
class AssociationContainerRegistriesUsersRow(Base):
__tablename__ = "association_container_registries_users"
id = IDColumn()
container_registry_id = sa.Column(
Expand All @@ -31,3 +49,192 @@ class AssociationContainerRegistriesUsers(Base):
def __init__(self, container_registry_id: uuid.UUID, user_id: uuid.UUID):
self.container_registry_id = container_registry_id
self.user_id = user_id

@classmethod
async def get(
cls,
session: AsyncSession,
id: str | uuid.UUID,
) -> "AssociationContainerRegistriesUsers":
query = sa.select(AssociationContainerRegistriesUsers).where(
AssociationContainerRegistriesUsers.id == id
)
result = await session.execute(query)
row = result.scalar()
if row is None:
raise NoResultFound
return row

@classmethod
async def get_by_registry_id_and_user_id(
cls,
session: AsyncSession,
registry_id: str | uuid.UUID,
user_id: str | uuid.UUID,
) -> "AssociationContainerRegistriesUsers":
query = sa.select(AssociationContainerRegistriesUsers).where(
AssociationContainerRegistriesUsers.user_id == user_id
and AssociationContainerRegistriesUsers.container_registry_id == registry_id
)
result = await session.execute(query)
row = result.scalar()
if row is None:
raise NoResultFound
return row


class AssociationContainerRegistriesUsers(graphene.ObjectType):
class Meta:
interfaces = (AsyncNode,)

id = graphene.ID()
container_registry_id = graphene.UUID(required=True)
user_id = graphene.UUID(required=True)

@classmethod
def from_row(
cls, ctx: GraphQueryContext, row: AssociationContainerRegistriesUsersRow
) -> "AssociationContainerRegistriesUsers":
return cls(
id=row.id,
container_registry_id=row.container_registry_id,
user_id=row.user_id,
)

@classmethod
async def load(
cls, ctx: GraphQueryContext, id: str | uuid.UUID
) -> "AssociationContainerRegistriesUsers":
async with ctx.db.begin_readonly_session() as session:
return cls.from_row(
ctx,
await AssociationContainerRegistriesUsersRow.get(
session,
id,
),
)

@classmethod
async def load_by_registry_id_and_user_id(
cls, ctx: GraphQueryContext, registry_id: str | uuid.UUID, user_id: str | uuid.UUID
) -> "AssociationContainerRegistriesUsers":
async with ctx.db.begin_readonly_session() as session:
return cls.from_row(
ctx,
await AssociationContainerRegistriesUsersRow.get_by_registry_id_and_user_id(
session,
registry_id,
user_id,
),
)

@classmethod
async def list_by_registry_id(
cls, ctx: GraphQueryContext, registry_id: str | uuid.UUID
) -> list["AssociationContainerRegistriesUsers"]:
async with ctx.db.begin_readonly_session() as session:
query = sa.select(AssociationContainerRegistriesUsers).where(
AssociationContainerRegistriesUsers.container_registry_id == registry_id
)
result = await session.execute(query)
return [cls.from_row(ctx, row) for row in result.scalars().all()]

# @classmethod
# async def list_by_project(
# cls, ctx: GraphQueryContext, project: str
# ) -> list["AssociationContainerRegistriesUsers"]:
# async with ctx.db.begin_readonly_session() as session:
# query = sa.select(AssociationContainerRegistriesUsers).where(
# AssociationContainerRegistriesUsers.registry_id == registry_id
# )
# result = await session.execute(query)
# return [cls.from_row(ctx, row) for row in result.scalars().all()]

@classmethod
async def list_by_user_id(
cls, ctx: GraphQueryContext, user_id: str | uuid.UUID
) -> list["AssociationContainerRegistriesUsers"]:
async with ctx.db.begin_readonly_session() as session:
query = sa.select(AssociationContainerRegistriesUsers).where(
AssociationContainerRegistriesUsers.user_id == user_id
)
result = await session.execute(query)
return [cls.from_row(ctx, row) for row in result.scalars().all()]


class CreateAssociationContainerRegistriesUsersInput(graphene.InputObjectType):
container_registry_id = graphene.UUID(required=True)
user_id = graphene.UUID(required=True)


class CreateAssociationContainerRegistriesUsers(graphene.Mutation):
allowed_roles = (UserRole.SUPERADMIN,)
id = graphene.UUID(required=True)
association = graphene.Field(AssociationContainerRegistriesUsers)

class Arguments:
props = CreateAssociationContainerRegistriesUsersInput(required=True)

@classmethod
@privileged_mutation(
UserRole.SUPERADMIN,
lambda id, **kwargs: (None, id),
)
async def mutate(
cls, root, info: graphene.ResolveInfo, props: CreateAssociationContainerRegistriesUsersInput
) -> "CreateAssociationContainerRegistriesUsers":
ctx: GraphQueryContext = info.context

input_config: dict[str, Any] = {
"container_registry_id": props.container_registry_id,
"user_id": props.user_id,
}

async with ctx.db.begin_session() as db_session:
association_row = AssociationContainerRegistriesUsersRow(**input_config)
db_session.add(association_row)
await db_session.flush()
await db_session.refresh(association_row)

return cls(
id=association_row.id,
association=AssociationContainerRegistriesUsers.from_row(ctx, association_row),
)


class DeleteAssociationContainerRegistriesUsersInput(graphene.InputObjectType):
id = graphene.String(
required=True, description="Object id. Can be either global id or object id"
)


class DeleteAssociationContainerRegistriesUsers(graphene.Mutation):
allowed_roles = (UserRole.SUPERADMIN,)
id = graphene.UUID(required=True)

class Arguments:
props = DeleteAssociationContainerRegistriesUsersInput(required=True)

@classmethod
@privileged_mutation(
UserRole.SUPERADMIN,
lambda id, **kwargs: (None, id),
)
async def mutate(
cls, root, info: graphene.ResolveInfo, props: DeleteAssociationContainerRegistriesUsersInput
) -> "DeleteAssociationContainerRegistriesUsers":
ctx: GraphQueryContext = info.context

async with ctx.db.begin_session() as db_session:
_, _id = AsyncNode.resolve_global_id(info, props.id)
association_id = uuid.UUID(_id) if _id else uuid.UUID(props.id)
association_row = await AssociationContainerRegistriesUsers.load(ctx, association_id)
await db_session.execute(
sa.delete(AssociationContainerRegistriesUsers).where(
AssociationContainerRegistriesUsers.id == association_id
)
)

return cls(
association=AssociationContainerRegistriesUsers.from_row(ctx, association_row),
)
61 changes: 61 additions & 0 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,20 @@
from graphene.types.inputobjecttype import set_input_object_type_default_value
from graphql import OperationType, Undefined

from ai.backend.manager.models.association_container_registries_users import (
AssociationContainerRegistriesUsers,
)

set_input_object_type_default_value(Undefined)

from ai.backend.common.types import QuotaScopeID
from ai.backend.manager.defs import DEFAULT_IMAGE_ARCH
from ai.backend.manager.models.gql_relay import AsyncNode, ConnectionResolverResult

from .association_container_registries_users import (
CreateAssociationContainerRegistriesUsers,
DeleteAssociationContainerRegistriesUsers,
)
from .container_registry import (
ContainerRegistry,
ContainerRegistryConnection,
Expand Down Expand Up @@ -247,6 +255,9 @@ class Mutations(graphene.ObjectType):
modify_container_registry = ModifyContainerRegistry.Field()
delete_container_registry = DeleteContainerRegistry.Field()

create_container_registries_users = CreateAssociationContainerRegistriesUsers.Field()
delete_container_registries_users = DeleteAssociationContainerRegistriesUsers.Field()

modify_endpoint = ModifyEndpoint.Field()


Expand Down Expand Up @@ -1013,11 +1024,31 @@ async def resolve_image(
info: graphene.ResolveInfo,
reference: str,
architecture: str,
user_id: str | None = None,
# project: str | None = None,
) -> Image:
ctx: GraphQueryContext = info.context
client_role = ctx.user["role"]
client_domain = ctx.user["domain_name"]
item = await Image.load_item(info.context, reference, architecture)

if registry_id := item.registry_id:
cr = await ContainerRegistry.load(ctx, registry_id)
if not cr.config.is_global:
if user_id:
association = (
await AssociationContainerRegistriesUsers.load_by_registry_id_and_user_id(
ctx, registry_id, user_id
)
)
if not association:
raise ImageNotFound

# if project:
# associations = await AssociationContainerRegistriesUsers.list_by_registry_id(
# project
# )

if client_role == UserRole.SUPERADMIN:
pass
elif client_role in (UserRole.ADMIN, UserRole.USER):
Expand Down Expand Up @@ -2057,6 +2088,36 @@ async def resolve_quota_scope(
storage_host_name=storage_host_name,
)

@staticmethod
@privileged_query(UserRole.SUPERADMIN)
async def resolve_assocation_container_registries_users(
root: Any,
info: graphene.ResolveInfo,
id: graphene.UUID,
) -> AssociationContainerRegistriesUsers:
ctx: GraphQueryContext = info.context
return await AssociationContainerRegistriesUsers.load(ctx, id)

@staticmethod
@privileged_query(UserRole.SUPERADMIN)
async def resolve_assocation_container_registries_users_by_registry_id(
root: Any,
info: graphene.ResolveInfo,
registry_id: graphene.UUID,
) -> Sequence[AssociationContainerRegistriesUsers]:
ctx: GraphQueryContext = info.context
return await AssociationContainerRegistriesUsers.list_by_registry_id(ctx, registry_id)

@staticmethod
@privileged_query(UserRole.SUPERADMIN)
async def resolve_assocation_container_registries_users_by_user_id(
root: Any,
info: graphene.ResolveInfo,
user_id: graphene.UUID,
) -> Sequence[AssociationContainerRegistriesUsers]:
ctx: GraphQueryContext = info.context
return await AssociationContainerRegistriesUsers.list_by_registry_id(ctx, user_id)

@staticmethod
@privileged_query(UserRole.SUPERADMIN)
async def resolve_container_registry(
Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/manager/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,8 @@ class Image(graphene.ObjectType):
supported_accelerators = graphene.List(graphene.String)
installed = graphene.Boolean()
installed_agents = graphene.List(graphene.String)
registry_id = graphene.UUID()

# legacy field
hash = graphene.String()

Expand Down Expand Up @@ -526,6 +528,7 @@ def populate_row(
supported_accelerators=(row.accelerators or "").split(","),
installed=len(installed_agents) > 0,
installed_agents=installed_agents if not hide_agents else None,
registry_id=row.registry_id,
# legacy
hash=row.config_digest,
)
Expand Down

0 comments on commit 7505e1e

Please sign in to comment.