diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index f5456d6035b..6bdc7e3349d 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -911,24 +911,6 @@ async def _rollback_container_creation() -> None: for k, v in kvpairs.items(): await writer.write(f"{k}={v}\n") - await container.start() - - if self.internal_data.get("sudo_session_enabled", False): - exec = await container.exec( - [ - # file ownership is guaranteed to be set as root:root since command is executed on behalf of root user - "sh", - "-c", - 'mkdir -p /etc/sudoers.d && echo "work ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/01-bai-work', - ], - user="root", - ) - shell_response = await exec.start(detach=True) - if shell_response: - raise ContainerCreationError( - container_id=cid, - message=f"sudoers provision failed: {shell_response.decode()}", - ) except asyncio.CancelledError: if container is not None: raise ContainerCreationError( @@ -943,6 +925,33 @@ async def _rollback_container_creation() -> None: ) raise + try: + await container.start() + except asyncio.CancelledError: + await _rollback_container_creation() + raise ContainerCreationError(container_id=cid, message="Container start cancelled") + except Exception as e: + await _rollback_container_creation() + raise ContainerCreationError(container_id=cid, message=f"unknown. {repr(e)}") + + if self.internal_data.get("sudo_session_enabled", False): + exec = await container.exec( + [ + # file ownership is guaranteed to be set as root:root since command is executed on behalf of root user + "sh", + "-c", + 'mkdir -p /etc/sudoers.d && echo "work ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/01-bai-work', + ], + user="root", + ) + shell_response = await exec.start(detach=True) + if shell_response: + await _rollback_container_creation() + raise ContainerCreationError( + container_id=cid, + message=f"sudoers provision failed: {shell_response.decode()}", + ) + additional_network_names: Set[str] = set() for dev_name, device_alloc in resource_spec.allocations.items(): n = await self.computers[dev_name].instance.get_docker_networks(device_alloc)