Skip to content

Commit

Permalink
[Azure] Use SkyPilot provisioner for status query (#3696)
Browse files Browse the repository at this point in the history
* Use SkyPilot for status query

* format

* Avoid reconfig

* Add todo

* Fix filtering for autodown clusters

* remove comment

* Address comments

* typing
  • Loading branch information
Michaelvll authored Jul 1, 2024
1 parent d3c1f8c commit 3d9c6ca
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 110 deletions.
90 changes: 1 addition & 89 deletions sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
from sky import clouds
from sky import exceptions
from sky import sky_logging
from sky import status_lib
from sky.adaptors import azure
from sky.clouds import service_catalog
from sky.skylet import log_lib
from sky.utils import common_utils
from sky.utils import resources_utils
from sky.utils import ux_utils
Expand Down Expand Up @@ -70,6 +68,7 @@ class Azure(clouds.Cloud):
_INDENT_PREFIX = ' ' * 4

PROVISIONER_VERSION = clouds.ProvisionerVersion.RAY_AUTOSCALER
STATUS_VERSION = clouds.StatusVersion.SKYPILOT

@classmethod
def _unsupported_features_for_resources(
Expand Down Expand Up @@ -613,90 +612,3 @@ def _get_disk_type(cls,
resources_utils.DiskTier.LOW: 'Standard_LRS',
}
return tier2name[tier]

@classmethod
def query_status(cls, name: str, tag_filters: Dict[str, str],
                 region: Optional[str], zone: Optional[str],
                 **kwargs) -> List[status_lib.ClusterStatus]:
    """Query the statuses of a cluster's VMs via the Azure CLI.

    Args:
        name: Cluster name; used only in error/assertion messages.
        tag_filters: Tag key/value pairs that every matching VM must carry.
        region: Not referenced by this implementation; VMs are located by
            tags across the whole subscription.
        zone: Unused (deleted immediately below).
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        One ClusterStatus per matching VM (entries that map to None are
        dropped).

    Raises:
        exceptions.ClusterStatusFetchingError: if an `az` CLI invocation
            fails or a returned power state cannot be mapped.
    """
    del zone  # unused
    # Map Azure power-state strings to SkyPilot cluster statuses.
    status_map = {
        'VM starting': status_lib.ClusterStatus.INIT,
        'VM running': status_lib.ClusterStatus.UP,
        # 'VM stopped' in Azure means Stopped (Allocated), which still bills
        # for the VM.
        'VM stopping': status_lib.ClusterStatus.INIT,
        'VM stopped': status_lib.ClusterStatus.INIT,
        # 'VM deallocated' in Azure means Stopped (Deallocated), which does
        # not bill for the VM.
        'VM deallocating': status_lib.ClusterStatus.STOPPED,
        'VM deallocated': status_lib.ClusterStatus.STOPPED,
    }
    # Build a JMESPath filter that requires every tag to match.
    tag_filter_str = ' '.join(
        f'tags.\\"{k}\\"==\'{v}\'' for k, v in tag_filters.items())

    # First CLI call: collect the resource IDs of all matching VMs.
    query_node_id = (f'az vm list --query "[?{tag_filter_str}].id" -o json')
    returncode, stdout, stderr = log_lib.run_with_log(query_node_id,
                                                      '/dev/null',
                                                      require_outputs=True,
                                                      shell=True)
    logger.debug(f'{query_node_id} returned {returncode}.\n'
                 '**** STDOUT ****\n'
                 f'{stdout}\n'
                 '**** STDERR ****\n'
                 f'{stderr}')
    if returncode == 0:
        if not stdout.strip():
            return []
        node_ids = json.loads(stdout.strip())
        if not node_ids:
            return []
        # `az vm show` returns a bare value (not a list) when given a single
        # ID, so the JMESPath expression differs for that case.
        state_str = '[].powerState'
        if len(node_ids) == 1:
            state_str = 'powerState'
        node_ids_str = '\t'.join(node_ids)
        # Second CLI call: fetch the power state(s) of the collected IDs.
        query_cmd = (
            f'az vm show -d --ids {node_ids_str} --query "{state_str}" -o json'
        )
        # NOTE: `returncode`/`stdout`/`stderr` are rebound here; the error
        # check below therefore reports whichever call ran last.
        returncode, stdout, stderr = log_lib.run_with_log(
            query_cmd, '/dev/null', require_outputs=True, shell=True)
        logger.debug(f'{query_cmd} returned {returncode}.\n'
                     '**** STDOUT ****\n'
                     f'{stdout}\n'
                     '**** STDERR ****\n'
                     f'{stderr}')

    # NOTE: Azure cli should be handled carefully. The query command above
    # takes about 1 second to run.
    # An alternative is the following command, but it will take more than
    # 20 seconds to run.
    # query_cmd = (
    #     f'az vm list --show-details --query "['
    #     f'?tags.\\"ray-cluster-name\\" == \'{handle.cluster_name}\' '
    #     '&& tags.\\"ray-node-type\\" == \'head\'].powerState" -o tsv'
    # )

    if returncode != 0:
        with ux_utils.print_exception_no_traceback():
            raise exceptions.ClusterStatusFetchingError(
                f'Failed to query Azure cluster {name!r} status: '
                f'{stdout + stderr}')

    assert stdout.strip(), f'No status returned for {name!r}'

    original_statuses_list = json.loads(stdout.strip())
    if not original_statuses_list:
        # No nodes found: the parsed result is an empty list; return [].
        return []
    if not isinstance(original_statuses_list, list):
        # Single-VM responses yield a bare string; normalize to a list.
        original_statuses_list = [original_statuses_list]
    statuses = []
    for s in original_statuses_list:
        if s not in status_map:
            with ux_utils.print_exception_no_traceback():
                raise exceptions.ClusterStatusFetchingError(
                    f'Failed to parse status from Azure response: {stdout}')
        node_status = status_map[s]
        if node_status is not None:
            statuses.append(node_status)
    return statuses
1 change: 1 addition & 0 deletions sky/provision/azure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

from sky.provision.azure.instance import cleanup_ports
from sky.provision.azure.instance import open_ports
from sky.provision.azure.instance import query_instances
113 changes: 113 additions & 0 deletions sky/provision/azure/instance.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
"""Azure instance provisioning."""
import logging
from multiprocessing import pool
import typing
from typing import Any, Callable, Dict, List, Optional

from sky import exceptions
from sky import sky_logging
from sky import status_lib
from sky.adaptors import azure
from sky.utils import common_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from azure.mgmt import compute as azure_compute

logger = sky_logging.init_logger(__name__)

# Suppress noisy logs from Azure SDK. Reference:
Expand All @@ -17,6 +25,8 @@
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
TAG_RAY_NODE_KIND = 'ray-node-type'

_RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound'


def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
"""Retrieve a callable function from Azure SDK client object.
Expand Down Expand Up @@ -93,3 +103,106 @@ def cleanup_ports(
# Azure will automatically cleanup network security groups when cleanup
# resource group. So we don't need to do anything here.
del cluster_name_on_cloud, ports, provider_config # Unused.


def _get_vm_status(compute_client: 'azure_compute.ComputeManagementClient',
vm_name: str, resource_group: str) -> str:
instance = compute_client.virtual_machines.instance_view(
resource_group_name=resource_group, vm_name=vm_name).as_dict()
for status in instance['statuses']:
code_state = status['code'].split('/')
# It is possible that sometimes the 'code' is empty string, and we
# should skip them.
if len(code_state) != 2:
continue
code, state = code_state
# skip provisioning status
if code == 'PowerState':
return state
raise ValueError(f'Failed to get status for VM {vm_name}')


def _filter_instances(
        compute_client: 'azure_compute.ComputeManagementClient',
        filters: Dict[str, str],
        resource_group: str) -> List['azure_compute.models.VirtualMachine']:
    """List the VMs in `resource_group` whose tags match all of `filters`.

    Returns an empty list when the resource group does not exist (e.g. the
    cluster was already terminated and its group cleaned up).
    """

    def match_tags(vm) -> bool:
        # Fix: `vm.tags` is None for VMs that carry no tags at all, which
        # would make `vm.tags.get(...)` raise AttributeError. Treat a
        # missing tag dict as empty so untagged VMs simply never match a
        # non-empty filter.
        tags = vm.tags or {}
        for k, v in filters.items():
            if tags.get(k) != v:
                return False
        return True

    try:
        list_virtual_machines = get_azure_sdk_function(
            client=compute_client.virtual_machines, function_name='list')
        vms = list_virtual_machines(resource_group_name=resource_group)
        nodes = list(filter(match_tags, vms))
    except azure.exceptions().ResourceNotFoundError as e:
        # Only swallow the specific "resource group not found" error; other
        # not-found conditions are re-raised for the caller to handle.
        if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e):
            return []
        raise
    return nodes


@common_utils.retry
def query_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    non_terminated_only: bool = True,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
    """See sky/provision/__init__.py"""
    assert provider_config is not None, cluster_name_on_cloud
    # Map Azure power states (from each VM's instance view) to statuses.
    status_map = {
        'starting': status_lib.ClusterStatus.INIT,
        'running': status_lib.ClusterStatus.UP,
        # 'stopped' in Azure means Stopped (Allocated), which still bills
        # for the VM.
        'stopping': status_lib.ClusterStatus.INIT,
        'stopped': status_lib.ClusterStatus.INIT,
        # 'deallocated' in Azure means Stopped (Deallocated), which does not
        # bill for the VM.
        'deallocating': status_lib.ClusterStatus.STOPPED,
        'deallocated': status_lib.ClusterStatus.STOPPED,
    }
    # Provisioning states take precedence over power states: a VM that is
    # still being created or deleted has no meaningful power state yet.
    provisioning_state_map = {
        'Creating': status_lib.ClusterStatus.INIT,
        'Updating': status_lib.ClusterStatus.INIT,
        'Failed': status_lib.ClusterStatus.INIT,
        'Migrating': status_lib.ClusterStatus.INIT,
        # None marks a (soon-to-be-)terminated node; filtered out below when
        # non_terminated_only is set.
        'Deleting': None,
        # Succeeded in provisioning state means the VM is provisioned but not
        # necessarily running. We exclude Succeeded state here, and the caller
        # should determine the status of the VM based on the power state.
        # 'Succeeded': status_lib.ClusterStatus.UP,
    }

    subscription_id = provider_config['subscription_id']
    resource_group = provider_config['resource_group']
    compute_client = azure.get_client('compute', subscription_id)
    filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
    nodes = _filter_instances(compute_client, filters, resource_group)
    statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {}

    def _fetch_and_map_status(
            compute_client: 'azure_compute.ComputeManagementClient', node,
            resource_group: str) -> None:
        if node.provisioning_state in provisioning_state_map:
            status = provisioning_state_map[node.provisioning_state]
        else:
            # Fully provisioned: fall back to the VM's power state.
            original_status = _get_vm_status(compute_client, node.name,
                                             resource_group)
            if original_status not in status_map:
                with ux_utils.print_exception_no_traceback():
                    # Fix: this message previously interpolated `status`,
                    # which is unassigned on this branch and would raise
                    # UnboundLocalError, masking the intended error.
                    raise exceptions.ClusterStatusFetchingError(
                        'Failed to parse status from Azure response: '
                        f'{original_status}')
            status = status_map[original_status]
        if status is None and non_terminated_only:
            return
        statuses[node.name] = status

    # Each VM needs its own instance-view API call; fetch them concurrently.
    with pool.ThreadPool() as p:
        p.starmap(_fetch_and_map_status,
                  [(compute_client, node, resource_group) for node in nodes])

    return statuses
19 changes: 19 additions & 0 deletions sky/provision/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Common data structures for provisioning"""
import abc
import dataclasses
import functools
import os
from typing import Any, Dict, List, Optional, Tuple

from sky import sky_logging
from sky.utils import resources_utils

# NOTE: we can use pydantic instead of dataclasses or namedtuples, because
Expand All @@ -14,6 +16,10 @@
# -------------------- input data model -------------------- #

InstanceId = str
_START_TITLE = '\n' + '-' * 20 + 'Start: {} ' + '-' * 20
_END_TITLE = '-' * 20 + 'End: {} ' + '-' * 20 + '\n'

logger = sky_logging.init_logger(__name__)


class ProvisionerError(RuntimeError):
Expand Down Expand Up @@ -268,3 +274,16 @@ def query_ports_passthrough(
for port in ports:
result[port] = [SocketEndpoint(port=port, host=head_ip)]
return result


def log_function_start_end(func):
    """Wrap `func` so a start banner and an end banner are logged around it.

    The end banner is emitted even when `func` raises, because the logging
    call sits in a `finally` block.
    """

    @functools.wraps(func)
    def _wrapped(*args, **kwargs):
        name = func.__name__
        logger.info(_START_TITLE.format(name))
        try:
            return func(*args, **kwargs)
        finally:
            logger.info(_END_TITLE.format(name))

    return _wrapped
27 changes: 6 additions & 21 deletions sky/provision/instance_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from sky.utils import ux_utils

logger = sky_logging.init_logger(__name__)
_START_TITLE = '\n' + '-' * 20 + 'Start: {} ' + '-' * 20
_END_TITLE = '-' * 20 + 'End: {} ' + '-' * 20 + '\n'

_MAX_RETRY = 6

Expand Down Expand Up @@ -99,19 +97,6 @@ def retry(*args, **kwargs):
return decorator


def _log_start_end(func):
    """Decorator bracketing `func`'s execution with start/end log banners.

    The end banner is logged in a `finally` block, so it appears even when
    `func` raises.
    """

    @functools.wraps(func)
    def _inner(*args, **kwargs):
        title = func.__name__
        logger.info(_START_TITLE.format(title))
        try:
            return func(*args, **kwargs)
        finally:
            logger.info(_END_TITLE.format(title))

    return _inner


def _hint_worker_log_path(cluster_name: str, cluster_info: common.ClusterInfo,
stage_name: str):
if cluster_info.num_instances > 1:
Expand Down Expand Up @@ -153,7 +138,7 @@ def _parallel_ssh_with_cache(func,
return [future.result() for future in results]


@_log_start_end
@common.log_function_start_end
def initialize_docker(cluster_name: str, docker_config: Dict[str, Any],
cluster_info: common.ClusterInfo,
ssh_credentials: Dict[str, Any]) -> Optional[str]:
Expand Down Expand Up @@ -184,7 +169,7 @@ def _initialize_docker(runner: command_runner.CommandRunner, log_path: str):
return docker_users[0]


@_log_start_end
@common.log_function_start_end
def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str],
cluster_info: common.ClusterInfo,
ssh_credentials: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -260,7 +245,7 @@ def _ray_gpu_options(custom_resource: str) -> str:
return f' --num-gpus={acc_count}'


@_log_start_end
@common.log_function_start_end
@_auto_retry()
def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
cluster_info: common.ClusterInfo,
Expand Down Expand Up @@ -320,7 +305,7 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
f'===== stderr ====={stderr}')


@_log_start_end
@common.log_function_start_end
@_auto_retry()
def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
custom_resource: Optional[str], ray_port: int,
Expand Down Expand Up @@ -417,7 +402,7 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner,
f'===== stderr ====={stderr}')


@_log_start_end
@common.log_function_start_end
@_auto_retry()
def start_skylet_on_head_node(cluster_name: str,
cluster_info: common.ClusterInfo,
Expand Down Expand Up @@ -501,7 +486,7 @@ def _max_workers_for_file_mounts(common_file_mounts: Dict[str, str]) -> int:
return max_workers


@_log_start_end
@common.log_function_start_end
def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str],
cluster_info: common.ClusterInfo,
ssh_credentials: Dict[str, str]) -> None:
Expand Down
2 changes: 2 additions & 0 deletions sky/skylet/providers/azure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from sky.adaptors import azure
from sky.utils import common_utils
from sky.provision import common

UNIQUE_ID_LEN = 4
_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS = 600
Expand Down Expand Up @@ -47,6 +48,7 @@ def bootstrap_azure(config):
return config


@common.log_function_start_end
def _configure_resource_group(config):
# TODO: look at availability sets
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/tutorial-availability-sets
Expand Down

0 comments on commit 3d9c6ca

Please sign in to comment.