diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index a92d13fd214..89f9dcdc695 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -3888,22 +3888,8 @@ def teardown_no_lock(self,
             self.post_teardown_cleanup(handle, terminate, purge)
             return
 
-        if terminate and isinstance(cloud, clouds.Azure):
-            # Here we handle termination of Azure by ourselves instead of Ray
-            # autoscaler.
-            resource_group = config['provider']['resource_group']
-            terminate_cmd = f'az group delete -y --name {resource_group}'
-            with rich_utils.safe_status(f'[bold cyan]Terminating '
-                                        f'[green]{cluster_name}'):
-                returncode, stdout, stderr = log_lib.run_with_log(
-                    terminate_cmd,
-                    log_abs_path,
-                    shell=True,
-                    stream_logs=False,
-                    require_outputs=True)
-
-        elif (isinstance(cloud, clouds.IBM) and terminate and
-              prev_cluster_status == status_lib.ClusterStatus.STOPPED):
+        if (isinstance(cloud, clouds.IBM) and terminate and
+                prev_cluster_status == status_lib.ClusterStatus.STOPPED):
             # pylint: disable= W0622 W0703 C0415
             from sky.adaptors import ibm
             from sky.skylet.providers.ibm.vpc_provider import IBMVPCProvider
@@ -4021,14 +4007,8 @@ def teardown_no_lock(self,
             # never launched and the errors are related to pre-launch
             # configurations (such as VPC not found). So it's safe & good UX
             # to not print a failure message.
-            #
-            # '(ResourceGroupNotFound)': this indicates the resource group on
-            # Azure is not found. That means the cluster is already deleted
-            # on the cloud. So it's safe & good UX to not print a failure
-            # message.
             elif ('TPU must be specified.' not in stderr and
-                  'SKYPILOT_ERROR_NO_NODES_LAUNCHED: ' not in stderr and
-                  '(ResourceGroupNotFound)' not in stderr):
+                  'SKYPILOT_ERROR_NO_NODES_LAUNCHED: ' not in stderr):
                 raise RuntimeError(
                     _TEARDOWN_FAILURE_MESSAGE.format(
                         extra_reason='',
diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py
index 852af5c0c77..b75f9207856 100644
--- a/sky/clouds/azure.py
+++ b/sky/clouds/azure.py
@@ -67,7 +67,7 @@ class Azure(clouds.Cloud):
 
     _INDENT_PREFIX = ' ' * 4
 
-    PROVISIONER_VERSION = clouds.ProvisionerVersion.RAY_AUTOSCALER
+    PROVISIONER_VERSION = clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR
     STATUS_VERSION = clouds.StatusVersion.SKYPILOT
 
     @classmethod
diff --git a/sky/provision/azure/__init__.py b/sky/provision/azure/__init__.py
index b28c161a866..2152728ba6e 100644
--- a/sky/provision/azure/__init__.py
+++ b/sky/provision/azure/__init__.py
@@ -3,3 +3,5 @@
 from sky.provision.azure.instance import cleanup_ports
 from sky.provision.azure.instance import open_ports
 from sky.provision.azure.instance import query_instances
+from sky.provision.azure.instance import stop_instances
+from sky.provision.azure.instance import terminate_instances
diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py
index 6693427d8ff..19c1ba3f3da 100644
--- a/sky/provision/azure/instance.py
+++ b/sky/provision/azure/instance.py
@@ -105,6 +105,63 @@ def cleanup_ports(
     del cluster_name_on_cloud, ports, provider_config  # Unused.
 
 
+def stop_instances(
+    cluster_name_on_cloud: str,
+    provider_config: Optional[Dict[str, Any]] = None,
+    worker_only: bool = False,
+) -> None:
+    """See sky/provision/__init__.py"""
+    assert provider_config is not None, (cluster_name_on_cloud, provider_config)
+
+    subscription_id = provider_config['subscription_id']
+    resource_group = provider_config['resource_group']
+    compute_client = azure.get_client('compute', subscription_id)
+    tag_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
+    if worker_only:
+        tag_filters[TAG_RAY_NODE_KIND] = 'worker'
+
+    nodes = _filter_instances(compute_client, tag_filters, resource_group)
+    stop_virtual_machine = get_azure_sdk_function(
+        client=compute_client.virtual_machines, function_name='deallocate')
+    with pool.ThreadPool() as p:
+        p.starmap(stop_virtual_machine,
+                  [(resource_group, node.name) for node in nodes])
+
+
+def terminate_instances(
+    cluster_name_on_cloud: str,
+    provider_config: Optional[Dict[str, Any]] = None,
+    worker_only: bool = False,
+) -> None:
+    """See sky/provision/__init__.py"""
+    assert provider_config is not None, (cluster_name_on_cloud, provider_config)
+    # TODO(zhwu): check the following. Also, seems we can directly force
+    # delete a resource group.
+    subscription_id = provider_config['subscription_id']
+    resource_group = provider_config['resource_group']
+    if worker_only:
+        compute_client = azure.get_client('compute', subscription_id)
+        delete_virtual_machine = get_azure_sdk_function(
+            client=compute_client.virtual_machines, function_name='delete')
+        filters = {
+            TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
+            TAG_RAY_NODE_KIND: 'worker'
+        }
+        nodes = _filter_instances(compute_client, filters, resource_group)
+        with pool.ThreadPool() as p:
+            p.starmap(delete_virtual_machine,
+                      [(resource_group, node.name) for node in nodes])
+        return
+
+    assert provider_config is not None, cluster_name_on_cloud
+
+    resource_group_client = azure.get_client('resource', subscription_id)
+    delete_resource_group = get_azure_sdk_function(
+        client=resource_group_client.resource_groups, function_name='delete')
+
+    delete_resource_group(resource_group, force_deletion_types=None)
+
+
 def _get_vm_status(compute_client: 'azure_compute.ComputeManagementClient',
                    vm_name: str, resource_group: str) -> str:
     instance = compute_client.virtual_machines.instance_view(
@@ -119,7 +176,7 @@ def _get_vm_status(compute_client: 'azure_compute.ComputeManagementClient',
         # skip provisioning status
         if code == 'PowerState':
             return state
-    raise ValueError(f'Failed to get status for VM {vm_name}')
+    raise ValueError(f'Failed to get power state for VM {vm_name}: {instance}')
 
 
 def _filter_instances(
@@ -185,8 +242,9 @@ def query_instances(
     statuses = {}
 
     def _fetch_and_map_status(
-            compute_client: 'azure_compute.ComputeManagementClient', node,
-            resource_group: str):
+            compute_client: 'azure_compute.ComputeManagementClient',
+            node: 'azure_compute.models.VirtualMachine',
+            resource_group: str) -> None:
         if node.provisioning_state in provisioning_state_map:
             status = provisioning_state_map[node.provisioning_state]
         else:
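
Not part of the diff above: a minimal sketch of how the new entry points in sky/provision/azure/instance.py can be driven directly, for reviewers who want to sanity-check the signatures. The cluster name, subscription id, and resource group values are placeholders; in SkyPilot they come from the cluster's provider config at teardown time, and users normally go through `sky stop` / `sky down` rather than calling these functions themselves.

```python
# Sketch only. Placeholder identifiers; not a real subscription or cluster.
from sky.provision.azure import instance as azure_instance

provider_config = {
    'subscription_id': '00000000-0000-0000-0000-000000000000',  # placeholder
    'resource_group': 'sky-my-cluster-rg',  # placeholder
}

# Stop: deallocates every VM tagged with the cluster name.
azure_instance.stop_instances('my-cluster-abcd', provider_config)

# Terminate workers only: deletes worker VMs, keeps the head node and the
# resource group.
azure_instance.terminate_instances('my-cluster-abcd',
                                   provider_config,
                                   worker_only=True)

# Full termination: deletes the entire resource group.
azure_instance.terminate_instances('my-cluster-abcd', provider_config)
```

Note the design choice: full termination deletes the whole resource group rather than individual VMs, which mirrors the `az group delete -y --name <resource_group>` path removed from `teardown_no_lock`, now done through the Azure SDK instead of a subprocess.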