diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index dde83e0d600..46d49e2d00a 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -1,6 +1,7 @@ """Azure instance provisioning.""" import logging from multiprocessing import pool +import typing from typing import Any, Callable, Dict, List, Optional from sky import exceptions @@ -10,6 +11,9 @@ from sky.utils import common_utils from sky.utils import ux_utils +if typing.TYPE_CHECKING: + from azure.mgmt import compute as azure_compute + logger = sky_logging.init_logger(__name__) # Suppress noisy logs from Azure SDK. Reference: @@ -21,6 +25,8 @@ TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' TAG_RAY_NODE_KIND = 'ray-node-type' +_RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound' + def get_azure_sdk_function(client: Any, function_name: str) -> Callable: """Retrieve a callable function from Azure SDK client object. @@ -156,32 +162,8 @@ def terminate_instances( delete_resource_group(resource_group, force_deletion_types=None) -# def _get_vm_ips(network_client, vm, resource_group: str, -# use_internal_ips: bool) -> Tuple[str, str]: -# nic_id = vm.network_profile.network_interfaces[0].id -# nic_name = nic_id.split("/")[-1] -# nic = network_client.network_interfaces.get( -# resource_group_name=resource_group, -# network_interface_name=nic_name, -# ) -# ip_config = nic.ip_configurations[0] - -# external_ip = None -# if not use_internal_ips: -# public_ip_id = ip_config.public_ip_address.id -# public_ip_name = public_ip_id.split("/")[-1] -# public_ip = network_client.public_ip_addresses.get( -# resource_group_name=resource_group, -# public_ip_address_name=public_ip_name, -# ) -# external_ip = public_ip.ip_address - -# internal_ip = ip_config.private_ip_address - -# return (external_ip, internal_ip) - - -def _get_vm_status(compute_client, vm_name: str, resource_group: str) -> str: +def _get_vm_status(compute_client: 'azure_compute.ComputeManagementClient', + vm_name: str, resource_group: str) -> str: instance = compute_client.virtual_machines.instance_view( resource_group_name=resource_group, vm_name=vm_name).as_dict() for status in instance['statuses']: @@ -197,8 +179,10 @@ def _get_vm_status(compute_client, vm_name: str, resource_group: str) -> str: raise ValueError(f'Failed to get status for VM {vm_name}') -def _filter_instances(compute_client, filters: Dict[str, str], - resource_group: str) -> List[Any]: +def _filter_instances( + compute_client: 'azure_compute.ComputeManagementClient', + filters: Dict[str, str], + resource_group: str) -> List['azure_compute.models.VirtualMachine']: def match_tags(vm): for k, v in filters.items(): @@ -212,7 +196,7 @@ def match_tags(vm): vms = list_virtual_machines(resource_group_name=resource_group) nodes = list(filter(match_tags, vms)) except azure.exceptions().ResourceNotFoundError as e: - if 'ResourceGroupNotFound' in str(e): + if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e): return [] raise return nodes @@ -245,7 +229,8 @@ def query_instances( 'Migrating': status_lib.ClusterStatus.INIT, 'Deleting': None, # Succeeded in provisioning state means the VM is provisioned but not - # necessarily running. + # necessarily running. We exclude Succeeded state here, and the caller + # should determine the status of the VM based on the power state. # 'Succeeded': status_lib.ClusterStatus.UP, } @@ -256,7 +241,9 @@ def query_instances( nodes = _filter_instances(compute_client, filters, resource_group) statuses = {} - def _fetch_and_map_status(compute_client, node, resource_group: str): + def _fetch_and_map_status( + compute_client: 'azure_compute.ComputeManagementClient', node, + resource_group: str): if node.provisioning_state in provisioning_state_map: status = provisioning_state_map[node.provisioning_state] else: