Skip to content

Commit

Permalink
[Azure] Fix to sync NSG status while opening ports (#3844)
Browse files Browse the repository at this point in the history
* fix to update NSG status while opening ports

* nit

* format

* refactor check for nsg creation

* format

* nit

* format

* Update sky/provision/azure/config.py

Co-authored-by: Zhanghao Wu <[email protected]>

* Update sky/provision/azure/instance.py

Co-authored-by: Zhanghao Wu <[email protected]>

* Update sky/provision/azure/config.py

Co-authored-by: Zhanghao Wu <[email protected]>

* Update sky/provision/azure/config.py

Co-authored-by: Zhanghao Wu <[email protected]>

* format

* additional TODO comments

---------

Co-authored-by: Zhanghao Wu <[email protected]>
  • Loading branch information
landscapepainter and Michaelvll authored Oct 25, 2024
1 parent 7c5b7e0 commit 057bc4b
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 59 deletions.
8 changes: 7 additions & 1 deletion sky/provision/azure/azure-config-template.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@
"metadata": {
"description": "Subnet parameters."
}
},
"nsgName": {
"type": "string",
"metadata": {
"description": "Name of the Network Security Group associated with the SkyPilot cluster."
}
}
},
"variables": {
"contributor": "[subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'b24988ac-6180-42a0-ab88-20f7382dd24c')]",
"location": "[resourceGroup().location]",
"msiName": "[concat('sky-', parameters('clusterId'), '-msi')]",
"roleAssignmentName": "[concat('sky-', parameters('clusterId'), '-ra')]",
"nsgName": "[concat('sky-', parameters('clusterId'), '-nsg')]",
"nsgName": "[parameters('nsgName')]",
"nsg": "[resourceId('Microsoft.Network/networkSecurityGroups', variables('nsgName'))]",
"vnetName": "[concat('sky-', parameters('clusterId'), '-vnet')]",
"subnetName": "[concat('sky-', parameters('clusterId'), '-subnet')]"
Expand Down
31 changes: 23 additions & 8 deletions sky/provision/azure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
import random
import time
from typing import Any, Callable
from typing import Any, Callable, Tuple

from sky import exceptions
from sky import sky_logging
Expand All @@ -22,6 +22,7 @@
_DEPLOYMENT_NAME = 'skypilot-config'
_LEGACY_DEPLOYMENT_NAME = 'ray-config'
_RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT = 480 # 8 minutes
_CLUSTER_ID = '{cluster_name_on_cloud}-{unique_id}'


def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
Expand All @@ -41,6 +42,19 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
return func


def get_cluster_id_and_nsg_name(resource_group: str,
cluster_name_on_cloud: str) -> Tuple[str, str]:
hasher = hashlib.md5(resource_group.encode('utf-8'))
unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN]
# We use the cluster name + resource group hash as the
# unique ID for the cluster, as we need to make sure that
# the deployments have unique names during failover.
cluster_id = _CLUSTER_ID.format(cluster_name_on_cloud=cluster_name_on_cloud,
unique_id=unique_id)
nsg_name = f'sky-{cluster_id}-nsg'
return cluster_id, nsg_name


@common.log_function_start_end
def bootstrap_instances(
region: str, cluster_name_on_cloud: str,
Expand Down Expand Up @@ -117,12 +131,13 @@ def bootstrap_instances(

logger.info(f'Using cluster name: {cluster_name_on_cloud}')

hasher = hashlib.md5(provider_config['resource_group'].encode('utf-8'))
unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN]
cluster_id, nsg_name = get_cluster_id_and_nsg_name(
resource_group=provider_config['resource_group'],
cluster_name_on_cloud=cluster_name_on_cloud)
subnet_mask = provider_config.get('subnet_mask')
if subnet_mask is None:
# choose a random subnet, skipping most common value of 0
random.seed(unique_id)
random.seed(cluster_id)
subnet_mask = f'10.{random.randint(1, 254)}.0.0/16'
logger.info(f'Using subnet mask: {subnet_mask}')

Expand All @@ -135,10 +150,10 @@ def bootstrap_instances(
'value': subnet_mask
},
'clusterId': {
# We use the cluster name + resource group hash as the
# unique ID for the cluster, as we need to make sure that
# the deployments have unique names during failover.
'value': f'{cluster_name_on_cloud}-{unique_id}'
'value': cluster_id
},
'nsgName': {
'value': nsg_name
},
},
}
Expand Down
141 changes: 91 additions & 50 deletions sky/provision/azure/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sky.adaptors import azure
from sky.provision import common
from sky.provision import constants
from sky.provision.azure import config as config_lib
from sky.utils import common_utils
from sky.utils import subprocess_utils
from sky.utils import ux_utils
Expand All @@ -31,6 +32,8 @@
# https://github.com/Azure/azure-sdk-for-python/issues/9422
azure_logger = logging.getLogger('azure')
azure_logger.setLevel(logging.WARNING)
Client = Any
NetworkSecurityGroup = Any

_RESUME_INSTANCE_TIMEOUT = 480 # 8 minutes
_RESUME_PER_INSTANCE_TIMEOUT = 120 # 2 minutes
Expand All @@ -40,6 +43,10 @@

_RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound'
_POLL_INTERVAL = 1
# TODO(Doyoung): _LEGACY_NSG_NAME can be remove this after 0.8.0 to ignore
# legacy nsg names.
_LEGACY_NSG_NAME = 'ray-{cluster_name_on_cloud}-nsg'
_SECOND_LEGACY_NSG_NAME = 'sky-{cluster_name_on_cloud}-nsg'


