From 94f7c610e4ee3a9fb535712dfe162e967afc586f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 1 Jul 2024 04:40:51 +0000 Subject: [PATCH] Address comments --- sky/provision/azure/instance.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 3d4b05da795..f50f874a4e3 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -1,6 +1,7 @@ """Azure instance provisioning.""" import logging from multiprocessing import pool +import typing from typing import Any, Callable, Dict, List, Optional from sky import exceptions @@ -10,6 +11,9 @@ from sky.utils import common_utils from sky.utils import ux_utils +if typing.TYPE_CHECKING: + from azure.mgmt import compute as azure_compute + logger = sky_logging.init_logger(__name__) # Suppress noisy logs from Azure SDK. Reference: @@ -21,6 +25,8 @@ TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' TAG_RAY_NODE_KIND = 'ray-node-type' +_RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound' + def get_azure_sdk_function(client: Any, function_name: str) -> Callable: """Retrieve a callable function from Azure SDK client object. @@ -99,7 +105,8 @@ def cleanup_ports( del cluster_name_on_cloud, ports, provider_config # Unused. -def _get_vm_status(compute_client, vm_name: str, resource_group: str) -> str: +def _get_vm_status(compute_client: 'azure_compute.ComputeManagementClient', + vm_name: str, resource_group: str) -> str: instance = compute_client.virtual_machines.instance_view( resource_group_name=resource_group, vm_name=vm_name).as_dict() for status in instance['statuses']: @@ -115,8 +122,9 @@ def _get_vm_status(compute_client, vm_name: str, resource_group: str) -> str: raise ValueError(f'Failed to get status for VM {vm_name}') -def _filter_instances(compute_client, filters: Dict[str, str], - resource_group: str) -> List[Any]: +def _filter_instances(compute_client: 'azure_compute.ComputeManagementClient', + filters: Dict[str, + str], resource_group: str) -> List[Any]: def match_tags(vm): for k, v in filters.items(): @@ -130,7 +138,7 @@ def match_tags(vm): vms = list_virtual_machines(resource_group_name=resource_group) nodes = list(filter(match_tags, vms)) except azure.exceptions().ResourceNotFoundError as e: - if 'ResourceGroupNotFound' in str(e): + if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e): return [] raise return nodes @@ -163,7 +171,8 @@ def query_instances( 'Migrating': status_lib.ClusterStatus.INIT, 'Deleting': None, # Succeeded in provisioning state means the VM is provisioned but not - # necessarily running. + # necessarily running. We exclude Succeeded state here, and the caller + # should determine the status of the VM based on the power state. # 'Succeeded': status_lib.ClusterStatus.UP, }