Skip to content

Commit

Permalink
Merge branch 'master' of github.com:skypilot-org/skypilot into azure-…
Browse files Browse the repository at this point in the history
…termination
  • Loading branch information
Michaelvll committed Jul 1, 2024
2 parents 8f25311 + 3d9c6ca commit 4a7e0f9
Showing 1 changed file with 18 additions and 31 deletions.
49 changes: 18 additions & 31 deletions sky/provision/azure/instance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Azure instance provisioning."""
import logging
from multiprocessing import pool
import typing
from typing import Any, Callable, Dict, List, Optional

from sky import exceptions
Expand All @@ -10,6 +11,9 @@
from sky.utils import common_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from azure.mgmt import compute as azure_compute

logger = sky_logging.init_logger(__name__)

# Suppress noisy logs from Azure SDK. Reference:
Expand All @@ -21,6 +25,8 @@
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
TAG_RAY_NODE_KIND = 'ray-node-type'

_RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound'


def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
"""Retrieve a callable function from Azure SDK client object.
Expand Down Expand Up @@ -156,32 +162,8 @@ def terminate_instances(
delete_resource_group(resource_group, force_deletion_types=None)


# def _get_vm_ips(network_client, vm, resource_group: str,
# use_internal_ips: bool) -> Tuple[str, str]:
# nic_id = vm.network_profile.network_interfaces[0].id
# nic_name = nic_id.split("/")[-1]
# nic = network_client.network_interfaces.get(
# resource_group_name=resource_group,
# network_interface_name=nic_name,
# )
# ip_config = nic.ip_configurations[0]

# external_ip = None
# if not use_internal_ips:
# public_ip_id = ip_config.public_ip_address.id
# public_ip_name = public_ip_id.split("/")[-1]
# public_ip = network_client.public_ip_addresses.get(
# resource_group_name=resource_group,
# public_ip_address_name=public_ip_name,
# )
# external_ip = public_ip.ip_address

# internal_ip = ip_config.private_ip_address

# return (external_ip, internal_ip)


def _get_vm_status(compute_client, vm_name: str, resource_group: str) -> str:
def _get_vm_status(compute_client: 'azure_compute.ComputeManagementClient',
vm_name: str, resource_group: str) -> str:
instance = compute_client.virtual_machines.instance_view(
resource_group_name=resource_group, vm_name=vm_name).as_dict()
for status in instance['statuses']:
Expand All @@ -197,8 +179,10 @@ def _get_vm_status(compute_client, vm_name: str, resource_group: str) -> str:
raise ValueError(f'Failed to get status for VM {vm_name}')


def _filter_instances(compute_client, filters: Dict[str, str],
resource_group: str) -> List[Any]:
def _filter_instances(
compute_client: 'azure_compute.ComputeManagementClient',
filters: Dict[str, str],
resource_group: str) -> List['azure_compute.models.VirtualMachine']:

def match_tags(vm):
for k, v in filters.items():
Expand All @@ -212,7 +196,7 @@ def match_tags(vm):
vms = list_virtual_machines(resource_group_name=resource_group)
nodes = list(filter(match_tags, vms))
except azure.exceptions().ResourceNotFoundError as e:
if 'ResourceGroupNotFound' in str(e):
if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e):
return []
raise
return nodes
Expand Down Expand Up @@ -245,7 +229,8 @@ def query_instances(
'Migrating': status_lib.ClusterStatus.INIT,
'Deleting': None,
# Succeeded in provisioning state means the VM is provisioned but not
# necessarily running.
# necessarily running. We exclude Succeeded state here, and the caller
# should determine the status of the VM based on the power state.
# 'Succeeded': status_lib.ClusterStatus.UP,
}

Expand All @@ -256,7 +241,9 @@ def query_instances(
nodes = _filter_instances(compute_client, filters, resource_group)
statuses = {}

def _fetch_and_map_status(compute_client, node, resource_group: str):
def _fetch_and_map_status(
compute_client: 'azure_compute.ComputeManagementClient', node,
resource_group: str):
if node.provisioning_state in provisioning_state_map:
status = provisioning_state_map[node.provisioning_state]
else:
Expand Down

0 comments on commit 4a7e0f9

Please sign in to comment.