Skip to content

Commit

Permalink
refactor: Clean image configuration when create kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Sep 14, 2024
1 parent 97f2390 commit 1a5302a
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 94 deletions.
7 changes: 1 addition & 6 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,7 @@ def __init__(
self.agent_id = agent_id
self.event_producer = event_producer
self.kernel_config = kernel_config
self.image_ref = ImageRef(
kernel_config["image"]["canonical"],
known_registries=[kernel_config["image"]["registry"]["name"]],
is_local=kernel_config["image"]["is_local"],
architecture=kernel_config["image"].get("architecture", get_arch_name()),
)
self.image_ref = ImageRef.from_image_config(kernel_config["image"])
self.distro = distro
self.internal_data = kernel_config["internal_data"] or {}
self.computers = computers
Expand Down
32 changes: 23 additions & 9 deletions src/ai/backend/common/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dataclasses import dataclass
from pathlib import Path, PurePath
from typing import (
TYPE_CHECKING,
Any,
Final,
Iterable,
Expand All @@ -20,6 +21,7 @@
Optional,
Sequence,
Union,
cast,
)

import aiohttp
Expand All @@ -37,6 +39,9 @@
from .exception import InvalidImageName, InvalidImageTag, UnknownImageRegistry
from .service_ports import parse_service_ports

if TYPE_CHECKING:
from .types import ImageConfig, ImageRegistry

__all__ = (
"arch_name_aliases",
"default_registry",
Expand Down Expand Up @@ -310,7 +315,7 @@ def is_known_registry(
return False


async def get_registry_info(etcd: AsyncEtcd, name: str) -> tuple[yarl.URL, dict]:
async def get_registry_info(etcd: AsyncEtcd, name: str) -> ImageRegistry:
reg_path = f"config/docker/registry/{etcd_quote(name)}"
item = await etcd.get_prefix(reg_path)
if not item:
Expand All @@ -319,14 +324,14 @@ async def get_registry_info(etcd: AsyncEtcd, name: str) -> tuple[yarl.URL, dict]
if not registry_addr:
raise UnknownImageRegistry(name)
assert isinstance(registry_addr, str)
creds = {}
username = item.get("username")
if username is not None:
creds["username"] = username
password = item.get("password")
if password is not None:
creds["password"] = password
return yarl.URL(registry_addr), creds
username = cast(str | None, item.get("username"))
password = cast(str | None, item.get("password"))
return {
"name": name,
"url": registry_addr,
"username": username,
"password": password,
}


def validate_image_labels(labels: dict[str, str]) -> dict[str, str]:
Expand Down Expand Up @@ -451,6 +456,15 @@ def __init__(
raise InvalidImageTag(self._tag, self._value)
self._update_tag_set()

@classmethod
def from_image_config(cls, config: ImageConfig) -> ImageRef:
    """
    Build an image reference from an image configuration mapping.

    The reference is created from ``config["canonical"]``, using the name of
    the config's registry as the only known registry, and carries over the
    ``is_local`` and ``architecture`` values as given.

    :param config: Image configuration.  The ``canonical``, ``registry``
        (with its ``name``), ``is_local`` and ``architecture`` keys are read
        by subscription, so a missing key raises ``KeyError``.
    :returns: A new instance of ``cls``.
    """
    # Instantiate through ``cls`` rather than a hard-coded class name so that
    # subclasses calling this alternate constructor get their own type back.
    return cls(
        config["canonical"],
        known_registries=[config["registry"]["name"]],
        is_local=config["is_local"],
        architecture=config["architecture"],
    )

@staticmethod
def _parse_image_tag(s: str, using_default_registry: bool = False) -> tuple[str, str]:
image_tag = s.rsplit(":", maxsplit=1)
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,7 @@ class ImageConfig(TypedDict):
registry: ImageRegistry
labels: Mapping[str, str]
is_local: bool
auto_pull: str # AutoPullBehavior value


class ServicePort(TypedDict):
Expand Down
67 changes: 56 additions & 11 deletions src/ai/backend/manager/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
import enum
import functools
import logging
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from decimal import Decimal
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Union,
cast,
Expand All @@ -27,14 +25,21 @@
from graphql import Undefined
from redis.asyncio import Redis
from redis.asyncio.client import Pipeline
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.orm import load_only, relationship, selectinload

from ai.backend.common import redis_helper
from ai.backend.common.docker import ImageRef
from ai.backend.common.docker import ImageRef, get_registry_info
from ai.backend.common.etcd import AsyncEtcd
from ai.backend.common.exception import UnknownImageReference
from ai.backend.common.types import BinarySize, ImageAlias, ResourceSlot
from ai.backend.common.types import (
AutoPullBehavior,
BinarySize,
ImageAlias,
ImageConfig,
ImageRegistry,
ResourceSlot,
)
from ai.backend.logging import BraceStyleAdapter

from ..api.exceptions import ImageNotFound, ObjectNotFound
Expand Down Expand Up @@ -250,6 +255,10 @@ def __init__(
self.labels = labels
self.resources = resources

@property
def trimmed_digest(self) -> str:
    """
    Return ``config_digest`` with surrounding whitespace removed.

    A missing (``None``) or empty digest yields an empty string instead of
    raising, so callers can write ``row.trimmed_digest or None`` to obtain
    an optional value.
    """
    # Fall back to "" so that a row without a digest does not raise
    # AttributeError on ``None.strip()``.
    return (self.config_digest or "").strip()

@property
def image_ref(self):
    """
    Image reference constructed from this row's ``name``, ``registry``,
    ``architecture`` and ``is_local`` attributes.

    The row's registry is passed as the single entry of the registry list.
    """
    # Positional order is kept exactly as the constructor call expects it.
    ctor_args = (
        self.name,
        [self.registry],
        self.architecture,
        self.is_local,
    )
    return ImageRef(*ctor_args)
Expand Down Expand Up @@ -458,7 +467,7 @@ def _parse_row(self):
"tag": self.tag,
"architecture": self.architecture,
"registry": self.registry,
"digest": self.config_digest.strip() if self.config_digest else None,
"digest": self.trimmed_digest or None,
"labels": self.labels,
"size_bytes": self.size_bytes,
"resource_limits": res_limits,
Expand Down Expand Up @@ -486,6 +495,42 @@ def set_resource_limit(
self.resources = resources


async def bulk_get_image_configs(
    image_refs: Iterable[ImageRef],
    auto_pull: AutoPullBehavior = AutoPullBehavior.DIGEST,
    *,
    db: ExtendedAsyncSAEngine,
    db_conn: AsyncConnection,
    etcd: AsyncEtcd,
) -> list[ImageConfig]:
    """
    Build one image configuration per given image reference.

    Each reference is resolved to an image row inside a single read-only
    database session opened on ``db_conn``.  The configuration takes its
    ``architecture`` and ``canonical`` values from the requested reference
    and its ``is_local``, ``digest`` and ``labels`` values from the resolved
    row.  ``repo_digest`` is always ``None`` and ``auto_pull`` is stored as
    the enum's value.

    After the database session is closed, the registry entry of every
    non-local image is replaced with the information returned by
    ``get_registry_info()``.  Local images keep a placeholder entry that
    points at ``http://127.0.0.1`` without credentials.

    :param image_refs: Image references to resolve, processed in order.
    :param auto_pull: Auto-pull behavior recorded in every configuration.
    :param db: Database engine used to open the read-only session.
    :param db_conn: Connection the read-only session is bound to.
    :param etcd: etcd client used to look up registry information.
    :returns: Configurations in the same order as ``image_refs``.
    :raises: Whatever ``ImageRow.resolve()`` and ``get_registry_info()``
        raise; nothing is caught here.
    """
    result: list[ImageConfig] = []
    async with db.begin_readonly_session(db_conn) as db_session:
        for ref in image_refs:
            resolved_image_info = await ImageRow.resolve(db_session, [ref])
            # Placeholder; overwritten below for images that are not local.
            registry_info: ImageRegistry = {
                "name": ref.registry,
                "url": "http://127.0.0.1",  # "http://localhost",
                "username": None,
                "password": None,
            }
            image_conf: ImageConfig = {
                "architecture": ref.architecture,
                "canonical": ref.canonical,
                "is_local": resolved_image_info.image_ref.is_local,
                "digest": resolved_image_info.trimmed_digest,
                "labels": resolved_image_info.labels,
                "repo_digest": None,
                "registry": registry_info,
                "auto_pull": auto_pull.value,
            }
            result.append(image_conf)

    # Many images usually share a registry, so query etcd only once per
    # registry name within this call instead of once per image.
    registry_cache: dict[str, ImageRegistry] = {}
    for conf in result:
        if conf["is_local"]:
            continue
        registry_name = conf["registry"]["name"]
        if registry_name not in registry_cache:
            registry_cache[registry_name] = await get_registry_info(etcd, registry_name)
        # Hand each configuration its own copy so that mutating one config's
        # registry entry cannot leak into another config.
        conf["registry"] = cast(ImageRegistry, dict(registry_cache[registry_name]))
    return result


class ImageAliasRow(Base):
__tablename__ = "image_aliases"
id = IDColumn("id")
Expand Down Expand Up @@ -560,7 +605,7 @@ def populate_row(
registry=row.registry,
architecture=row.architecture,
is_local=row.is_local,
digest=row.config_digest.strip() if row.config_digest else None,
digest=row.trimmed_digest or None,
labels=[KVPair(key=k, value=v) for k, v in row.labels.items()],
aliases=[alias_row.alias for alias_row in row.aliases],
size_bytes=row.size_bytes,
Expand All @@ -576,7 +621,7 @@ def populate_row(
installed=len(installed_agents) > 0,
installed_agents=installed_agents if not hide_agents else None,
# legacy
hash=row.config_digest.strip() if row.config_digest else None,
hash=row.trimmed_digest or None,
)
ret.raw_labels = row.labels
return ret
Expand Down Expand Up @@ -798,7 +843,7 @@ def from_row(cls, row: ImageRow | None) -> ImageNode | None:
registry=row.registry,
architecture=row.architecture,
is_local=row.is_local,
digest=row.config_digest.strip() if row.config_digest else None,
digest=row.trimmed_digest or None,
labels=[KVPair(key=k, value=v) for k, v in row.labels.items()],
size_bytes=row.size_bytes,
resource_limits=[
Expand All @@ -823,7 +868,7 @@ def from_legacy_image(cls, row: Image) -> ImageNode:
registry=row.registry,
architecture=row.architecture,
is_local=row.is_local,
digest=row.digest,
digest=row.trimmed_digest,
labels=row.labels,
size_bytes=row.size_bytes,
resource_limits=row.resource_limits,
Expand Down
Loading

0 comments on commit 1a5302a

Please sign in to comment.