diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index 4df1cd4a4bf..852af5c0c77 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -14,10 +14,8 @@ from sky import clouds from sky import exceptions from sky import sky_logging -from sky import status_lib from sky.adaptors import azure from sky.clouds import service_catalog -from sky.skylet import log_lib from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import ux_utils @@ -70,6 +68,7 @@ class Azure(clouds.Cloud): _INDENT_PREFIX = ' ' * 4 PROVISIONER_VERSION = clouds.ProvisionerVersion.RAY_AUTOSCALER + STATUS_VERSION = clouds.StatusVersion.SKYPILOT @classmethod def _unsupported_features_for_resources( @@ -613,90 +612,3 @@ def _get_disk_type(cls, resources_utils.DiskTier.LOW: 'Standard_LRS', } return tier2name[tier] - - @classmethod - def query_status(cls, name: str, tag_filters: Dict[str, str], - region: Optional[str], zone: Optional[str], - **kwargs) -> List[status_lib.ClusterStatus]: - del zone # unused - status_map = { - 'VM starting': status_lib.ClusterStatus.INIT, - 'VM running': status_lib.ClusterStatus.UP, - # 'VM stopped' in Azure means Stopped (Allocated), which still bills - # for the VM. - 'VM stopping': status_lib.ClusterStatus.INIT, - 'VM stopped': status_lib.ClusterStatus.INIT, - # 'VM deallocated' in Azure means Stopped (Deallocated), which does not - # bill for the VM. 
- 'VM deallocating': status_lib.ClusterStatus.STOPPED, - 'VM deallocated': status_lib.ClusterStatus.STOPPED, - } - tag_filter_str = ' '.join( - f'tags.\\"{k}\\"==\'{v}\'' for k, v in tag_filters.items()) - - query_node_id = (f'az vm list --query "[?{tag_filter_str}].id" -o json') - returncode, stdout, stderr = log_lib.run_with_log(query_node_id, - '/dev/null', - require_outputs=True, - shell=True) - logger.debug(f'{query_node_id} returned {returncode}.\n' - '**** STDOUT ****\n' - f'{stdout}\n' - '**** STDERR ****\n' - f'{stderr}') - if returncode == 0: - if not stdout.strip(): - return [] - node_ids = json.loads(stdout.strip()) - if not node_ids: - return [] - state_str = '[].powerState' - if len(node_ids) == 1: - state_str = 'powerState' - node_ids_str = '\t'.join(node_ids) - query_cmd = ( - f'az vm show -d --ids {node_ids_str} --query "{state_str}" -o json' - ) - returncode, stdout, stderr = log_lib.run_with_log( - query_cmd, '/dev/null', require_outputs=True, shell=True) - logger.debug(f'{query_cmd} returned {returncode}.\n' - '**** STDOUT ****\n' - f'{stdout}\n' - '**** STDERR ****\n' - f'{stderr}') - - # NOTE: Azure cli should be handled carefully. The query command above - # takes about 1 second to run. - # An alternative is the following command, but it will take more than - # 20 seconds to run. - # query_cmd = ( - # f'az vm list --show-details --query "[' - # f'?tags.\\"ray-cluster-name\\" == \'{handle.cluster_name}\' ' - # '&& tags.\\"ray-node-type\\" == \'head\'].powerState" -o tsv' - # ) - - if returncode != 0: - with ux_utils.print_exception_no_traceback(): - raise exceptions.ClusterStatusFetchingError( - f'Failed to query Azure cluster {name!r} status: ' - f'{stdout + stderr}') - - assert stdout.strip(), f'No status returned for {name!r}' - - original_statuses_list = json.loads(stdout.strip()) - if not original_statuses_list: - # No nodes found. The original_statuses_list will be empty string. - # Return empty list. 
- return [] - if not isinstance(original_statuses_list, list): - original_statuses_list = [original_statuses_list] - statuses = [] - for s in original_statuses_list: - if s not in status_map: - with ux_utils.print_exception_no_traceback(): - raise exceptions.ClusterStatusFetchingError( - f'Failed to parse status from Azure response: {stdout}') - node_status = status_map[s] - if node_status is not None: - statuses.append(node_status) - return statuses diff --git a/sky/provision/azure/__init__.py b/sky/provision/azure/__init__.py index b83dbb462d9..b28c161a866 100644 --- a/sky/provision/azure/__init__.py +++ b/sky/provision/azure/__init__.py @@ -2,3 +2,4 @@ from sky.provision.azure.instance import cleanup_ports from sky.provision.azure.instance import open_ports +from sky.provision.azure.instance import query_instances diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index de5c7cbf0e9..6693427d8ff 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -1,11 +1,19 @@ """Azure instance provisioning.""" import logging +from multiprocessing import pool +import typing from typing import Any, Callable, Dict, List, Optional +from sky import exceptions from sky import sky_logging +from sky import status_lib from sky.adaptors import azure +from sky.utils import common_utils from sky.utils import ux_utils +if typing.TYPE_CHECKING: + from azure.mgmt import compute as azure_compute + logger = sky_logging.init_logger(__name__) # Suppress noisy logs from Azure SDK. Reference: @@ -17,6 +25,8 @@ TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' TAG_RAY_NODE_KIND = 'ray-node-type' +_RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound' + def get_azure_sdk_function(client: Any, function_name: str) -> Callable: """Retrieve a callable function from Azure SDK client object. @@ -93,3 +103,106 @@ def cleanup_ports( # Azure will automatically cleanup network security groups when cleanup # resource group. 
So we don't need to do anything here. del cluster_name_on_cloud, ports, provider_config # Unused. + + +def _get_vm_status(compute_client: 'azure_compute.ComputeManagementClient', + vm_name: str, resource_group: str) -> str: + instance = compute_client.virtual_machines.instance_view( + resource_group_name=resource_group, vm_name=vm_name).as_dict() + for status in instance['statuses']: + code_state = status['code'].split('/') + # It is possible that sometimes the 'code' is empty string, and we + # should skip them. + if len(code_state) != 2: + continue + code, state = code_state + # skip provisioning status + if code == 'PowerState': + return state + raise ValueError(f'Failed to get status for VM {vm_name}') + + +def _filter_instances( + compute_client: 'azure_compute.ComputeManagementClient', + filters: Dict[str, str], + resource_group: str) -> List['azure_compute.models.VirtualMachine']: + + def match_tags(vm): + for k, v in filters.items(): + if vm.tags.get(k) != v: + return False + return True + + try: + list_virtual_machines = get_azure_sdk_function( + client=compute_client.virtual_machines, function_name='list') + vms = list_virtual_machines(resource_group_name=resource_group) + nodes = list(filter(match_tags, vms)) + except azure.exceptions().ResourceNotFoundError as e: + if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e): + return [] + raise + return nodes + + +@common_utils.retry +def query_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + non_terminated_only: bool = True, +) -> Dict[str, Optional[status_lib.ClusterStatus]]: + """See sky/provision/__init__.py""" + assert provider_config is not None, cluster_name_on_cloud + status_map = { + 'starting': status_lib.ClusterStatus.INIT, + 'running': status_lib.ClusterStatus.UP, + # 'stopped' in Azure means Stopped (Allocated), which still bills + # for the VM. 
+ 'stopping': status_lib.ClusterStatus.INIT, + 'stopped': status_lib.ClusterStatus.INIT, + # 'deallocated' in Azure means Stopped (Deallocated), which does not + # bill for the VM. + 'deallocating': status_lib.ClusterStatus.STOPPED, + 'deallocated': status_lib.ClusterStatus.STOPPED, + } + provisioning_state_map = { + 'Creating': status_lib.ClusterStatus.INIT, + 'Updating': status_lib.ClusterStatus.INIT, + 'Failed': status_lib.ClusterStatus.INIT, + 'Migrating': status_lib.ClusterStatus.INIT, + 'Deleting': None, + # Succeeded in provisioning state means the VM is provisioned but not + # necessarily running. We exclude Succeeded state here, and the caller + # should determine the status of the VM based on the power state. + # 'Succeeded': status_lib.ClusterStatus.UP, + } + + subscription_id = provider_config['subscription_id'] + resource_group = provider_config['resource_group'] + compute_client = azure.get_client('compute', subscription_id) + filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + nodes = _filter_instances(compute_client, filters, resource_group) + statuses = {} + + def _fetch_and_map_status( + compute_client: 'azure_compute.ComputeManagementClient', node, + resource_group: str): + if node.provisioning_state in provisioning_state_map: + status = provisioning_state_map[node.provisioning_state] + else: + original_status = _get_vm_status(compute_client, node.name, + resource_group) + if original_status not in status_map: + with ux_utils.print_exception_no_traceback(): + raise exceptions.ClusterStatusFetchingError( + f'Failed to parse status from Azure response: {original_status}') + status = status_map[original_status] + if status is None and non_terminated_only: + return + statuses[node.name] = status + + with pool.ThreadPool() as p: + p.starmap(_fetch_and_map_status, + [(compute_client, node, resource_group) for node in nodes]) + + return statuses diff --git a/sky/provision/common.py b/sky/provision/common.py index 7c1bcb32652..e5df26a4c09 100644 --- 
a/sky/provision/common.py +++ b/sky/provision/common.py @@ -1,9 +1,11 @@ """Common data structures for provisioning""" import abc import dataclasses +import functools import os from typing import Any, Dict, List, Optional, Tuple +from sky import sky_logging from sky.utils import resources_utils # NOTE: we can use pydantic instead of dataclasses or namedtuples, because @@ -14,6 +16,10 @@ # -------------------- input data model -------------------- # InstanceId = str +_START_TITLE = '\n' + '-' * 20 + 'Start: {} ' + '-' * 20 +_END_TITLE = '-' * 20 + 'End: {} ' + '-' * 20 + '\n' + +logger = sky_logging.init_logger(__name__) class ProvisionerError(RuntimeError): @@ -268,3 +274,16 @@ def query_ports_passthrough( for port in ports: result[port] = [SocketEndpoint(port=port, host=head_ip)] return result + + +def log_function_start_end(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + logger.info(_START_TITLE.format(func.__name__)) + try: + return func(*args, **kwargs) + finally: + logger.info(_END_TITLE.format(func.__name__)) + + return wrapper diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 1fb80ba542a..2d9ead3dc01 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -23,8 +23,6 @@ from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) -_START_TITLE = '\n' + '-' * 20 + 'Start: {} ' + '-' * 20 -_END_TITLE = '-' * 20 + 'End: {} ' + '-' * 20 + '\n' _MAX_RETRY = 6 @@ -99,19 +97,6 @@ def retry(*args, **kwargs): return decorator -def _log_start_end(func): - - @functools.wraps(func) - def wrapper(*args, **kwargs): - logger.info(_START_TITLE.format(func.__name__)) - try: - return func(*args, **kwargs) - finally: - logger.info(_END_TITLE.format(func.__name__)) - - return wrapper - - def _hint_worker_log_path(cluster_name: str, cluster_info: common.ClusterInfo, stage_name: str): if cluster_info.num_instances > 1: @@ -153,7 +138,7 @@ def _parallel_ssh_with_cache(func, return 
[future.result() for future in results] -@_log_start_end +@common.log_function_start_end def initialize_docker(cluster_name: str, docker_config: Dict[str, Any], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> Optional[str]: @@ -184,7 +169,7 @@ def _initialize_docker(runner: command_runner.CommandRunner, log_path: str): return docker_users[0] -@_log_start_end +@common.log_function_start_end def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: @@ -260,7 +245,7 @@ def _ray_gpu_options(custom_resource: str) -> str: return f' --num-gpus={acc_count}' -@_log_start_end +@common.log_function_start_end @_auto_retry() def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], cluster_info: common.ClusterInfo, @@ -320,7 +305,7 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], f'===== stderr ====={stderr}') -@_log_start_end +@common.log_function_start_end @_auto_retry() def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, custom_resource: Optional[str], ray_port: int, @@ -417,7 +402,7 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, f'===== stderr ====={stderr}') -@_log_start_end +@common.log_function_start_end @_auto_retry() def start_skylet_on_head_node(cluster_name: str, cluster_info: common.ClusterInfo, @@ -501,7 +486,7 @@ def _max_workers_for_file_mounts(common_file_mounts: Dict[str, str]) -> int: return max_workers -@_log_start_end +@common.log_function_start_end def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, str]) -> None: diff --git a/sky/skylet/providers/azure/config.py b/sky/skylet/providers/azure/config.py index 35008ef13d7..13ecd64a987 100644 --- a/sky/skylet/providers/azure/config.py +++ b/sky/skylet/providers/azure/config.py @@ -14,6 +14,7 @@ from sky.adaptors 
import azure from sky.utils import common_utils +from sky.provision import common UNIQUE_ID_LEN = 4 _WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS = 600 @@ -47,6 +48,7 @@ def bootstrap_azure(config): return config +@common.log_function_start_end def _configure_resource_group(config): # TODO: look at availability sets # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/tutorial-availability-sets