Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: image configuration when create kernels #2646

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,7 @@ def __init__(
self.agent_id = agent_id
self.event_producer = event_producer
self.kernel_config = kernel_config
self.image_ref = ImageRef(
kernel_config["image"]["canonical"],
known_registries=[kernel_config["image"]["registry"]["name"]],
is_local=kernel_config["image"]["is_local"],
architecture=kernel_config["image"].get("architecture", get_arch_name()),
)
self.image_ref = ImageRef.from_image_config(kernel_config["image"])
self.distro = distro
self.internal_data = kernel_config["internal_data"] or {}
self.computers = computers
Expand Down
32 changes: 23 additions & 9 deletions src/ai/backend/common/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dataclasses import dataclass
from pathlib import Path, PurePath
from typing import (
TYPE_CHECKING,
Any,
Final,
Iterable,
Expand All @@ -20,6 +21,7 @@
Optional,
Sequence,
Union,
cast,
)

import aiohttp
Expand All @@ -37,6 +39,9 @@
from .exception import InvalidImageName, InvalidImageTag, UnknownImageRegistry
from .service_ports import parse_service_ports

if TYPE_CHECKING:
from .types import ImageConfig, ImageRegistry

__all__ = (
"arch_name_aliases",
"default_registry",
Expand Down Expand Up @@ -310,7 +315,7 @@ def is_known_registry(
return False


async def get_registry_info(etcd: AsyncEtcd, name: str) -> tuple[yarl.URL, dict]:
async def get_registry_info(etcd: AsyncEtcd, name: str) -> ImageRegistry:
reg_path = f"config/docker/registry/{etcd_quote(name)}"
item = await etcd.get_prefix(reg_path)
if not item:
Expand All @@ -319,14 +324,14 @@ async def get_registry_info(etcd: AsyncEtcd, name: str) -> tuple[yarl.URL, dict]
if not registry_addr:
raise UnknownImageRegistry(name)
assert isinstance(registry_addr, str)
creds = {}
username = item.get("username")
if username is not None:
creds["username"] = username
password = item.get("password")
if password is not None:
creds["password"] = password
return yarl.URL(registry_addr), creds
username = cast(str | None, item.get("username"))
password = cast(str | None, item.get("password"))
return {
"name": name,
"url": registry_addr,
"username": username,
"password": password,
}


def validate_image_labels(labels: dict[str, str]) -> dict[str, str]:
Expand Down Expand Up @@ -451,6 +456,15 @@ def __init__(
raise InvalidImageTag(self._tag, self._value)
self._update_tag_set()

@classmethod
def from_image_config(cls, config: ImageConfig) -> ImageRef:
    """Build an image reference from an ``ImageConfig`` mapping.

    Reads ``config["canonical"]``, ``config["registry"]["name"]``,
    ``config["is_local"]`` and ``config["architecture"]``; every one of
    these keys must be present, otherwise ``KeyError`` is raised.

    The image's own registry name is passed as the only known registry.
    """
    # Instantiate through ``cls`` (not the hard-coded class name) so that
    # subclasses calling this alternate constructor get their own type back.
    return cls(
        config["canonical"],
        known_registries=[config["registry"]["name"]],
        is_local=config["is_local"],
        architecture=config["architecture"],
    )

@staticmethod
def _parse_image_tag(s: str, using_default_registry: bool = False) -> tuple[str, str]:
image_tag = s.rsplit(":", maxsplit=1)
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,7 @@ class ImageConfig(TypedDict):
registry: ImageRegistry
labels: Mapping[str, str]
is_local: bool
auto_pull: str # AutoPullBehavior value


class ServicePort(TypedDict):
Expand Down
67 changes: 56 additions & 11 deletions src/ai/backend/manager/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
import enum
import functools
import logging
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from decimal import Decimal
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Union,
cast,
Expand All @@ -27,14 +25,21 @@
from graphql import Undefined
from redis.asyncio import Redis
from redis.asyncio.client import Pipeline
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.orm import load_only, relationship, selectinload

from ai.backend.common import redis_helper
from ai.backend.common.docker import ImageRef
from ai.backend.common.docker import ImageRef, get_registry_info
from ai.backend.common.etcd import AsyncEtcd
from ai.backend.common.exception import UnknownImageReference
from ai.backend.common.types import BinarySize, ImageAlias, ResourceSlot
from ai.backend.common.types import (
AutoPullBehavior,
BinarySize,
ImageAlias,
ImageConfig,
ImageRegistry,
ResourceSlot,
)
from ai.backend.logging import BraceStyleAdapter

from ..api.exceptions import ImageNotFound, ObjectNotFound
Expand Down Expand Up @@ -250,6 +255,10 @@ def __init__(
self.labels = labels
self.resources = resources

@property
def trimmed_digest(self) -> str:
    """Return ``config_digest`` with leading/trailing whitespace removed.

    An unset (``None``) or empty digest yields ``""`` instead of raising,
    so ``trimmed_digest or None`` gives ``None`` when no digest is stored.
    """
    # NOTE(review): the whitespace is presumably column padding — confirm.
    digest = self.config_digest
    return digest.strip() if digest else ""

@property
def image_ref(self):
return ImageRef(self.name, [self.registry], self.architecture, self.is_local)
Expand Down Expand Up @@ -458,7 +467,7 @@ def _parse_row(self):
"tag": self.tag,
"architecture": self.architecture,
"registry": self.registry,
"digest": self.config_digest.strip() if self.config_digest else None,
"digest": self.trimmed_digest or None,
"labels": self.labels,
"size_bytes": self.size_bytes,
"resource_limits": res_limits,
Expand Down Expand Up @@ -486,6 +495,42 @@ def set_resource_limit(
self.resources = resources


async def bulk_get_image_configs(
    image_refs: Iterable[ImageRef],
    auto_pull: AutoPullBehavior = AutoPullBehavior.DIGEST,
    *,
    db: ExtendedAsyncSAEngine,
    db_conn: AsyncConnection,
    etcd: AsyncEtcd,
) -> list[ImageConfig]:
    """Build one ``ImageConfig`` per image reference, in input order.

    Each reference is resolved through ``ImageRow.resolve`` inside a
    read-only session opened on ``db_conn``.  Non-local images then get
    their ``registry`` entry replaced with the result of
    ``get_registry_info``; local images keep the placeholder registry.

    :param image_refs: References of the images to describe.
    :param auto_pull: Its ``.value`` is stored under ``"auto_pull"``.
    :param db: Engine used to open the read-only session.
    :param db_conn: Connection the session is bound to.
    :param etcd: Passed to ``get_registry_info`` for registry lookups.
    :return: The image configs, one per reference.

    Any exception raised by ``ImageRow.resolve`` or ``get_registry_info``
    propagates to the caller.
    """
    result: list[ImageConfig] = []
    async with db.begin_readonly_session(db_conn) as db_session:
        for ref in image_refs:
            resolved_image_info = await ImageRow.resolve(db_session, [ref])
            # Placeholder registry entry; overwritten below for non-local images.
            registry_info: ImageRegistry = {
                "name": ref.registry,
                "url": "http://127.0.0.1",  # "http://localhost",
                "username": None,
                "password": None,
            }
            image_conf: ImageConfig = {
                "architecture": ref.architecture,
                "canonical": ref.canonical,
                "is_local": resolved_image_info.image_ref.is_local,
                "digest": resolved_image_info.trimmed_digest,
                "labels": resolved_image_info.labels,
                "repo_digest": None,
                "registry": registry_info,
                "auto_pull": auto_pull.value,
            }
            result.append(image_conf)

    # Query each distinct registry only once, even if many images share it.
    registry_cache: dict[str, ImageRegistry] = {}
    for conf in result:
        if conf["is_local"]:
            continue
        registry_name = conf["registry"]["name"]
        if registry_name not in registry_cache:
            registry_cache[registry_name] = await get_registry_info(etcd, registry_name)
        # Give every config its own copy so later mutation of one config's
        # registry entry cannot leak into the others.
        conf["registry"] = cast(ImageRegistry, dict(registry_cache[registry_name]))
    return result


class ImageAliasRow(Base):
__tablename__ = "image_aliases"
id = IDColumn("id")
Expand Down Expand Up @@ -560,7 +605,7 @@ def populate_row(
registry=row.registry,
architecture=row.architecture,
is_local=row.is_local,
digest=row.config_digest.strip() if row.config_digest else None,
digest=row.trimmed_digest or None,
labels=[KVPair(key=k, value=v) for k, v in row.labels.items()],
aliases=[alias_row.alias for alias_row in row.aliases],
size_bytes=row.size_bytes,
Expand All @@ -576,7 +621,7 @@ def populate_row(
installed=len(installed_agents) > 0,
installed_agents=installed_agents if not hide_agents else None,
# legacy
hash=row.config_digest.strip() if row.config_digest else None,
hash=row.trimmed_digest or None,
)
ret.raw_labels = row.labels
return ret
Expand Down Expand Up @@ -798,7 +843,7 @@ def from_row(cls, row: ImageRow | None) -> ImageNode | None:
registry=row.registry,
architecture=row.architecture,
is_local=row.is_local,
digest=row.config_digest.strip() if row.config_digest else None,
digest=row.trimmed_digest or None,
labels=[KVPair(key=k, value=v) for k, v in row.labels.items()],
size_bytes=row.size_bytes,
resource_limits=[
Expand All @@ -823,7 +868,7 @@ def from_legacy_image(cls, row: Image) -> ImageNode:
registry=row.registry,
architecture=row.architecture,
is_local=row.is_local,
digest=row.digest,
digest=row.trimmed_digest,
labels=row.labels,
size_bytes=row.size_bytes,
resource_limits=row.resource_limits,
Expand Down
Loading