Skip to content

Commit

Permalink
Azure finished
Browse files Browse the repository at this point in the history
  • Loading branch information
cblmemo committed Sep 3, 2023
1 parent 356b02f commit 08ce689
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 36 deletions.
37 changes: 34 additions & 3 deletions sky/adaptors/azure.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Azure cli adaptor"""

# pylint: disable=import-outside-toplevel
from functools import wraps
import functools
import threading

azure = None
_session_creation_lock = threading.RLock()


def import_package(func):

@wraps(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
global azure
if azure is None:
Expand All @@ -35,3 +36,33 @@ def get_current_account_user() -> str:
"""Get the default account user."""
from azure.common import credentials
return credentials.get_cli_profile().get_current_account_user()


@import_package
def http_error_exception():
"""HttpError exception."""
from azure.core import exceptions
return exceptions.HttpResponseError


@functools.lru_cache()
@import_package
def get_client(name: str, subscription_id: str):
# Sky only supports Azure CLI credential for now.
# Increase the timeout to fix the Azure get-access-token timeout issue.
# Tracked in
# https://github.com/Azure/azure-cli/issues/20404#issuecomment-1249575110
from azure.identity import AzureCliCredential
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.resource import ResourceManagementClient
with _session_creation_lock:
credential = AzureCliCredential(process_timeout=30)
if name == 'compute':
from azure.mgmt.compute import ComputeManagementClient
return ComputeManagementClient(credential, subscription_id)
elif name == 'network':
return NetworkManagementClient(credential, subscription_id)
elif name == 'resource':
return ResourceManagementClient(credential, subscription_id)
else:
raise ValueError(f'Client not supported: "{name}"')
9 changes: 2 additions & 7 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2877,7 +2877,7 @@ def _get_zone(runner):
def _open_inexistent_ports(self, handle: CloudVmRayResourceHandle,
ports_to_open: List[Union[int, str]]) -> None:
cloud = handle.launched_resources.cloud
if not isinstance(cloud, (clouds.AWS, clouds.GCP)):
if not isinstance(cloud, (clouds.AWS, clouds.GCP, clouds.Azure)):
logger.warning(f'Cannot open ports for {cloud} that not support '
'new provisioner API.')
return
Expand Down Expand Up @@ -4011,12 +4011,7 @@ def post_teardown_cleanup(self,
if terminate:
cloud = handle.launched_resources.cloud
config = common_utils.read_yaml(handle.cluster_yaml)
if isinstance(cloud, (clouds.AWS, clouds.GCP)):
# Clean up AWS SGs or GCP firewall rules
# We don't need to clean up on Azure since it is done by
# our sky node provider.
# TODO(tian): Adding a no-op cleanup_ports API after #2286
# merged.
if isinstance(cloud, (clouds.AWS, clouds.GCP, clouds.Azure)):
provision_lib.cleanup_ports(repr(cloud), cluster_name_on_cloud,
config['provider'])

Expand Down
4 changes: 4 additions & 0 deletions sky/provision/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Azure provisioner for SkyPilot."""

from sky.provision.aws.instance import cleanup_ports
from sky.provision.azure.instance import open_ports
83 changes: 83 additions & 0 deletions sky/provision/azure/instance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Azure instance provisioning."""
from typing import Any, Callable, Dict, List, Optional, Union

from sky import sky_logging
from sky.adaptors import azure

logger = sky_logging.init_logger(__name__)

# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
TAG_RAY_NODE_KIND = 'ray-node-type'


def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
"""Retrieve a callable function from Azure SDK client object.
Newer versions of the various client SDKs renamed function names to
have a begin_ prefix. This function supports both the old and new
versions of the SDK by first trying the old name and falling back to
the prefixed new name.
"""
func = getattr(client, function_name,
getattr(client, f'begin_{function_name}', None))
if func is None:
raise AttributeError(
'"{obj}" object has no {func} or begin_{func} attribute'.format(
obj={client.__name__}, func=function_name))
return func


def open_ports(
cluster_name_on_cloud: str,
ports: List[Union[int, str]],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
"""See sky/provision/__init__.py"""
assert provider_config is not None, cluster_name_on_cloud
subscription_id = provider_config['subscription_id']
resource_group = provider_config['resource_group']
network_client = azure.get_client('network', subscription_id)
create_or_update = get_azure_sdk_function(
client=network_client.security_rules, function_name='create_or_update')
ports = [str(port) for port in ports if port != 22]
rule_name = f'user-ports-{"-".join(ports)}'

def security_rule_parameters(priority: int) -> Dict[str, Any]:
return {
"priority": priority,
"protocol": "TCP",
"access": "Allow",
"direction": "Inbound",
"sourceAddressPrefix": "*",
"sourcePortRange": "*",
"destinationAddressPrefix": "*",
"destinationPortRanges": ports,
}

list_nsg = get_azure_sdk_function(
client=network_client.network_security_groups, function_name='list')
for nsg in list_nsg(resource_group):
try:
# Azure NSG rules have a priority field that determines the order
# in which they are applied. The priority must be unique across
# all inbound rules in one NSG.
max_inbound_priority = max([
rule.priority
for rule in nsg.security_rules
if rule.direction == 'Inbound'
])
create_or_update(resource_group, nsg.name, rule_name,
security_rule_parameters(max_inbound_priority + 1))
except azure.http_error_exception() as e:
logger.warning(
f'Failed to open ports {ports} in NSG {nsg.name}: {e}')


def cleanup_ports(
cluster_name_on_cloud: str,
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
"""See sky/provision/__init__.py"""
# Azure will automatically cleanup network security groups when cleanup
# resource group. So we don't need to do anything here.
del cluster_name_on_cloud, provider_config # Unused.
26 changes: 0 additions & 26 deletions sky/skylet/providers/azure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,32 +81,6 @@ def _configure_resource_group(config):
with open(template_path, "r") as template_fp:
template = json.load(template_fp)

# Setup firewall rules for ports
nsg_resource = None
for resource in template["resources"]:
if resource["type"] == "Microsoft.Network/networkSecurityGroups":
nsg_resource = resource
break
assert nsg_resource is not None, "Could not find NSG resource in template"
ports = config["provider"].get("ports", None)
if ports is not None:
ports = [str(port) for port in ports if port != 22]
nsg_resource["properties"]["securityRules"].append(
{
"name": "user-ports",
"properties": {
"priority": 1001,
"protocol": "TCP",
"access": "Allow",
"direction": "Inbound",
"sourceAddressPrefix": "*",
"sourcePortRange": "*",
"destinationAddressPrefix": "*",
"destinationPortRanges": ports,
},
}
)

logger.info("Using cluster name: %s", config["cluster_name"])

# set unique id for resources in this cluster
Expand Down

0 comments on commit 08ce689

Please sign in to comment.