diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 23cd493e8a3..a19bdcd020d 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -35,12 +35,10 @@ jobs: - name: Running yapf run: | yapf --diff --recursive ./ --exclude 'sky/skylet/ray_patches/**' \ - --exclude 'sky/skylet/providers/azure/**' \ --exclude 'sky/skylet/providers/ibm/**' - name: Running black run: | - black --diff --check sky/skylet/providers/azure/ \ - sky/skylet/providers/ibm/ + black --diff --check sky/skylet/providers/ibm/ - name: Running isort for black formatted files run: | isort --diff --check --profile black -l 88 -m 3 \ @@ -48,5 +46,4 @@ jobs: - name: Running isort for yapf formatted files run: | isort --diff --check ./ --sg 'sky/skylet/ray_patches/**' \ - --sg 'sky/skylet/providers/azure/**' \ --sg 'sky/skylet/providers/ibm/**' diff --git a/format.sh b/format.sh index e3bcfde0f18..66b966c3029 100755 --- a/format.sh +++ b/format.sh @@ -48,18 +48,15 @@ YAPF_FLAGS=( YAPF_EXCLUDES=( '--exclude' 'build/**' - '--exclude' 'sky/skylet/providers/azure/**' '--exclude' 'sky/skylet/providers/ibm/**' ) ISORT_YAPF_EXCLUDES=( '--sg' 'build/**' - '--sg' 'sky/skylet/providers/azure/**' '--sg' 'sky/skylet/providers/ibm/**' ) BLACK_INCLUDES=( - 'sky/skylet/providers/azure' 'sky/skylet/providers/ibm' ) diff --git a/sky/adaptors/azure.py b/sky/adaptors/azure.py index 6bd57bc6bec..9ec58dbcbc0 100644 --- a/sky/adaptors/azure.py +++ b/sky/adaptors/azure.py @@ -82,3 +82,10 @@ def get_client(name: str, subscription_id: str): def create_security_rule(**kwargs): from azure.mgmt.network.models import SecurityRule return SecurityRule(**kwargs) + + +@common.load_lazy_modules(modules=_LAZY_MODULES) +def deployment_mode(): + """Azure deployment mode.""" + from azure.mgmt.resource.resources.models import DeploymentMode + return DeploymentMode diff --git a/sky/authentication.py b/sky/authentication.py index c61e0ce36c8..7eeb0e0ec9c 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -19,7 +19,6 @@ is an exception, due to the limitation of the cloud provider. See the comments in setup_lambda_authentication) """ -import base64 import copy import functools import os @@ -270,36 +269,6 @@ def setup_gcp_authentication(config: Dict[str, Any]) -> Dict[str, Any]: return configure_ssh_info(config) -# In Azure, cloud-init script must be encoded in base64. See -# https://learn.microsoft.com/en-us/azure/virtual-machines/custom-data -# for more information. Here we decode it and replace the ssh user -# and public key content, then encode it back. -def setup_azure_authentication(config: Dict[str, Any]) -> Dict[str, Any]: - _, public_key_path = get_or_generate_keys() - with open(public_key_path, 'r', encoding='utf-8') as f: - public_key = f.read().strip() - for node_type in config['available_node_types']: - node_config = config['available_node_types'][node_type]['node_config'] - cloud_init = ( - node_config['azure_arm_parameters']['cloudInitSetupCommands']) - cloud_init = base64.b64decode(cloud_init).decode('utf-8') - cloud_init = cloud_init.replace('skypilot:ssh_user', - config['auth']['ssh_user']) - cloud_init = cloud_init.replace('skypilot:ssh_public_key_content', - public_key) - cloud_init = base64.b64encode( - cloud_init.encode('utf-8')).decode('utf-8') - node_config['azure_arm_parameters']['cloudInitSetupCommands'] = ( - cloud_init) - config_str = common_utils.dump_yaml_str(config) - config_str = config_str.replace('skypilot:ssh_user', - config['auth']['ssh_user']) - config_str = config_str.replace('skypilot:ssh_public_key_content', - public_key) - config = yaml.safe_load(config_str) - return config - - def setup_lambda_authentication(config: Dict[str, Any]) -> Dict[str, Any]: get_or_generate_keys() diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index c726f21a597..e0829b3cd64 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -158,7 +158,8 @@ ('available_node_types', 'ray.head.default', 'node_config', 'IamInstanceProfile'), ('available_node_types', 'ray.head.default', 'node_config', 'UserData'), - ('available_node_types', 'ray.worker.default', 'node_config', 'UserData'), + ('available_node_types', 'ray.head.default', 'node_config', + 'azure_arm_parameters', 'cloudInitSetupCommands'), ] @@ -1019,13 +1020,18 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str): """ config = common_utils.read_yaml(cluster_config_file) # Check the availability of the cloud type. - if isinstance(cloud, (clouds.AWS, clouds.OCI, clouds.SCP, clouds.Vsphere, - clouds.Cudo, clouds.Paperspace)): + if isinstance(cloud, ( + clouds.AWS, + clouds.OCI, + clouds.SCP, + clouds.Vsphere, + clouds.Cudo, + clouds.Paperspace, + clouds.Azure, + )): config = auth.configure_ssh_info(config) elif isinstance(cloud, clouds.GCP): config = auth.setup_gcp_authentication(config) - elif isinstance(cloud, clouds.Azure): - config = auth.setup_azure_authentication(config) elif isinstance(cloud, clouds.Lambda): config = auth.setup_lambda_authentication(config) elif isinstance(cloud, clouds.Kubernetes): diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 2fdef130afb..e391e1fecd8 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -18,7 +18,8 @@ import threading import time import typing -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import (Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, + Union) import colorama import filelock @@ -702,56 +703,38 @@ class FailoverCloudErrorHandlerV1: """ @staticmethod - def _azure_handler(blocked_resources: Set['resources_lib.Resources'], - launchable_resources: 'resources_lib.Resources', - region: 'clouds.Region', - zones: Optional[List['clouds.Zone']], stdout: str, - stderr: str): - del zones # Unused. - # The underlying ray autoscaler will try all zones of a region at once. - style = colorama.Style + def _handle_errors(stdout: str, stderr: str, + is_error_str_known: Callable[[str], bool]) -> List[str]: stdout_splits = stdout.split('\n') stderr_splits = stderr.split('\n') errors = [ s.strip() for s in stdout_splits + stderr_splits - if ('Exception Details:' in s.strip() or 'InvalidTemplateDeployment' - in s.strip() or '(ReadOnlyDisabledSubscription)' in s.strip()) + if is_error_str_known(s.strip()) ] - if not errors: - if 'Head node fetch timed out' in stderr: - # Example: click.exceptions.ClickException: Head node fetch - # timed out. Failed to create head node. - # This is a transient error, but we have retried in need_ray_up - # and failed. So we skip this region. - logger.info('Got \'Head node fetch timed out\' in ' - f'{region.name}.') - _add_to_blocked_resources( - blocked_resources, - launchable_resources.copy(region=region.name)) - elif 'rsync: command not found' in stderr: - with ux_utils.print_exception_no_traceback(): - raise RuntimeError(_RSYNC_NOT_FOUND_MESSAGE) - logger.info('====== stdout ======') - for s in stdout_splits: - print(s) - logger.info('====== stderr ======') - for s in stderr_splits: - print(s) + if errors: + return errors + if 'rsync: command not found' in stderr: with ux_utils.print_exception_no_traceback(): - raise RuntimeError('Errors occurred during provision; ' - 'check logs above.') - - logger.warning(f'Got error(s) in {region.name}:') - messages = '\n\t'.join(errors) - logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') - if any('(ReadOnlyDisabledSubscription)' in s for s in errors): - _add_to_blocked_resources( - blocked_resources, - resources_lib.Resources(cloud=clouds.Azure())) - else: - _add_to_blocked_resources(blocked_resources, - launchable_resources.copy(zone=None)) + e = RuntimeError(_RSYNC_NOT_FOUND_MESSAGE) + setattr(e, 'detailed_reason', + f'stdout: {stdout}\nstderr: {stderr}') + raise e + detailed_reason = textwrap.dedent(f"""\ + ====== stdout ====== + {stdout} + ====== stderr ====== + {stderr} + """) + logger.info('====== stdout ======') + print(stdout) + logger.info('====== stderr ======') + print(stderr) + with ux_utils.print_exception_no_traceback(): + e = RuntimeError('Errors occurred during provision; ' + 'check logs above.') + setattr(e, 'detailed_reason', detailed_reason) + raise e @staticmethod def _lambda_handler(blocked_resources: Set['resources_lib.Resources'], @@ -760,30 +743,13 @@ def _lambda_handler(blocked_resources: Set['resources_lib.Resources'], zones: Optional[List['clouds.Zone']], stdout: str, stderr: str): del zones # Unused. - style = colorama.Style - stdout_splits = stdout.split('\n') - stderr_splits = stderr.split('\n') - errors = [ - s.strip() - for s in stdout_splits + stderr_splits - if 'LambdaCloudError:' in s.strip() - ] - if not errors: - if 'rsync: command not found' in stderr: - with ux_utils.print_exception_no_traceback(): - raise RuntimeError(_RSYNC_NOT_FOUND_MESSAGE) - logger.info('====== stdout ======') - for s in stdout_splits: - print(s) - logger.info('====== stderr ======') - for s in stderr_splits: - print(s) - with ux_utils.print_exception_no_traceback(): - raise RuntimeError('Errors occurred during provision; ' - 'check logs above.') - + errors = FailoverCloudErrorHandlerV1._handle_errors( + stdout, + stderr, + is_error_str_known=lambda x: 'LambdaCloudError:' in x.strip()) logger.warning(f'Got error(s) in {region.name}:') messages = '\n\t'.join(errors) + style = colorama.Style logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') _add_to_blocked_resources(blocked_resources, launchable_resources.copy(zone=None)) @@ -797,65 +763,21 @@ def _lambda_handler(blocked_resources: Set['resources_lib.Resources'], blocked_resources, launchable_resources.copy(region=r.name, zone=None)) - @staticmethod - def _kubernetes_handler(blocked_resources: Set['resources_lib.Resources'], - launchable_resources: 'resources_lib.Resources', - region, zones, stdout, stderr): - del zones # Unused. - style = colorama.Style - stdout_splits = stdout.split('\n') - stderr_splits = stderr.split('\n') - errors = [ - s.strip() - for s in stdout_splits + stderr_splits - if 'KubernetesError:' in s.strip() - ] - if not errors: - logger.info('====== stdout ======') - for s in stdout_splits: - print(s) - logger.info('====== stderr ======') - for s in stderr_splits: - print(s) - with ux_utils.print_exception_no_traceback(): - raise RuntimeError('Errors occurred during provisioning; ' - 'check logs above.') - - logger.warning(f'Got error(s) in {region.name}:') - messages = '\n\t'.join(errors) - logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') - _add_to_blocked_resources(blocked_resources, - launchable_resources.copy(zone=None)) - @staticmethod def _scp_handler(blocked_resources: Set['resources_lib.Resources'], - launchable_resources: 'resources_lib.Resources', region, - zones, stdout, stderr): + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], stdout: str, + stderr: str): del zones # Unused. - style = colorama.Style - stdout_splits = stdout.split('\n') - stderr_splits = stderr.split('\n') - errors = [ - s.strip() - for s in stdout_splits + stderr_splits - if 'SCPError:' in s.strip() - ] - if not errors: - if 'rsync: command not found' in stderr: - with ux_utils.print_exception_no_traceback(): - raise RuntimeError(_RSYNC_NOT_FOUND_MESSAGE) - logger.info('====== stdout ======') - for s in stdout_splits: - print(s) - logger.info('====== stderr ======') - for s in stderr_splits: - print(s) - with ux_utils.print_exception_no_traceback(): - raise RuntimeError('Errors occurred during provision; ' - 'check logs above.') + errors = FailoverCloudErrorHandlerV1._handle_errors( + stdout, + stderr, + is_error_str_known=lambda x: 'SCPError:' in x.strip()) logger.warning(f'Got error(s) in {region.name}:') messages = '\n\t'.join(errors) + style = colorama.Style logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') _add_to_blocked_resources(blocked_resources, launchable_resources.copy(zone=None)) @@ -876,29 +798,13 @@ def _ibm_handler(blocked_resources: Set['resources_lib.Resources'], zones: Optional[List['clouds.Zone']], stdout: str, stderr: str): - style = colorama.Style - stdout_splits = stdout.split('\n') - stderr_splits = stderr.split('\n') - errors = [ - s.strip() - for s in stdout_splits + stderr_splits - if 'ERR' in s.strip() or 'PANIC' in s.strip() - ] - if not errors: - if 'rsync: command not found' in stderr: - with ux_utils.print_exception_no_traceback(): - raise RuntimeError(_RSYNC_NOT_FOUND_MESSAGE) - logger.info('====== stdout ======') - for s in stdout_splits: - print(s) - logger.info('====== stderr ======') - for s in stderr_splits: - print(s) - with ux_utils.print_exception_no_traceback(): - raise RuntimeError('Errors occurred during provision; ' - 'check logs above.') + errors = FailoverCloudErrorHandlerV1._handle_errors( + stdout, stderr, + lambda x: 'ERR' in x.strip() or 'PANIC' in x.strip()) + logger.warning(f'Got error(s) on IBM cluster, in {region.name}:') messages = '\n\t'.join(errors) + style = colorama.Style logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') for zone in zones: # type: ignore[union-attr] @@ -912,35 +818,17 @@ def _oci_handler(blocked_resources: Set['resources_lib.Resources'], region: 'clouds.Region', zones: Optional[List['clouds.Zone']], stdout: str, stderr: str): - - style = colorama.Style - stdout_splits = stdout.split('\n') - stderr_splits = stderr.split('\n') - errors = [ - s.strip() - for s in stdout_splits + stderr_splits - if ('VcnSubnetNotFound' in s.strip()) or - ('oci.exceptions.ServiceError' in s.strip() and - ('NotAuthorizedOrNotFound' in s.strip() or 'CannotParseRequest' in - s.strip() or 'InternalError' in s.strip() or - 'LimitExceeded' in s.strip() or 'NotAuthenticated' in s.strip())) + known_service_errors = [ + 'NotAuthorizedOrNotFound', 'CannotParseRequest', 'InternalError', + 'LimitExceeded', 'NotAuthenticated' ] - if not errors: - if 'rsync: command not found' in stderr: - with ux_utils.print_exception_no_traceback(): - raise RuntimeError(_RSYNC_NOT_FOUND_MESSAGE) - logger.info('====== stdout ======') - for s in stdout_splits: - print(s) - logger.info('====== stderr ======') - for s in stderr_splits: - print(s) - with ux_utils.print_exception_no_traceback(): - raise RuntimeError('Errors occurred during provision; ' - 'check logs above.') - + errors = FailoverCloudErrorHandlerV1._handle_errors( + stdout, stderr, lambda x: 'VcnSubnetNotFound' in x.strip() or + ('oci.exceptions.ServiceError' in x.strip() and any( + known_err in x.strip() for known_err in known_service_errors))) logger.warning(f'Got error(s) in {region.name}:') messages = '\n\t'.join(errors) + style = colorama.Style logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') if zones is not None: @@ -1022,6 +910,25 @@ class FailoverCloudErrorHandlerV2: stdout and stderr. """ + @staticmethod + def _azure_handler(blocked_resources: Set['resources_lib.Resources'], + launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', zones: List['clouds.Zone'], + err: Exception): + del region, zones # Unused. + if '(ReadOnlyDisabledSubscription)' in str(err): + logger.info( + f'{colorama.Style.DIM}Azure subscription is read-only. ' + 'Skip provisioning on Azure. Please check the subscription set ' + 'with az account set -s .' + f'{colorama.Style.RESET_ALL}') + _add_to_blocked_resources( + blocked_resources, + resources_lib.Resources(cloud=clouds.Azure())) + else: + _add_to_blocked_resources(blocked_resources, + launchable_resources.copy(zone=None)) + @staticmethod def _gcp_handler(blocked_resources: Set['resources_lib.Resources'], launchable_resources: 'resources_lib.Resources', @@ -1826,19 +1733,6 @@ def need_ray_up( if returncode == 0: return False - if isinstance(to_provision_cloud, clouds.Azure): - if 'Failed to invoke the Azure CLI' in stderr: - logger.info( - 'Retrying head node provisioning due to Azure CLI ' - 'issues.') - return True - if ('Head node fetch timed out. Failed to create head node.' - in stderr): - logger.info( - 'Retrying head node provisioning due to head fetching ' - 'timeout.') - return True - if isinstance(to_provision_cloud, clouds.Lambda): if 'Your API requests are being rate limited.' in stderr: logger.info( @@ -2446,8 +2340,20 @@ def get_command_runners(self, self.cluster_yaml, self.docker_user, self.ssh_user) if avoid_ssh_control: ssh_credentials.pop('ssh_control_name', None) + updated_to_skypilot_provisioner_after_provisioned = ( + self.launched_resources.cloud.PROVISIONER_VERSION >= + clouds.ProvisionerVersion.SKYPILOT and + self.cached_external_ips is not None and + self.cached_cluster_info is None) + if updated_to_skypilot_provisioner_after_provisioned: + logger.debug( + f'{self.launched_resources.cloud} has been updated to the new ' + f'provisioner after cluster {self.cluster_name} was ' + f'provisioned. Cached IPs are used for connecting to the ' + 'cluster.') if (clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR >= - self.launched_resources.cloud.PROVISIONER_VERSION): + self.launched_resources.cloud.PROVISIONER_VERSION or + updated_to_skypilot_provisioner_after_provisioned): ip_list = (self.cached_external_ips if force_cached else self.external_ips()) if ip_list is None: @@ -2460,7 +2366,15 @@ def get_command_runners(self, zip(ip_list, port_list), **ssh_credentials) return runners if self.cached_cluster_info is None: - assert not force_cached, 'cached_cluster_info is None.' + # We have `or self.cached_external_ips is None` here, because + # when a cluster's cloud is just upgraded to the new provsioner, + # although it has the cached_external_ips, the cached_cluster_info + # can be None. We need to update it here, even when force_cached is + # set to True. + # TODO: We can remove `self.cached_external_ips is None` after + # version 0.8.0. + assert not force_cached or self.cached_external_ips is not None, ( + force_cached, self.cached_external_ips) self._update_cluster_info() assert self.cached_cluster_info is not None, self runners = provision_lib.get_command_runners( @@ -3292,8 +3206,8 @@ def _exec_code_on_head( '--address=http://127.0.0.1:$RAY_DASHBOARD_PORT ' f'--submission-id {job_id}-$(whoami) --no-wait ' # Redirect stderr to /dev/null to avoid distracting error from ray. - f'"{constants.SKY_PYTHON_CMD} -u {script_path} > {remote_log_path} 2> /dev/null"' - ) + f'"{constants.SKY_PYTHON_CMD} -u {script_path} > {remote_log_path} ' + '2> /dev/null"') code = job_lib.JobLibCodeGen.queue_job(job_id, job_submit_cmd) job_submit_cmd = ' && '.join([mkdir_code, create_script_code, code]) diff --git a/sky/benchmark/benchmark_utils.py b/sky/benchmark/benchmark_utils.py index 23a37c573ae..ffbcc8b6279 100644 --- a/sky/benchmark/benchmark_utils.py +++ b/sky/benchmark/benchmark_utils.py @@ -262,10 +262,16 @@ def _delete_remote_dir(remote_dir: str, bucket_type: data.StoreType) -> None: check=True) elif bucket_type == data.StoreType.GCS: remote_dir = f'gs://{remote_dir}' - subprocess.run(['gsutil', '-m', 'rm', '-r', remote_dir], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - check=True) + proc = subprocess.run(['gsutil', '-m', 'rm', '-r', remote_dir], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False) + if proc.returncode != 0: + stderr = proc.stderr.decode('utf-8') + if 'BucketNotFoundException: 404' in stderr: + logger.warning(f'Bucket {remote_dir} does not exist. Skip') + else: + raise RuntimeError(f'Failed to delete {remote_dir}: {stderr}') else: raise RuntimeError('Azure Blob Storage is not supported yet.') diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index 4b159bb0035..638a01dcf6b 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -1,5 +1,4 @@ """Azure.""" -import base64 import functools import json import os @@ -69,7 +68,7 @@ class Azure(clouds.Cloud): _INDENT_PREFIX = ' ' * 4 - PROVISIONER_VERSION = clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR + PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT STATUS_VERSION = clouds.StatusVersion.SKYPILOT @classmethod @@ -327,8 +326,7 @@ def make_deploy_resources_variables( # restarted, identified by a file /tmp/__restarted is existing. # Also, add default user to docker group. # pylint: disable=line-too-long - cloud_init_setup_commands = base64.b64encode( - textwrap.dedent("""\ + cloud_init_setup_commands = textwrap.dedent("""\ #cloud-config runcmd: - sed -i 's/#Banner none/Banner none/' /etc/ssh/sshd_config @@ -344,7 +342,7 @@ def make_deploy_resources_variables( - path: /etc/apt/apt.conf.d/10cloudinit-disable content: | APT::Periodic::Enable "0"; - """).encode('utf-8')).decode('utf-8') + """).split('\n') def _failover_disk_tier() -> Optional[resources_utils.DiskTier]: if (r.disk_tier is not None and diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index aadf5a64684..524a0cb0478 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -843,4 +843,5 @@ def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag') -> str: @classmethod def _build(cls, code: str) -> str: generated_code = cls._PREFIX + '\n' + code + return f'{constants.SKY_PYTHON_CMD} -u -c {shlex.quote(generated_code)}' diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py index 23cc69b664d..647d23ef225 100644 --- a/sky/provision/aws/instance.py +++ b/sky/provision/aws/instance.py @@ -15,6 +15,7 @@ from sky.adaptors import aws from sky.clouds import aws as aws_cloud from sky.provision import common +from sky.provision import constants from sky.provision.aws import utils from sky.utils import common_utils from sky.utils import resources_utils @@ -25,11 +26,6 @@ _T = TypeVar('_T') -# Tag uniquely identifying all nodes of a cluster -TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' -TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' -TAG_RAY_NODE_KIND = 'ray-node-type' # legacy tag for backward compatibility -TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' # Max retries for general AWS API calls. BOTO_MAX_RETRIES = 12 # Max retries for creating an instance. @@ -103,7 +99,7 @@ def _default_ec2_resource(region: str) -> Any: def _cluster_name_filter(cluster_name_on_cloud: str) -> List[Dict[str, Any]]: return [{ - 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', + 'Name': f'tag:{constants.TAG_RAY_CLUSTER_NAME}', 'Values': [cluster_name_on_cloud], }] @@ -181,8 +177,8 @@ def _create_instances(ec2_fail_fast, cluster_name: str, count: int, associate_public_ip_address: bool) -> List: tags = { 'Name': cluster_name, - TAG_RAY_CLUSTER_NAME: cluster_name, - TAG_SKYPILOT_CLUSTER_NAME: cluster_name, + constants.TAG_RAY_CLUSTER_NAME: cluster_name, + constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name, **tags } conf = node_config.copy() @@ -250,10 +246,8 @@ def _create_instances(ec2_fail_fast, cluster_name: str, def _get_head_instance_id(instances: List) -> Optional[str]: head_instance_id = None - head_node_markers = ( - (TAG_SKYPILOT_HEAD_NODE, '1'), - (TAG_RAY_NODE_KIND, 'head'), # backward compat with Ray - ) + head_node_markers = tuple(constants.HEAD_NODE_TAGS.items()) + for inst in instances: for t in inst.tags: if (t['Key'], t['Value']) in head_node_markers: @@ -288,7 +282,7 @@ def run_instances(region: str, cluster_name_on_cloud: str, 'Name': 'instance-state-name', 'Values': ['pending', 'running', 'stopping', 'stopped'], }, { - 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', + 'Name': f'tag:{constants.TAG_RAY_CLUSTER_NAME}', 'Values': [cluster_name_on_cloud], }] exist_instances = list(ec2.instances.filter(Filters=filters)) @@ -314,28 +308,19 @@ def run_instances(region: str, cluster_name_on_cloud: str, raise RuntimeError(f'Impossible state "{state}".') def _create_node_tag(target_instance, is_head: bool = True) -> str: + node_type_tags = (constants.HEAD_NODE_TAGS + if is_head else constants.WORKER_NODE_TAGS) + node_tag = [{'Key': k, 'Value': v} for k, v in node_type_tags.items()] if is_head: - node_tag = [{ - 'Key': TAG_SKYPILOT_HEAD_NODE, - 'Value': '1' - }, { - 'Key': TAG_RAY_NODE_KIND, - 'Value': 'head' - }, { + node_tag.append({ 'Key': 'Name', 'Value': f'sky-{cluster_name_on_cloud}-head' - }] + }) else: - node_tag = [{ - 'Key': TAG_SKYPILOT_HEAD_NODE, - 'Value': '0' - }, { - 'Key': TAG_RAY_NODE_KIND, - 'Value': 'worker' - }, { + node_tag.append({ 'Key': 'Name', 'Value': f'sky-{cluster_name_on_cloud}-worker' - }] + }) ec2.meta.client.create_tags( Resources=[target_instance.id], Tags=target_instance.tags + node_tag, @@ -563,7 +548,7 @@ def stop_instances( ] if worker_only: filters.append({ - 'Name': f'tag:{TAG_RAY_NODE_KIND}', + 'Name': f'tag:{constants.TAG_RAY_NODE_KIND}', 'Values': ['worker'], }) instances = _filter_instances(ec2, @@ -601,7 +586,7 @@ def terminate_instances( ] if worker_only: filters.append({ - 'Name': f'tag:{TAG_RAY_NODE_KIND}', + 'Name': f'tag:{constants.TAG_RAY_NODE_KIND}', 'Values': ['worker'], }) # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#EC2.Instance @@ -814,7 +799,7 @@ def wait_instances(region: str, cluster_name_on_cloud: str, filters = [ { - 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', + 'Name': f'tag:{constants.TAG_RAY_CLUSTER_NAME}', 'Values': [cluster_name_on_cloud], }, ] @@ -865,7 +850,7 @@ def get_cluster_info( 'Values': ['running'], }, { - 'Name': f'tag:{TAG_RAY_CLUSTER_NAME}', + 'Name': f'tag:{constants.TAG_RAY_CLUSTER_NAME}', 'Values': [cluster_name_on_cloud], }, ] diff --git a/sky/provision/azure/__init__.py b/sky/provision/azure/__init__.py index 2152728ba6e..378bda0e112 100644 --- a/sky/provision/azure/__init__.py +++ b/sky/provision/azure/__init__.py @@ -1,7 +1,11 @@ """Azure provisioner for SkyPilot.""" +from sky.provision.azure.config import bootstrap_instances from sky.provision.azure.instance import cleanup_ports +from sky.provision.azure.instance import get_cluster_info from sky.provision.azure.instance import open_ports from sky.provision.azure.instance import query_instances +from sky.provision.azure.instance import run_instances from sky.provision.azure.instance import stop_instances from sky.provision.azure.instance import terminate_instances +from sky.provision.azure.instance import wait_instances diff --git a/sky/skylet/providers/azure/azure-config-template.json b/sky/provision/azure/azure-config-template.json similarity index 91% rename from sky/skylet/providers/azure/azure-config-template.json rename to sky/provision/azure/azure-config-template.json index 1a13a67a121..489783faf98 100644 --- a/sky/skylet/providers/azure/azure-config-template.json +++ b/sky/provision/azure/azure-config-template.json @@ -5,7 +5,7 @@ "clusterId": { "type": "string", "metadata": { - "description": "Unique string appended to resource names to isolate resources from different ray clusters." + "description": "Unique string appended to resource names to isolate resources from different SkyPilot clusters." } }, "subnet": { @@ -18,12 +18,12 @@ "variables": { "contributor": "[subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'b24988ac-6180-42a0-ab88-20f7382dd24c')]", "location": "[resourceGroup().location]", - "msiName": "[concat('ray-', parameters('clusterId'), '-msi')]", - "roleAssignmentName": "[concat('ray-', parameters('clusterId'), '-ra')]", - "nsgName": "[concat('ray-', parameters('clusterId'), '-nsg')]", + "msiName": "[concat('sky-', parameters('clusterId'), '-msi')]", + "roleAssignmentName": "[concat('sky-', parameters('clusterId'), '-ra')]", + "nsgName": "[concat('sky-', parameters('clusterId'), '-nsg')]", "nsg": "[resourceId('Microsoft.Network/networkSecurityGroups', variables('nsgName'))]", - "vnetName": "[concat('ray-', parameters('clusterId'), '-vnet')]", - "subnetName": "[concat('ray-', parameters('clusterId'), '-subnet')]" + "vnetName": "[concat('sky-', parameters('clusterId'), '-vnet')]", + "subnetName": "[concat('sky-', parameters('clusterId'), '-subnet')]" }, "resources": [ { diff --git a/sky/skylet/providers/azure/azure-vm-template.json b/sky/provision/azure/azure-vm-template.json similarity index 100% rename from sky/skylet/providers/azure/azure-vm-template.json rename to sky/provision/azure/azure-vm-template.json diff --git a/sky/provision/azure/config.py b/sky/provision/azure/config.py new file mode 100644 index 00000000000..5d9385bd73c --- /dev/null +++ b/sky/provision/azure/config.py @@ -0,0 +1,169 @@ +"""Azure configuration bootstrapping. + +Creates the resource group and deploys the configuration template to Azure for +a cluster to be launched. +""" +import json +import logging +from pathlib import Path +import random +import time +from typing import Any, Callable + +from sky.adaptors import azure +from sky.provision import common + +logger = logging.getLogger(__name__) + +_DEPLOYMENT_NAME = 'skypilot-config' +_LEGACY_DEPLOYMENT_NAME = 'ray-config' +_RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT = 480 # 8 minutes + + +def get_azure_sdk_function(client: Any, function_name: str) -> Callable: + """Retrieve a callable function from Azure SDK client object. + + Newer versions of the various client SDKs renamed function names to + have a begin_ prefix. This function supports both the old and new + versions of the SDK by first trying the old name and falling back to + the prefixed new name. + """ + func = getattr(client, function_name, + getattr(client, f'begin_{function_name}', None)) + if func is None: + raise AttributeError( + f'{client.__name__!r} object has no {function_name} or ' + f'begin_{function_name} attribute') + return func + + +@common.log_function_start_end +def bootstrap_instances( + region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionConfig: + """See sky/provision/__init__.py""" + del region # unused + provider_config = config.provider_config + subscription_id = provider_config.get('subscription_id') + if subscription_id is None: + subscription_id = azure.get_subscription_id() + # Increase the timeout to fix the Azure get-access-token (used by ray azure + # node_provider) timeout issue. + # Tracked in https://github.com/Azure/azure-cli/issues/20404#issuecomment-1249575110 # pylint: disable=line-too-long + resource_client = azure.get_client('resource', subscription_id) + provider_config['subscription_id'] = subscription_id + logger.info(f'Using subscription id: {subscription_id}') + + assert ( + 'resource_group' + in provider_config), 'Provider config must include resource_group field' + resource_group = provider_config['resource_group'] + + assert ('location' + in provider_config), 'Provider config must include location field' + params = {'location': provider_config['location']} + + if 'tags' in provider_config: + params['tags'] = provider_config['tags'] + + logger.info(f'Creating/Updating resource group: {resource_group}') + rg_create_or_update = get_azure_sdk_function( + client=resource_client.resource_groups, + function_name='create_or_update') + rg_creation_start = time.time() + retry = 0 + while (time.time() - rg_creation_start < + _RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT): + try: + rg_create_or_update(resource_group_name=resource_group, + parameters=params) + break + except azure.exceptions().ResourceExistsError as e: + if 'ResourceGroupBeingDeleted' in str(e): + if retry % 5 == 0: + logger.info( + f'Azure resource group {resource_group} of a recent ' + f'terminated cluster {cluster_name_on_cloud} is being ' + 'deleted. It can only be provisioned after it is fully' + 'deleted. Waiting...') + time.sleep(1) + retry += 1 + continue + raise + else: + raise TimeoutError( + f'Timed out waiting for resource group {resource_group} to be ' + 'deleted.') + + # load the template file + current_path = Path(__file__).parent + template_path = current_path.joinpath('azure-config-template.json') + with open(template_path, 'r', encoding='utf-8') as template_fp: + template = json.load(template_fp) + + logger.info(f'Using cluster name: {cluster_name_on_cloud}') + + subnet_mask = provider_config.get('subnet_mask') + if subnet_mask is None: + # choose a random subnet, skipping most common value of 0 + random.seed(cluster_name_on_cloud) + subnet_mask = f'10.{random.randint(1, 254)}.0.0/16' + logger.info(f'Using subnet mask: {subnet_mask}') + + parameters = { + 'properties': { + 'mode': azure.deployment_mode().incremental, + 'template': template, + 'parameters': { + 'subnet': { + 'value': subnet_mask + }, + 'clusterId': { + # We use the cluster name as the unique ID for the cluster, + # as we have already appended the user hash to the cluster + # name. + 'value': cluster_name_on_cloud + }, + }, + } + } + + # Skip creating or updating the deployment if the deployment already exists + # and the cluster name is the same. + get_deployment = get_azure_sdk_function(client=resource_client.deployments, + function_name='get') + deployment_exists = False + for deployment_name in [_DEPLOYMENT_NAME, _LEGACY_DEPLOYMENT_NAME]: + try: + deployment = get_deployment(resource_group_name=resource_group, + deployment_name=deployment_name) + logger.info(f'Deployment {deployment_name!r} already exists. ' + 'Skipping deployment creation.') + + outputs = deployment.properties.outputs + if outputs is not None: + deployment_exists = True + break + except azure.exceptions().ResourceNotFoundError: + deployment_exists = False + + if not deployment_exists: + logger.info(f'Creating/Updating deployment: {_DEPLOYMENT_NAME}') + create_or_update = get_azure_sdk_function( + client=resource_client.deployments, + function_name='create_or_update') + # TODO (skypilot): this takes a long time (> 40 seconds) to run. + outputs = create_or_update( + resource_group_name=resource_group, + deployment_name=_DEPLOYMENT_NAME, + parameters=parameters, + ).result().properties.outputs + + nsg_id = outputs['nsg']['value'] + + # append output resource ids to be used with vm creation + provider_config['msi'] = outputs['msi']['value'] + provider_config['nsg'] = nsg_id + provider_config['subnet'] = outputs['subnet']['value'] + + return config diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 19c1ba3f3da..2a8d54273c2 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -1,18 +1,28 @@ """Azure instance provisioning.""" +import base64 +import copy +import enum +import json import logging from multiprocessing import pool +import pathlib +import time import typing -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple +from uuid import uuid4 from sky import exceptions from sky import sky_logging from sky import status_lib from sky.adaptors import azure +from sky.provision import common +from sky.provision import constants from sky.utils import common_utils from sky.utils import ux_utils if typing.TYPE_CHECKING: from azure.mgmt import compute as azure_compute + from azure.mgmt import resource as azure_resource logger = sky_logging.init_logger(__name__) @@ -21,14 +31,100 @@ azure_logger = logging.getLogger('azure') azure_logger.setLevel(logging.WARNING) -# Tag uniquely identifying all nodes of a cluster -TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' -TAG_RAY_NODE_KIND = 'ray-node-type' +_RESUME_INSTANCE_TIMEOUT = 480 # 8 minutes +_RESUME_PER_INSTANCE_TIMEOUT = 120 # 2 minutes +UNIQUE_ID_LEN = 4 +_TAG_SKYPILOT_VM_ID = 'skypilot-vm-id' +_WAIT_CREATION_TIMEOUT_SECONDS = 600 _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound' +_POLL_INTERVAL = 1 + + +class AzureInstanceStatus(enum.Enum): + """Statuses enum for Azure instances with power and provisioning states.""" + PENDING = 'pending' + RUNNING = 'running' + STOPPING = 'stopping' + STOPPED = 'stopped' + DELETING = 'deleting' + + @classmethod + def power_state_map(cls) -> Dict[str, 'AzureInstanceStatus']: + return { + 'starting': cls.PENDING, + 'running': cls.RUNNING, + # 'stopped' in Azure means Stopped (Allocated), which still bills + # for the VM. + 'stopping': cls.STOPPING, + 'stopped': cls.STOPPED, + # 'VM deallocated' in Azure means Stopped (Deallocated), which does + # not bill for the VM. + 'deallocating': cls.STOPPING, + 'deallocated': cls.STOPPED, + } + + @classmethod + def provisioning_state_map(cls) -> Dict[str, 'AzureInstanceStatus']: + return { + 'Creating': cls.PENDING, + 'Updating': cls.PENDING, + 'Failed': cls.PENDING, + 'Migrating': cls.PENDING, + 'Deleting': cls.DELETING, + # Succeeded in provisioning state means the VM is provisioned but + # not necessarily running. The caller should further check the + # power state to determine the actual VM status. + 'Succeeded': cls.RUNNING, + } + + @classmethod + def cluster_status_map( + cls + ) -> Dict['AzureInstanceStatus', Optional[status_lib.ClusterStatus]]: + return { + cls.PENDING: status_lib.ClusterStatus.INIT, + cls.STOPPING: status_lib.ClusterStatus.INIT, + cls.RUNNING: status_lib.ClusterStatus.UP, + cls.STOPPED: status_lib.ClusterStatus.STOPPED, + cls.DELETING: None, + } + + @classmethod + def from_raw_states(cls, provisioning_state: str, + power_state: Optional[str]) -> 'AzureInstanceStatus': + provisioning_state_map = cls.provisioning_state_map() + power_state_map = cls.power_state_map() + status = None + if power_state is None: + if provisioning_state not in provisioning_state_map: + with ux_utils.print_exception_no_traceback(): + raise exceptions.ClusterStatusFetchingError( + 'Failed to parse status from Azure response: ' + f'{provisioning_state}') + status = provisioning_state_map[provisioning_state] + if status is None or status == cls.RUNNING: + # We should further check the power state to determine the actual + # VM status. + if power_state not in power_state_map: + with ux_utils.print_exception_no_traceback(): + raise exceptions.ClusterStatusFetchingError( + 'Failed to parse status from Azure response: ' + f'{power_state}.') + status = power_state_map[power_state] + if status is None: + with ux_utils.print_exception_no_traceback(): + raise exceptions.ClusterStatusFetchingError( + 'Failed to parse status from Azure response: ' + f'provisioning state ({provisioning_state}), ' + f'power state ({power_state})') + return status + + def to_cluster_status(self) -> Optional[status_lib.ClusterStatus]: + return self.cluster_status_map().get(self) -def get_azure_sdk_function(client: Any, function_name: str) -> Callable: +def _get_azure_sdk_function(client: Any, function_name: str) -> Callable: """Retrieve a callable function from Azure SDK client object. Newer versions of the various client SDKs renamed function names to @@ -45,64 +141,412 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable: return func -def open_ports( - cluster_name_on_cloud: str, - ports: List[str], - provider_config: Optional[Dict[str, Any]] = None, -) -> None: +def _get_instance_ips(network_client, vm, resource_group: str, + use_internal_ips: bool) -> Tuple[str, Optional[str]]: + nic_id = vm.network_profile.network_interfaces[0].id + nic_name = nic_id.split('/')[-1] + nic = network_client.network_interfaces.get( + resource_group_name=resource_group, + network_interface_name=nic_name, + ) + ip_config = nic.ip_configurations[0] + + external_ip = None + if not use_internal_ips: + public_ip_id = ip_config.public_ip_address.id + public_ip_name = public_ip_id.split('/')[-1] + public_ip = network_client.public_ip_addresses.get( + resource_group_name=resource_group, + public_ip_address_name=public_ip_name, + ) + external_ip = public_ip.ip_address + + internal_ip = ip_config.private_ip_address + + return (internal_ip, external_ip) + + +def _get_head_instance_id(instances: List) -> Optional[str]: + head_instance_id = None + head_node_tags = tuple(constants.HEAD_NODE_TAGS.items()) + for inst in instances: + for k, v in inst.tags.items(): + if (k, v) in head_node_tags: + if head_instance_id is not None: + logger.warning( + 'There are multiple head nodes in the cluster ' + f'(current head instance id: {head_instance_id}, ' + f'newly discovered id: {inst.name}). It is likely ' + f'that something goes wrong.') + head_instance_id = inst.name + break + return head_instance_id + + +def _create_instances( + compute_client: 'azure_compute.ComputeManagementClient', + resource_client: 'azure_resource.ResourceManagementClient', + cluster_name_on_cloud: str, resource_group: str, + provider_config: Dict[str, Any], node_config: Dict[str, Any], + tags: Dict[str, str], count: int) -> List: + vm_id = uuid4().hex[:UNIQUE_ID_LEN] + tags = { + constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, + constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud, + **constants.WORKER_NODE_TAGS, + _TAG_SKYPILOT_VM_ID: vm_id, + **tags, + } + node_tags = node_config['tags'].copy() + node_tags.update(tags) + + # load the template file + current_path = pathlib.Path(__file__).parent + template_path = current_path.joinpath('azure-vm-template.json') + with open(template_path, 'r', encoding='utf-8') as template_fp: + template = json.load(template_fp) + + vm_name = f'{cluster_name_on_cloud}-{vm_id}' + use_internal_ips = provider_config.get('use_internal_ips', False) + + template_params = node_config['azure_arm_parameters'].copy() + # We don't include 'head' or 'worker' in the VM name as on Azure the VM + # name is immutable and we may change the node type for existing VM in the + # multi-node cluster, due to manual termination of the head node. + template_params['vmName'] = vm_name + template_params['provisionPublicIp'] = not use_internal_ips + template_params['vmTags'] = node_tags + template_params['vmCount'] = count + template_params['msi'] = provider_config['msi'] + template_params['nsg'] = provider_config['nsg'] + template_params['subnet'] = provider_config['subnet'] + # In Azure, cloud-init script must be encoded in base64. For more + # information, see: + # https://learn.microsoft.com/en-us/azure/virtual-machines/custom-data + template_params['cloudInitSetupCommands'] = (base64.b64encode( + template_params['cloudInitSetupCommands'].encode('utf-8')).decode( + 'utf-8')) + + if node_config.get('need_nvidia_driver_extension', False): + # pylint: disable=line-too-long + # Configure driver extension for A10 GPUs. A10 GPUs requires a + # special type of drivers which is available at Microsoft HPC + # extension. Reference: https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2 + for r in template['resources']: + if r['type'] == 'Microsoft.Compute/virtualMachines': + # Add a nested extension resource for A10 GPUs + r['resources'] = [ + { + 'type': 'extensions', + 'apiVersion': '2015-06-15', + 'location': '[variables(\'location\')]', + 'dependsOn': [ + '[concat(\'Microsoft.Compute/virtualMachines/\', parameters(\'vmName\'), copyIndex())]' + ], + 'name': 'NvidiaGpuDriverLinux', + 'properties': { + 'publisher': 'Microsoft.HpcCompute', + 'type': 'NvidiaGpuDriverLinux', + 'typeHandlerVersion': '1.9', + 'autoUpgradeMinorVersion': True, + 'settings': {}, + }, + }, + ] + break + + parameters = { + 'properties': { + 'mode': azure.deployment_mode().incremental, + 'template': template, + 'parameters': { + key: { + 'value': value + } for key, value in template_params.items() + }, + } + } + + create_or_update = _get_azure_sdk_function( + client=resource_client.deployments, function_name='create_or_update') + create_or_update( + resource_group_name=resource_group, + deployment_name=vm_name, + parameters=parameters, + ).wait() + filters = { + constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, + _TAG_SKYPILOT_VM_ID: vm_id + } + instances = _filter_instances(compute_client, resource_group, filters) + assert len(instances) == count, (len(instances), count) + return instances + + +def run_instances(region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: """See sky/provision/__init__.py""" - assert provider_config is not None, cluster_name_on_cloud - subscription_id = provider_config['subscription_id'] + # TODO(zhwu): This function is too long. We should refactor it. + provider_config = config.provider_config resource_group = provider_config['resource_group'] - network_client = azure.get_client('network', subscription_id) - # The NSG should have been created by the cluster provisioning. - update_network_security_groups = get_azure_sdk_function( - client=network_client.network_security_groups, - function_name='create_or_update') - list_network_security_groups = get_azure_sdk_function( - client=network_client.network_security_groups, function_name='list') - for nsg in list_network_security_groups(resource_group): - try: - # Azure NSG rules have a priority field that determines the order - # in which they are applied. The priority must be unique across - # all inbound rules in one NSG. - priority = max(rule.priority - for rule in nsg.security_rules - if rule.direction == 'Inbound') + 1 - nsg.security_rules.append( - azure.create_security_rule( - name=f'sky-ports-{cluster_name_on_cloud}-{priority}', - priority=priority, - protocol='Tcp', - access='Allow', - direction='Inbound', - source_address_prefix='*', - source_port_range='*', - destination_address_prefix='*', - destination_port_ranges=ports, - )) - poller = update_network_security_groups(resource_group, nsg.name, - nsg) - poller.wait() - if poller.status() != 'Succeeded': - with ux_utils.print_exception_no_traceback(): - raise ValueError(f'Failed to open ports {ports} in NSG ' - f'{nsg.name}: {poller.status()}') - except azure.exceptions().HttpResponseError as e: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - f'Failed to open ports {ports} in NSG {nsg.name}.') from e + subscription_id = provider_config['subscription_id'] + compute_client = azure.get_client('compute', subscription_id) + instances_to_resume = [] + resumed_instance_ids: List[str] = [] + created_instance_ids: List[str] = [] + + # sort tags by key to support deterministic unit test stubbing + tags = dict(sorted(copy.deepcopy(config.tags).items())) + filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + + non_deleting_states = (set(AzureInstanceStatus) - + {AzureInstanceStatus.DELETING}) + existing_instances = _filter_instances( + compute_client, + tag_filters=filters, + resource_group=resource_group, + status_filters=list(non_deleting_states), + ) + logger.debug( + f'run_instances: Found {[inst.name for inst in existing_instances]} ' + 'existing instances in cluster.') + existing_instances.sort(key=lambda x: x.name) + + pending_instances = [] + running_instances = [] + stopping_instances = [] + stopped_instances = [] + + for instance in existing_instances: + status = _get_instance_status(compute_client, instance, resource_group) + logger.debug( + f'run_instances: Instance {instance.name} has status {status}.') + + if status == AzureInstanceStatus.RUNNING: + running_instances.append(instance) + elif status == AzureInstanceStatus.STOPPED: + stopped_instances.append(instance) + elif status == AzureInstanceStatus.STOPPING: + stopping_instances.append(instance) + elif status == AzureInstanceStatus.PENDING: + pending_instances.append(instance) + + def _create_instance_tag(target_instance, is_head: bool = True) -> str: + new_instance_tags = (constants.HEAD_NODE_TAGS + if is_head else constants.WORKER_NODE_TAGS) + + tags = target_instance.tags + tags.update(new_instance_tags) + + update = _get_azure_sdk_function(compute_client.virtual_machines, + 'update') + update(resource_group, target_instance.name, parameters={'tags': tags}) + return target_instance.name + + head_instance_id = _get_head_instance_id(existing_instances) + if head_instance_id is None: + if running_instances: + head_instance_id = _create_instance_tag(running_instances[0]) + elif pending_instances: + head_instance_id = _create_instance_tag(pending_instances[0]) + + if config.resume_stopped_nodes and len(existing_instances) > config.count: + raise RuntimeError( + 'The number of pending/running/stopped/stopping ' + f'instances combined ({len(existing_instances)}) in ' + f'cluster "{cluster_name_on_cloud}" is greater than the ' + f'number requested by the user ({config.count}). ' + 'This is likely a resource leak. ' + 'Use "sky down" to terminate the cluster.') + + to_start_count = config.count - len(pending_instances) - len( + running_instances) + + if to_start_count < 0: + raise RuntimeError( + 'The number of running+pending instances ' + f'({config.count - to_start_count}) in cluster ' + f'"{cluster_name_on_cloud}" is greater than the number ' + f'requested by the user ({config.count}). ' + 'This is likely a resource leak. ' + 'Use "sky down" to terminate the cluster.') + + if config.resume_stopped_nodes and to_start_count > 0 and ( + stopping_instances or stopped_instances): + time_start = time.time() + if stopping_instances: + plural = 's' if len(stopping_instances) > 1 else '' + verb = 'are' if len(stopping_instances) > 1 else 'is' + # TODO(zhwu): double check the correctness of the following on Azure + logger.warning( + f'Instance{plural} {[inst.name for inst in stopping_instances]}' + f' {verb} still in STOPPING state on Azure. It can only be ' + 'resumed after it is fully STOPPED. Waiting ...') + while (stopping_instances and + to_start_count > len(stopped_instances) and + time.time() - time_start < _RESUME_INSTANCE_TIMEOUT): + inst = stopping_instances.pop(0) + per_instance_time_start = time.time() + while (time.time() - per_instance_time_start < + _RESUME_PER_INSTANCE_TIMEOUT): + status = _get_instance_status(compute_client, inst, + resource_group) + if status == AzureInstanceStatus.STOPPED: + break + time.sleep(1) + else: + logger.warning( + f'Instance {inst.name} is still in stopping state ' + f'(Timeout: {_RESUME_PER_INSTANCE_TIMEOUT}). ' + 'Retrying ...') + stopping_instances.append(inst) + time.sleep(5) + continue + stopped_instances.append(inst) + if stopping_instances and to_start_count > len(stopped_instances): + msg = ('Timeout for waiting for existing instances ' + f'{stopping_instances} in STOPPING state to ' + 'be STOPPED before restarting them. Please try again later.') + logger.error(msg) + raise RuntimeError(msg) + + instances_to_resume = stopped_instances[:to_start_count] + instances_to_resume.sort(key=lambda x: x.name) + instances_to_resume_ids = [t.name for t in instances_to_resume] + logger.debug('run_instances: Resuming stopped instances ' + f'{instances_to_resume_ids}.') + start_virtual_machine = _get_azure_sdk_function( + compute_client.virtual_machines, 'start') + with pool.ThreadPool() as p: + p.starmap( + start_virtual_machine, + [(resource_group, inst.name) for inst in instances_to_resume]) + resumed_instance_ids = instances_to_resume_ids + + to_start_count -= len(resumed_instance_ids) + + if to_start_count > 0: + resource_client = azure.get_client('resource', subscription_id) + logger.debug(f'run_instances: Creating {to_start_count} instances.') + created_instances = _create_instances( + compute_client=compute_client, + resource_client=resource_client, + cluster_name_on_cloud=cluster_name_on_cloud, + resource_group=resource_group, + provider_config=provider_config, + node_config=config.node_config, + tags=tags, + count=to_start_count) + created_instance_ids = [inst.name for inst in created_instances] + + non_running_instance_statuses = list( + set(AzureInstanceStatus) - {AzureInstanceStatus.RUNNING}) + start = time.time() + while True: + # Wait for all instances to be in running state + instances = _filter_instances( + compute_client, + resource_group, + filters, + status_filters=non_running_instance_statuses, + included_instances=created_instance_ids + resumed_instance_ids) + if not instances: + break + if time.time() - start > _WAIT_CREATION_TIMEOUT_SECONDS: + raise TimeoutError( + 'run_instances: Timed out waiting for Azure instances to be ' + f'running: {instances}') + logger.debug(f'run_instances: Waiting for {len(instances)} instances ' + 'in PENDING status.') + time.sleep(_POLL_INTERVAL) + + running_instances = _filter_instances( + compute_client, + resource_group, + filters, + status_filters=[AzureInstanceStatus.RUNNING]) + head_instance_id = _get_head_instance_id(running_instances) + instances_to_tag = copy.copy(running_instances) + if head_instance_id is None: + head_instance_id = _create_instance_tag(instances_to_tag[0]) + instances_to_tag = instances_to_tag[1:] + else: + instances_to_tag = [ + inst for inst in instances_to_tag if inst.name != head_instance_id + ] + + if instances_to_tag: + # Tag the instances in case the old resumed instances are not correctly + # tagged. + with pool.ThreadPool() as p: + p.starmap( + _create_instance_tag, + # is_head=False for all wokers. + [(inst, False) for inst in instances_to_tag]) + + assert head_instance_id is not None, head_instance_id + return common.ProvisionRecord( + provider_name='azure', + region=region, + zone=None, + cluster_name=cluster_name_on_cloud, + head_instance_id=head_instance_id, + created_instance_ids=created_instance_ids, + resumed_instance_ids=resumed_instance_ids, + ) + + +def wait_instances(region: str, cluster_name_on_cloud: str, + state: Optional[status_lib.ClusterStatus]) -> None: + """See sky/provision/__init__.py""" + del region, cluster_name_on_cloud, state + # We already wait for the instances to be running in run_instances. + # So we don't need to wait here. -def cleanup_ports( - cluster_name_on_cloud: str, - ports: List[str], - provider_config: Optional[Dict[str, Any]] = None, -) -> None: + +def get_cluster_info( + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """See sky/provision/__init__.py""" - # Azure will automatically cleanup network security groups when cleanup - # resource group. So we don't need to do anything here. - del cluster_name_on_cloud, ports, provider_config # Unused. + del region + filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + assert provider_config is not None, (cluster_name_on_cloud, provider_config) + resource_group = provider_config['resource_group'] + subscription_id = provider_config.get('subscription_id', + azure.get_subscription_id()) + compute_client = azure.get_client('compute', subscription_id) + network_client = azure.get_client('network', subscription_id) + + running_instances = _filter_instances( + compute_client, + resource_group, + filters, + status_filters=[AzureInstanceStatus.RUNNING]) + head_instance_id = _get_head_instance_id(running_instances) + + instances = {} + use_internal_ips = provider_config.get('use_internal_ips', False) + for inst in running_instances: + internal_ip, external_ip = _get_instance_ips(network_client, inst, + resource_group, + use_internal_ips) + instances[inst.name] = [ + common.InstanceInfo( + instance_id=inst.name, + internal_ip=internal_ip, + external_ip=external_ip, + tags=inst.tags, + ) + ] + instances = dict(sorted(instances.items(), key=lambda x: x[0])) + return common.ClusterInfo( + provider_name='azure', + head_instance_id=head_instance_id, + instances=instances, + provider_config=provider_config, + ) def stop_instances( @@ -116,12 +560,12 @@ def stop_instances( subscription_id = provider_config['subscription_id'] resource_group = provider_config['resource_group'] compute_client = azure.get_client('compute', subscription_id) - tag_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + tag_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} if worker_only: - tag_filters[TAG_RAY_NODE_KIND] = 'worker' + tag_filters[constants.TAG_RAY_NODE_KIND] = 'worker' - nodes = _filter_instances(compute_client, tag_filters, resource_group) - stop_virtual_machine = get_azure_sdk_function( + nodes = _filter_instances(compute_client, resource_group, tag_filters) + stop_virtual_machine = _get_azure_sdk_function( client=compute_client.virtual_machines, function_name='deallocate') with pool.ThreadPool() as p: p.starmap(stop_virtual_machine, @@ -141,13 +585,13 @@ def terminate_instances( resource_group = provider_config['resource_group'] if worker_only: compute_client = azure.get_client('compute', subscription_id) - delete_virtual_machine = get_azure_sdk_function( + delete_virtual_machine = _get_azure_sdk_function( client=compute_client.virtual_machines, function_name='delete') filters = { - TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, - TAG_RAY_NODE_KIND: 'worker' + constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, + constants.TAG_RAY_NODE_KIND: 'worker' } - nodes = _filter_instances(compute_client, filters, resource_group) + nodes = _filter_instances(compute_client, resource_group, filters) with pool.ThreadPool() as p: p.starmap(delete_virtual_machine, [(resource_group, node.name) for node in nodes]) @@ -156,17 +600,32 @@ def terminate_instances( assert provider_config is not None, cluster_name_on_cloud resource_group_client = azure.get_client('resource', subscription_id) - delete_resource_group = get_azure_sdk_function( + delete_resource_group = _get_azure_sdk_function( client=resource_group_client.resource_groups, function_name='delete') - delete_resource_group(resource_group, force_deletion_types=None) + try: + delete_resource_group(resource_group, force_deletion_types=None) + except azure.exceptions().ResourceNotFoundError as e: + if 'ResourceGroupNotFound' in str(e): + logger.warning(f'Resource group {resource_group} not found. Skip ' + 'terminating it.') + return + raise -def _get_vm_status(compute_client: 'azure_compute.ComputeManagementClient', - vm_name: str, resource_group: str) -> str: - instance = compute_client.virtual_machines.instance_view( - resource_group_name=resource_group, vm_name=vm_name).as_dict() - for status in instance['statuses']: +def _get_instance_status( + compute_client: 'azure_compute.ComputeManagementClient', vm, + resource_group: str) -> Optional[AzureInstanceStatus]: + try: + instance = compute_client.virtual_machines.instance_view( + resource_group_name=resource_group, vm_name=vm.name) + except azure.exceptions().ResourceNotFoundError as e: + if 'ResourceNotFound' in str(e): + return None + raise + provisioning_state = vm.provisioning_state + instance_dict = instance.as_dict() + for status in instance_dict['statuses']: code_state = status['code'].split('/') # It is possible that sometimes the 'code' is empty string, and we # should skip them. @@ -175,23 +634,27 @@ def _get_vm_status(compute_client: 'azure_compute.ComputeManagementClient', code, state = code_state # skip provisioning status if code == 'PowerState': - return state - raise ValueError(f'Failed to get power state for VM {vm_name}: {instance}') + return AzureInstanceStatus.from_raw_states(provisioning_state, + state) + return AzureInstanceStatus.from_raw_states(provisioning_state, None) def _filter_instances( - compute_client: 'azure_compute.ComputeManagementClient', - filters: Dict[str, str], - resource_group: str) -> List['azure_compute.models.VirtualMachine']: + compute_client: 'azure_compute.ComputeManagementClient', + resource_group: str, + tag_filters: Dict[str, str], + status_filters: Optional[List[AzureInstanceStatus]] = None, + included_instances: Optional[List[str]] = None, +) -> List['azure_compute.models.VirtualMachine']: def match_tags(vm): - for k, v in filters.items(): + for k, v in tag_filters.items(): if vm.tags.get(k) != v: return False return True try: - list_virtual_machines = get_azure_sdk_function( + list_virtual_machines = _get_azure_sdk_function( client=compute_client.virtual_machines, function_name='list') vms = list_virtual_machines(resource_group_name=resource_group) nodes = list(filter(match_tags, vms)) @@ -199,6 +662,13 @@ def match_tags(vm): if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e): return [] raise + if status_filters is not None: + nodes = [ + node for node in nodes if _get_instance_status( + compute_client, node, resource_group) in status_filters + ] + if included_instances: + nodes = [node for node in nodes if node.name in included_instances] return nodes @@ -210,57 +680,104 @@ def query_instances( ) -> Dict[str, Optional[status_lib.ClusterStatus]]: """See sky/provision/__init__.py""" assert provider_config is not None, cluster_name_on_cloud - status_map = { - 'starting': status_lib.ClusterStatus.INIT, - 'running': status_lib.ClusterStatus.UP, - # 'stopped' in Azure means Stopped (Allocated), which still bills - # for the VM. - 'stopping': status_lib.ClusterStatus.INIT, - 'stopped': status_lib.ClusterStatus.INIT, - # 'VM deallocated' in Azure means Stopped (Deallocated), which does not - # bill for the VM. - 'deallocating': status_lib.ClusterStatus.STOPPED, - 'deallocated': status_lib.ClusterStatus.STOPPED, - } - provisioning_state_map = { - 'Creating': status_lib.ClusterStatus.INIT, - 'Updating': status_lib.ClusterStatus.INIT, - 'Failed': status_lib.ClusterStatus.INIT, - 'Migrating': status_lib.ClusterStatus.INIT, - 'Deleting': None, - # Succeeded in provisioning state means the VM is provisioned but not - # necessarily running. We exclude Succeeded state here, and the caller - # should determine the status of the VM based on the power state. - # 'Succeeded': status_lib.ClusterStatus.UP, - } subscription_id = provider_config['subscription_id'] resource_group = provider_config['resource_group'] + filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} compute_client = azure.get_client('compute', subscription_id) - filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} - nodes = _filter_instances(compute_client, filters, resource_group) - statuses = {} - - def _fetch_and_map_status( - compute_client: 'azure_compute.ComputeManagementClient', - node: 'azure_compute.models.VirtualMachine', - resource_group: str) -> None: - if node.provisioning_state in provisioning_state_map: - status = provisioning_state_map[node.provisioning_state] - else: - original_status = _get_vm_status(compute_client, node.name, - resource_group) - if original_status not in status_map: - with ux_utils.print_exception_no_traceback(): - raise exceptions.ClusterStatusFetchingError( - f'Failed to parse status from Azure response: {status}') - status = status_map[original_status] + nodes = _filter_instances(compute_client, resource_group, filters) + statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {} + + def _fetch_and_map_status(node, resource_group: str) -> None: + compute_client = azure.get_client('compute', subscription_id) + status = _get_instance_status(compute_client, node, resource_group) + if status is None and non_terminated_only: return - statuses[node.name] = status + statuses[node.name] = (None if status is None else + status.to_cluster_status()) with pool.ThreadPool() as p: p.starmap(_fetch_and_map_status, - [(compute_client, node, resource_group) for node in nodes]) + [(node, resource_group) for node in nodes]) return statuses + + +def open_ports( + cluster_name_on_cloud: str, + ports: List[str], + provider_config: Optional[Dict[str, Any]] = None, +) -> None: + """See sky/provision/__init__.py""" + assert provider_config is not None, cluster_name_on_cloud + subscription_id = provider_config['subscription_id'] + resource_group = provider_config['resource_group'] + network_client = azure.get_client('network', subscription_id) + + update_network_security_groups = _get_azure_sdk_function( + client=network_client.network_security_groups, + function_name='create_or_update') + list_network_security_groups = _get_azure_sdk_function( + client=network_client.network_security_groups, function_name='list') + for nsg in list_network_security_groups(resource_group): + try: + # Wait the NSG creation to be finished before opening a port. The + # cluster provisioning triggers the NSG creation, but it may not be + # finished yet. + backoff = common_utils.Backoff(max_backoff_factor=1) + start_time = time.time() + while True: + if nsg.provisioning_state not in ['Creating', 'Updating']: + break + if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS: + logger.warning( + f'Fails to wait for the creation of NSG {nsg.name} in ' + f'{resource_group} within ' + f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. ' + 'Skip this NSG.') + backoff_time = backoff.current_backoff() + logger.info(f'NSG {nsg.name} is not created yet. Waiting for ' + f'{backoff_time} seconds before checking again.') + time.sleep(backoff_time) + + # Azure NSG rules have a priority field that determines the order + # in which they are applied. The priority must be unique across + # all inbound rules in one NSG. + priority = max(rule.priority + for rule in nsg.security_rules + if rule.direction == 'Inbound') + 1 + nsg.security_rules.append( + azure.create_security_rule( + name=f'sky-ports-{cluster_name_on_cloud}-{priority}', + priority=priority, + protocol='Tcp', + access='Allow', + direction='Inbound', + source_address_prefix='*', + source_port_range='*', + destination_address_prefix='*', + destination_port_ranges=ports, + )) + poller = update_network_security_groups(resource_group, nsg.name, + nsg) + poller.wait() + if poller.status() != 'Succeeded': + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Failed to open ports {ports} in NSG ' + f'{nsg.name}: {poller.status()}') + except azure.exceptions().HttpResponseError as e: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Failed to open ports {ports} in NSG {nsg.name}.') from e + + +def cleanup_ports( + cluster_name_on_cloud: str, + ports: List[str], + provider_config: Optional[Dict[str, Any]] = None, +) -> None: + """See sky/provision/__init__.py""" + # Azure will automatically cleanup network security groups when cleanup + # resource group. So we don't need to do anything here. + del cluster_name_on_cloud, ports, provider_config # Unused. diff --git a/sky/provision/common.py b/sky/provision/common.py index e5df26a4c09..a588fbe94e8 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -129,7 +129,8 @@ def get_head_instance(self) -> Optional[InstanceInfo]: if self.head_instance_id is None: return None if self.head_instance_id not in self.instances: - raise ValueError('Head instance ID not in the cluster metadata.') + raise ValueError('Head instance ID not in the cluster metadata. ' + f'ClusterInfo: {self.__dict__}') return self.instances[self.head_instance_id][0] def get_worker_instances(self) -> List[InstanceInfo]: diff --git a/sky/provision/constants.py b/sky/provision/constants.py new file mode 100644 index 00000000000..760abc4861a --- /dev/null +++ b/sky/provision/constants.py @@ -0,0 +1,18 @@ +"""Constants used in the SkyPilot provisioner.""" + +# Tag uniquely identifying all nodes of a cluster +TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' +TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' +# Legacy tag for backward compatibility to distinguish head and worker nodes. +TAG_RAY_NODE_KIND = 'ray-node-type' +TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' + +HEAD_NODE_TAGS = { + TAG_RAY_NODE_KIND: 'head', + TAG_SKYPILOT_HEAD_NODE: '1', +} + +WORKER_NODE_TAGS = { + TAG_RAY_NODE_KIND: 'worker', + TAG_SKYPILOT_HEAD_NODE: '0', +} diff --git a/sky/provision/docker_utils.py b/sky/provision/docker_utils.py index 9fbc19c2959..aa29a3666a3 100644 --- a/sky/provision/docker_utils.py +++ b/sky/provision/docker_utils.py @@ -12,9 +12,6 @@ logger = sky_logging.init_logger(__name__) -DOCKER_PERMISSION_DENIED_STR = ('permission denied while trying to connect to ' - 'the Docker daemon socket') - # Configure environment variables. A docker image can have environment variables # set in the Dockerfile with `ENV``. We need to export these variables to the # shell environment, so that our ssh session can access them. @@ -26,6 +23,13 @@ '$(prefix_cmd) mv ~/container_env_var.sh /etc/profile.d/container_env_var.sh' ) +# Docker daemon may not be ready when the machine is firstly started. The error +# message starts with the following string. We should wait for a while and retry +# the command. +DOCKER_PERMISSION_DENIED_STR = ('permission denied while trying to connect to ' + 'the Docker daemon socket') +_DOCKER_SOCKET_WAIT_TIMEOUT_SECONDS = 30 + @dataclasses.dataclass class DockerLoginConfig: @@ -140,7 +144,8 @@ def _run(self, cmd, run_env='host', wait_for_docker_daemon: bool = False, - separate_stderr: bool = False) -> str: + separate_stderr: bool = False, + log_err_when_fail: bool = True) -> str: if run_env == 'docker': cmd = self._docker_expand_user(cmd, any_char=True) @@ -153,8 +158,7 @@ def _run(self, f' {shlex.quote(cmd)} ') logger.debug(f'+ {cmd}') - cnt = 0 - retry = 3 + start = time.time() while True: rc, stdout, stderr = self.runner.run( cmd, @@ -162,24 +166,30 @@ def _run(self, stream_logs=False, separate_stderr=separate_stderr, log_path=self.log_path) - if (not wait_for_docker_daemon or - DOCKER_PERMISSION_DENIED_STR not in stdout + stderr): - break - - cnt += 1 - if cnt > retry: - break - logger.info( - 'Failed to run docker command, retrying in 10 seconds... ' - f'({cnt}/{retry})') - time.sleep(10) + if (DOCKER_PERMISSION_DENIED_STR in stdout + stderr and + wait_for_docker_daemon): + if time.time() - start > _DOCKER_SOCKET_WAIT_TIMEOUT_SECONDS: + if rc == 0: + # Set returncode to 1 if failed to connect to docker + # daemon after timeout. + rc = 1 + break + # Close the cached connection to make the permission update of + # ssh user take effect, e.g. usermod -aG docker $USER, called + # by cloud-init of Azure. + self.runner.close_cached_connection() + logger.info('Failed to connect to docker daemon. It might be ' + 'initializing, retrying in 5 seconds...') + time.sleep(5) + continue + break subprocess_utils.handle_returncode( rc, cmd, error_msg='Failed to run docker setup commands.', stderr=stdout + stderr, # Print out the error message if the command failed. - stream_logs=True) + stream_logs=log_err_when_fail) return stdout.strip() def initialize(self) -> str: @@ -370,7 +380,7 @@ def _configure_runtime(self, run_options: List[str]) -> List[str]: 'info -f "{{.Runtimes}}"')) if 'nvidia-container-runtime' in runtime_output: try: - self._run('nvidia-smi') + self._run('nvidia-smi', log_err_when_fail=False) return run_options + ['--runtime=nvidia'] except Exception as e: # pylint: disable=broad-except logger.debug( diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index 8f9341bd342..4f442709b0c 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -215,12 +215,6 @@ # Stopping instances can take several minutes, so we increase the timeout MAX_POLLS_STOP = MAX_POLLS * 8 -TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' -# Tag uniquely identifying all nodes of a cluster -TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' -TAG_RAY_NODE_KIND = 'ray-node-type' -TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' - # MIG constants MANAGED_INSTANCE_GROUP_CONFIG = 'managed-instance-group' DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT = 900 # 15 minutes diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 967e673ee70..3e4cee65e4c 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -9,6 +9,7 @@ from sky import sky_logging from sky.adaptors import gcp from sky.provision import common +from sky.provision import constants as provision_constants from sky.provision.gcp import constants from sky.provision.gcp import instance_utils from sky.utils import common_utils @@ -61,7 +62,9 @@ def query_instances( assert provider_config is not None, (cluster_name_on_cloud, provider_config) zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = { + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud + } handler: Type[ instance_utils.GCPInstance] = instance_utils.GCPComputeInstance @@ -126,8 +129,8 @@ def _get_head_instance_id(instances: List) -> Optional[str]: head_instance_id = None for inst in instances: labels = inst.get('labels', {}) - if (labels.get(constants.TAG_RAY_NODE_KIND) == 'head' or - labels.get(constants.TAG_SKYPILOT_HEAD_NODE) == '1'): + if (labels.get(provision_constants.TAG_RAY_NODE_KIND) == 'head' or + labels.get(provision_constants.TAG_SKYPILOT_HEAD_NODE) == '1'): head_instance_id = inst['name'] break return head_instance_id @@ -160,7 +163,9 @@ def _run_instances(region: str, cluster_name_on_cloud: str, else: raise ValueError(f'Unknown node type {node_type}') - filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + filter_labels = { + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud + } # wait until all stopping instances are stopped/terminated while True: @@ -393,7 +398,9 @@ def get_cluster_info( assert provider_config is not None, cluster_name_on_cloud zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = { + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud + } handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance @@ -421,7 +428,7 @@ def get_cluster_info( project_id, zone, { - **label_filters, constants.TAG_RAY_NODE_KIND: 'head' + **label_filters, provision_constants.TAG_RAY_NODE_KIND: 'head' }, lambda h: [h.RUNNING_STATE], ) @@ -447,14 +454,16 @@ def stop_instances( assert provider_config is not None, cluster_name_on_cloud zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = { + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud + } tpu_node = provider_config.get('tpu_node') if tpu_node is not None: instance_utils.delete_tpu_node(project_id, zone, tpu_node) if worker_only: - label_filters[constants.TAG_RAY_NODE_KIND] = 'worker' + label_filters[provision_constants.TAG_RAY_NODE_KIND] = 'worker' handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance @@ -523,9 +532,11 @@ def terminate_instances( project_id, zone, cluster_name_on_cloud) return - label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = { + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud + } if worker_only: - label_filters[constants.TAG_RAY_NODE_KIND] = 'worker' + label_filters[provision_constants.TAG_RAY_NODE_KIND] = 'worker' handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance @@ -568,7 +579,9 @@ def open_ports( project_id = provider_config['project_id'] firewall_rule_name = provider_config['firewall_rule'] - label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = { + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud + } handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance, instance_utils.GCPTPUVMInstance, diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index e1e72a25d6c..933df5e08a1 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -13,6 +13,7 @@ from sky.adaptors import gcp from sky.clouds import gcp as gcp_cloud from sky.provision import common +from sky.provision import constants as provision_constants from sky.provision.gcp import constants from sky.provision.gcp import mig_utils from sky.utils import common_utils @@ -21,8 +22,6 @@ # Tag for the name of the node INSTANCE_NAME_MAX_LEN = 64 INSTANCE_NAME_UUID_LEN = 8 -TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' -TAG_RAY_NODE_KIND = 'ray-node-type' TPU_NODE_CREATION_FAILURE = 'Failed to provision TPU node.' @@ -284,15 +283,9 @@ def create_node_tag(cls, target_instance_id: str, is_head: bool = True) -> str: if is_head: - node_tag = { - TAG_SKYPILOT_HEAD_NODE: '1', - TAG_RAY_NODE_KIND: 'head', - } + node_tag = provision_constants.HEAD_NODE_TAGS else: - node_tag = { - TAG_SKYPILOT_HEAD_NODE: '0', - TAG_RAY_NODE_KIND: 'worker', - } + node_tag = provision_constants.WORKER_NODE_TAGS cls.set_labels(project_id=project_id, availability_zone=availability_zone, node_id=target_instance_id, @@ -676,8 +669,8 @@ def create_instances( config.update({ 'labels': dict( labels, **{ - constants.TAG_RAY_CLUSTER_NAME: cluster_name, - constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name, + provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name }), }) @@ -999,11 +992,11 @@ def create_instances( 'labels': dict( labels, **{ - constants.TAG_RAY_CLUSTER_NAME: cluster_name, + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name, # Assume all nodes are workers, we can update the head node # once the instances are created. - constants.TAG_RAY_NODE_KIND: 'worker', - constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name, + **provision_constants.WORKER_NODE_TAGS, + provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name, }), }) cls._convert_selflinks_in_config(config) @@ -1021,17 +1014,18 @@ def create_instances( project_id, zone, managed_instance_group_name) label_filters = { - constants.TAG_RAY_CLUSTER_NAME: cluster_name, + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name, } potential_head_instances = [] if mig_exists: - instances = cls.filter(project_id, - zone, - label_filters={ - constants.TAG_RAY_NODE_KIND: 'head', - **label_filters, - }, - status_filters=cls.NEED_TO_TERMINATE_STATES) + instances = cls.filter( + project_id, + zone, + label_filters={ + provision_constants.TAG_RAY_NODE_KIND: 'head', + **label_filters, + }, + status_filters=cls.NEED_TO_TERMINATE_STATES) potential_head_instances = list(instances.keys()) config['labels'] = { @@ -1165,7 +1159,7 @@ def _add_labels_and_find_head( pending_running_instances = cls.filter( project_id, zone, - {constants.TAG_RAY_CLUSTER_NAME: cluster_name}, + {provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name}, # Find all provisioning and running instances. status_filters=cls.NEED_TO_STOP_STATES) for running_instance_name in pending_running_instances.keys(): @@ -1452,8 +1446,8 @@ def create_instances( config.update({ 'labels': dict( labels, **{ - constants.TAG_RAY_CLUSTER_NAME: cluster_name, - constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name + provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name, + provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name }), }) @@ -1479,11 +1473,10 @@ def create_instances( for i, name in enumerate(names): node_config = config.copy() if i == 0: - node_config['labels'][TAG_SKYPILOT_HEAD_NODE] = '1' - node_config['labels'][TAG_RAY_NODE_KIND] = 'head' + node_config['labels'].update(provision_constants.HEAD_NODE_TAGS) else: - node_config['labels'][TAG_SKYPILOT_HEAD_NODE] = '0' - node_config['labels'][TAG_RAY_NODE_KIND] = 'worker' + node_config['labels'].update( + provision_constants.WORKER_NODE_TAGS) try: logger.debug('Launching GCP TPU VM ...') request = ( diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index ff9a59d85ad..ac442702d55 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -9,6 +9,7 @@ from sky import skypilot_config from sky.adaptors import kubernetes from sky.provision import common +from sky.provision import constants from sky.provision import docker_utils from sky.provision.kubernetes import config as config_lib from sky.provision.kubernetes import network_utils @@ -25,7 +26,6 @@ logger = sky_logging.init_logger(__name__) TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' -TAG_RAY_NODE_KIND = 'ray-node-type' # legacy tag for backward compatibility TAG_POD_INITIALIZED = 'skypilot-initialized' POD_STATUSES = { @@ -74,7 +74,7 @@ def _filter_pods(namespace: str, tag_filters: Dict[str, str], def _get_head_pod_name(pods: Dict[str, Any]) -> Optional[str]: head_pod_name = None for pod_name, pod in pods.items(): - if pod.metadata.labels[TAG_RAY_NODE_KIND] == 'head': + if pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head': head_pod_name = pod_name break return head_pod_name @@ -455,12 +455,12 @@ def _create_pods(region: str, cluster_name_on_cloud: str, f'(count={to_start_count}).') for _ in range(to_start_count): if head_pod_name is None: - pod_spec['metadata']['labels'][TAG_RAY_NODE_KIND] = 'head' + pod_spec['metadata']['labels'].update(constants.HEAD_NODE_TAGS) head_selector = head_service_selector(cluster_name_on_cloud) pod_spec['metadata']['labels'].update(head_selector) pod_spec['metadata']['name'] = f'{cluster_name_on_cloud}-head' else: - pod_spec['metadata']['labels'][TAG_RAY_NODE_KIND] = 'worker' + pod_spec['metadata']['labels'].update(constants.WORKER_NODE_TAGS) pod_uuid = str(uuid.uuid4())[:4] pod_name = f'{cluster_name_on_cloud}-{pod_uuid}' pod_spec['metadata']['name'] = f'{pod_name}-worker' @@ -636,7 +636,7 @@ def terminate_instances( pods = _filter_pods(namespace, tag_filters, None) def _is_head(pod) -> bool: - return pod.metadata.labels[TAG_RAY_NODE_KIND] == 'head' + return pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head' for pod_name, pod in pods.items(): logger.debug(f'Terminating instance {pod_name}: {pod}') @@ -685,7 +685,7 @@ def get_cluster_info( tags=pod.metadata.labels, ) ] - if pod.metadata.labels[TAG_RAY_NODE_KIND] == 'head': + if pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head': head_pod_name = pod_name head_spec = pod.spec assert head_spec is not None, pod diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 2b0a8bbcc05..ebea79476a8 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -426,10 +426,11 @@ def _post_provision_setup( head_instance = cluster_info.get_head_instance() if head_instance is None: - raise RuntimeError( - f'Provision failed for cluster {cluster_name!r}. ' - 'Could not find any head instance. To fix: refresh ' - 'status with: sky status -r; and retry provisioning.') + e = RuntimeError(f'Provision failed for cluster {cluster_name!r}. ' + 'Could not find any head instance. To fix: refresh ' + f'status with: sky status -r; and retry provisioning.') + setattr(e, 'detailed_reason', str(cluster_info)) + raise e # TODO(suquark): Move wheel build here in future PRs. # We don't set docker_user here, as we are configuring the VM itself. diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 6b4759e18bd..2189d4f1bc6 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -414,7 +414,7 @@ def terminate_services(service_names: Optional[List[str]], purge: bool) -> str: for service_name in service_names: service_status = _get_service_status(service_name, with_replica_info=False) - assert service_status is not None + assert service_status is not None, service_name if service_status['status'] == serve_state.ServiceStatus.SHUTTING_DOWN: # Already scheduled to be terminated. continue diff --git a/sky/setup_files/MANIFEST.in b/sky/setup_files/MANIFEST.in index ad0163a2e22..54ab3b55a32 100644 --- a/sky/setup_files/MANIFEST.in +++ b/sky/setup_files/MANIFEST.in @@ -1,13 +1,11 @@ include sky/backends/monkey_patches/*.py exclude sky/clouds/service_catalog/data_fetchers/analyze.py include sky/provision/kubernetes/manifests/* +include sky/provision/azure/* include sky/setup_files/* include sky/skylet/*.sh include sky/skylet/LICENSE -include sky/skylet/providers/azure/* -include sky/skylet/providers/gcp/* include sky/skylet/providers/ibm/* -include sky/skylet/providers/kubernetes/* include sky/skylet/providers/lambda_cloud/* include sky/skylet/providers/oci/* include sky/skylet/providers/scp/* diff --git a/sky/skylet/providers/azure/__init__.py b/sky/skylet/providers/azure/__init__.py deleted file mode 100644 index dfe4805dfa1..00000000000 --- a/sky/skylet/providers/azure/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Azure node provider""" -from sky.skylet.providers.azure.node_provider import AzureNodeProvider diff --git a/sky/skylet/providers/azure/config.py b/sky/skylet/providers/azure/config.py deleted file mode 100644 index 4c6322f00e5..00000000000 --- a/sky/skylet/providers/azure/config.py +++ /dev/null @@ -1,218 +0,0 @@ -import json -import logging -import random -from hashlib import sha256 -from pathlib import Path -import time -from typing import Any, Callable - -from azure.common.credentials import get_cli_profile -from azure.identity import AzureCliCredential -from azure.mgmt.network import NetworkManagementClient -from azure.mgmt.resource import ResourceManagementClient -from azure.mgmt.resource.resources.models import DeploymentMode - -from sky.adaptors import azure -from sky.utils import common_utils -from sky.provision import common - -UNIQUE_ID_LEN = 4 -_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS = 600 -_WAIT_FOR_RESOURCE_GROUP_DELETION_TIMEOUT_SECONDS = 480 # 8 minutes - - -logger = logging.getLogger(__name__) - - -def get_azure_sdk_function(client: Any, function_name: str) -> Callable: - """Retrieve a callable function from Azure SDK client object. - - Newer versions of the various client SDKs renamed function names to - have a begin_ prefix. This function supports both the old and new - versions of the SDK by first trying the old name and falling back to - the prefixed new name. - """ - func = getattr( - client, function_name, getattr(client, f"begin_{function_name}", None) - ) - if func is None: - raise AttributeError( - "'{obj}' object has no {func} or begin_{func} attribute".format( - obj={client.__name__}, func=function_name - ) - ) - return func - - -def bootstrap_azure(config): - config = _configure_key_pair(config) - config = _configure_resource_group(config) - return config - - -@common.log_function_start_end -def _configure_resource_group(config): - # TODO: look at availability sets - # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/tutorial-availability-sets - subscription_id = config["provider"].get("subscription_id") - if subscription_id is None: - subscription_id = get_cli_profile().get_subscription_id() - # Increase the timeout to fix the Azure get-access-token (used by ray azure - # node_provider) timeout issue. - # Tracked in https://github.com/Azure/azure-cli/issues/20404#issuecomment-1249575110 - credentials = AzureCliCredential(process_timeout=30) - resource_client = ResourceManagementClient(credentials, subscription_id) - config["provider"]["subscription_id"] = subscription_id - logger.info("Using subscription id: %s", subscription_id) - - assert ( - "resource_group" in config["provider"] - ), "Provider config must include resource_group field" - resource_group = config["provider"]["resource_group"] - - assert ( - "location" in config["provider"] - ), "Provider config must include location field" - params = {"location": config["provider"]["location"]} - - if "tags" in config["provider"]: - params["tags"] = config["provider"]["tags"] - - logger.info("Creating/Updating resource group: %s", resource_group) - rg_create_or_update = get_azure_sdk_function( - client=resource_client.resource_groups, function_name="create_or_update" - ) - rg_creation_start = time.time() - retry = 0 - while ( - time.time() - rg_creation_start - < _WAIT_FOR_RESOURCE_GROUP_DELETION_TIMEOUT_SECONDS - ): - try: - rg_create_or_update(resource_group_name=resource_group, parameters=params) - break - except azure.exceptions().ResourceExistsError as e: - if "ResourceGroupBeingDeleted" in str(e): - if retry % 5 == 0: - # TODO(zhwu): This should be shown in terminal for better - # UX, which will be achieved after we move Azure to use - # SkyPilot provisioner. - logger.warning( - f"Azure resource group {resource_group} of a recent " - "terminated cluster {config['cluster_name']} is being " - "deleted. It can only be provisioned after it is fully" - "deleted. Waiting..." - ) - time.sleep(1) - retry += 1 - continue - raise - - # load the template file - current_path = Path(__file__).parent - template_path = current_path.joinpath("azure-config-template.json") - with open(template_path, "r") as template_fp: - template = json.load(template_fp) - - logger.info("Using cluster name: %s", config["cluster_name"]) - - # set unique id for resources in this cluster - unique_id = config["provider"].get("unique_id") - if unique_id is None: - hasher = sha256() - hasher.update(config["provider"]["resource_group"].encode("utf-8")) - unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN] - else: - unique_id = str(unique_id) - config["provider"]["unique_id"] = unique_id - logger.info("Using unique id: %s", unique_id) - cluster_id = "{}-{}".format(config["cluster_name"], unique_id) - - subnet_mask = config["provider"].get("subnet_mask") - if subnet_mask is None: - # choose a random subnet, skipping most common value of 0 - random.seed(unique_id) - subnet_mask = "10.{}.0.0/16".format(random.randint(1, 254)) - logger.info("Using subnet mask: %s", subnet_mask) - - parameters = { - "properties": { - "mode": DeploymentMode.incremental, - "template": template, - "parameters": { - "subnet": {"value": subnet_mask}, - "clusterId": {"value": cluster_id}, - }, - } - } - - create_or_update = get_azure_sdk_function( - client=resource_client.deployments, function_name="create_or_update" - ) - # Skip creating or updating the deployment if the deployment already exists - # and the cluster name is the same. - get_deployment = get_azure_sdk_function( - client=resource_client.deployments, function_name="get" - ) - deployment_exists = False - try: - deployment = get_deployment( - resource_group_name=resource_group, deployment_name="ray-config" - ) - logger.info("Deployment already exists. Skipping deployment creation.") - - outputs = deployment.properties.outputs - if outputs is not None: - deployment_exists = True - except azure.exceptions().ResourceNotFoundError: - deployment_exists = False - - if not deployment_exists: - # This takes a long time (> 40 seconds), we should be careful calling - # this function. - outputs = ( - create_or_update( - resource_group_name=resource_group, - deployment_name="ray-config", - parameters=parameters, - ) - .result() - .properties.outputs - ) - - # We should wait for the NSG to be created before opening any ports - # to avoid overriding the newly-added NSG rules. - nsg_id = outputs["nsg"]["value"] - nsg_name = nsg_id.split("/")[-1] - network_client = NetworkManagementClient(credentials, subscription_id) - backoff = common_utils.Backoff(max_backoff_factor=1) - start_time = time.time() - while True: - nsg = network_client.network_security_groups.get(resource_group, nsg_name) - if nsg.provisioning_state == "Succeeded": - break - if time.time() - start_time > _WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS: - raise RuntimeError( - f"Fails to create NSG {nsg_name} in {resource_group} within " - f"{_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS} seconds." - ) - backoff_time = backoff.current_backoff() - logger.info( - f"NSG {nsg_name} is not created yet. Waiting for " - f"{backoff_time} seconds before checking again." - ) - time.sleep(backoff_time) - - # append output resource ids to be used with vm creation - config["provider"]["msi"] = outputs["msi"]["value"] - config["provider"]["nsg"] = nsg_id - config["provider"]["subnet"] = outputs["subnet"]["value"] - - return config - - -def _configure_key_pair(config): - # SkyPilot: The original checks and configurations are no longer - # needed, since we have already set them up in the upper level - # SkyPilot codes. See sky/templates/azure-ray.yml.j2 - return config diff --git a/sky/skylet/providers/azure/node_provider.py b/sky/skylet/providers/azure/node_provider.py deleted file mode 100644 index 5f87e57245e..00000000000 --- a/sky/skylet/providers/azure/node_provider.py +++ /dev/null @@ -1,488 +0,0 @@ -import copy -import json -import logging -from pathlib import Path -from threading import RLock -from uuid import uuid4 - -from azure.identity import AzureCliCredential -from azure.mgmt.compute import ComputeManagementClient -from azure.mgmt.network import NetworkManagementClient -from azure.mgmt.resource import ResourceManagementClient -from azure.mgmt.resource.resources.models import DeploymentMode - -from sky.adaptors import azure -from sky.skylet.providers.azure.config import ( - bootstrap_azure, - get_azure_sdk_function, -) -from sky.skylet.providers.command_runner import SkyDockerCommandRunner -from sky.provision import docker_utils - -from ray.autoscaler._private.command_runner import SSHCommandRunner -from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import ( - TAG_RAY_CLUSTER_NAME, - TAG_RAY_LAUNCH_CONFIG, - TAG_RAY_NODE_KIND, - TAG_RAY_NODE_NAME, - TAG_RAY_USER_NODE_TYPE, -) - -VM_NAME_MAX_LEN = 64 -UNIQUE_ID_LEN = 4 - -logger = logging.getLogger(__name__) -azure_logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy") -azure_logger.setLevel(logging.WARNING) - - -def synchronized(f): - def wrapper(self, *args, **kwargs): - self.lock.acquire() - try: - return f(self, *args, **kwargs) - finally: - self.lock.release() - - return wrapper - - -class AzureNodeProvider(NodeProvider): - """Node Provider for Azure - - This provider assumes Azure credentials are set by running ``az login`` - and the default subscription is configured through ``az account`` - or set in the ``provider`` field of the autoscaler configuration. - - Nodes may be in one of three states: {pending, running, terminated}. Nodes - appear immediately once started by ``create_node``, and transition - immediately to terminated when ``terminate_node`` is called. - """ - - def __init__(self, provider_config, cluster_name): - NodeProvider.__init__(self, provider_config, cluster_name) - - subscription_id = provider_config["subscription_id"] - self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True) - # Sky only supports Azure CLI credential for now. - # Increase the timeout to fix the Azure get-access-token (used by ray azure - # node_provider) timeout issue. - # Tracked in https://github.com/Azure/azure-cli/issues/20404#issuecomment-1249575110 - credential = AzureCliCredential(process_timeout=30) - self.compute_client = ComputeManagementClient(credential, subscription_id) - self.network_client = NetworkManagementClient(credential, subscription_id) - self.resource_client = ResourceManagementClient(credential, subscription_id) - - self.lock = RLock() - - # cache node objects - self.cached_nodes = {} - - @synchronized - def _get_filtered_nodes(self, tag_filters): - # add cluster name filter to only get nodes from this cluster - cluster_tag_filters = {**tag_filters, TAG_RAY_CLUSTER_NAME: self.cluster_name} - - def match_tags(vm): - for k, v in cluster_tag_filters.items(): - if vm.tags.get(k) != v: - return False - return True - - try: - vms = list( - self.compute_client.virtual_machines.list( - resource_group_name=self.provider_config["resource_group"] - ) - ) - except azure.exceptions().ResourceNotFoundError as e: - if "Code: ResourceGroupNotFound" in e.exc_msg: - logger.debug( - "Resource group not found. VMs should have been terminated." - ) - vms = [] - else: - raise - - nodes = [self._extract_metadata(vm) for vm in filter(match_tags, vms)] - self.cached_nodes = {node["name"]: node for node in nodes} - return self.cached_nodes - - def _extract_metadata(self, vm): - # get tags - metadata = {"name": vm.name, "tags": vm.tags, "status": ""} - - # get status - resource_group = self.provider_config["resource_group"] - instance = self.compute_client.virtual_machines.instance_view( - resource_group_name=resource_group, vm_name=vm.name - ).as_dict() - for status in instance["statuses"]: - code_state = status["code"].split("/") - # It is possible that sometimes the 'code' is empty string, and we - # should skip them. - if len(code_state) != 2: - continue - code, state = code_state - # skip provisioning status - if code == "PowerState": - metadata["status"] = state - break - - # get ip data - nic_id = vm.network_profile.network_interfaces[0].id - metadata["nic_name"] = nic_id.split("/")[-1] - nic = self.network_client.network_interfaces.get( - resource_group_name=resource_group, - network_interface_name=metadata["nic_name"], - ) - ip_config = nic.ip_configurations[0] - - if not self.provider_config.get("use_internal_ips", False): - public_ip_id = ip_config.public_ip_address.id - metadata["public_ip_name"] = public_ip_id.split("/")[-1] - public_ip = self.network_client.public_ip_addresses.get( - resource_group_name=resource_group, - public_ip_address_name=metadata["public_ip_name"], - ) - metadata["external_ip"] = public_ip.ip_address - - metadata["internal_ip"] = ip_config.private_ip_address - - return metadata - - def stopped_nodes(self, tag_filters): - """Return a list of stopped node ids filtered by the specified tags dict.""" - nodes = self._get_filtered_nodes(tag_filters=tag_filters) - return [k for k, v in nodes.items() if v["status"].startswith("deallocat")] - - def non_terminated_nodes(self, tag_filters): - """Return a list of node ids filtered by the specified tags dict. - - This list must not include terminated nodes. For performance reasons, - providers are allowed to cache the result of a call to nodes() to - serve single-node queries (e.g. is_running(node_id)). This means that - nodes() must be called again to refresh results. - - Examples: - >>> from ray.autoscaler.tags import TAG_RAY_NODE_KIND - >>> provider = ... # doctest: +SKIP - >>> provider.non_terminated_nodes( # doctest: +SKIP - ... {TAG_RAY_NODE_KIND: "worker"}) - ["node-1", "node-2"] - """ - nodes = self._get_filtered_nodes(tag_filters=tag_filters) - return [k for k, v in nodes.items() if not v["status"].startswith("deallocat")] - - def is_running(self, node_id): - """Return whether the specified node is running.""" - # always get current status - node = self._get_node(node_id=node_id) - return node["status"] == "running" - - def is_terminated(self, node_id): - """Return whether the specified node is terminated.""" - # always get current status - node = self._get_node(node_id=node_id) - return node["status"].startswith("deallocat") - - def node_tags(self, node_id): - """Returns the tags of the given node (string dict).""" - return self._get_cached_node(node_id=node_id)["tags"] - - def external_ip(self, node_id): - """Returns the external ip of the given node.""" - ip = ( - self._get_cached_node(node_id=node_id)["external_ip"] - or self._get_node(node_id=node_id)["external_ip"] - ) - return ip - - def internal_ip(self, node_id): - """Returns the internal ip (Ray ip) of the given node.""" - ip = ( - self._get_cached_node(node_id=node_id)["internal_ip"] - or self._get_node(node_id=node_id)["internal_ip"] - ) - return ip - - def create_node(self, node_config, tags, count): - resource_group = self.provider_config["resource_group"] - - if self.cache_stopped_nodes: - VALIDITY_TAGS = [ - TAG_RAY_CLUSTER_NAME, - TAG_RAY_NODE_KIND, - TAG_RAY_USER_NODE_TYPE, - ] - filters = {tag: tags[tag] for tag in VALIDITY_TAGS if tag in tags} - filters_with_launch_config = copy.copy(filters) - if TAG_RAY_LAUNCH_CONFIG in tags: - filters_with_launch_config[TAG_RAY_LAUNCH_CONFIG] = tags[ - TAG_RAY_LAUNCH_CONFIG - ] - - # SkyPilot: We try to use the instances with the same matching launch_config first. If - # there is not enough instances with matching launch_config, we then use all the - # instances with the same matching launch_config plus some instances with wrong - # launch_config. - nodes_matching_launch_config = self.stopped_nodes( - filters_with_launch_config - ) - nodes_matching_launch_config.sort(reverse=True) - if len(nodes_matching_launch_config) >= count: - reuse_nodes = nodes_matching_launch_config[:count] - else: - nodes_all = self.stopped_nodes(filters) - nodes_non_matching_launch_config = [ - n for n in nodes_all if n not in nodes_matching_launch_config - ] - # This sort is for backward compatibility, where the user already has - # leaked stopped nodes with the different launch config before update - # to #1671, and the total number of the leaked nodes is greater than - # the number of nodes to be created. With this, we make sure the nodes - # are reused in a deterministic order (sorting by str IDs; we cannot - # get the launch time info here; otherwise, sort by the launch time - # is more accurate.) - # This can be removed in the future when we are sure all the users - # have updated to #1671. - nodes_non_matching_launch_config.sort(reverse=True) - reuse_nodes = ( - nodes_matching_launch_config + nodes_non_matching_launch_config - ) - # The total number of reusable nodes can be less than the number of nodes to be created. - # This `[:count]` is fine, as it will get all the reusable nodes, even if there are - # less nodes. - reuse_nodes = reuse_nodes[:count] - - logger.info( - f"Reusing nodes {list(reuse_nodes)}. " - "To disable reuse, set `cache_stopped_nodes: False` " - "under `provider` in the cluster configuration.", - ) - start = get_azure_sdk_function( - client=self.compute_client.virtual_machines, function_name="start" - ) - for node_id in reuse_nodes: - start(resource_group_name=resource_group, vm_name=node_id).wait() - self.set_node_tags(node_id, tags) - count -= len(reuse_nodes) - - if count: - self._create_node(node_config, tags, count) - - def _create_node(self, node_config, tags, count): - """Creates a number of nodes within the namespace.""" - resource_group = self.provider_config["resource_group"] - - # load the template file - current_path = Path(__file__).parent - template_path = current_path.joinpath("azure-vm-template.json") - with open(template_path, "r") as template_fp: - template = json.load(template_fp) - - # get the tags - config_tags = node_config.get("tags", {}).copy() - config_tags.update(tags) - config_tags[TAG_RAY_CLUSTER_NAME] = self.cluster_name - - vm_name = "{node}-{unique_id}-{vm_id}".format( - node=config_tags.get(TAG_RAY_NODE_NAME, "node"), - unique_id=self.provider_config["unique_id"], - vm_id=uuid4().hex[:UNIQUE_ID_LEN], - )[:VM_NAME_MAX_LEN] - use_internal_ips = self.provider_config.get("use_internal_ips", False) - - template_params = node_config["azure_arm_parameters"].copy() - template_params["vmName"] = vm_name - template_params["provisionPublicIp"] = not use_internal_ips - template_params["vmTags"] = config_tags - template_params["vmCount"] = count - template_params["msi"] = self.provider_config["msi"] - template_params["nsg"] = self.provider_config["nsg"] - template_params["subnet"] = self.provider_config["subnet"] - - if node_config.get("need_nvidia_driver_extension", False): - # Configure driver extension for A10 GPUs. A10 GPUs requires a - # special type of drivers which is available at Microsoft HPC - # extension. Reference: https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2 - for r in template["resources"]: - if r["type"] == "Microsoft.Compute/virtualMachines": - # Add a nested extension resource for A10 GPUs - r["resources"] = [ - { - "type": "extensions", - "apiVersion": "2015-06-15", - "location": "[variables('location')]", - "dependsOn": [ - "[concat('Microsoft.Compute/virtualMachines/', parameters('vmName'), copyIndex())]" - ], - "name": "NvidiaGpuDriverLinux", - "properties": { - "publisher": "Microsoft.HpcCompute", - "type": "NvidiaGpuDriverLinux", - "typeHandlerVersion": "1.9", - "autoUpgradeMinorVersion": True, - "settings": {}, - }, - }, - ] - break - - parameters = { - "properties": { - "mode": DeploymentMode.incremental, - "template": template, - "parameters": { - key: {"value": value} for key, value in template_params.items() - }, - } - } - - # TODO: we could get the private/public ips back directly - create_or_update = get_azure_sdk_function( - client=self.resource_client.deployments, function_name="create_or_update" - ) - create_or_update( - resource_group_name=resource_group, - deployment_name=vm_name, - parameters=parameters, - ).wait() - - @synchronized - def set_node_tags(self, node_id, tags): - """Sets the tag values (string dict) for the specified node.""" - node_tags = self._get_cached_node(node_id)["tags"] - node_tags.update(tags) - update = get_azure_sdk_function( - client=self.compute_client.virtual_machines, function_name="update" - ) - update( - resource_group_name=self.provider_config["resource_group"], - vm_name=node_id, - parameters={"tags": node_tags}, - ) - self.cached_nodes[node_id]["tags"] = node_tags - - def terminate_node(self, node_id): - """Terminates the specified node. This will delete the VM and - associated resources (NIC, IP, Storage) for the specified node.""" - - resource_group = self.provider_config["resource_group"] - try: - # get metadata for node - metadata = self._get_node(node_id) - except KeyError: - # node no longer exists - return - - if self.cache_stopped_nodes: - try: - # stop machine and leave all resources - logger.info( - f"Stopping instance {node_id}" - "(to fully terminate instead, " - "set `cache_stopped_nodes: False` " - "under `provider` in the cluster configuration)" - ) - stop = get_azure_sdk_function( - client=self.compute_client.virtual_machines, - function_name="deallocate", - ) - stop(resource_group_name=resource_group, vm_name=node_id) - except Exception as e: - logger.warning("Failed to stop VM: {}".format(e)) - else: - vm = self.compute_client.virtual_machines.get( - resource_group_name=resource_group, vm_name=node_id - ) - disks = {d.name for d in vm.storage_profile.data_disks} - disks.add(vm.storage_profile.os_disk.name) - - try: - # delete machine, must wait for this to complete - delete = get_azure_sdk_function( - client=self.compute_client.virtual_machines, function_name="delete" - ) - delete(resource_group_name=resource_group, vm_name=node_id).wait() - except Exception as e: - logger.warning("Failed to delete VM: {}".format(e)) - - try: - # delete nic - delete = get_azure_sdk_function( - client=self.network_client.network_interfaces, - function_name="delete", - ) - delete( - resource_group_name=resource_group, - network_interface_name=metadata["nic_name"], - ) - except Exception as e: - logger.warning("Failed to delete nic: {}".format(e)) - - # delete ip address - if "public_ip_name" in metadata: - try: - delete = get_azure_sdk_function( - client=self.network_client.public_ip_addresses, - function_name="delete", - ) - delete( - resource_group_name=resource_group, - public_ip_address_name=metadata["public_ip_name"], - ) - except Exception as e: - logger.warning("Failed to delete public ip: {}".format(e)) - - # delete disks - for disk in disks: - try: - delete = get_azure_sdk_function( - client=self.compute_client.disks, function_name="delete" - ) - delete(resource_group_name=resource_group, disk_name=disk) - except Exception as e: - logger.warning("Failed to delete disk: {}".format(e)) - - def _get_node(self, node_id): - self._get_filtered_nodes({}) # Side effect: updates cache - return self.cached_nodes[node_id] - - def _get_cached_node(self, node_id): - if node_id in self.cached_nodes: - return self.cached_nodes[node_id] - return self._get_node(node_id=node_id) - - @staticmethod - def bootstrap_config(cluster_config): - return bootstrap_azure(cluster_config) - - def get_command_runner( - self, - log_prefix, - node_id, - auth_config, - cluster_name, - process_runner, - use_internal_ip, - docker_config=None, - ): - common_args = { - "log_prefix": log_prefix, - "node_id": node_id, - "provider": self, - "auth_config": auth_config, - "cluster_name": cluster_name, - "process_runner": process_runner, - "use_internal_ip": use_internal_ip, - } - if docker_config and docker_config["container_name"] != "": - if "docker_login_config" in self.provider_config: - docker_config["docker_login_config"] = docker_utils.DockerLoginConfig( - **self.provider_config["docker_login_config"] - ) - return SkyDockerCommandRunner(docker_config, **common_args) - else: - return SSHCommandRunner(**common_args) diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index e8c388e1879..16eb1d9dd23 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -21,7 +21,7 @@ docker: provider: type: external - module: sky.skylet.providers.azure.AzureNodeProvider + module: sky.provision.azure location: {{region}} # Ref: https://github.com/ray-project/ray/blob/2367a2cb9033913b68b1230316496ae273c25b54/python/ray/autoscaler/_private/_azure/node_provider.py#L87 # For Azure, ray distinguishes different instances by the resource_group, @@ -72,45 +72,19 @@ available_node_types: imageVersion: {{image_version}} osDiskSizeGB: {{disk_size}} osDiskTier: {{disk_tier}} - cloudInitSetupCommands: {{cloud_init_setup_commands}} - # optionally set priority to use Spot instances {%- if use_spot %} + # optionally set priority to use Spot instances priority: Spot # set a maximum price for spot instances if desired # billingProfile: # maxPrice: -1 {%- endif %} + cloudInitSetupCommands: |- + {%- for cmd in cloud_init_setup_commands %} + {{ cmd }} + {%- endfor %} need_nvidia_driver_extension: {{need_nvidia_driver_extension}} # TODO: attach disk -{% if num_nodes > 1 %} - ray.worker.default: - min_workers: {{num_nodes - 1}} - max_workers: {{num_nodes - 1}} - resources: {} - node_config: - tags: - skypilot-user: {{ user }} - azure_arm_parameters: - adminUsername: skypilot:ssh_user - publicKey: | - skypilot:ssh_public_key_content - vmSize: {{instance_type}} - # List images https://docs.microsoft.com/en-us/azure/virtual-machines/linux/cli-ps-findimage - imagePublisher: {{image_publisher}} - imageOffer: {{image_offer}} - imageSku: "{{image_sku}}" - imageVersion: {{image_version}} - osDiskSizeGB: {{disk_size}} - osDiskTier: {{disk_tier}} - cloudInitSetupCommands: {{cloud_init_setup_commands}} - {%- if use_spot %} - priority: Spot - # set a maximum price for spot instances if desired - # billingProfile: - # maxPrice: -1 - {%- endif %} - need_nvidia_driver_extension: {{need_nvidia_driver_extension}} -{%- endif %} head_node_type: ray.head.default @@ -123,9 +97,6 @@ file_mounts: { {%- endfor %} } -rsync_exclude: [] - -initialization_commands: [] # List of shell commands to run to set up nodes. # NOTE: these are very performance-sensitive. Each new item opens/closes an SSH @@ -159,34 +130,3 @@ setup_commands: mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config; [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); sudo mv /etc/nccl.conf /etc/nccl.conf.bak || true; - -# Command to start ray on the head node. You don't need to change this. -# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH -# connection, which is expensive. Try your best to co-locate commands into fewer -# items! The same comment applies for worker_start_ray_commands. -# -# Increment the following for catching performance bugs easier: -# current num items (num SSH connections): 2 -head_start_ray_commands: - # NOTE: --disable-usage-stats in `ray start` saves 10 seconds of idle wait. - - {{ sky_activate_python_env }}; {{ sky_ray_cmd }} stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 {{ sky_ray_cmd }} start --disable-usage-stats --head --port={{ray_port}} --dashboard-port={{ray_dashboard_port}} --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml {{"--num-gpus=%s" % num_gpus if num_gpus}} {{"--resources='%s'" % custom_resources if custom_resources}} --temp-dir {{ray_temp_dir}} || exit 1; - which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; - {{dump_port_command}}; - {{ray_head_wait_initialized_command}} - -{%- if num_nodes > 1 %} -worker_start_ray_commands: - - {{ sky_activate_python_env }}; {{ sky_ray_cmd }} stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 {{ sky_ray_cmd }} start --disable-usage-stats --address=$RAY_HEAD_IP:{{ray_port}} --object-manager-port=8076 {{"--num-gpus=%s" % num_gpus if num_gpus}} {{"--resources='%s'" % custom_resources if custom_resources}} --temp-dir {{ray_temp_dir}} || exit 1; - which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; -{%- else %} -worker_start_ray_commands: [] -{%- endif %} - -head_node: {} -worker_nodes: {} - -# These fields are required for external cloud providers. -head_setup_commands: [] -worker_setup_commands: [] -cluster_synced_files: [] -file_mounts_sync_continuously: False diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index dce5ee22ba7..8529874092a 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -384,6 +384,10 @@ def check_connection(self) -> bool: returncode = self.run('true', connect_timeout=5, stream_logs=False) return returncode == 0 + def close_cached_connection(self) -> None: + """Close the cached connection to the remote machine.""" + pass + class SSHCommandRunner(CommandRunner): """Runner for SSH commands.""" @@ -482,6 +486,26 @@ def _ssh_base_command(self, *, ssh_mode: SshMode, f'{self.ssh_user}@{self.ip}' ] + def close_cached_connection(self) -> None: + """Close the cached connection to the remote machine. + + This is useful when we need to make the permission update effective of a + ssh user, e.g. usermod -aG docker $USER. + """ + if self.ssh_control_name is not None: + control_path = _ssh_control_path(self.ssh_control_name) + if control_path is not None: + cmd = (f'ssh -O exit -S {control_path}/%C ' + f'{self.ssh_user}@{self.ip}') + logger.debug(f'Closing cached connection {control_path!r} with ' + f'cmd: {cmd}') + log_lib.run_with_log(cmd, + log_path=os.devnull, + require_outputs=False, + stream_logs=False, + process_stream=False, + shell=True) + @timeline.event def run( self, @@ -683,6 +707,7 @@ def run( SkyPilot but we still want to get rid of some warning messages, such as SSH warnings. + Returns: returncode or diff --git a/sky/utils/command_runner.pyi b/sky/utils/command_runner.pyi index 077447e1d5c..45dfc77a167 100644 --- a/sky/utils/command_runner.pyi +++ b/sky/utils/command_runner.pyi @@ -114,6 +114,9 @@ class CommandRunner: def check_connection(self) -> bool: ... + def close_cached_connection(self) -> None: + ... + class SSHCommandRunner(CommandRunner): ip: str diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 477ebe8d1ba..5df8e25ad9e 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -238,7 +238,8 @@ def _get_cloud_dependencies_installation_commands( '! command -v curl &> /dev/null || ' '! command -v socat &> /dev/null || ' '! command -v netcat &> /dev/null; ' - 'then apt update && apt install curl socat netcat -y; ' + 'then apt update && apt install curl socat netcat -y ' + '&> /dev/null; ' 'fi" && ' # Install kubectl '(command -v kubectl &>/dev/null || ' diff --git a/tests/backward_compatibility_tests.sh b/tests/backward_compatibility_tests.sh index 2156057953c..4f83c379ccf 100644 --- a/tests/backward_compatibility_tests.sh +++ b/tests/backward_compatibility_tests.sh @@ -52,10 +52,10 @@ conda activate sky-back-compat-master rm -r ~/.sky/wheels || true which sky # Job 1 -sky launch --cloud ${CLOUD} -y --cpus 2 -c ${CLUSTER_NAME} examples/minimal.yaml +sky launch --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -c ${CLUSTER_NAME} examples/minimal.yaml sky autostop -i 10 -y ${CLUSTER_NAME} # Job 2 -sky exec -d --cloud ${CLOUD} ${CLUSTER_NAME} sleep 100 +sky exec -d --cloud ${CLOUD} --num-nodes 2 ${CLUSTER_NAME} sleep 100 conda activate sky-back-compat-current sky status -r ${CLUSTER_NAME} | grep ${CLUSTER_NAME} | grep UP @@ -84,21 +84,21 @@ fi if [ "$start_from" -le 2 ]; then conda activate sky-back-compat-master rm -r ~/.sky/wheels || true -sky launch --cloud ${CLOUD} -y --cpus 2 -c ${CLUSTER_NAME}-2 examples/minimal.yaml +sky launch --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -c ${CLUSTER_NAME}-2 examples/minimal.yaml conda activate sky-back-compat-current rm -r ~/.sky/wheels || true sky stop -y ${CLUSTER_NAME}-2 sky start -y ${CLUSTER_NAME}-2 s=$(sky exec --cloud ${CLOUD} -d ${CLUSTER_NAME}-2 examples/minimal.yaml) -echo $s -echo $s | sed -r "s/\x1B\[([0-9]{1,3}(;[0-9]{1,2})?)?[mGK]//g" | grep "Job ID: 2" || exit 1 +echo "$s" +echo "$s" | sed -r "s/\x1B\[([0-9]{1,3}(;[0-9]{1,2})?)?[mGK]//g" | grep "Job ID: 2" || exit 1 fi # `sky autostop` + `sky status -r` if [ "$start_from" -le 3 ]; then conda activate sky-back-compat-master rm -r ~/.sky/wheels || true -sky launch --cloud ${CLOUD} -y --cpus 2 -c ${CLUSTER_NAME}-3 examples/minimal.yaml +sky launch --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -c ${CLUSTER_NAME}-3 examples/minimal.yaml conda activate sky-back-compat-current rm -r ~/.sky/wheels || true sky autostop -y -i0 ${CLUSTER_NAME}-3 @@ -111,11 +111,11 @@ fi if [ "$start_from" -le 4 ]; then conda activate sky-back-compat-master rm -r ~/.sky/wheels || true -sky launch --cloud ${CLOUD} -y --cpus 2 -c ${CLUSTER_NAME}-4 examples/minimal.yaml +sky launch --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -c ${CLUSTER_NAME}-4 examples/minimal.yaml sky stop -y ${CLUSTER_NAME}-4 conda activate sky-back-compat-current rm -r ~/.sky/wheels || true -sky launch --cloud ${CLOUD} -y -c ${CLUSTER_NAME}-4 examples/minimal.yaml +sky launch --cloud ${CLOUD} -y --num-nodes 2 -c ${CLUSTER_NAME}-4 examples/minimal.yaml sky queue ${CLUSTER_NAME}-4 sky logs ${CLUSTER_NAME}-4 1 --status sky logs ${CLUSTER_NAME}-4 2 --status @@ -127,7 +127,7 @@ fi if [ "$start_from" -le 5 ]; then conda activate sky-back-compat-master rm -r ~/.sky/wheels || true -sky launch --cloud ${CLOUD} -y --cpus 2 -c ${CLUSTER_NAME}-5 examples/minimal.yaml +sky launch --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -c ${CLUSTER_NAME}-5 examples/minimal.yaml sky stop -y ${CLUSTER_NAME}-5 conda activate sky-back-compat-current rm -r ~/.sky/wheels || true @@ -145,7 +145,7 @@ fi if [ "$start_from" -le 6 ]; then conda activate sky-back-compat-master rm -r ~/.sky/wheels || true -sky launch --cloud ${CLOUD} -y --cpus 2 -c ${CLUSTER_NAME}-6 examples/multi_hostname.yaml +sky launch --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -c ${CLUSTER_NAME}-6 examples/multi_hostname.yaml sky stop -y ${CLUSTER_NAME}-6 conda activate sky-back-compat-current rm -r ~/.sky/wheels || true @@ -167,15 +167,15 @@ MANAGED_JOB_JOB_NAME=${CLUSTER_NAME}-${uuid:0:4} if [ "$start_from" -le 7 ]; then conda activate sky-back-compat-master rm -r ~/.sky/wheels || true -sky spot launch -d --cloud ${CLOUD} -y --cpus 2 -n ${MANAGED_JOB_JOB_NAME}-7-0 "echo hi; sleep 1000" -sky spot launch -d --cloud ${CLOUD} -y --cpus 2 -n ${MANAGED_JOB_JOB_NAME}-7-1 "echo hi; sleep 300" +sky spot launch -d --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -n ${MANAGED_JOB_JOB_NAME}-7-0 "echo hi; sleep 1000" +sky spot launch -d --cloud ${CLOUD} -y --cpus 2 --num-nodes 2 -n ${MANAGED_JOB_JOB_NAME}-7-1 "echo hi; sleep 400" conda activate sky-back-compat-current rm -r ~/.sky/wheels || true s=$(sky jobs queue | grep ${MANAGED_JOB_JOB_NAME}-7 | grep "RUNNING" | wc -l) s=$(sky jobs logs --no-follow -n ${MANAGED_JOB_JOB_NAME}-7-1) echo "$s" echo "$s" | grep " hi" || exit 1 -sky jobs launch -d --cloud ${CLOUD} -y -n ${MANAGED_JOB_JOB_NAME}-7-2 "echo hi; sleep 40" +sky jobs launch -d --cloud ${CLOUD} --num-nodes 2 -y -n ${MANAGED_JOB_JOB_NAME}-7-2 "echo hi; sleep 40" s=$(sky jobs logs --no-follow -n ${MANAGED_JOB_JOB_NAME}-7-2) echo "$s" echo "$s" | grep " hi" || exit 1 @@ -183,7 +183,7 @@ s=$(sky jobs queue | grep ${MANAGED_JOB_JOB_NAME}-7) echo "$s" echo "$s" | grep "RUNNING" | wc -l | grep 3 || exit 1 sky jobs cancel -y -n ${MANAGED_JOB_JOB_NAME}-7-0 -sky jobs logs -n "${MANAGED_JOB_JOB_NAME}-7-1" +sky jobs logs -n "${MANAGED_JOB_JOB_NAME}-7-1" || exit 1 s=$(sky jobs queue | grep ${MANAGED_JOB_JOB_NAME}-7) echo "$s" echo "$s" | grep "SUCCEEDED" | wc -l | grep 2 || exit 1 diff --git a/tests/skyserve/readiness_timeout/task.yaml b/tests/skyserve/readiness_timeout/task.yaml index f618ee730cb..335c949e9de 100644 --- a/tests/skyserve/readiness_timeout/task.yaml +++ b/tests/skyserve/readiness_timeout/task.yaml @@ -11,4 +11,6 @@ resources: cpus: 2+ ports: 8081 +setup: pip install fastapi uvicorn + run: python3 server.py --port 8081 diff --git a/tests/skyserve/readiness_timeout/task_large_timeout.yaml b/tests/skyserve/readiness_timeout/task_large_timeout.yaml index 3039b438d5e..e797d09e059 100644 --- a/tests/skyserve/readiness_timeout/task_large_timeout.yaml +++ b/tests/skyserve/readiness_timeout/task_large_timeout.yaml @@ -12,4 +12,6 @@ resources: cpus: 2+ ports: 8081 +setup: pip install fastapi uvicorn + run: python3 server.py --port 8081 diff --git a/tests/skyserve/update/new_autoscaler_after.yaml b/tests/skyserve/update/new_autoscaler_after.yaml index 64cf3b01772..f5a2e552f67 100644 --- a/tests/skyserve/update/new_autoscaler_after.yaml +++ b/tests/skyserve/update/new_autoscaler_after.yaml @@ -1,7 +1,7 @@ service: readiness_probe: path: /health - initial_delay_seconds: 100 + initial_delay_seconds: 150 replica_policy: min_replicas: 5 max_replicas: 5 @@ -20,6 +20,6 @@ run: | # Sleep for the last replica in the test_skyserve_new_autoscaler_update # so that we can check the behavior difference between rolling and # blue-green update. - sleep 60 + sleep 120 fi python3 server.py diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 9616ef26482..c5e2becff3a 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -3075,7 +3075,6 @@ def test_kubernetes_custom_image(image_id): run_one_test(test) -@pytest.mark.slow def test_azure_start_stop_two_nodes(): name = _get_cluster_name() test = Test( @@ -3862,8 +3861,15 @@ def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str): """Test skyserve with update that changes autoscaler""" name = _get_service_name() + mode + wait_until_no_pending = ( + f's=$(sky serve status {name}); echo "$s"; ' + 'until ! echo "$s" | grep PENDING; do ' + ' echo "Waiting for replica to be out of pending..."; ' + f' sleep 5; s=$(sky serve status {name}); ' + ' echo "$s"; ' + 'done') four_spot_up_cmd = _check_replica_in_status(name, [(4, True, 'READY')]) - update_check = [f'until ({four_spot_up_cmd}); do sleep 5; done; sleep 10;'] + update_check = [f'until ({four_spot_up_cmd}); do sleep 5; done; sleep 15;'] if mode == 'rolling': # Check rolling update, it will terminate one of the old on-demand # instances, once there are 4 spot instance ready. @@ -3892,7 +3898,8 @@ def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str): 's=$(curl http://$endpoint); echo "$s"; echo "$s" | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} --mode {mode} -y tests/skyserve/update/new_autoscaler_after.yaml', # Wait for update to be registered - f'sleep 120', + f'sleep 90', + wait_until_no_pending, _check_replica_in_status( name, [(4, True, _SERVICE_LAUNCHING_STATUS_REGEX + '\|READY'), (1, False, _SERVICE_LAUNCHING_STATUS_REGEX),