@functools.lru_cache()
@common.load_lazy_modules(modules=_LAZY_MODULES)
def azure_mgmt_models(name: str):
    """Return the Azure management SDK ``models`` module for a service.

    Args:
        name: Which service's models to load; either 'compute' or 'network'.

    Returns:
        ``azure.mgmt.compute.models`` when ``name`` is 'compute', or
        ``azure.mgmt.network.models`` when ``name`` is 'network'.

    Raises:
        ValueError: If ``name`` is not one of the supported services.
    """
    # The imports are kept inside the function so the Azure SDK is only
    # imported when a models module is actually requested.
    if name == 'compute':
        from azure.mgmt.compute import models
        return models
    if name == 'network':
        from azure.mgmt.network import models
        return models
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an AttributeError at the first attribute access.
    raise ValueError(f'Unsupported Azure mgmt models module: {name!r}. '
                     'Expected \'compute\' or \'network\'.')
- } - }, - "publicKey": { - "type": "securestring", - "metadata": { - "description": "SSH Key for the Virtual Machine" - } - }, - "imagePublisher": { - "type": "string", - "metadata": { - "description": "The publisher of the VM image" - } - }, - "imageOffer": { - "type": "string", - "metadata": { - "description": "The offer of the VM image" - } - }, - "imageSku": { - "type": "string", - "metadata": { - "description": "The sku of the VM image" - } - }, - "imageVersion": { - "type": "string", - "metadata": { - "description": "The version of the VM image" - } - }, - "vmSize": { - "type": "string", - "metadata": { - "description": "The size of the VM" - } - }, - "vmTags": { - "type": "object", - "metadata": { - "description": "Tags for the VM" - } - }, - "vmCount": { - "type": "int", - "metadata": { - "description": "Number of VMs to deploy" - } - }, - "provisionPublicIp": { - "type": "bool", - "defaultValue": true, - "metadata": { - "description": "If true creates a public ip" - } - }, - "priority": { - "type": "string", - "defaultValue": "Regular", - "metadata": { - "description": "Specifies the priority for the virtual machine." - } - }, - "billingProfile": { - "type": "object", - "defaultValue": {}, - "metadata": { - "description": "Specifies the maximum price to pay for Azure Spot VM." - } - }, - "osDiskSizeGB": { - "type": "int", - "metadata": { - "description": "OS disk size in GBs." - } - }, - "msi": { - "type": "string", - "metadata": { - "description": "Managed service identity resource id." - } - }, - "nsg": { - "type": "string", - "metadata": { - "description": "Network security group resource id." - } - }, - "subnet": { - "type": "string", - "metadata": { - "descriptions": "Subnet resource id." - } - }, - "osDiskTier": { - "type": "string", - "allowedValues": [ - "Premium_LRS", - "StandardSSD_LRS", - "Standard_LRS" - ], - "metadata": { - "description": "OS disk tier." 
- } - }, - "cloudInitSetupCommands": { - "type": "string", - "metadata": { - "description": "Base64 encoded cloud-init setup commands." - } - } - }, - "variables": { - "location": "[resourceGroup().location]", - "networkInterfaceNamePrivate": "[concat(parameters('vmName'), '-nic')]", - "networkInterfaceNamePublic": "[concat(parameters('vmName'), '-nic-public')]", - "networkInterfaceName": "[if(parameters('provisionPublicIp'), variables('networkInterfaceNamePublic'), variables('networkInterfaceNamePrivate'))]", - "networkIpConfig": "[guid(resourceGroup().id, parameters('vmName'))]", - "publicIpAddressName": "[concat(parameters('vmName'), '-ip')]" - }, - "resources": [ - { - "type": "Microsoft.Network/networkInterfaces", - "apiVersion": "2020-06-01", - "name": "[concat(variables('networkInterfaceNamePublic'), copyIndex())]", - "location": "[variables('location')]", - "dependsOn": [ - "[resourceId('Microsoft.Network/publicIpAddresses/', concat(variables('publicIpAddressName'), copyIndex()))]" - ], - "copy": { - "name": "NICPublicCopy", - "count": "[parameters('vmCount')]" - }, - "properties": { - "ipConfigurations": [ - { - "name": "[variables('networkIpConfig')]", - "properties": { - "subnet": { - "id": "[parameters('subnet')]" - }, - "privateIPAllocationMethod": "Dynamic", - "publicIpAddress": { - "id": "[resourceId('Microsoft.Network/publicIPAddresses', concat(variables('publicIPAddressName'), copyIndex()))]" - } - } - } - ], - "networkSecurityGroup": { - "id": "[parameters('nsg')]" - } - }, - "condition": "[parameters('provisionPublicIp')]" - }, - { - "type": "Microsoft.Network/networkInterfaces", - "apiVersion": "2020-06-01", - "name": "[concat(variables('networkInterfaceNamePrivate'), copyIndex())]", - "location": "[variables('location')]", - "copy": { - "name": "NICPrivateCopy", - "count": "[parameters('vmCount')]" - }, - "properties": { - "ipConfigurations": [ - { - "name": "[variables('networkIpConfig')]", - "properties": { - "subnet": { - "id": 
"[parameters('subnet')]" - }, - "privateIPAllocationMethod": "Dynamic" - } - } - ], - "networkSecurityGroup": { - "id": "[parameters('nsg')]" - } - }, - "condition": "[not(parameters('provisionPublicIp'))]" - }, - { - "type": "Microsoft.Network/publicIpAddresses", - "apiVersion": "2019-02-01", - "name": "[concat(variables('publicIpAddressName'), copyIndex())]", - "location": "[variables('location')]", - "properties": { - "publicIpAllocationMethod": "Static", - "publicIPAddressVersion": "IPv4" - }, - "copy": { - "name": "PublicIpCopy", - "count": "[parameters('vmCount')]" - }, - "sku": { - "name": "Basic", - "tier": "Regional" - }, - "condition": "[parameters('provisionPublicIp')]" - }, - { - "type": "Microsoft.Compute/virtualMachines", - "apiVersion": "2019-03-01", - "name": "[concat(parameters('vmName'), copyIndex())]", - "location": "[variables('location')]", - "dependsOn": [ - "[resourceId('Microsoft.Network/networkInterfaces/', concat(variables('networkInterfaceName'), copyIndex()))]" - ], - "copy": { - "name": "VmCopy", - "count": "[parameters('vmCount')]" - }, - "tags": "[parameters('vmTags')]", - "properties": { - "hardwareProfile": { - "vmSize": "[parameters('vmSize')]" - }, - "storageProfile": { - "osDisk": { - "createOption": "fromImage", - "managedDisk": { - "storageAccountType": "[parameters('osDiskTier')]" - }, - "diskSizeGB": "[parameters('osDiskSizeGB')]" - }, - "imageReference": { - "publisher": "[parameters('imagePublisher')]", - "offer": "[parameters('imageOffer')]", - "sku": "[parameters('imageSku')]", - "version": "[parameters('imageVersion')]" - } - }, - "networkProfile": { - "networkInterfaces": [ - { - "id": "[resourceId('Microsoft.Network/networkInterfaces', concat(variables('networkInterfaceName'), copyIndex()))]" - } - ] - }, - "osProfile": { - "computerName": "[concat(parameters('vmName'), copyIndex())]", - "adminUsername": "[parameters('adminUsername')]", - "adminPassword": "[parameters('publicKey')]", - "linuxConfiguration": { - 
"disablePasswordAuthentication": true, - "ssh": { - "publicKeys": [ - { - "path": "[concat('/home/', parameters('adminUsername'), '/.ssh/authorized_keys')]", - "keyData": "[parameters('publicKey')]" - } - ] - } - }, - "customData": "[parameters('cloudInitSetupCommands')]" - }, - "priority": "[parameters('priority')]", - "billingProfile": "[parameters('billingProfile')]" - }, - "identity": { - "type": "UserAssigned", - "userAssignedIdentities": { - "[parameters('msi')]": { - } - } - } - } - ], - "outputs": { - "publicIp": { - "type": "array", - "copy": { - "count": "[parameters('vmCount')]", - "input": "[reference(concat(variables('publicIpAddressName'), copyIndex())).ipAddress]" - }, - "condition": "[parameters('provisionPublicIp')]" - }, - "privateIp": { - "type": "array", - "copy": { - "count": "[parameters('vmCount')]", - "input": "[reference(concat(variables('networkInterfaceName'), copyIndex())).ipConfigurations[0].properties.privateIPAddress]" - } - } - } -} diff --git a/sky/provision/azure/config.py b/sky/provision/azure/config.py index b3cb357512a..22982a99075 100644 --- a/sky/provision/azure/config.py +++ b/sky/provision/azure/config.py @@ -46,6 +46,7 @@ def bootstrap_instances( region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionConfig: """See sky/provision/__init__.py""" + # TODO: use new azure sdk instead of ARM deployment. 
def _create_network_interface(
    network_client: 'azure_network.NetworkManagementClient', vm_name: str,
    provider_config: Dict[str,
                          Any]) -> 'azure_network_models.NetworkInterface':
    """Create the network interface (and public IP, if wanted) for one VM.

    If ``provider_config['use_internal_ips']`` is truthy, a NIC named
    ``<vm_name>-nic-private`` with only a dynamically allocated private IP is
    created. Otherwise a static IPv4 public IP named ``<vm_name>-ip`` is
    created first and attached to a NIC named ``<vm_name>-nic-public``.

    Args:
        network_client: Client used to create the public IP and the NIC.
        vm_name: Name of the VM the interface is created for; used to derive
            the resource names.
        provider_config: Provider settings. Reads 'location',
            'resource_group', 'subnet', 'nsg' and, optionally,
            'use_internal_ips'.

    Returns:
        The network interface returned by the creation poller's result.
    """
    network = azure.azure_mgmt_models('network')
    compute = azure.azure_mgmt_models('compute')
    logger.info(f'Start creating network interface for {vm_name}...')

    wants_public_ip = not provider_config.get('use_internal_ips', False)
    if wants_public_ip:
        nic_name = f'{vm_name}-nic-public'
        public_ip_params = network.PublicIPAddress(
            location=provider_config['location'],
            public_ip_allocation_method='Static',
            public_ip_address_version='IPv4',
            sku=network.PublicIPAddressSku(name='Basic', tier='Regional'))
        # result() blocks until the long-running operation has finished.
        created_ip = network_client.public_ip_addresses.begin_create_or_update(
            resource_group_name=provider_config['resource_group'],
            public_ip_address_name=f'{vm_name}-ip',
            parameters=public_ip_params).result()
        logger.info(f'Created public IP address {created_ip.name} '
                    f'with address {created_ip.ip_address}.')
        ip_config = network.IPConfiguration(
            name=f'ip-config-public-{vm_name}',
            subnet=compute.SubResource(id=provider_config['subnet']),
            private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC,
            public_ip_address=network.PublicIPAddress(id=created_ip.id))
    else:
        nic_name = f'{vm_name}-nic-private'
        ip_config = network.IPConfiguration(
            name=f'ip-config-private-{vm_name}',
            subnet=compute.SubResource(id=provider_config['subnet']),
            private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC)

    created_nic = network_client.network_interfaces.begin_create_or_update(
        resource_group_name=provider_config['resource_group'],
        network_interface_name=nic_name,
        parameters=network.NetworkInterface(
            location=provider_config['location'],
            ip_configurations=[ip_config],
            network_security_group=network.NetworkSecurityGroup(
                id=provider_config['nsg']))).result()
    logger.info(f'Created network interface {created_nic.name}.')
    return created_nic


def _create_vm(
        compute_client: 'azure_compute.ComputeManagementClient', vm_name: str,
        node_tags: Dict[str, str], provider_config: Dict[str, Any],
        node_config: Dict[str, Any],
        network_interface_id: str) -> 'azure_compute_models.VirtualMachine':
    """Create a single VM attached to an existing network interface.

    Args:
        compute_client: Client used to create the VM and, when requested,
            its driver extension.
        vm_name: Name of the VM; also used as its computer name.
        node_tags: Tags to set on the VM.
        provider_config: Provider settings. Reads 'location',
            'resource_group' and 'msi'.
        node_config: Node settings. Reads the 'azure_arm_parameters' mapping
            and, optionally, 'need_nvidia_driver_extension'.
        network_interface_id: ID of the network interface to attach as the
            primary interface.

    Returns:
        The virtual machine returned by the creation poller's result.
    """
    compute = azure.azure_mgmt_models('compute')
    logger.info(f'Start creating VM {vm_name}...')
    arm_params = node_config['azure_arm_parameters']

    hardware_profile = compute.HardwareProfile(vm_size=arm_params['vmSize'])
    primary_nic = compute.NetworkInterfaceReference(id=network_interface_id,
                                                    primary=True)
    network_profile = compute.NetworkProfile(network_interfaces=[primary_nic])

    ssh_public_key = arm_params['publicKey']
    admin_user = arm_params['adminUsername']
    # In Azure, cloud-init script must be encoded in base64. For more
    # information, see:
    # https://learn.microsoft.com/en-us/azure/virtual-machines/custom-data
    raw_cloud_init = arm_params['cloudInitSetupCommands'].encode('utf-8')
    encoded_cloud_init = base64.b64encode(raw_cloud_init).decode('utf-8')

    authorized_key = compute.SshPublicKey(
        path=f'/home/{admin_user}/.ssh/authorized_keys',
        key_data=ssh_public_key)
    linux_configuration = compute.LinuxConfiguration(
        disable_password_authentication=True,
        ssh=compute.SshConfiguration(public_keys=[authorized_key]))
    # NOTE(review): the SSH public key is also passed as admin_password
    # while password authentication is disabled above - confirm that this
    # field is still needed.
    os_profile = compute.OSProfile(admin_username=admin_user,
                                   computer_name=vm_name,
                                   admin_password=ssh_public_key,
                                   linux_configuration=linux_configuration,
                                   custom_data=encoded_cloud_init)

    community_image_id = arm_params.get('communityGalleryImageId', None)
    if community_image_id is None:
        image_reference = compute.ImageReference(
            publisher=arm_params['imagePublisher'],
            offer=arm_params['imageOffer'],
            sku=arm_params['imageSku'],
            version=arm_params['imageVersion'])
    else:
        # Prioritize using community gallery image if specified.
        image_reference = compute.ImageReference(
            community_gallery_image_id=community_image_id)
        logger.info(
            f'Used community_image_id: {community_image_id} for VM {vm_name}.')

    os_disk = compute.OSDisk(
        create_option=compute.DiskCreateOptionTypes.FROM_IMAGE,
        managed_disk=compute.ManagedDiskParameters(
            storage_account_type=arm_params['osDiskTier']),
        disk_size_gb=arm_params['osDiskSizeGB'])
    storage_profile = compute.StorageProfile(image_reference=image_reference,
                                             os_disk=os_disk)

    vm_spec = compute.VirtualMachine(
        location=provider_config['location'],
        tags=node_tags,
        hardware_profile=hardware_profile,
        os_profile=os_profile,
        storage_profile=storage_profile,
        network_profile=network_profile,
        identity=compute.VirtualMachineIdentity(
            type='UserAssigned',
            user_assigned_identities={provider_config['msi']: {}}))
    vm_poller = compute_client.virtual_machines.begin_create_or_update(
        resource_group_name=provider_config['resource_group'],
        vm_name=vm_name,
        parameters=vm_spec,
    )
    # poller.result() will block on async operation until it's done.
    created_vm = vm_poller.result()
    logger.info(f'Created VM {created_vm.name}.')

    if not node_config.get('need_nvidia_driver_extension', False):
        return created_vm

    # Configure driver extension for A10 GPUs. A10 GPUs requires a
    # special type of drivers which is available at Microsoft HPC
    # extension. Reference:
    # https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2
    # This can take more than 20mins for setting up the A10 GPUs
    extensions = compute_client.virtual_machine_extensions
    driver_extension = compute.VirtualMachineExtension(
        location=provider_config['location'],
        publisher='Microsoft.HpcCompute',
        type_properties_type='NvidiaGpuDriverLinux',
        type_handler_version='1.9',
        auto_upgrade_minor_version=True,
        settings='{}')
    created_extension = extensions.begin_create_or_update(
        resource_group_name=provider_config['resource_group'],
        vm_name=vm_name,
        vm_extension_name='NvidiaGpuDriverLinux',
        extension_parameters=driver_extension).result()
    logger.info(
        f'Created VM extension {created_extension.name} for VM {vm_name}.')
    return created_vm
For more - # information, see: - # https://learn.microsoft.com/en-us/azure/virtual-machines/custom-data - template_params['cloudInitSetupCommands'] = (base64.b64encode( - template_params['cloudInitSetupCommands'].encode('utf-8')).decode( - 'utf-8')) + # Create VM instances in parallel. + def create_single_instance(vm_i): + vm_name = f'{cluster_name_on_cloud}-{vm_id}-{vm_i}' + network_interface = _create_network_interface(network_client, vm_name, + provider_config) + _create_vm(compute_client, vm_name, node_tags, provider_config, + node_config, network_interface.id) - if node_config.get('need_nvidia_driver_extension', False): - # pylint: disable=line-too-long - # Configure driver extension for A10 GPUs. A10 GPUs requires a - # special type of drivers which is available at Microsoft HPC - # extension. Reference: https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2 - for r in template['resources']: - if r['type'] == 'Microsoft.Compute/virtualMachines': - # Add a nested extension resource for A10 GPUs - r['resources'] = [ - { - 'type': 'extensions', - 'apiVersion': '2015-06-15', - 'location': '[variables(\'location\')]', - 'dependsOn': [ - '[concat(\'Microsoft.Compute/virtualMachines/\', parameters(\'vmName\'), copyIndex())]' - ], - 'name': 'NvidiaGpuDriverLinux', - 'properties': { - 'publisher': 'Microsoft.HpcCompute', - 'type': 'NvidiaGpuDriverLinux', - 'typeHandlerVersion': '1.9', - 'autoUpgradeMinorVersion': True, - 'settings': {}, - }, - }, - ] - break - - parameters = { - 'properties': { - 'mode': azure.deployment_mode().incremental, - 'template': template, - 'parameters': { - key: { - 'value': value - } for key, value in template_params.items() - }, - } - } - - create_or_update = _get_azure_sdk_function( - client=resource_client.deployments, function_name='create_or_update') - create_or_update( - resource_group_name=resource_group, - deployment_name=vm_name, - parameters=parameters, - ).wait() + 
subprocess_utils.run_in_parallel(create_single_instance, range(count)) + # Update disk performance tier performance_tier = node_config.get('disk_performance_tier', None) if performance_tier is not None: disks = compute_client.disks.list_by_resource_group(resource_group) @@ -286,12 +358,14 @@ def _create_instances( f'az disk update -n {name} -g {resource_group} ' f'--set tier={performance_tier}') + # Validation filters = { constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, _TAG_SKYPILOT_VM_ID: vm_id } instances = _filter_instances(compute_client, resource_group, filters) assert len(instances) == count, (len(instances), count) + return instances @@ -303,7 +377,7 @@ def run_instances(region: str, cluster_name_on_cloud: str, resource_group = provider_config['resource_group'] subscription_id = provider_config['subscription_id'] compute_client = azure.get_client('compute', subscription_id) - + network_client = azure.get_client('network', subscription_id) instances_to_resume = [] resumed_instance_ids: List[str] = [] created_instance_ids: List[str] = [] @@ -439,12 +513,11 @@ def _create_instance_tag(target_instance, is_head: bool = True) -> str: to_start_count -= len(resumed_instance_ids) if to_start_count > 0: - resource_client = azure.get_client('resource', subscription_id) logger.debug(f'run_instances: Creating {to_start_count} instances.') try: created_instances = _create_instances( compute_client=compute_client, - resource_client=resource_client, + network_client=network_client, cluster_name_on_cloud=cluster_name_on_cloud, resource_group=resource_group, provider_config=provider_config, diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 77ddda6652f..b956530fccc 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -67,6 +67,8 @@ available_node_types: imageOffer: {{image_offer}} imageSku: "{{image_sku}}" imageVersion: {{image_version}} + # Community Gallery Image ID + communityGalleryImageId: 
{{community_gallery_image_id}} osDiskSizeGB: {{disk_size}} osDiskTier: {{disk_tier}} {%- if use_spot %}