Skip to content

Commit

Permalink
refactor check for nsg creation
Browse files Browse the repository at this point in the history
  • Loading branch information
landscapepainter committed Aug 21, 2024
1 parent b6edd1e commit befcbc1
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 37 deletions.
8 changes: 7 additions & 1 deletion sky/provision/azure/azure-config-template.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@
"metadata": {
"description": "Subnet parameters."
}
},
"nsgName": {
"type": "string",
"metadata": {
"description": "Name of the Network Security Group associated with the SkyPilot cluster."
}
}
},
"variables": {
"contributor": "[subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'b24988ac-6180-42a0-ab88-20f7382dd24c')]",
"location": "[resourceGroup().location]",
"msiName": "[concat('sky-', parameters('clusterId'), '-msi')]",
"roleAssignmentName": "[concat('sky-', parameters('clusterId'), '-ra')]",
"nsgName": "[concat('sky-', parameters('clusterId'), '-nsg')]",
"nsgName": "[parameters('nsgName')]",
"nsg": "[resourceId('Microsoft.Network/networkSecurityGroups', variables('nsgName'))]",
"vnetName": "[concat('sky-', parameters('clusterId'), '-vnet')]",
"subnetName": "[concat('sky-', parameters('clusterId'), '-subnet')]"
Expand Down
39 changes: 32 additions & 7 deletions sky/provision/azure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
return func


def get_cluster_id(resource_group: str, cluster_name_on_cloud: str):
"""Generate a unique cluster ID."""
hasher = hashlib.md5(resource_group.encode('utf-8'))
unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN]
# We use the cluster name + resource group hash as the
# unique ID for the cluster, as we need to make sure that
# the deployments have unique names during failover.
cluster_id = f'{cluster_name_on_cloud}-{unique_id}'
return cluster_id


def get_nsg_name(resource_group: str = '',
cluster_name_on_cloud: str = '',
cluster_id: str = '') -> str:
"""Return the NSG name for a given cluster.."""
if not cluster_id:
assert resource_group
assert cluster_name_on_cloud
cluster_id = get_cluster_id(resource_group=resource_group,
cluster_name_on_cloud=cluster_name_on_cloud)
nsg_name = f'sky-{cluster_id}-nsg'
return nsg_name


