diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index 67f5e14159..442f7c783b 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -228,12 +228,7 @@ def __init__( self.agent_id = agent_id self.event_producer = event_producer self.kernel_config = kernel_config - self.image_ref = ImageRef( - kernel_config["image"]["canonical"], - known_registries=[kernel_config["image"]["registry"]["name"]], - is_local=kernel_config["image"]["is_local"], - architecture=kernel_config["image"].get("architecture", get_arch_name()), - ) + self.image_ref = ImageRef.from_image_config(kernel_config["image"]) self.distro = distro self.internal_data = kernel_config["internal_data"] or {} self.computers = computers diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index 092202c1e6..b3038a5633 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from pathlib import Path, PurePath from typing import ( + TYPE_CHECKING, Any, Final, Iterable, @@ -20,6 +21,7 @@ Optional, Sequence, Union, + cast, ) import aiohttp @@ -36,6 +38,9 @@ from .logging import BraceStyleAdapter from .service_ports import parse_service_ports +if TYPE_CHECKING: + from .types import ImageConfig, ImageRegistry + __all__ = ( "arch_name_aliases", "default_registry", @@ -309,7 +314,7 @@ def is_known_registry( return False -async def get_registry_info(etcd: AsyncEtcd, name: str) -> tuple[yarl.URL, dict]: +async def get_registry_info(etcd: AsyncEtcd, name: str) -> ImageRegistry: reg_path = f"config/docker/registry/{etcd_quote(name)}" item = await etcd.get_prefix(reg_path) if not item: @@ -318,14 +323,14 @@ async def get_registry_info(etcd: AsyncEtcd, name: str) -> tuple[yarl.URL, dict] if not registry_addr: raise UnknownImageRegistry(name) assert isinstance(registry_addr, str) - creds = {} - username = item.get("username") - if username is not None: - creds["username"] = username - password = item.get("password") - if password is not None: - creds["password"] = password - return yarl.URL(registry_addr), creds + username = cast(str | None, item.get("username")) + password = cast(str | None, item.get("password")) + return { + "name": name, + "url": registry_addr, + "username": username, + "password": password, + } def validate_image_labels(labels: dict[str, str]) -> dict[str, str]: @@ -450,6 +455,15 @@ def __init__( raise InvalidImageTag(self._tag, self._value) self._update_tag_set() + @classmethod + def from_image_config(cls, config: ImageConfig) -> ImageRef: + return ImageRef( + config["canonical"], + known_registries=[config["registry"]["name"]], + is_local=config["is_local"], + architecture=config["architecture"], + ) + @staticmethod def _parse_image_tag(s: str, using_default_registry: bool = False) -> tuple[str, str]: image_tag = s.rsplit(":", maxsplit=1) diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index f8b53f8b34..5c860783f6 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -1044,6 +1044,7 @@ class ImageConfig(TypedDict): registry: ImageRegistry labels: Mapping[str, str] is_local: bool + auto_pull: str # AutoPullBehavior value class ServicePort(TypedDict): diff --git a/src/ai/backend/manager/models/image.py b/src/ai/backend/manager/models/image.py index 4ac0960215..b79d461f02 100644 --- a/src/ai/backend/manager/models/image.py +++ b/src/ai/backend/manager/models/image.py @@ -3,16 +3,14 @@ import enum import functools import logging +from collections.abc import Iterable, Mapping, MutableMapping, Sequence from decimal import Decimal from typing import ( TYPE_CHECKING, Any, AsyncIterator, List, - Mapping, - MutableMapping, Optional, - Sequence, Tuple, Union, cast, @@ -31,11 +29,18 @@ from sqlalchemy.orm import load_only, relationship, selectinload from ai.backend.common import redis_helper -from ai.backend.common.docker import ImageRef +from ai.backend.common.docker import ImageRef, get_registry_info from ai.backend.common.etcd import AsyncEtcd from ai.backend.common.exception import UnknownImageReference from ai.backend.common.logging import BraceStyleAdapter -from ai.backend.common.types import BinarySize, ImageAlias, ResourceSlot +from ai.backend.common.types import ( + AutoPullBehavior, + BinarySize, + ImageAlias, + ImageConfig, + ImageRegistry, + ResourceSlot, +) from ..api.exceptions import ImageNotFound, ObjectNotFound from ..container_registry import get_container_registry_cls @@ -250,6 +255,10 @@ def __init__( self.labels = labels self.resources = resources + @property + def trimmed_digest(self) -> str: + return self.config_digest.strip() + @property def image_ref(self): return ImageRef(self.name, [self.registry], self.architecture, self.is_local) @@ -458,7 +467,7 @@ def _parse_row(self): "tag": self.tag, "architecture": self.architecture, "registry": self.registry, - "digest": self.config_digest.strip() if self.config_digest else None, + "digest": self.trimmed_digest or None, "labels": self.labels, "size_bytes": self.size_bytes, "resource_limits": res_limits, @@ -486,6 +495,40 @@ def set_resource_limit( self.resources = resources +async def bulk_get_image_configs( + db_session: AsyncSession, + etcd: AsyncEtcd, + image_refs: Iterable[ImageRef], + auto_pull: AutoPullBehavior = AutoPullBehavior.DIGEST, +) -> list[ImageConfig]: + result = [] + for ref in image_refs: + resolved_image_info = await ImageRow.resolve(db_session, [ref]) + if resolved_image_info.image_ref.is_local: + is_local = True + registry_conf: ImageRegistry = { + "name": ref.registry, + "url": "http://127.0.0.1", # "http://localhost", + "username": None, + "password": None, + } + else: + is_local = False + registry_conf = await get_registry_info(etcd, ref.registry) + image_conf: ImageConfig = { + "architecture": ref.architecture, + "canonical": ref.canonical, + "is_local": is_local, + "digest": resolved_image_info.digest, + "labels": resolved_image_info.labels, + "repo_digest": None, + "registry": registry_conf, + "auto_pull": auto_pull.value, + } + result.append(image_conf) + return result + + class ImageAliasRow(Base): __tablename__ = "image_aliases" id = IDColumn("id") @@ -560,7 +603,7 @@ def populate_row( registry=row.registry, architecture=row.architecture, is_local=row.is_local, - digest=row.config_digest.strip() if row.config_digest else None, + digest=row.trimmed_digest or None, labels=[KVPair(key=k, value=v) for k, v in row.labels.items()], aliases=[alias_row.alias for alias_row in row.aliases], size_bytes=row.size_bytes, @@ -576,7 +619,7 @@ def populate_row( installed=len(installed_agents) > 0, installed_agents=installed_agents if not hide_agents else None, # legacy - hash=row.config_digest.strip() if row.config_digest else None, + hash=row.trimmed_digest or None, ) ret.raw_labels = row.labels return ret @@ -798,7 +841,7 @@ def from_row(cls, row: ImageRow | None) -> ImageNode | None: registry=row.registry, architecture=row.architecture, is_local=row.is_local, - digest=row.config_digest.strip() if row.config_digest else None, + digest=row.trimmed_digest or None, labels=[KVPair(key=k, value=v) for k, v in row.labels.items()], size_bytes=row.size_bytes, resource_limits=[ @@ -823,7 +866,7 @@ def from_legacy_image(cls, row: Image) -> ImageNode: registry=row.registry, architecture=row.architecture, is_local=row.is_local, - digest=row.digest, + digest=row.trimmed_digest, labels=row.labels, size_bytes=row.size_bytes, resource_limits=row.resource_limits, diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 37a3586008..68aafb2b90 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -52,7 +52,7 @@ from ai.backend.common import msgpack, redis_helper from ai.backend.common.asyncio import cancel_tasks -from ai.backend.common.docker import ImageRef, get_known_registries, get_registry_info +from ai.backend.common.docker import ImageRef, get_known_registries from ai.backend.common.events import ( AgentHeartbeatEvent, AgentStartedEvent, @@ -88,6 +88,7 @@ AbuseReport, AccessKey, AgentId, + AutoPullBehavior, BinarySize, ClusterInfo, ClusterMode, @@ -97,6 +98,7 @@ DeviceId, HardwareMetadata, ImageAlias, + ImageConfig, ImageRegistry, KernelEnqueueingConfig, KernelId, @@ -166,6 +168,7 @@ scaling_groups, verify_vfolder_name, ) +from .models.image import bulk_get_image_configs from .models.session import ( COMPUTE_CONCURRENCY_USED_KEY_PREFIX, SESSION_KERNEL_STATUS_MAPPING, @@ -1356,40 +1359,18 @@ async def start_session( ) result = await db_sess.execute(query) resource_policy = result.scalars().first() - auto_pull = self.shared_config["docker"]["image"]["auto_pull"] + idle_timeout = cast(int, resource_policy.idle_timeout) + auto_pull = cast(str, self.shared_config["docker"]["image"]["auto_pull"]) # Aggregate image registry information - keyfunc = lambda item: item.kernel.image_ref - image_infos: MutableMapping[str, ImageRow] = {} - is_local_image = True - registry_url = URL("http://localhost") - registry_creds: dict[str, str] = {} - async with self.db.begin_readonly_session() as session: - for image_ref, _ in itertools.groupby( - sorted(kernel_agent_bindings, key=keyfunc), - key=keyfunc, - ): - # img_query = sa.select(ImageRow).where(ImageRow.id == image_id) - # img_row: ImageRow = (await session.execute(img_query)).scalars().first() - # image_ref = img_row.image_ref - log.debug( - "start_session(): image ref => {} ({})", image_ref, image_ref.architecture - ) - resolved_image_info = await ImageRow.resolve(session, [image_ref]) - image_infos[str(image_ref)] = resolved_image_info - if not resolved_image_info.image_ref.is_local: - is_local_image = False - registry_url, registry_creds = await get_registry_info( - self.shared_config.etcd, image_ref.registry - ) - image_info = { - "image_infos": image_infos, - "registry_url": registry_url, - "registry_creds": registry_creds, - "resource_policy": resource_policy, - "auto_pull": auto_pull, - "is_local": is_local_image, - } + _image_refs: set[ImageRef] = set([item.kernel.image_ref for item in kernel_agent_bindings]) + _log_msg = ",".join([f"image ref => {ref} ({ref.architecture})" for ref in _image_refs]) + log.debug(f"start_session(): {_log_msg}") + async with self.db.begin_readonly_session() as db_session: + configs = await bulk_get_image_configs( + db_session, self.shared_config.etcd, _image_refs, AutoPullBehavior(auto_pull) + ) + image_ref_config_map = {item["canonical"]: item for item in configs} network_name: Optional[str] = None cluster_ssh_port_mapping: Optional[Dict[str, Tuple[str, int]]] = None @@ -1525,8 +1506,9 @@ async def start_session( agent_alloc_ctx, scheduled_session, items, - image_info, + image_ref_config_map, cluster_info, + idle_timeout, ), ), ) @@ -1577,15 +1559,10 @@ async def _create_kernels_in_one_agent( agent_alloc_ctx: AgentAllocationContext, scheduled_session: SessionRow, items: Sequence[KernelAgentBinding], - image_info: Mapping[str, Any], + image_configs: Mapping[str, ImageConfig], cluster_info, + idle_timeout: float | int, ) -> None: - registry_url = image_info["registry_url"] - registry_creds = image_info["registry_creds"] - image_infos = image_info["image_infos"] - is_local = image_info["is_local"] - resource_policy: KeyPairResourcePolicyRow = image_info["resource_policy"] - auto_pull = image_info["auto_pull"] assert agent_alloc_ctx.agent_id is not None assert scheduled_session.id is not None @@ -1609,7 +1586,10 @@ async def _update_kernel() -> None: agent_alloc_ctx.agent_id, order_key=str(scheduled_session.id), ) as rpc: - get_image_ref = lambda k: image_infos[str(k.image_ref)].image_ref + + def get_image_conf(kernel: KernelRow) -> ImageConfig: + return image_configs[str(kernel.image_ref)] + # Issue a batched RPC call to create kernels on this agent # created_infos = await rpc.call.create_kernels( await rpc.call.create_kernels( @@ -1619,37 +1599,35 @@ async def _update_kernel() -> None: { "image": { # TODO: refactor registry and is_local to be specified per kernel. - "registry": { - "name": get_image_ref(binding.kernel).registry, - "url": str(registry_url), - **registry_creds, # type: ignore - }, - "digest": image_infos[binding.kernel.image].config_digest, - "repo_digest": None, - "canonical": get_image_ref(binding.kernel).canonical, - "architecture": get_image_ref(binding.kernel).architecture, - "labels": image_infos[binding.kernel.image].labels, - "is_local": is_local, + "registry": get_image_conf(binding.kernel)["registry"], + "digest": get_image_conf(binding.kernel)["digest"], + "repo_digest": get_image_conf(binding.kernel)["repo_digest"], + "canonical": get_image_conf(binding.kernel)["canonical"], + "architecture": get_image_conf(binding.kernel)["architecture"], + "labels": get_image_conf(binding.kernel)["labels"], + "is_local": get_image_conf(binding.kernel)["is_local"], }, "session_type": scheduled_session.session_type.value, "cluster_role": binding.kernel.cluster_role, "cluster_idx": binding.kernel.cluster_idx, "local_rank": binding.kernel.local_rank, "cluster_hostname": binding.kernel.cluster_hostname, - "idle_timeout": resource_policy.idle_timeout, + "idle_timeout": idle_timeout, "mounts": [item.to_json() for item in scheduled_session.vfolder_mounts], "environ": { # inherit per-session environment variables **scheduled_session.environ, # set per-kernel environment variables "BACKENDAI_KERNEL_ID": str(binding.kernel.id), - "BACKENDAI_KERNEL_IMAGE": str(get_image_ref(binding.kernel)), + "BACKENDAI_KERNEL_IMAGE": get_image_conf(binding.kernel)[ + "canonical" + ], "BACKENDAI_CLUSTER_ROLE": binding.kernel.cluster_role, "BACKENDAI_CLUSTER_IDX": str(binding.kernel.cluster_idx), "BACKENDAI_CLUSTER_LOCAL_RANK": str(binding.kernel.local_rank), "BACKENDAI_CLUSTER_HOST": str(binding.kernel.cluster_hostname), "BACKENDAI_SERVICE_PORTS": str( - image_infos[binding.kernel.image].labels.get( + get_image_conf(binding.kernel)["labels"].get( "ai.backend.service-ports" ) ), @@ -1659,7 +1637,7 @@ async def _update_kernel() -> None: "bootstrap_script": binding.kernel.bootstrap_script, "startup_command": binding.kernel.startup_command, "internal_data": scheduled_session.main_kernel.internal_data, - "auto_pull": auto_pull, + "auto_pull": get_image_conf(binding.kernel)["auto_pull"], "preopen_ports": scheduled_session.main_kernel.preopen_ports, "allocated_host_ports": list(binding.allocated_host_ports), "agent_addr": binding.agent_alloc_ctx.agent_addr,