From e1fa87d8dbed067341145b0b9a292f60d6745e9d Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Thu, 19 Sep 2024 07:19:22 +0000 Subject: [PATCH] fix: broken `start_session` --- src/ai/backend/manager/registry.py | 148 ++++++++++++++++------------- 1 file changed, 82 insertions(+), 66 deletions(-) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index fc4326f263..ecc55347cd 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -1385,18 +1385,22 @@ async def start_session( idle_timeout = cast(int, resource_policy.idle_timeout) auto_pull = cast(str, self.shared_config["docker"]["image"]["auto_pull"]) - # Aggregate image registry information - image_refs: set[ImageRef] = set() - - for binding in kernel_agent_bindings: - image_refs.add( - ( - await ImageRow.resolve( - db_sess, - [ImageIdentifier(binding.kernel.image, binding.kernel.architecture)], - ) - ).image_ref - ) + # Aggregate image registry information + image_refs: set[ImageRef] = set() + + for binding in kernel_agent_bindings: + image_refs.add( + ( + await ImageRow.resolve( + db_sess, + [ + ImageIdentifier( + binding.kernel.image, binding.kernel.architecture + ) + ], + ) + ).image_ref + ) _log_msg = ",".join([f"image ref => {ref} ({ref.architecture})" for ref in image_refs]) log.debug(f"start_session(): {_log_msg}") @@ -1619,68 +1623,80 @@ async def _update_kernel() -> None: await execute_with_retry(_update_kernel) try: - kernel_image_refs: dict[KernelId, ImageRef] = {} - async with self.agent_cache.rpc_context( agent_alloc_ctx.agent_id, order_key=str(scheduled_session.id), ) as rpc: def get_image_conf(kernel: KernelRow) -> ImageConfig: - return image_configs[str(kernel.image_ref)] - - for binding in items: - raw_kernel_ids = [str(binding.kernel.id) for binding in items] + return image_configs[kernel.image] + + kernel_image_refs: dict[KernelId, ImageRef] = {} + + async with self.db.begin_readonly_session() as db_sess: + for binding in items: + kernel_image_refs[binding.kernel.id] = ( + await ImageRow.resolve( + db_sess, + [ + ImageIdentifier( + binding.kernel.image, binding.kernel.architecture + ) + ], + ) + ).image_ref - raw_configs = [ - { - "image": { - # TODO: refactor registry and is_local to be specified per kernel. - "registry": get_image_conf(binding.kernel)["registry"], - "digest": get_image_conf(binding.kernel)["digest"], - "repo_digest": get_image_conf(binding.kernel)["repo_digest"], - "canonical": get_image_conf(binding.kernel)["canonical"], - "architecture": get_image_conf(binding.kernel)["architecture"], - "labels": get_image_conf(binding.kernel)["labels"], - "is_local": get_image_conf(binding.kernel)["is_local"], - }, - "session_type": scheduled_session.session_type.value, - "cluster_role": binding.kernel.cluster_role, - "cluster_idx": binding.kernel.cluster_idx, - "local_rank": binding.kernel.local_rank, - "cluster_hostname": binding.kernel.cluster_hostname, - "idle_timeout": idle_timeout, - "mounts": [item.to_json() for item in scheduled_session.vfolder_mounts], - "environ": { - # inherit per-session environment variables - **scheduled_session.environ, - # set per-kernel environment variables - "BACKENDAI_KERNEL_ID": str(binding.kernel.id), - "BACKENDAI_KERNEL_IMAGE": get_image_conf(binding.kernel)[ - "canonical" + raw_configs = [ + { + "image": { + # TODO: refactor registry and is_local to be specified per kernel. + "registry": get_image_conf(binding.kernel)["registry"], + "digest": get_image_conf(binding.kernel)["digest"], + "repo_digest": get_image_conf(binding.kernel)["repo_digest"], + "canonical": get_image_conf(binding.kernel)["canonical"], + "architecture": get_image_conf(binding.kernel)["architecture"], + "labels": get_image_conf(binding.kernel)["labels"], + "is_local": get_image_conf(binding.kernel)["is_local"], + }, + "session_type": scheduled_session.session_type.value, + "cluster_role": binding.kernel.cluster_role, + "cluster_idx": binding.kernel.cluster_idx, + "local_rank": binding.kernel.local_rank, + "cluster_hostname": binding.kernel.cluster_hostname, + "idle_timeout": idle_timeout, + "mounts": [ + item.to_json() for item in scheduled_session.vfolder_mounts ], - "BACKENDAI_CLUSTER_ROLE": binding.kernel.cluster_role, - "BACKENDAI_CLUSTER_IDX": str(binding.kernel.cluster_idx), - "BACKENDAI_CLUSTER_LOCAL_RANK": str(binding.kernel.local_rank), - "BACKENDAI_CLUSTER_HOST": str(binding.kernel.cluster_hostname), - "BACKENDAI_SERVICE_PORTS": str( - get_image_conf(binding.kernel)["labels"].get( - "ai.backend.service-ports" - ) - ), - }, - "resource_slots": binding.kernel.requested_slots.to_json(), - "resource_opts": binding.kernel.resource_opts, - "bootstrap_script": binding.kernel.bootstrap_script, - "startup_command": binding.kernel.startup_command, - "internal_data": scheduled_session.main_kernel.internal_data, - "auto_pull": get_image_conf(binding.kernel)["auto_pull"], - "preopen_ports": scheduled_session.main_kernel.preopen_ports, - "allocated_host_ports": list(binding.allocated_host_ports), - "agent_addr": binding.agent_alloc_ctx.agent_addr, - "scaling_group": binding.agent_alloc_ctx.scaling_group, - } - ] + "environ": { + # inherit per-session environment variables + **scheduled_session.environ, + # set per-kernel environment variables + "BACKENDAI_KERNEL_ID": str(binding.kernel.id), + "BACKENDAI_KERNEL_IMAGE": get_image_conf(binding.kernel)[ + "canonical" + ], + "BACKENDAI_CLUSTER_ROLE": binding.kernel.cluster_role, + "BACKENDAI_CLUSTER_IDX": str(binding.kernel.cluster_idx), + "BACKENDAI_CLUSTER_LOCAL_RANK": str(binding.kernel.local_rank), + "BACKENDAI_CLUSTER_HOST": str(binding.kernel.cluster_hostname), + "BACKENDAI_SERVICE_PORTS": str( + get_image_conf(binding.kernel)["labels"].get( + "ai.backend.service-ports" + ) + ), + }, + "resource_slots": binding.kernel.requested_slots.to_json(), + "resource_opts": binding.kernel.resource_opts, + "bootstrap_script": binding.kernel.bootstrap_script, + "startup_command": binding.kernel.startup_command, + "internal_data": scheduled_session.main_kernel.internal_data, + "auto_pull": get_image_conf(binding.kernel)["auto_pull"], + "preopen_ports": scheduled_session.main_kernel.preopen_ports, + "allocated_host_ports": list(binding.allocated_host_ports), + "agent_addr": binding.agent_alloc_ctx.agent_addr, + "scaling_group": binding.agent_alloc_ctx.scaling_group, + } + ] raw_kernel_ids = [str(binding.kernel.id) for binding in items]