@common.log_function_start_end
def bootstrap_instances(
region: str, cluster_name_on_cloud: str,
Expand Down Expand Up @@ -105,14 +129,15 @@ def bootstrap_instances(

logger.info(f'Using cluster name: {cluster_name_on_cloud}')

hasher = hashlib.md5(provider_config['resource_group'].encode('utf-8'))
unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN]
cluster_id = get_cluster_id(resource_group=provider_config['resource_group'],
cluster_name_on_cloud=cluster_name_on_cloud)
subnet_mask = provider_config.get('subnet_mask')
if subnet_mask is None:
# choose a random subnet, skipping most common value of 0
random.seed(unique_id)
random.seed(cluster_id)
subnet_mask = f'10.{random.randint(1, 254)}.0.0/16'
logger.info(f'Using subnet mask: {subnet_mask}')
nsg_name = get_nsg_name(cluster_id=cluster_id)

parameters = {
'properties': {
Expand All @@ -123,10 +148,10 @@ def bootstrap_instances(
'value': subnet_mask
},
'clusterId': {
# We use the cluster name + resource group hash as the
# unique ID for the cluster, as we need to make sure that
# the deployments have unique names during failover.
'value': f'{cluster_name_on_cloud}-{unique_id}'
'value': cluster_id
},
'nsgName': {
'value': nsg_name
},
},
}
Expand Down
74 changes: 45 additions & 29 deletions sky/provision/azure/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sky.adaptors import azure
from sky.provision import common
from sky.provision import constants
from sky.provision.azure import config as config_lib
from sky.utils import common_utils
from sky.utils import ux_utils

Expand All @@ -30,6 +31,8 @@
# https://github.com/Azure/azure-sdk-for-python/issues/9422
azure_logger = logging.getLogger('azure')
azure_logger.setLevel(logging.WARNING)
Client = Any
NetworkSecurityGroup = Any

_RESUME_INSTANCE_TIMEOUT = 480 # 8 minutes
_RESUME_PER_INSTANCE_TIMEOUT = 120 # 2 minutes
Expand All @@ -39,9 +42,8 @@

_RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound'
_POLL_INTERVAL = 1
# Naming convention we use to create Network Security Group for provisioned
# VM at our ARM template, sky/provision/azure/azure-config-template.json.
_NSG_NAME = 'sky-{cluster_name_on_cloud}-nsg'
_LEGACY_NSG_NAME = 'ray-{cluster_name_on_cloud}-nsg'
_SECOND_LEGACY_NSG_NAME = 'sky-{cluster_name_on_cloud}-nsg'


class AzureInstanceStatus(enum.Enum):
Expand Down Expand Up @@ -707,6 +709,31 @@ def _fetch_and_map_status(node, resource_group: str) -> None:
return statuses


def _get_cluster_nsg(network_client: Client,
resource_group: str,
cluster_name_on_cloud: str) -> NetworkSecurityGroup:
"""Retrieve the NSG associated with the given name of the cluster."""
list_network_security_groups = _get_azure_sdk_function(
client=network_client.network_security_groups, function_name='list')
legacy_nsg_name = _LEGACY_NSG_NAME.format(
cluster_name_on_cloud=cluster_name_on_cloud)
second_legacy_nsg_name = _SECOND_LEGACY_NSG_NAME.format(
cluster_name_on_cloud=cluster_name_on_cloud)
nsg_name = config_lib.get_nsg_name(
resource_group=resource_group,
cluster_name_on_cloud=cluster_name_on_cloud)
possible_nsg_names = [nsg_name, legacy_nsg_name, second_legacy_nsg_name]
for nsg in list_network_security_groups(resource_group):
if nsg.name in possible_nsg_names:
return nsg

# Raise an error if no matching NSG is found
raise ValueError('Failed to find a matching NSG for cluster '
f'{cluster_name_on_cloud!r} in resource group '
f'{resource_group!r}. Expected NSG names were: '
f'{possible_nsg_names}.')


def open_ports(
cluster_name_on_cloud: str,
ports: List[str],
Expand All @@ -717,28 +744,26 @@ def open_ports(
subscription_id = provider_config['subscription_id']
resource_group = provider_config['resource_group']
network_client = azure.get_client('network', subscription_id)
nsg_name = _NSG_NAME.format(cluster_name_on_cloud=cluster_name_on_cloud)

update_network_security_groups = _get_azure_sdk_function(
client=network_client.network_security_groups,
function_name='create_or_update')
list_network_security_groups = _get_azure_sdk_function(
client=network_client.network_security_groups, function_name='list')

try:
# Wait for the NSG creation to be finished before opening a port. The
# cluster provisioning triggers the NSG creation, but it may not be
# finished yet.
nsg_to_open_ports = None
nsg_created = False
backoff = common_utils.Backoff(max_backoff_factor=1)
start_time = time.time()
while not nsg_created:
while True:
nsg = _get_cluster_nsg(network_client, resource_group, cluster_name_on_cloud)
if nsg.provisioning_state not in ['Creating', 'Updating']:
break
if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS:
with ux_utils.print_exception_no_traceback():
raise TimeoutError(
f'Timed out while waiting for the Network '
f'Security Group {nsg_name!r} to be ready for '
f'Security Group {nsg.name!r} to be ready for '
f'cluster {cluster_name_on_cloud!r} in '
f'resource group {resource_group!r}. The NSG '
f'did not reach a stable state '
Expand All @@ -747,28 +772,19 @@ def open_ports(
'Consequently, the operation to open ports '
f'{ports} failed.')

nsg_list = list_network_security_groups(resource_group)
for nsg in nsg_list:
if nsg_name == nsg.name:
if nsg.provisioning_state not in ['Creating', 'Updating']:
nsg_to_open_ports = nsg
nsg_created = True
break

backoff_time = backoff.current_backoff()
logger.info(
f'NSG {nsg_name} is not created yet. Waiting for '
f'{backoff_time} seconds before checking again.')
time.sleep(backoff_time)
backoff_time = backoff.current_backoff()
logger.info(
f'NSG {nsg.name} is not created yet. Waiting for '
f'{backoff_time} seconds before checking again.')
time.sleep(backoff_time)

# Azure NSG rules have a priority field that determines the order
# in which they are applied. The priority must be unique across
# all inbound rules in one NSG.
assert nsg_to_open_ports is not None
priority = max(rule.priority
for rule in nsg_to_open_ports.security_rules
for rule in nsg.security_rules
if rule.direction == 'Inbound') + 1
nsg_to_open_ports.security_rules.append(
nsg.security_rules.append(
azure.create_security_rule(
name=f'sky-ports-{cluster_name_on_cloud}-{priority}',
priority=priority,
Expand All @@ -781,13 +797,13 @@ def open_ports(
destination_port_ranges=ports,
))
poller = update_network_security_groups(resource_group,
nsg_to_open_ports.name,
nsg_to_open_ports)
nsg.name,
nsg)
poller.wait()
if poller.status() != 'Succeeded':
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Failed to open ports {ports} in NSG '
f'{nsg_to_open_ports.name}: {poller.status()}')
f'{nsg.name}: {poller.status()}')

except azure.exceptions().HttpResponseError as e:
with ux_utils.print_exception_no_traceback():
Expand Down

0 comments on commit befcbc1

Please sign in to comment.