From ae344ad431a1197ee785cc800a2256e2bf5ce222 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 7 Jul 2024 22:26:32 +0000 Subject: [PATCH] make cloud init more readable --- sky/authentication.py | 31 ------------------------------- sky/backends/backend_utils.py | 16 +++++++++++----- sky/clouds/azure.py | 6 ++---- sky/provision/azure/instance.py | 4 ++++ sky/provision/docker_utils.py | 22 +++++++++++++--------- sky/templates/azure-ray.yml.j2 | 7 +++++-- 6 files changed, 35 insertions(+), 51 deletions(-) diff --git a/sky/authentication.py b/sky/authentication.py index c61e0ce36c8..7eeb0e0ec9c 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -19,7 +19,6 @@ is an exception, due to the limitation of the cloud provider. See the comments in setup_lambda_authentication) """ -import base64 import copy import functools import os @@ -270,36 +269,6 @@ def setup_gcp_authentication(config: Dict[str, Any]) -> Dict[str, Any]: return configure_ssh_info(config) -# In Azure, cloud-init script must be encoded in base64. See -# https://learn.microsoft.com/en-us/azure/virtual-machines/custom-data -# for more information. Here we decode it and replace the ssh user -# and public key content, then encode it back. -def setup_azure_authentication(config: Dict[str, Any]) -> Dict[str, Any]: - _, public_key_path = get_or_generate_keys() - with open(public_key_path, 'r', encoding='utf-8') as f: - public_key = f.read().strip() - for node_type in config['available_node_types']: - node_config = config['available_node_types'][node_type]['node_config'] - cloud_init = ( - node_config['azure_arm_parameters']['cloudInitSetupCommands']) - cloud_init = base64.b64decode(cloud_init).decode('utf-8') - cloud_init = cloud_init.replace('skypilot:ssh_user', - config['auth']['ssh_user']) - cloud_init = cloud_init.replace('skypilot:ssh_public_key_content', - public_key) - cloud_init = base64.b64encode( - cloud_init.encode('utf-8')).decode('utf-8') - node_config['azure_arm_parameters']['cloudInitSetupCommands'] = ( - cloud_init) - config_str = common_utils.dump_yaml_str(config) - config_str = config_str.replace('skypilot:ssh_user', - config['auth']['ssh_user']) - config_str = config_str.replace('skypilot:ssh_public_key_content', - public_key) - config = yaml.safe_load(config_str) - return config - - def setup_lambda_authentication(config: Dict[str, Any]) -> Dict[str, Any]: get_or_generate_keys() diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index a1c86fdb624..84fbd399dc0 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -156,7 +156,8 @@ ('provider', 'tpu_node'), ('provider', 'security_group', 'GroupName'), ('available_node_types', 'ray.head.default', 'node_config', 'UserData'), - ('available_node_types', 'ray.worker.default', 'node_config', 'UserData'), + ('available_node_types', 'ray.head.default', 'node_config', + 'azure_arm_parameters', 'cloudInitSetupCommands'), ] @@ -1029,13 +1030,18 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str): """ config = common_utils.read_yaml(cluster_config_file) # Check the availability of the cloud type. - if isinstance(cloud, (clouds.AWS, clouds.OCI, clouds.SCP, clouds.Vsphere, - clouds.Cudo, clouds.Paperspace)): + if isinstance(cloud, ( + clouds.AWS, + clouds.OCI, + clouds.SCP, + clouds.Vsphere, + clouds.Cudo, + clouds.Paperspace, + clouds.Azure, + )): config = auth.configure_ssh_info(config) elif isinstance(cloud, clouds.GCP): config = auth.setup_gcp_authentication(config) - elif isinstance(cloud, clouds.Azure): - config = auth.setup_azure_authentication(config) elif isinstance(cloud, clouds.Lambda): config = auth.setup_lambda_authentication(config) elif isinstance(cloud, clouds.Kubernetes): diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index d8a52a184e8..1a08b2cca39 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -1,5 +1,4 @@ """Azure.""" -import base64 import functools import json import os @@ -324,8 +323,7 @@ def make_deploy_resources_variables(self, # restarted, identified by a file /tmp/__restarted is existing. # Also, add default user to docker group. # pylint: disable=line-too-long - cloud_init_setup_commands = base64.b64encode( - textwrap.dedent("""\ + cloud_init_setup_commands = textwrap.dedent("""\ #cloud-config runcmd: - sed -i 's/#Banner none/Banner none/' /etc/ssh/sshd_config @@ -341,7 +339,7 @@ def make_deploy_resources_variables(self, - path: /etc/apt/apt.conf.d/10cloudinit-disable content: | APT::Periodic::Enable "0"; - """).encode('utf-8')).decode('utf-8') + """).split('\n') def _failover_disk_tier() -> Optional[resources_utils.DiskTier]: if (r.disk_tier is not None and diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index b5f43991635..f96110e754e 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -1,4 +1,5 @@ """Azure instance provisioning.""" +import base64 import copy import enum import json @@ -223,6 +224,9 @@ def _create_instances( template_params['msi'] = provider_config['msi'] template_params['nsg'] = provider_config['nsg'] template_params['subnet'] = provider_config['subnet'] + template_params['cloudInitSetupCommands'] = base64.b64encode( + template_params['cloudInitSetupCommands'].encode('utf-8')).decode( + 'utf-8') if node_config.get('need_nvidia_driver_extension', False): # pylint: disable=line-too-long diff --git a/sky/provision/docker_utils.py b/sky/provision/docker_utils.py index 6ebec2577ad..4dec5f96f1d 100644 --- a/sky/provision/docker_utils.py +++ b/sky/provision/docker_utils.py @@ -166,15 +166,19 @@ def _run(self, stream_logs=False, separate_stderr=separate_stderr, log_path=self.log_path) - if (not wait_for_docker_daemon or - DOCKER_PERMISSION_DENIED_STR not in stdout + stderr): - break - - if time.time() - start > _DOCKER_WAIT_FOR_SOCKET_TIMEOUT_SECONDS: - break - logger.info( - 'Failed to run docker command, retrying in 15 seconds...') - time.sleep(15) + if (DOCKER_PERMISSION_DENIED_STR in stdout + stderr and + wait_for_docker_daemon): + if time.time() - start > _DOCKER_SOCKET_WAIT_TIMEOUT_SECONDS: + if rc == 0: + # Set returncode to 1 if failed to connect to docker + # daemon after timeout. + rc = 1 + break + logger.info('Failed to connect to docker daemon. It might be ' + 'initializing, retrying in 30 seconds...') + time.sleep(30) + continue + break subprocess_utils.handle_returncode( rc, cmd, diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 48e4511a8e7..16eb1d9dd23 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -72,14 +72,17 @@ available_node_types: imageVersion: {{image_version}} osDiskSizeGB: {{disk_size}} osDiskTier: {{disk_tier}} - cloudInitSetupCommands: {{cloud_init_setup_commands}} - # optionally set priority to use Spot instances {%- if use_spot %} + # optionally set priority to use Spot instances priority: Spot # set a maximum price for spot instances if desired # billingProfile: # maxPrice: -1 {%- endif %} + cloudInitSetupCommands: |- + {%- for cmd in cloud_init_setup_commands %} + {{ cmd }} + {%- endfor %} need_nvidia_driver_extension: {{need_nvidia_driver_extension}} # TODO: attach disk