Skip to content

Commit

Permalink
feat: expose email address of model service owner (#1831)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyujin-cho authored Jan 11, 2024
1 parent b835aff commit fbea3be
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 5 deletions.
1 change: 1 addition & 0 deletions changes/1831.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introduce `endpoint.created_user_email` and `endpoint.session_owner_email` GQL field
2 changes: 2 additions & 0 deletions src/ai/backend/manager/api/admin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import traceback
from typing import TYPE_CHECKING, Any, Iterable, Tuple

import aiohttp_cors
Expand Down Expand Up @@ -96,6 +97,7 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu
else:
errmsg = {"message": str(e)}
log.error("ADMIN.GQL Exception: {}", errmsg)
log.debug("{}", "".join(traceback.format_exception(e)))
return result


Expand Down
50 changes: 45 additions & 5 deletions src/ai/backend/manager/models/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ class EndpointRow(Base):
routings = relationship("RoutingRow", back_populates="endpoint_row")
tokens = relationship("EndpointTokenRow", back_populates="endpoint_row")
image_row = relationship("ImageRow", back_populates="endpoints")
created_user_row = relationship(
"UserRow", back_populates="created_endpoints", foreign_keys="EndpointRow.created_user"
)
session_owner_row = relationship(
"UserRow", back_populates="owned_endpoints", foreign_keys="EndpointRow.session_owner"
)

def __init__(
self,
Expand Down Expand Up @@ -204,6 +210,8 @@ async def get(
load_routes=False,
load_tokens=False,
load_image=False,
load_created_user=False,
load_session_owner=False,
) -> "EndpointRow":
"""
:raises: sqlalchemy.orm.exc.NoResultFound
Expand All @@ -215,6 +223,10 @@ async def get(
query = query.options(selectinload(EndpointRow.tokens))
if load_image:
query = query.options(selectinload(EndpointRow.image_row))
if load_created_user:
query = query.options(selectinload(EndpointRow.created_user_row))
if load_session_owner:
query = query.options(selectinload(EndpointRow.session_owner_row))
if project:
query = query.filter(EndpointRow.project == project)
if domain:
Expand All @@ -237,6 +249,8 @@ async def list(
load_routes=False,
load_image=False,
load_tokens=False,
load_created_user=False,
load_session_owner=False,
status_filter=[EndpointLifecycle.CREATED],
) -> List["EndpointRow"]:
query = (
Expand All @@ -250,6 +264,10 @@ async def list(
query = query.options(selectinload(EndpointRow.tokens))
if load_image:
query = query.options(selectinload(EndpointRow.image_row))
if load_created_user:
query = query.options(selectinload(EndpointRow.created_user_row))
if load_session_owner:
query = query.options(selectinload(EndpointRow.session_owner_row))
if project:
query = query.filter(EndpointRow.project == project)
if domain:
Expand Down Expand Up @@ -368,8 +386,16 @@ class Meta:
url = graphene.String()
model = graphene.UUID()
model_mount_destiation = graphene.String()
created_user = graphene.UUID()
session_owner = graphene.UUID()
created_user = graphene.UUID(
deprecation_reason="Deprecated since 23.09.8; use `created_user_id`"
)
created_user_email = graphene.String(description="Added at 23.09.8")
created_user_id = graphene.UUID(description="Added at 23.09.8")
session_owner = graphene.UUID(
deprecation_reason="Deprecated since 23.09.8; use `session_owner_id`"
)
session_owner_email = graphene.String(description="Added at 23.09.8")
session_owner_id = graphene.UUID(description="Added at 23.09.8")
tag = graphene.String()
startup_command = graphene.String()
bootstrap_script = graphene.String()
Expand Down Expand Up @@ -410,7 +436,11 @@ async def from_row(
model=row.model,
model_mount_destiation=row.model_mount_destiation,
created_user=row.created_user,
created_user_id=row.created_user,
created_user_email=row.created_user_row.email,
session_owner=row.session_owner,
session_owner_id=row.session_owner,
session_owner_email=row.session_owner_row.email,
tag=row.tag,
startup_command=row.startup_command,
bootstrap_script=row.bootstrap_script,
Expand Down Expand Up @@ -477,6 +507,8 @@ async def load_slice(
.offset(offset)
.options(selectinload(EndpointRow.image_row))
.options(selectinload(EndpointRow.routings))
.options(selectinload(EndpointRow.created_user_row))
.options(selectinload(EndpointRow.session_owner_row))
.order_by(sa.desc(EndpointRow.created_at))
.filter(
EndpointRow.lifecycle_stage.in_([
Expand Down Expand Up @@ -514,9 +546,15 @@ async def load_all(
) -> Sequence["Endpoint"]:
async with ctx.db.begin_readonly_session() as session:
rows = await EndpointRow.list(
session, project=project, domain=domain_name, user_uuid=user_uuid, load_image=True
session,
project=project,
domain=domain_name,
user_uuid=user_uuid,
load_image=True,
load_created_user=True,
load_session_owner=True,
)
return [await Endpoint.from_row(ctx, row) for row in rows]
return [await Endpoint.from_row(ctx, row) for row in rows]

@classmethod
async def load_item(
Expand All @@ -541,10 +579,12 @@ async def load_item(
project=project,
load_image=True,
load_routes=True,
load_created_user=True,
load_session_owner=True,
)
return await Endpoint.from_row(ctx, row)
except NoResultFound:
raise EndpointNotFound
return await Endpoint.from_row(ctx, row)

async def resolve_status(self, info: graphene.ResolveInfo) -> str:
if self.retries > SERVICE_MAX_RETRIES:
Expand Down
7 changes: 7 additions & 0 deletions src/ai/backend/manager/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ class UserRow(Base):
resource_policy_row = relationship("UserResourcePolicyRow", back_populates="users")
keypairs = relationship("KeyPairRow", back_populates="user_row", foreign_keys="KeyPairRow.user")

created_endpoints = relationship(
"EndpointRow", back_populates="created_user_row", foreign_keys="EndpointRow.created_user"
)
owned_endpoints = relationship(
"EndpointRow", back_populates="session_owner_row", foreign_keys="EndpointRow.session_owner"
)

main_keypair = relationship("KeyPairRow", foreign_keys=users.c.main_access_key)


Expand Down

0 comments on commit fbea3be

Please sign in to comment.