class AzureInstanceStatus(enum.Enum):
Expand Down Expand Up @@ -795,6 +802,32 @@ def _fetch_and_map_status(node, resource_group: str) -> None:
return statuses


# TODO(Doyoung): _get_cluster_nsg can be remove this after 0.8.0 to ignore
# legacy nsg names.
def _get_cluster_nsg(network_client: Client, resource_group: str,
cluster_name_on_cloud: str) -> NetworkSecurityGroup:
"""Retrieve the NSG associated with the given name of the cluster."""
list_network_security_groups = _get_azure_sdk_function(
client=network_client.network_security_groups, function_name='list')
legacy_nsg_name = _LEGACY_NSG_NAME.format(
cluster_name_on_cloud=cluster_name_on_cloud)
second_legacy_nsg_name = _SECOND_LEGACY_NSG_NAME.format(
cluster_name_on_cloud=cluster_name_on_cloud)
_, nsg_name = config_lib.get_cluster_id_and_nsg_name(
resource_group=resource_group,
cluster_name_on_cloud=cluster_name_on_cloud)
possible_nsg_names = [nsg_name, legacy_nsg_name, second_legacy_nsg_name]
for nsg in list_network_security_groups(resource_group):
if nsg.name in possible_nsg_names:
return nsg

# Raise an error if no matching NSG is found
raise ValueError('Failed to find a matching NSG for cluster '
f'{cluster_name_on_cloud!r} in resource group '
f'{resource_group!r}. Expected NSG names were: '
f'{possible_nsg_names}.')


def open_ports(
cluster_name_on_cloud: str,
ports: List[str],
Expand All @@ -809,58 +842,66 @@ def open_ports(
update_network_security_groups = _get_azure_sdk_function(
client=network_client.network_security_groups,
function_name='create_or_update')
list_network_security_groups = _get_azure_sdk_function(
client=network_client.network_security_groups, function_name='list')
for nsg in list_network_security_groups(resource_group):
try:
# Wait the NSG creation to be finished before opening a port. The
# cluster provisioning triggers the NSG creation, but it may not be
# finished yet.
backoff = common_utils.Backoff(max_backoff_factor=1)
start_time = time.time()
while True:
if nsg.provisioning_state not in ['Creating', 'Updating']:
break
if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS:
logger.warning(
f'Fails to wait for the creation of NSG {nsg.name} in '
f'{resource_group} within '
f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. '
'Skip this NSG.')
backoff_time = backoff.current_backoff()
logger.info(f'NSG {nsg.name} is not created yet. Waiting for '
f'{backoff_time} seconds before checking again.')
time.sleep(backoff_time)

# Azure NSG rules have a priority field that determines the order
# in which they are applied. The priority must be unique across
# all inbound rules in one NSG.
priority = max(rule.priority
for rule in nsg.security_rules
if rule.direction == 'Inbound') + 1
nsg.security_rules.append(
azure.create_security_rule(
name=f'sky-ports-{cluster_name_on_cloud}-{priority}',
priority=priority,
protocol='Tcp',
access='Allow',
direction='Inbound',
source_address_prefix='*',
source_port_range='*',
destination_address_prefix='*',
destination_port_ranges=ports,
))
poller = update_network_security_groups(resource_group, nsg.name,
nsg)
poller.wait()
if poller.status() != 'Succeeded':

try:
# Wait for the NSG creation to be finished before opening a port. The
# cluster provisioning triggers the NSG creation, but it may not be
# finished yet.
backoff = common_utils.Backoff(max_backoff_factor=1)
start_time = time.time()
while True:
nsg = _get_cluster_nsg(network_client, resource_group,
cluster_name_on_cloud)
if nsg.provisioning_state not in ['Creating', 'Updating']:
break
if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Failed to open ports {ports} in NSG '
f'{nsg.name}: {poller.status()}')
except azure.exceptions().HttpResponseError as e:
raise TimeoutError(
f'Timed out while waiting for the Network '
f'Security Group {nsg.name!r} to be ready for '
f'cluster {cluster_name_on_cloud!r} in '
f'resource group {resource_group!r}. The NSG '
f'did not reach a stable state '
'(Creating/Updating) within the allocated '
f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. '
'Consequently, the operation to open ports '
f'{ports} failed.')

backoff_time = backoff.current_backoff()
logger.info(f'NSG {nsg.name} is not created yet. Waiting for '
f'{backoff_time} seconds before checking again.')
time.sleep(backoff_time)

# Azure NSG rules have a priority field that determines the order
# in which they are applied. The priority must be unique across
# all inbound rules in one NSG.
priority = max(rule.priority
for rule in nsg.security_rules
if rule.direction == 'Inbound') + 1
nsg.security_rules.append(
azure.create_security_rule(
name=f'sky-ports-{cluster_name_on_cloud}-{priority}',
priority=priority,
protocol='Tcp',
access='Allow',
direction='Inbound',
source_address_prefix='*',
source_port_range='*',
destination_address_prefix='*',
destination_port_ranges=ports,
))
poller = update_network_security_groups(resource_group, nsg.name, nsg)
poller.wait()
if poller.status() != 'Succeeded':
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Failed to open ports {ports} in NSG {nsg.name}.') from e
raise ValueError(f'Failed to open ports {ports} in NSG '
f'{nsg.name}: {poller.status()}')

except azure.exceptions().HttpResponseError as e:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Failed to open ports {ports} in NSG for cluster '
f'{cluster_name_on_cloud!r} within resource group '
f'{resource_group!r}.') from e


def cleanup_ports(
Expand Down

0 comments on commit 057bc4b

Please sign in to comment.