diff --git a/sky/adaptors/azure.py b/sky/adaptors/azure.py index 2752129e305..0730b76ec88 100644 --- a/sky/adaptors/azure.py +++ b/sky/adaptors/azure.py @@ -131,6 +131,9 @@ def get_client(name: str, from azure.mgmt import authorization return authorization.AuthorizationManagementClient( credential, subscription_id) + elif name == 'msi': + from azure.mgmt import msi + return msi.ManagedServiceIdentityClient(credential, subscription_id) elif name == 'graph': import msgraph return msgraph.GraphServiceClient(credential) diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index fc9579d17c0..cc90f273dd9 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -12,6 +12,7 @@ from sky import clouds from sky import exceptions from sky import sky_logging +from sky import skypilot_config from sky.adaptors import azure from sky.clouds import service_catalog from sky.clouds.utils import azure_utils @@ -353,6 +354,13 @@ def make_deploy_resources_variables( need_nvidia_driver_extension = (acc_dict is not None and 'A10' in acc_dict) + # Determine resource group for deploying the instance. + resource_group_name = skypilot_config.get_nested( + ('azure', 'resource_group_vm'), None) + use_external_resource_group = resource_group_name is not None + if resource_group_name is None: + resource_group_name = f'{cluster_name.name_on_cloud}-{region_name}' + # Setup commands to eliminate the banner and restart sshd. # This script will modify /etc/ssh/sshd_config and add a bash script # into .bashrc. 
The bash script will restart sshd if it has not been @@ -409,7 +417,8 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]: 'disk_tier': Azure._get_disk_type(disk_tier), 'cloud_init_setup_commands': cloud_init_setup_commands, 'azure_subscription_id': self.get_project_id(dryrun), - 'resource_group': f'{cluster_name.name_on_cloud}-{region_name}', + 'resource_group': resource_group_name, + 'use_external_resource_group': use_external_resource_group, } # Setting disk performance tier for high disk tier. diff --git a/sky/provision/azure/azure-config-template.json b/sky/provision/azure/azure-config-template.json index c743dd40215..0c70c4d3999 100644 --- a/sky/provision/azure/azure-config-template.json +++ b/sky/provision/azure/azure-config-template.json @@ -14,6 +14,12 @@ "description": "Subnet parameters." } }, + "location": { + "type": "string", + "metadata": { + "description": "Location of where the resources are allocated." + } + }, "nsgName": { "type": "string", "metadata": { @@ -23,7 +29,7 @@ }, "variables": { "contributor": "[subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'b24988ac-6180-42a0-ab88-20f7382dd24c')]", - "location": "[resourceGroup().location]", + "location": "[parameters('location')]", "msiName": "[concat('sky-', parameters('clusterId'), '-msi')]", "roleAssignmentName": "[concat('sky-', parameters('clusterId'), '-ra')]", "nsgName": "[parameters('nsgName')]", diff --git a/sky/provision/azure/config.py b/sky/provision/azure/config.py index afa94b4adbe..e7ab59daa33 100644 --- a/sky/provision/azure/config.py +++ b/sky/provision/azure/config.py @@ -14,13 +14,12 @@ from sky import sky_logging from sky.adaptors import azure from sky.provision import common +from sky.provision import constants from sky.utils import common_utils logger = sky_logging.init_logger(__name__) UNIQUE_ID_LEN = 4 -_DEPLOYMENT_NAME = 'skypilot-config' -_LEGACY_DEPLOYMENT_NAME = 'ray-config' _RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT = 480 # 8 minutes 
_CLUSTER_ID = '{cluster_name_on_cloud}-{unique_id}' @@ -82,46 +81,55 @@ def bootstrap_instances( in provider_config), 'Provider config must include location field' params = {'location': provider_config['location']} + assert ('use_external_resource_group' + in provider_config), ('Provider config must include ' + 'use_external_resource_group field') + use_external_resource_group = provider_config['use_external_resource_group'] + if 'tags' in provider_config: params['tags'] = provider_config['tags'] - logger.info(f'Creating/Updating resource group: {resource_group}') - rg_create_or_update = get_azure_sdk_function( - client=resource_client.resource_groups, - function_name='create_or_update') - rg_creation_start = time.time() - retry = 0 - while (time.time() - rg_creation_start < - _RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT): - try: - rg_create_or_update(resource_group_name=resource_group, - parameters=params) - break - except azure.exceptions().ResourceExistsError as e: - if 'ResourceGroupBeingDeleted' in str(e): - if retry % 5 == 0: - logger.info( - f'Azure resource group {resource_group} of a recent ' - f'terminated cluster {cluster_name_on_cloud} is being ' - 'deleted. It can only be provisioned after it is fully ' - 'deleted. Waiting...') - time.sleep(1) - retry += 1 - continue - raise - except azure.exceptions().ClientAuthenticationError as e: + # When resource group is user specified, it already exists in certain + # region. 
+ if not use_external_resource_group: + logger.info(f'Creating/Updating resource group: {resource_group}') + rg_create_or_update = get_azure_sdk_function( + client=resource_client.resource_groups, + function_name='create_or_update') + rg_creation_start = time.time() + retry = 0 + while (time.time() - rg_creation_start < + _RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT): + try: + rg_create_or_update(resource_group_name=resource_group, + parameters=params) + break + except azure.exceptions().ResourceExistsError as e: + if 'ResourceGroupBeingDeleted' in str(e): + if retry % 5 == 0: + logger.info( + f'Azure resource group {resource_group} of a ' + 'recent terminated cluster ' + f'{cluster_name_on_cloud} is being deleted. It can' + ' only be provisioned after it is fully deleted. ' + 'Waiting...') + time.sleep(1) + retry += 1 + continue + raise + except azure.exceptions().ClientAuthenticationError as e: + message = ( + 'Failed to authenticate with Azure. Please check your ' + 'Azure credentials. Error: ' + f'{common_utils.format_exception(e)}').replace('\n', ' ') + logger.error(message) + raise exceptions.NoClusterLaunchedError(message) from e + else: message = ( - 'Failed to authenticate with Azure. Please check your Azure ' - f'credentials. 
Error: {common_utils.format_exception(e)}' - ).replace('\n', ' ') + f'Timed out waiting for resource group {resource_group} to be ' + 'deleted.') logger.error(message) - raise exceptions.NoClusterLaunchedError(message) from e - else: - message = ( - f'Timed out waiting for resource group {resource_group} to be ' - 'deleted.') - logger.error(message) - raise TimeoutError(message) + raise TimeoutError(message) # load the template file current_path = Path(__file__).parent @@ -155,6 +163,9 @@ def bootstrap_instances( 'nsgName': { 'value': nsg_name }, + 'location': { + 'value': params['location'] + } }, } } @@ -164,11 +175,22 @@ def bootstrap_instances( get_deployment = get_azure_sdk_function(client=resource_client.deployments, function_name='get') deployment_exists = False - for deployment_name in [_DEPLOYMENT_NAME, _LEGACY_DEPLOYMENT_NAME]: + if use_external_resource_group: + deployment_name = ( + constants.EXTERNAL_RG_BOOTSTRAP_DEPLOYMENT_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud)) + deployment_list = [deployment_name] + else: + deployment_name = constants.DEPLOYMENT_NAME + deployment_list = [ + constants.DEPLOYMENT_NAME, constants.LEGACY_DEPLOYMENT_NAME + ] + + for deploy_name in deployment_list: try: deployment = get_deployment(resource_group_name=resource_group, - deployment_name=deployment_name) - logger.info(f'Deployment {deployment_name!r} already exists. ' + deployment_name=deploy_name) + logger.info(f'Deployment {deploy_name!r} already exists. ' 'Skipping deployment creation.') outputs = deployment.properties.outputs @@ -179,22 +201,20 @@ def bootstrap_instances( deployment_exists = False if not deployment_exists: - logger.info(f'Creating/Updating deployment: {_DEPLOYMENT_NAME}') + logger.info(f'Creating/Updating deployment: {deployment_name}') create_or_update = get_azure_sdk_function( client=resource_client.deployments, function_name='create_or_update') # TODO (skypilot): this takes a long time (> 40 seconds) to run. 
outputs = create_or_update( resource_group_name=resource_group, - deployment_name=_DEPLOYMENT_NAME, + deployment_name=deployment_name, parameters=parameters, ).result().properties.outputs - nsg_id = outputs['nsg']['value'] - # append output resource ids to be used with vm creation provider_config['msi'] = outputs['msi']['value'] - provider_config['nsg'] = nsg_id + provider_config['nsg'] = outputs['nsg']['value'] provider_config['subnet'] = outputs['subnet']['value'] return config diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index cc2dc692dec..24ba012ea9e 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -41,6 +41,15 @@ _TAG_SKYPILOT_VM_ID = 'skypilot-vm-id' _WAIT_CREATION_TIMEOUT_SECONDS = 600 +_RESOURCE_MANAGED_IDENTITY_TYPE = ( + 'Microsoft.ManagedIdentity/userAssignedIdentities') +_RESOURCE_NETWORK_SECURITY_GROUP_TYPE = ( + 'Microsoft.Network/networkSecurityGroups') +_RESOURCE_VIRTUAL_NETWORK_TYPE = 'Microsoft.Network/virtualNetworks' +_RESOURCE_PUBLIC_IP_ADDRESS_TYPE = 'Microsoft.Network/publicIPAddresses' +_RESOURCE_VIRTUAL_MACHINE_TYPE = 'Microsoft.Compute/virtualMachines' +_RESOURCE_NETWORK_INTERFACE_TYPE = 'Microsoft.Network/networkInterfaces' + _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound' _POLL_INTERVAL = 1 # TODO(Doyoung): _LEGACY_NSG_NAME can be remove this after 0.8.0 to ignore @@ -282,6 +291,7 @@ def _create_vm( image_reference=image_reference, os_disk=compute.OSDisk( create_option=compute.DiskCreateOptionTypes.FROM_IMAGE, + delete_option=compute.DiskDeleteOptionTypes.DELETE, managed_disk=compute.ManagedDiskParameters( storage_account_type=node_config['azure_arm_parameters'] ['osDiskTier']), @@ -697,18 +707,30 @@ def terminate_instances( assert provider_config is not None, cluster_name_on_cloud - resource_group_client = azure.get_client('resource', subscription_id) - delete_resource_group = _get_azure_sdk_function( - client=resource_group_client.resource_groups, 
function_name='delete') - - try: - delete_resource_group(resource_group, force_deletion_types=None) - except azure.exceptions().ResourceNotFoundError as e: - if 'ResourceGroupNotFound' in str(e): - logger.warning(f'Resource group {resource_group} not found. Skip ' - 'terminating it.') - return - raise + use_external_resource_group = provider_config.get( + 'use_external_resource_group', False) + # When user specified resource group through config.yaml to create a VM, we + # cannot remove the entire resource group as it may contain other resources + # unrelated to this VM being removed. + if use_external_resource_group: + delete_vm_and_attached_resources(subscription_id, resource_group, + cluster_name_on_cloud) + else: + # For SkyPilot default resource groups, delete entire resource group. + # This automatically terminates all resources within, including VMs + resource_group_client = azure.get_client('resource', subscription_id) + delete_resource_group = _get_azure_sdk_function( + client=resource_group_client.resource_groups, + function_name='delete') + try: + delete_resource_group(resource_group, force_deletion_types=None) + except azure.exceptions().ResourceNotFoundError as e: + if 'ResourceGroupNotFound' in str(e): + logger.warning( + f'Resource group {resource_group} not found. Skip ' + 'terminating it.') + return + raise def _get_instance_status( @@ -770,6 +792,188 @@ def match_tags(vm): return nodes +def _delete_nic_with_retries(network_client, + resource_group, + nic_name, + max_retries=15, + retry_interval=20): + """Delete a NIC with retries. + + When a VM is created, its NIC is reserved for 180 seconds, preventing its + immediate deletion. If the NIC is in this reserved state, we must retry + deletion with intervals until the reservation expires. This situation + commonly arises if a VM termination is followed by a failover to another + region due to provisioning failures. 
+ """ + delete_network_interfaces = _get_azure_sdk_function( + client=network_client.network_interfaces, function_name='begin_delete') + for _ in range(max_retries): + try: + delete_network_interfaces(resource_group_name=resource_group, + network_interface_name=nic_name).result() + return + except azure.exceptions().HttpResponseError as e: + if 'NicReservedForAnotherVm' in str(e): + # Retry when deletion fails with reserved NIC. + logger.warning(f'NIC {nic_name} is reserved. ' + f'Retrying in {retry_interval} seconds...') + time.sleep(retry_interval) + else: + raise e + logger.error( + f'Failed to delete NIC {nic_name} after {max_retries} attempts.') + + +def delete_vm_and_attached_resources(subscription_id: str, resource_group: str, + cluster_name_on_cloud: str) -> None: + """Removes VM with attached resources and Deployments. + + This function deletes a virtual machine and its associated resources + (public IP addresses, virtual networks, managed identities, network + interface and network security groups) that match cluster_name_on_cloud. + There is one attached resources that is not removed within this + method: OS disk. It is configured to be deleted when VM is terminated while + setting up storage profile from _create_vm. + + Args: + subscription_id: The Azure subscription ID. + resource_group: The name of the resource group. + cluster_name_on_cloud: The name of the cluster to filter resources. 
+ """ + resource_client = azure.get_client('resource', subscription_id) + try: + list_resources = _get_azure_sdk_function( + client=resource_client.resources, + function_name='list_by_resource_group') + resources = list(list_resources(resource_group)) + except azure.exceptions().ResourceNotFoundError as e: + if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e): + return + raise + + filtered_resources: Dict[str, List[str]] = { + _RESOURCE_VIRTUAL_MACHINE_TYPE: [], + _RESOURCE_MANAGED_IDENTITY_TYPE: [], + _RESOURCE_NETWORK_SECURITY_GROUP_TYPE: [], + _RESOURCE_VIRTUAL_NETWORK_TYPE: [], + _RESOURCE_PUBLIC_IP_ADDRESS_TYPE: [], + _RESOURCE_NETWORK_INTERFACE_TYPE: [] + } + + for resource in resources: + if (resource.type in filtered_resources and + cluster_name_on_cloud in resource.name): + filtered_resources[resource.type].append(resource.name) + + network_client = azure.get_client('network', subscription_id) + msi_client = azure.get_client('msi', subscription_id) + compute_client = azure.get_client('compute', subscription_id) + auth_client = azure.get_client('authorization', subscription_id) + + delete_virtual_machine = _get_azure_sdk_function( + client=compute_client.virtual_machines, function_name='delete') + delete_public_ip_addresses = _get_azure_sdk_function( + client=network_client.public_ip_addresses, function_name='begin_delete') + delete_virtual_networks = _get_azure_sdk_function( + client=network_client.virtual_networks, function_name='begin_delete') + delete_managed_identity = _get_azure_sdk_function( + client=msi_client.user_assigned_identities, function_name='delete') + delete_network_security_group = _get_azure_sdk_function( + client=network_client.network_security_groups, + function_name='begin_delete') + delete_role_assignment = _get_azure_sdk_function( + client=auth_client.role_assignments, function_name='delete') + + for vm_name in filtered_resources[_RESOURCE_VIRTUAL_MACHINE_TYPE]: + try: + # Before removing Network Interface, we need to wait for the 
VM to + # be completely removed with .result() so the dependency of VM on + # Network Interface is disassociated. This takes about ~30s. + delete_virtual_machine(resource_group_name=resource_group, + vm_name=vm_name).result() + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete VM: {}'.format(e)) + + for nic_name in filtered_resources[_RESOURCE_NETWORK_INTERFACE_TYPE]: + try: + # Before removing Public IP Address, we need to wait for the + # Network Interface to be completely removed with .result() so the + # dependency of Network Interface on Public IP Address is + # disassociated. This takes about ~1s. + _delete_nic_with_retries(network_client, resource_group, nic_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete nic: {}'.format(e)) + + for public_ip_name in filtered_resources[_RESOURCE_PUBLIC_IP_ADDRESS_TYPE]: + try: + delete_public_ip_addresses(resource_group_name=resource_group, + public_ip_address_name=public_ip_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete public ip: {}'.format(e)) + + for vnet_name in filtered_resources[_RESOURCE_VIRTUAL_NETWORK_TYPE]: + try: + delete_virtual_networks(resource_group_name=resource_group, + virtual_network_name=vnet_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete vnet: {}'.format(e)) + + for msi_name in filtered_resources[_RESOURCE_MANAGED_IDENTITY_TYPE]: + user_assigned_identities = ( + msi_client.user_assigned_identities.list_by_resource_group( + resource_group_name=resource_group)) + for identity in user_assigned_identities: + if msi_name == identity.name: + # We use the principal_id to find the correct guid converted + # role assignment name because each managed identity has a + # unique principal_id, and role assignments are associated + # with security principals (like managed identities) via this
+ target_principal_id = identity.principal_id + scope = (f'/subscriptions/{subscription_id}' + f'/resourceGroups/{resource_group}') + role_assignments = auth_client.role_assignments.list_for_scope( + scope) + for assignment in role_assignments: + if target_principal_id == assignment.principal_id: + guid_role_assignment_name = assignment.name + try: + delete_role_assignment( + scope=scope, + role_assignment_name=guid_role_assignment_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete role ' + 'assignment: {}'.format(e)) + break + try: + delete_managed_identity(resource_group_name=resource_group, + resource_name=msi_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete msi: {}'.format(e)) + + for nsg_name in filtered_resources[_RESOURCE_NETWORK_SECURITY_GROUP_TYPE]: + try: + delete_network_security_group(resource_group_name=resource_group, + network_security_group_name=nsg_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete nsg: {}'.format(e)) + + delete_deployment = _get_azure_sdk_function( + client=resource_client.deployments, function_name='begin_delete') + deployment_names = [ + constants.EXTERNAL_RG_BOOTSTRAP_DEPLOYMENT_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud), + constants.EXTERNAL_RG_VM_DEPLOYMENT_NAME.format( + cluster_name_on_cloud=cluster_name_on_cloud) + ] + for deployment_name in deployment_names: + try: + delete_deployment(resource_group_name=resource_group, + deployment_name=deployment_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to delete deployment: {}'.format(e)) + + @common_utils.retry def query_instances( cluster_name_on_cloud: str, @@ -842,66 +1046,67 @@ def open_ports( update_network_security_groups = _get_azure_sdk_function( client=network_client.network_security_groups, function_name='create_or_update') + list_network_security_groups = 
_get_azure_sdk_function( + client=network_client.network_security_groups, function_name='list') - try: - # Wait for the NSG creation to be finished before opening a port. The - # cluster provisioning triggers the NSG creation, but it may not be - # finished yet. - backoff = common_utils.Backoff(max_backoff_factor=1) - start_time = time.time() - while True: - nsg = _get_cluster_nsg(network_client, resource_group, - cluster_name_on_cloud) - if nsg.provisioning_state not in ['Creating', 'Updating']: - break - if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS: - with ux_utils.print_exception_no_traceback(): - raise TimeoutError( - f'Timed out while waiting for the Network ' - f'Security Group {nsg.name!r} to be ready for ' - f'cluster {cluster_name_on_cloud!r} in ' - f'resource group {resource_group!r}. The NSG ' - f'did not reach a stable state ' - '(Creating/Updating) within the allocated ' - f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. ' - 'Consequently, the operation to open ports ' - f'{ports} failed.') - - backoff_time = backoff.current_backoff() - logger.info(f'NSG {nsg.name} is not created yet. Waiting for ' + for nsg in list_network_security_groups(resource_group): + # Given resource group can contain network security groups that are + # irrelevant to this provisioning especially with user specified + # resource group at ~/.sky/config. So we make sure to check for the + # completion of nsg relevant to the VM being provisioned. + if cluster_name_on_cloud in nsg.name: + try: + # Wait the NSG creation to be finished before opening a port. + # The cluster provisioning triggers the NSG creation, but it + # may not be finished yet. 
+ backoff = common_utils.Backoff(max_backoff_factor=1) + start_time = time.time() + while True: + if nsg.provisioning_state not in ['Creating', 'Updating']: + break + if time.time( + ) - start_time > _WAIT_CREATION_TIMEOUT_SECONDS: + logger.warning( + f'Fails to wait for the creation of NSG {nsg.name}' + f' in {resource_group} within ' + f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. ' + 'Skip this NSG.') + backoff_time = backoff.current_backoff() + logger.info( + f'NSG {nsg.name} is not created yet. Waiting for ' f'{backoff_time} seconds before checking again.') - time.sleep(backoff_time) - - # Azure NSG rules have a priority field that determines the order - # in which they are applied. The priority must be unique across - # all inbound rules in one NSG. - priority = max(rule.priority - for rule in nsg.security_rules - if rule.direction == 'Inbound') + 1 - nsg.security_rules.append( - azure.create_security_rule( - name=f'sky-ports-{cluster_name_on_cloud}-{priority}', - priority=priority, - protocol='Tcp', - access='Allow', - direction='Inbound', - source_address_prefix='*', - source_port_range='*', - destination_address_prefix='*', - destination_port_ranges=ports, - )) - poller = update_network_security_groups(resource_group, nsg.name, nsg) - poller.wait() - if poller.status() != 'Succeeded': - with ux_utils.print_exception_no_traceback(): - raise ValueError(f'Failed to open ports {ports} in NSG ' - f'{nsg.name}: {poller.status()}') - - except azure.exceptions().HttpResponseError as e: - with ux_utils.print_exception_no_traceback(): - raise ValueError(f'Failed to open ports {ports} in NSG for cluster ' - f'{cluster_name_on_cloud!r} within resource group ' - f'{resource_group!r}.') from e + time.sleep(backoff_time) + + # Azure NSG rules have a priority field that determines the + # order in which they are applied. The priority must be unique + # across all inbound rules in one NSG. 
+ priority = max(rule.priority + for rule in nsg.security_rules + if rule.direction == 'Inbound') + 1 + nsg.security_rules.append( + azure.create_security_rule( + name=f'sky-ports-{cluster_name_on_cloud}-{priority}', + priority=priority, + protocol='Tcp', + access='Allow', + direction='Inbound', + source_address_prefix='*', + source_port_range='*', + destination_address_prefix='*', + destination_port_ranges=ports, + )) + poller = update_network_security_groups(resource_group, + nsg.name, nsg) + poller.wait() + if poller.status() != 'Succeeded': + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Failed to open ports {ports} in NSG ' + f'{nsg.name}: {poller.status()}') + except azure.exceptions().HttpResponseError as e: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Failed to open ports {ports} in NSG {nsg.name}.' + ) from e def cleanup_ports( diff --git a/sky/provision/constants.py b/sky/provision/constants.py index 760abc4861a..8e8ad5ddf1b 100644 --- a/sky/provision/constants.py +++ b/sky/provision/constants.py @@ -16,3 +16,10 @@ TAG_RAY_NODE_KIND: 'worker', TAG_SKYPILOT_HEAD_NODE: '0', } + +# Names for Azure Deployments. +DEPLOYMENT_NAME = 'skypilot-config' +LEGACY_DEPLOYMENT_NAME = 'ray-config' +EXTERNAL_RG_BOOTSTRAP_DEPLOYMENT_NAME = ( + 'skypilot-bootstrap-{cluster_name_on_cloud}') +EXTERNAL_RG_VM_DEPLOYMENT_NAME = 'skypilot-vm-{cluster_name_on_cloud}' diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index b956530fccc..36bf9468b23 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -34,6 +34,7 @@ provider: # instead of the cluster_name. This ensures that ray creates new instances # for different cluster_name. resource_group: {{resource_group}} + use_external_resource_group: {{use_external_resource_group}} # Keep (otherwise cannot reuse when re-provisioning). # teardown(terminate=True) will override this. 
cache_stopped_nodes: True diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index d9f105db8b0..71c51dded0f 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -758,6 +758,9 @@ def get_config_schema(): 'storage_account': { 'type': 'string', }, + 'resource_group_vm': { + 'type': 'string', + }, } }, 'kubernetes': { diff --git a/tests/test_smoke.py b/tests/test_smoke.py index ed86f93ca27..4fdeefd12ce 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -5566,7 +5566,7 @@ def test_multiple_accelerators_unordered(): def test_multiple_accelerators_unordered_with_default(): name = _get_cluster_name() test = Test( - 'multiple-accelerators-unordered', + 'multiple-accelerators-unordered-with-default', [ f'sky launch -y -c {name} tests/test_yamls/test_multiple_accelerators_unordered_with_default.yaml', f'sky logs {name} 1 --status', # Ensure the job succeeded.