From befcbc1db99c1feb5773638e9c8cd53d12a57ccf Mon Sep 17 00:00:00 2001 From: Doyoung Kim Date: Wed, 21 Aug 2024 04:18:56 +0000 Subject: [PATCH] refactor check for nsg creation --- .../azure/azure-config-template.json | 8 +- sky/provision/azure/config.py | 39 ++++++++-- sky/provision/azure/instance.py | 74 +++++++++++-------- 3 files changed, 84 insertions(+), 37 deletions(-) diff --git a/sky/provision/azure/azure-config-template.json b/sky/provision/azure/azure-config-template.json index 489783faf98..c743dd40215 100644 --- a/sky/provision/azure/azure-config-template.json +++ b/sky/provision/azure/azure-config-template.json @@ -13,6 +13,12 @@ "metadata": { "description": "Subnet parameters." } + }, + "nsgName": { + "type": "string", + "metadata": { + "description": "Name of the Network Security Group associated with the SkyPilot cluster." + } } }, "variables": { @@ -20,7 +26,7 @@ "location": "[resourceGroup().location]", "msiName": "[concat('sky-', parameters('clusterId'), '-msi')]", "roleAssignmentName": "[concat('sky-', parameters('clusterId'), '-ra')]", - "nsgName": "[concat('sky-', parameters('clusterId'), '-nsg')]", + "nsgName": "[parameters('nsgName')]", "nsg": "[resourceId('Microsoft.Network/networkSecurityGroups', variables('nsgName'))]", "vnetName": "[concat('sky-', parameters('clusterId'), '-vnet')]", "subnetName": "[concat('sky-', parameters('clusterId'), '-subnet')]" diff --git a/sky/provision/azure/config.py b/sky/provision/azure/config.py index 7b50c3d8c0f..97320aef329 100644 --- a/sky/provision/azure/config.py +++ b/sky/provision/azure/config.py @@ -39,6 +39,30 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable: return func +def get_cluster_id(resource_group: str, cluster_name_on_cloud: str): + """Generate a unique cluster ID.""" + hasher = hashlib.md5(resource_group.encode('utf-8')) + unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN] + # We use the cluster name + resource group hash as the + # unique ID for the cluster, as we 
need to make sure that + # the deployments have unique names during failover. + cluster_id = f'{cluster_name_on_cloud}-{unique_id}' + return cluster_id + + +def get_nsg_name(resource_group: str = '', + cluster_name_on_cloud: str = '', + cluster_id: str = '') -> str: + """Return the NSG name for a given cluster.""" + if not cluster_id: + assert resource_group + assert cluster_name_on_cloud + cluster_id = get_cluster_id(resource_group=resource_group, + cluster_name_on_cloud=cluster_name_on_cloud) + nsg_name = f'sky-{cluster_id}-nsg' + return nsg_name + + @common.log_function_start_end def bootstrap_instances( region: str, cluster_name_on_cloud: str, @@ -105,14 +129,15 @@ def bootstrap_instances( logger.info(f'Using cluster name: {cluster_name_on_cloud}') - hasher = hashlib.md5(provider_config['resource_group'].encode('utf-8')) - unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN] + cluster_id = get_cluster_id(resource_group=provider_config['resource_group'], + cluster_name_on_cloud=cluster_name_on_cloud) subnet_mask = provider_config.get('subnet_mask') if subnet_mask is None: # choose a random subnet, skipping most common value of 0 - random.seed(unique_id) + random.seed(cluster_id) subnet_mask = f'10.{random.randint(1, 254)}.0.0/16' logger.info(f'Using subnet mask: {subnet_mask}') + nsg_name = get_nsg_name(cluster_id=cluster_id) parameters = { 'properties': { @@ -123,10 +148,10 @@ def bootstrap_instances( 'value': subnet_mask }, 'clusterId': { - # We use the cluster name + resource group hash as the - # unique ID for the cluster, as we need to make sure that - # the deployments have unique names during failover.
- 'value': f'{cluster_name_on_cloud}-{unique_id}' + 'value': cluster_id + }, + 'nsgName': { + 'value': nsg_name }, }, } diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 91593fd1160..c5dcfa52823 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -17,6 +17,7 @@ from sky.adaptors import azure from sky.provision import common from sky.provision import constants +from sky.provision.azure import config as config_lib from sky.utils import common_utils from sky.utils import ux_utils @@ -30,6 +31,8 @@ # https://github.com/Azure/azure-sdk-for-python/issues/9422 azure_logger = logging.getLogger('azure') azure_logger.setLevel(logging.WARNING) +Client = Any +NetworkSecurityGroup = Any _RESUME_INSTANCE_TIMEOUT = 480 # 8 minutes _RESUME_PER_INSTANCE_TIMEOUT = 120 # 2 minutes @@ -39,9 +42,8 @@ _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound' _POLL_INTERVAL = 1 -# Naming convention we use to create Network Security Group for provisioned -# VM at our ARM template, sky/provision/azure/azure-config-template.json. 
-_NSG_NAME = 'sky-{cluster_name_on_cloud}-nsg' +_LEGACY_NSG_NAME = 'ray-{cluster_name_on_cloud}-nsg' +_SECOND_LEGACY_NSG_NAME = 'sky-{cluster_name_on_cloud}-nsg' class AzureInstanceStatus(enum.Enum): @@ -707,6 +709,31 @@ def _fetch_and_map_status(node, resource_group: str) -> None: return statuses +def _get_cluster_nsg(network_client: Client, + resource_group: str, + cluster_name_on_cloud: str) -> NetworkSecurityGroup: + """Retrieve the NSG associated with the given name of the cluster.""" + list_network_security_groups = _get_azure_sdk_function( + client=network_client.network_security_groups, function_name='list') + legacy_nsg_name = _LEGACY_NSG_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud) + second_legacy_nsg_name = _SECOND_LEGACY_NSG_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud) + nsg_name = config_lib.get_nsg_name( + resource_group=resource_group, + cluster_name_on_cloud=cluster_name_on_cloud) + possible_nsg_names = [nsg_name, legacy_nsg_name, second_legacy_nsg_name] + for nsg in list_network_security_groups(resource_group): + if nsg.name in possible_nsg_names: + return nsg + + # Raise an error if no matching NSG is found + raise ValueError('Failed to find a matching NSG for cluster ' + f'{cluster_name_on_cloud!r} in resource group ' + f'{resource_group!r}. 
Expected NSG names were: ' + f'{possible_nsg_names}.') + + def open_ports( cluster_name_on_cloud: str, ports: List[str], @@ -717,28 +744,26 @@ def open_ports( subscription_id = provider_config['subscription_id'] resource_group = provider_config['resource_group'] network_client = azure.get_client('network', subscription_id) - nsg_name = _NSG_NAME.format(cluster_name_on_cloud=cluster_name_on_cloud) update_network_security_groups = _get_azure_sdk_function( client=network_client.network_security_groups, function_name='create_or_update') - list_network_security_groups = _get_azure_sdk_function( - client=network_client.network_security_groups, function_name='list') try: # Wait for the NSG creation to be finished before opening a port. The # cluster provisioning triggers the NSG creation, but it may not be # finished yet. - nsg_to_open_ports = None - nsg_created = False backoff = common_utils.Backoff(max_backoff_factor=1) start_time = time.time() - while not nsg_created: + while True: + nsg = _get_cluster_nsg(network_client, resource_group, cluster_name_on_cloud) + if nsg.provisioning_state not in ['Creating', 'Updating']: + break if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS: with ux_utils.print_exception_no_traceback(): raise TimeoutError( f'Timed out while waiting for the Network ' - f'Security Group {nsg_name!r} to be ready for ' + f'Security Group {nsg.name!r} to be ready for ' f'cluster {cluster_name_on_cloud!r} in ' f'resource group {resource_group!r}. The NSG ' f'did not reach a stable state ' @@ -747,28 +772,19 @@ def open_ports( 'Consequently, the operation to open ports ' f'{ports} failed.') - nsg_list = list_network_security_groups(resource_group) - for nsg in nsg_list: - if nsg_name == nsg.name: - if nsg.provisioning_state not in ['Creating', 'Updating']: - nsg_to_open_ports = nsg - nsg_created = True - break - - backoff_time = backoff.current_backoff() - logger.info( - f'NSG {nsg_name} is not created yet. 
Waiting for ' - f'{backoff_time} seconds before checking again.') - time.sleep(backoff_time) + backoff_time = backoff.current_backoff() + logger.info( + f'NSG {nsg.name} is not created yet. Waiting for ' + f'{backoff_time} seconds before checking again.') + time.sleep(backoff_time) # Azure NSG rules have a priority field that determines the order # in which they are applied. The priority must be unique across # all inbound rules in one NSG. - assert nsg_to_open_ports is not None priority = max(rule.priority - for rule in nsg_to_open_ports.security_rules + for rule in nsg.security_rules if rule.direction == 'Inbound') + 1 - nsg_to_open_ports.security_rules.append( + nsg.security_rules.append( azure.create_security_rule( name=f'sky-ports-{cluster_name_on_cloud}-{priority}', priority=priority, @@ -781,13 +797,13 @@ def open_ports( destination_port_ranges=ports, )) poller = update_network_security_groups(resource_group, - nsg_to_open_ports.name, - nsg_to_open_ports) + nsg.name, + nsg) poller.wait() if poller.status() != 'Succeeded': with ux_utils.print_exception_no_traceback(): raise ValueError(f'Failed to open ports {ports} in NSG ' - f'{nsg_to_open_ports.name}: {poller.status()}') + f'{nsg.name}: {poller.status()}') except azure.exceptions().HttpResponseError as e: with ux_utils.print_exception_no_traceback():