Skip to content

Commit

Permalink
refactor: Add Image batch loader function (#3048)
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa authored Nov 7, 2024
1 parent 988471c commit 6e542a1
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion src/ai/backend/manager/models/gql_models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from ...api.exceptions import ImageNotFound, ObjectNotFound
from ...defs import DEFAULT_IMAGE_ARCH
from ..base import set_if_set
from ..base import batch_multiresult_in_scalar_stream, set_if_set
from ..gql_relay import AsyncNode
from ..image import (
ImageAliasRow,
Expand Down Expand Up @@ -344,6 +344,36 @@ class Meta:
graphene.String, description="Added in 24.03.4. The array of image aliases."
)

@classmethod
async def batch_load_by_name_and_arch(
cls,
graph_ctx: GraphQueryContext,
name_and_arch: Sequence[tuple[str, str]],
) -> Sequence[Sequence[ImageNode]]:
query = (
sa.select(ImageRow)
.where(sa.tuple_(ImageRow.name, ImageRow.architecture).in_(name_and_arch))
.options(selectinload(ImageRow.aliases))
)
async with graph_ctx.db.begin_readonly_session() as db_session:
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_session,
query,
cls,
name_and_arch,
lambda row: (row.name, row.architecture),
)

@classmethod
async def batch_load_by_image_identifier(
cls,
graph_ctx: GraphQueryContext,
image_ids: Sequence[ImageIdentifier],
) -> Sequence[Sequence[ImageNode]]:
name_and_arch_tuples = [(img.canonical, img.architecture) for img in image_ids]
return await cls.batch_load_by_name_and_arch(graph_ctx, name_and_arch_tuples)

@overload
@classmethod
def from_row(cls, row: ImageRow) -> ImageNode: ...
Expand Down

0 comments on commit 6e542a1

Please sign in to comment.