From e6ee397c00e579cf0227c14b6837efa56abf7df3 Mon Sep 17 00:00:00 2001 From: Gurcan Gercek <111535545+gurcangercek@users.noreply.github.com> Date: Tue, 11 Jun 2024 11:35:30 +0300 Subject: [PATCH] [GCP] GCE DWS Support (#3574) * [GCP] initial take for dws support with migs * fix lint errors * dependency and format fix * refactor mig instance creation * fix * remove unecessary instance creation code for mig * Fix deletion * Fix instance template logic * Restart * format * format * move to REST APIs instead of python APIs * add multi-node back * Fix multi-node * Avoid spot * format * format * fix scheduling * fix cancel * Add smoke test * revert some changes * fix smoke * Fix * fix * Fix smoke * [GCP] Changing the config name for DWS support and fix for resize request cancellation (#5) * Fix config fields * fix cancel * Add loggings * remove useless codes --------- Co-authored-by: Zhanghao Wu Co-authored-by: Zhanghao Wu --- docs/source/reference/config.rst | 24 ++ sky/clouds/gcp.py | 34 ++- sky/provision/gcp/constants.py | 12 + sky/provision/gcp/instance.py | 65 +++--- sky/provision/gcp/instance_utils.py | 329 ++++++++++++++++++++++++--- sky/provision/gcp/mig_utils.py | 209 +++++++++++++++++ sky/templates/gcp-ray.yml.j2 | 7 + sky/utils/schemas.py | 13 ++ tests/test_smoke.py | 36 +++ tests/test_yamls/use_mig_config.yaml | 4 + 10 files changed, 662 insertions(+), 71 deletions(-) create mode 100644 sky/provision/gcp/mig_utils.py create mode 100644 tests/test_yamls/use_mig_config.yaml diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index dce0ce1f643..74cd2c01092 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -247,6 +247,30 @@ Available fields and semantics: - projects/my-project/reservations/my-reservation2 + # Managed instance group / DWS (optional). + # + # SkyPilot supports launching instances in a managed instance group (MIG), + # which schedules GPU instance creation through DWS, offering better + # availability. This feature is only applied when a resource request + # contains GPU instances. managed_instance_group: + # Duration for a created instance to be kept alive (in seconds, required). + # + # This is required for DWS to work properly. After the + # specified duration, the instance will be terminated. + run_duration: 3600 + # Timeout for provisioning an instance by DWS (in seconds, optional). + # + # This timeout determines how long SkyPilot will wait for a managed + # instance group to create the requested resources before giving up, + # deleting the MIG and failing over to other locations. Larger timeouts + # may increase the chance of getting a resource, but will block failover + # to other zones/regions/clouds. + # + # Default: 900 + provision_timeout: 900 + + # Identity to use for all GCP instances (optional). 
# # LOCAL_CREDENTIALS: The user's local credential files will be uploaded to diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index fd88045dc12..7e7dacc539f 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -14,6 +14,7 @@ from sky import clouds from sky import exceptions from sky import sky_logging +from sky import skypilot_config from sky.adaptors import gcp from sky.clouds import service_catalog from sky.clouds.utils import gcp_utils @@ -179,20 +180,31 @@ class GCP(clouds.Cloud): def _unsupported_features_for_resources( cls, resources: 'resources.Resources' ) -> Dict[clouds.CloudImplementationFeatures, str]: + unsupported = {} if gcp_utils.is_tpu_vm_pod(resources): - return { + unsupported = { clouds.CloudImplementationFeatures.STOP: ( - 'TPU VM pods cannot be stopped. Please refer to: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources' + 'TPU VM pods cannot be stopped. Please refer to: ' + 'https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources' ) } if gcp_utils.is_tpu(resources) and not gcp_utils.is_tpu_vm(resources): # TPU node does not support multi-node. - return { - clouds.CloudImplementationFeatures.MULTI_NODE: - ('TPU node does not support multi-node. Please set ' - 'num_nodes to 1.') - } - return {} + unsupported[clouds.CloudImplementationFeatures.MULTI_NODE] = ( + 'TPU node does not support multi-node. Please set ' + 'num_nodes to 1.') + # TODO(zhwu): We probably need to store the MIG requirement in resources + # because `skypilot_config` may change for an existing cluster. + # Clusters created with MIG (only GPU clusters) cannot be stopped. + if (skypilot_config.get_nested( + ('gcp', 'managed_instance_group'), None) is not None and + resources.accelerators): + unsupported[clouds.CloudImplementationFeatures.STOP] = ( + 'Managed Instance Group (MIG) does not support stopping yet.') + unsupported[clouds.CloudImplementationFeatures.SPOT_INSTANCE] = ( + 'Managed Instance Group with DWS does not support ' + 'spot instances.') + return unsupported @classmethod def max_cluster_name_length(cls) -> Optional[int]: @@ -493,6 +505,12 @@ def make_deploy_resources_variables( resources_vars['tpu_node_name'] = tpu_node_name + managed_instance_group_config = skypilot_config.get_nested( + ('gcp', 'managed_instance_group'), None) + use_mig = managed_instance_group_config is not None + resources_vars['gcp_use_managed_instance_group'] = use_mig + if use_mig: + resources_vars.update(managed_instance_group_config) return resources_vars def _get_feasible_launchable_resources( diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index 7ed8d3da6e0..8f9341bd342 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -214,3 +214,15 @@ MAX_POLLS = 60 // POLL_INTERVAL # Stopping instances can take several minutes, so we increase the timeout MAX_POLLS_STOP = MAX_POLLS * 8 + +TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' +# Tag uniquely identifying all nodes of a cluster +TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' +TAG_RAY_NODE_KIND = 'ray-node-type' +TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' + +# MIG constants +MANAGED_INSTANCE_GROUP_CONFIG = 'managed-instance-group' +DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT = 900 # 15 minutes +MIG_NAME_PREFIX = 'sky-mig-' +INSTANCE_TEMPLATE_NAME_PREFIX = 'sky-it-' diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index a4996fc4d4b..62f234725dd 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py 
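A minimal sketch of how the new `managed_instance_group` section flows from the user's config into the deploy variables consumed by gcp-ray.yml.j2, mirroring the `skypilot_config` lookup added to `GCP.make_deploy_resources_variables` above; the config values are illustrative, not requirements of this patch:

    # Sketch only: assumes ~/.sky/config.yaml contains
    #   gcp:
    #     managed_instance_group:
    #       run_duration: 3600
    #       provision_timeout: 900
    from sky import skypilot_config

    mig_config = skypilot_config.get_nested(('gcp', 'managed_instance_group'), None)
    use_mig = mig_config is not None
    resources_vars = {'gcp_use_managed_instance_group': use_mig}
    if use_mig:
        # Forwards run_duration / provision_timeout to the gcp-ray.yml.j2 template.
        resources_vars.update(mig_config)

With this section present, GPU requests on GCP are routed through a MIG, and the cloud reports STOP and SPOT_INSTANCE as unsupported for such clusters, as in the `_unsupported_features_for_resources` hunk above.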
@@ -16,11 +16,6 @@ logger = sky_logging.init_logger(__name__) -TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' -# Tag uniquely identifying all nodes of a cluster -TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' -TAG_RAY_NODE_KIND = 'ray-node-type' - _INSTANCE_RESOURCE_NOT_FOUND_PATTERN = re.compile( r'The resource \'projects/.*/zones/.*/instances/.*\' was not found') @@ -66,7 +61,7 @@ def query_instances( assert provider_config is not None, (cluster_name_on_cloud, provider_config) zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} handler: Type[ instance_utils.GCPInstance] = instance_utils.GCPComputeInstance @@ -124,15 +119,15 @@ def _wait_for_operations( logger.debug( f'wait_for_compute_{op_type}_operation: ' f'Waiting for operation {operation["name"]} to finish...') - handler.wait_for_operation(operation, project_id, zone) + handler.wait_for_operation(operation, project_id, zone=zone) def _get_head_instance_id(instances: List) -> Optional[str]: head_instance_id = None for inst in instances: labels = inst.get('labels', {}) - if (labels.get(TAG_RAY_NODE_KIND) == 'head' or - labels.get(TAG_SKYPILOT_HEAD_NODE) == '1'): + if (labels.get(constants.TAG_RAY_NODE_KIND) == 'head' or + labels.get(constants.TAG_SKYPILOT_HEAD_NODE) == '1'): head_instance_id = inst['name'] break return head_instance_id @@ -158,12 +153,14 @@ def _run_instances(region: str, cluster_name_on_cloud: str, resource: Type[instance_utils.GCPInstance] if node_type == instance_utils.GCPNodeType.COMPUTE: resource = instance_utils.GCPComputeInstance + elif node_type == instance_utils.GCPNodeType.MIG: + resource = instance_utils.GCPManagedInstanceGroup elif node_type == instance_utils.GCPNodeType.TPU: resource = instance_utils.GCPTPUVMInstance else: raise ValueError(f'Unknown node type {node_type}') - filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} # wait until all stopping instances are stopped/terminated while True: @@ -264,12 +261,16 @@ def get_order_key(node): if config.resume_stopped_nodes and to_start_count > 0 and stopped_instances: resumed_instance_ids = [n['name'] for n in stopped_instances] if resumed_instance_ids: - for instance_id in resumed_instance_ids: - resource.start_instance(instance_id, project_id, - availability_zone) - resource.set_labels(project_id, availability_zone, instance_id, - labels) - to_start_count -= len(resumed_instance_ids) + resumed_instance_ids = resource.start_instances( + cluster_name_on_cloud, project_id, availability_zone, + resumed_instance_ids, labels) + # In MIG case, the resumed_instance_ids will include the previously + # PENDING and RUNNING instances. To avoid double counting, we need to + # remove them from the resumed_instance_ids. 
+ ready_instances = set(resumed_instance_ids) + ready_instances |= set([n['name'] for n in running_instances]) + ready_instances |= set([n['name'] for n in pending_instances]) + to_start_count = config.count - len(ready_instances) if head_instance_id is None: head_instance_id = resource.create_node_tag( @@ -281,9 +282,14 @@ def get_order_key(node): if to_start_count > 0: errors, created_instance_ids = resource.create_instances( - cluster_name_on_cloud, project_id, availability_zone, - config.node_config, labels, to_start_count, - head_instance_id is None) + cluster_name_on_cloud, + project_id, + availability_zone, + config.node_config, + labels, + to_start_count, + total_count=config.count, + include_head_node=head_instance_id is None) if errors: error = common.ProvisionerError('Failed to launch instances.') error.errors = errors @@ -387,7 +393,7 @@ def get_cluster_info( assert provider_config is not None, cluster_name_on_cloud zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance @@ -415,7 +421,7 @@ def get_cluster_info( project_id, zone, { - **label_filters, TAG_RAY_NODE_KIND: 'head' + **label_filters, constants.TAG_RAY_NODE_KIND: 'head' }, lambda h: [h.RUNNING_STATE], ) @@ -441,14 +447,14 @@ def stop_instances( assert provider_config is not None, cluster_name_on_cloud zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} tpu_node = provider_config.get('tpu_node') if tpu_node is not None: instance_utils.delete_tpu_node(project_id, zone, tpu_node) if worker_only: - label_filters[TAG_RAY_NODE_KIND] = 'worker' + label_filters[constants.TAG_RAY_NODE_KIND] = 'worker' handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance @@ -510,9 +516,16 @@ def terminate_instances( if tpu_node is not None: instance_utils.delete_tpu_node(project_id, zone, tpu_node) - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + use_mig = provider_config.get('use_managed_instance_group', False) + if use_mig: + # Deleting the MIG will also delete the instances. 
+ instance_utils.GCPManagedInstanceGroup.delete_mig( + project_id, zone, cluster_name_on_cloud) + return + + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} if worker_only: - label_filters[TAG_RAY_NODE_KIND] = 'worker' + label_filters[constants.TAG_RAY_NODE_KIND] = 'worker' handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance @@ -555,7 +568,7 @@ def open_ports( project_id = provider_config['project_id'] firewall_rule_name = provider_config['firewall_rule'] - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance, instance_utils.GCPTPUVMInstance, diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index be17861e9f8..e1e72a25d6c 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -14,12 +14,10 @@ from sky.clouds import gcp as gcp_cloud from sky.provision import common from sky.provision.gcp import constants +from sky.provision.gcp import mig_utils from sky.utils import common_utils from sky.utils import ux_utils -# Tag uniquely identifying all nodes of a cluster -TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' -TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' # Tag for the name of the node INSTANCE_NAME_MAX_LEN = 64 INSTANCE_NAME_UUID_LEN = 8 @@ -134,6 +132,8 @@ def instance_to_handler(instance: str): return GCPComputeInstance elif instance_type == 'tpu': return GCPTPUVMInstance + elif instance.startswith(constants.MIG_NAME_PREFIX): + return GCPManagedInstanceGroup else: raise ValueError(f'Unknown instance type: {instance_type}') @@ -177,8 +177,11 @@ def terminate( raise NotImplementedError @classmethod - def wait_for_operation(cls, operation: dict, project_id: str, - zone: Optional[str]) -> None: + def wait_for_operation(cls, + operation: dict, + project_id: str, + region: Optional[str] = None, + zone: Optional[str] = None) -> None: raise NotImplementedError @classmethod @@ -240,6 +243,7 @@ def create_instances( node_config: dict, labels: dict, count: int, + total_count: int, include_head_node: bool, ) -> Tuple[Optional[List], List[str]]: """Creates multiple instances and returns result. @@ -248,6 +252,21 @@ def create_instances( """ raise NotImplementedError + @classmethod + def start_instances(cls, cluster_name: str, project_id: str, zone: str, + instances: List[str], labels: Dict[str, + str]) -> List[str]: + """Start multiple instances. + + Returns: + List of instance names that are started. 
+ """ + del cluster_name # Unused + for instance_id in instances: + cls.start_instance(instance_id, project_id, zone) + cls.set_labels(project_id, zone, instance_id, labels) + return instances + @classmethod def start_instance(cls, node_id: str, project_id: str, zone: str) -> None: """Start a stopped instance.""" @@ -401,11 +420,18 @@ def filter( return instances @classmethod - def wait_for_operation(cls, operation: dict, project_id: str, - zone: Optional[str]) -> None: + def wait_for_operation(cls, + operation: dict, + project_id: str, + region: Optional[str] = None, + zone: Optional[str] = None, + timeout: int = GCP_TIMEOUT) -> None: if zone is not None: kwargs = {'zone': zone} operation_caller = cls.load_resource().zoneOperations() + elif region is not None: + kwargs = {'region': region} + operation_caller = cls.load_resource().regionOperations() else: kwargs = {} operation_caller = cls.load_resource().globalOperations() @@ -424,13 +450,13 @@ def call_operation(fn, timeout: int): return request.execute(num_retries=GCP_MAX_RETRIES) wait_start = time.time() - while time.time() - wait_start < GCP_TIMEOUT: + while time.time() - wait_start < timeout: # Retry the wait() call until it succeeds or times out. # This is because the wait() call is only best effort, and does not # guarantee that the operation is done when it returns. # Reference: https://cloud.google.com/workflows/docs/reference/googleapis/compute/v1/zoneOperations/wait # pylint: disable=line-too-long - timeout = max(GCP_TIMEOUT - (time.time() - wait_start), 1) - result = call_operation(operation_caller.wait, timeout) + remaining_timeout = max(timeout - (time.time() - wait_start), 1) + result = call_operation(operation_caller.wait, remaining_timeout) if result['status'] == 'DONE': # NOTE: Error example: # { @@ -454,9 +480,10 @@ def call_operation(fn, timeout: int): else: logger.warning('wait_for_operation: Timeout waiting for creation ' 'operation, cancelling the operation ...') - timeout = max(GCP_TIMEOUT - (time.time() - wait_start), 1) + remaining_timeout = max(timeout - (time.time() - wait_start), 1) try: - result = call_operation(operation_caller.delete, timeout) + result = call_operation(operation_caller.delete, + remaining_timeout) except gcp.http_error_exception() as e: logger.debug('wait_for_operation: failed to cancel operation ' f'due to error: {e}') @@ -611,7 +638,7 @@ def set_labels(cls, project_id: str, availability_zone: str, node_id: str, body=body, ).execute(num_retries=GCP_CREATE_MAX_RETRIES)) - cls.wait_for_operation(operation, project_id, availability_zone) + cls.wait_for_operation(operation, project_id, zone=availability_zone) @classmethod def create_instances( @@ -622,6 +649,7 @@ def create_instances( node_config: dict, labels: dict, count: int, + total_count: int, include_head_node: bool, ) -> Tuple[Optional[List], List[str]]: # NOTE: The syntax for bulkInsert() is different from insert(). 
@@ -648,8 +676,8 @@ def create_instances( config.update({ 'labels': dict( labels, **{ - TAG_RAY_CLUSTER_NAME: cluster_name, - TAG_SKYPILOT_CLUSTER_NAME: cluster_name + constants.TAG_RAY_CLUSTER_NAME: cluster_name, + constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name }), }) @@ -744,6 +772,19 @@ def _insert(cls, names: List[str], project_id: str, zone: str, logger.debug('"insert" operation requested ...') return operations + @classmethod + def _convert_selflinks_in_config(cls, config: dict) -> None: + """Convert selflinks to names in the config.""" + for disk in config.get('disks', []): + disk_type = disk.get('initializeParams', {}).get('diskType') + if disk_type is not None: + disk['initializeParams']['diskType'] = selflink_to_name( + disk_type) + config['machineType'] = selflink_to_name(config['machineType']) + for accelerator in config.get('guestAccelerators', []): + accelerator['acceleratorType'] = selflink_to_name( + accelerator['acceleratorType']) + @classmethod def _bulk_insert(cls, names: List[str], project_id: str, zone: str, config: dict) -> List[dict]: @@ -757,15 +798,7 @@ def _bulk_insert(cls, names: List[str], project_id: str, zone: str, k: v for d in config['scheduling'] for k, v in d.items() } - for disk in config.get('disks', []): - disk_type = disk.get('initializeParams', {}).get('diskType') - if disk_type is not None: - disk['initializeParams']['diskType'] = selflink_to_name( - disk_type) - config['machineType'] = selflink_to_name(config['machineType']) - for accelerator in config.get('guestAccelerators', []): - accelerator['acceleratorType'] = selflink_to_name( - accelerator['acceleratorType']) + cls._convert_selflinks_in_config(config) body = { 'count': len(names), @@ -860,7 +893,7 @@ def _handle_http_error(e): logger.debug('Waiting GCP instances to be ready ...') try: for operation in operations: - cls.wait_for_operation(operation, project_id, zone) + cls.wait_for_operation(operation, project_id, zone=zone) except common.ProvisionerError as e: return e.errors except gcp.http_error_exception() as e: @@ -881,7 +914,7 @@ def start_instance(cls, node_id: str, project_id: str, zone: str) -> None: instance=node_id, ).execute()) - cls.wait_for_operation(operation, project_id, zone) + cls.wait_for_operation(operation, project_id, zone=zone) @classmethod def get_instance_info(cls, project_id: str, availability_zone: str, @@ -940,7 +973,219 @@ def resize_disk(cls, project_id: str, availability_zone: str, logger.warning(f'googleapiclient.errors.HttpError: {e.reason}') return - cls.wait_for_operation(operation, project_id, availability_zone) + cls.wait_for_operation(operation, project_id, zone=availability_zone) + + +class GCPManagedInstanceGroup(GCPComputeInstance): + """Handler for GCP Managed Instance Group.""" + + @classmethod + def create_instances( + cls, + cluster_name: str, + project_id: str, + zone: str, + node_config: dict, + labels: dict, + count: int, + total_count: int, + include_head_node: bool, + ) -> Tuple[Optional[List], List[str]]: + logger.debug(f'Creating cluster with MIG: {cluster_name!r}') + config = copy.deepcopy(node_config) + labels = dict(config.get('labels', {}), **labels) + + config.update({ + 'labels': dict( + labels, + **{ + constants.TAG_RAY_CLUSTER_NAME: cluster_name, + # Assume all nodes are workers, we can update the head node + # once the instances are created. 
+ constants.TAG_RAY_NODE_KIND: 'worker', + constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name, + }), + }) + cls._convert_selflinks_in_config(config) + + # Convert label values to string and lowercase per MIG API requirement. + region = zone.rpartition('-')[0] + instance_template_name = mig_utils.get_instance_template_name( + cluster_name) + managed_instance_group_name = mig_utils.get_managed_instance_group_name( + cluster_name) + + instance_template_exists = mig_utils.check_instance_template_exits( + project_id, region, instance_template_name) + mig_exists = mig_utils.check_managed_instance_group_exists( + project_id, zone, managed_instance_group_name) + + label_filters = { + constants.TAG_RAY_CLUSTER_NAME: cluster_name, + } + potential_head_instances = [] + if mig_exists: + instances = cls.filter(project_id, + zone, + label_filters={ + constants.TAG_RAY_NODE_KIND: 'head', + **label_filters, + }, + status_filters=cls.NEED_TO_TERMINATE_STATES) + potential_head_instances = list(instances.keys()) + + config['labels'] = { + k: str(v).lower() for k, v in config['labels'].items() + } + if instance_template_exists: + if mig_exists: + logger.debug( + f'Instance template {instance_template_name} already ' + 'exists. Skip creating it.') + else: + logger.debug( + f'Instance template {instance_template_name!r} ' + 'exists and no instance group is using it. This is a ' + 'leftover of a previous autodown. Delete it and recreate ' + 'it.') + # TODO(zhwu): this is a bit hacky as we cannot delete instance + # template during an autodown, we can only defer the deletion + # to the next launch of a cluster with the same name. We should + # find a better way to handle this. + cls._delete_instance_template(project_id, zone, + instance_template_name) + instance_template_exists = False + + if not instance_template_exists: + operation = mig_utils.create_region_instance_template( + cluster_name, project_id, region, instance_template_name, + config) + cls.wait_for_operation(operation, project_id, region=region) + # create managed instance group + instance_template_url = (f'projects/{project_id}/regions/{region}/' + f'instanceTemplates/{instance_template_name}') + if not mig_exists: + # Create a new MIG with size 0 and resize it later for triggering + # DWS, according to the doc: https://cloud.google.com/compute/docs/instance-groups/create-mig-with-gpu-vms # pylint: disable=line-too-long + operation = mig_utils.create_managed_instance_group( + project_id, + zone, + managed_instance_group_name, + instance_template_url, + size=0) + cls.wait_for_operation(operation, project_id, zone=zone) + + managed_instance_group_config = config[ + constants.MANAGED_INSTANCE_GROUP_CONFIG] + if count > 0: + # Use resize to trigger DWS for creating VMs. + operation = mig_utils.resize_managed_instance_group( + project_id, + zone, + managed_instance_group_name, + count, + run_duration=managed_instance_group_config['run_duration']) + cls.wait_for_operation(operation, project_id, zone=zone) + + # This will block the provisioning until the nodes are ready, which + # makes the failover not effective. We rely on the request timeout set + # by user to trigger failover. 
+ mig_utils.wait_for_managed_group_to_be_stable( + project_id, + zone, + managed_instance_group_name, + timeout=managed_instance_group_config.get( + 'provision_timeout', + constants.DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT)) + + pending_running_instance_names = cls._add_labels_and_find_head( + cluster_name, project_id, zone, labels, potential_head_instances) + assert len(pending_running_instance_names) == total_count, ( + pending_running_instance_names, total_count) + cls.create_node_tag( + project_id, + zone, + pending_running_instance_names[0], + is_head=True, + ) + return None, pending_running_instance_names + + @classmethod + def _delete_instance_template(cls, project_id: str, zone: str, + instance_template_name: str) -> None: + logger.debug(f'Deleting instance template {instance_template_name}...') + region = zone.rpartition('-')[0] + try: + operation = cls.load_resource().regionInstanceTemplates().delete( + project=project_id, + region=region, + instanceTemplate=instance_template_name).execute() + cls.wait_for_operation(operation, project_id, region=region) + except gcp.http_error_exception() as e: + if re.search(mig_utils.IT_RESOURCE_NOT_FOUND_PATTERN, + str(e)) is None: + raise + logger.warning( + f'Instance template {instance_template_name!r} does not exist. ' + 'Skip deletion.') + + @classmethod + def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: + mig_name = mig_utils.get_managed_instance_group_name(cluster_name) + # Get all resize request of the MIG and cancel them. + mig_utils.cancel_all_resize_request_for_mig(project_id, zone, mig_name) + logger.debug(f'Deleting MIG {mig_name!r} ...') + try: + operation = cls.load_resource().instanceGroupManagers().delete( + project=project_id, zone=zone, + instanceGroupManager=mig_name).execute() + cls.wait_for_operation(operation, project_id, zone=zone) + except gcp.http_error_exception() as e: + if re.search(mig_utils.MIG_RESOURCE_NOT_FOUND_PATTERN, + str(e)) is None: + raise + logger.warning(f'MIG {mig_name!r} does not exist. Skip ' + 'deletion.') + + # In the autostop case, the following deletion of instance template + # will not be executed as the instance that runs the deletion will be + # terminated with the managed instance group. It is ok to leave the + # instance template there as when a user creates a new cluster with the + # same name, the instance template will be updated in our + # create_instances method. + cls._delete_instance_template( + project_id, zone, + mig_utils.get_instance_template_name(cluster_name)) + + @classmethod + def _add_labels_and_find_head( + cls, cluster_name: str, project_id: str, zone: str, + labels: Dict[str, str], + potential_head_instances: List[str]) -> List[str]: + pending_running_instances = cls.filter( + project_id, + zone, + {constants.TAG_RAY_CLUSTER_NAME: cluster_name}, + # Find all provisioning and running instances. + status_filters=cls.NEED_TO_STOP_STATES) + for running_instance_name in pending_running_instances.keys(): + if running_instance_name in potential_head_instances: + head_instance_name = running_instance_name + break + else: + head_instance_name = list(pending_running_instances.keys())[0] + # We need to update the node's label if mig already exists, as the + # config is not updated during the resize operation. 
+ for instance_name in pending_running_instances.keys(): + cls.set_labels(project_id=project_id, + availability_zone=zone, + node_id=instance_name, + labels=labels) + + pending_running_instance_names = list(pending_running_instances.keys()) + pending_running_instance_names.remove(head_instance_name) + # Label for head node type will be set by caller + return [head_instance_name] + pending_running_instance_names class GCPTPUVMInstance(GCPInstance): @@ -964,10 +1209,13 @@ def load_resource(cls): discoveryServiceUrl='https://tpu.googleapis.com/$discovery/rest') @classmethod - def wait_for_operation(cls, operation: dict, project_id: str, - zone: Optional[str]) -> None: + def wait_for_operation(cls, + operation: dict, + project_id: str, + region: Optional[str] = None, + zone: Optional[str] = None) -> None: """Poll for TPU operation until finished.""" - del project_id, zone # unused + del project_id, region, zone # unused @_retry_on_http_exception( f'Failed to wait for operation {operation["name"]}') @@ -1181,6 +1429,7 @@ def create_instances( node_config: dict, labels: dict, count: int, + total_count: int, include_head_node: bool, ) -> Tuple[Optional[List], List[str]]: config = copy.deepcopy(node_config) @@ -1203,8 +1452,8 @@ def create_instances( config.update({ 'labels': dict( labels, **{ - TAG_RAY_CLUSTER_NAME: cluster_name, - TAG_SKYPILOT_CLUSTER_NAME: cluster_name + constants.TAG_RAY_CLUSTER_NAME: cluster_name, + constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name }), }) @@ -1411,10 +1660,11 @@ class GCPNodeType(enum.Enum): """Enum for GCP node types (compute & tpu)""" COMPUTE = 'compute' + MIG = 'mig' TPU = 'tpu' -def get_node_type(node: dict) -> GCPNodeType: +def get_node_type(config: Dict[str, Any]) -> GCPNodeType: """Returns node type based on the keys in ``node``. This is a very simple check. If we have a ``machineType`` key, @@ -1424,17 +1674,22 @@ def get_node_type(node: dict) -> GCPNodeType: This works for both node configs and API returned nodes. """ - - if 'machineType' not in node and 'acceleratorType' not in node: + if ('machineType' not in config and 'acceleratorType' not in config): raise ValueError( 'Invalid node. For a Compute instance, "machineType" is ' 'required. ' 'For a TPU instance, "acceleratorType" and no "machineType" ' 'is required. ' - f'Got {list(node)}') + f'Got {list(config)}') - if 'machineType' not in node and 'acceleratorType' in node: + if 'machineType' not in config and 'acceleratorType' in config: return GCPNodeType.TPU + + if (config.get(constants.MANAGED_INSTANCE_GROUP_CONFIG, None) is not None + and config.get('guestAccelerators', None) is not None): + # DWS in MIG only works for machine with GPUs. 
+ return GCPNodeType.MIG + return GCPNodeType.COMPUTE diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py new file mode 100644 index 00000000000..9e33f5171e2 --- /dev/null +++ b/sky/provision/gcp/mig_utils.py @@ -0,0 +1,209 @@ +"""Managed Instance Group Utils""" +import re +import subprocess +from typing import Any, Dict + +from sky import sky_logging +from sky.adaptors import gcp +from sky.provision.gcp import constants + +logger = sky_logging.init_logger(__name__) + +MIG_RESOURCE_NOT_FOUND_PATTERN = re.compile( + r'The resource \'projects/.*/zones/.*/instanceGroupManagers/.*\' was not ' + r'found') + +IT_RESOURCE_NOT_FOUND_PATTERN = re.compile( + r'The resource \'projects/.*/regions/.*/instanceTemplates/.*\' was not ' + 'found') + + +def get_instance_template_name(cluster_name: str) -> str: + return f'{constants.INSTANCE_TEMPLATE_NAME_PREFIX}{cluster_name}' + + +def get_managed_instance_group_name(cluster_name: str) -> str: + return f'{constants.MIG_NAME_PREFIX}{cluster_name}' + + +def check_instance_template_exits(project_id: str, region: str, + template_name: str) -> bool: + compute = gcp.build('compute', + 'v1', + credentials=None, + cache_discovery=False) + try: + compute.regionInstanceTemplates().get( + project=project_id, region=region, + instanceTemplate=template_name).execute() + except gcp.http_error_exception() as e: + if IT_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is not None: + # Instance template does not exist. + return False + raise + return True + + +def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, + region: str, template_name: str, + node_config: Dict[str, Any]) -> dict: + """Create a regional instance template.""" + logger.debug(f'Creating regional instance template {template_name!r}.') + compute = gcp.build('compute', + 'v1', + credentials=None, + cache_discovery=False) + config = node_config.copy() + config.pop(constants.MANAGED_INSTANCE_GROUP_CONFIG, None) + + # We have to ignore user defined scheduling for DWS. + # TODO: Add a warning log for this behvaiour. 
+ scheduling = config.get('scheduling', {}) + assert scheduling.get('provisioningModel') != 'SPOT', ( + 'DWS does not support spot VMs.') + + reservations_affinity = config.pop('reservation_affinity', None) + if reservations_affinity is not None: + logger.warning( + f'Ignoring reservations_affinity {reservations_affinity} ' + 'for DWS.') + + # Create the regional instance template request + operation = compute.regionInstanceTemplates().insert( + project=project_id, + region=region, + body={ + 'name': template_name, + 'properties': dict( + description=( + 'SkyPilot instance template for ' + f'{cluster_name_on_cloud!r} to support DWS requests.'), + reservationAffinity=dict( + consumeReservationType='NO_RESERVATION'), + **config, + ) + }).execute() + return operation + + +def create_managed_instance_group(project_id: str, zone: str, group_name: str, + instance_template_url: str, + size: int) -> dict: + logger.debug(f'Creating managed instance group {group_name!r}.') + compute = gcp.build('compute', + 'v1', + credentials=None, + cache_discovery=False) + operation = compute.instanceGroupManagers().insert( + project=project_id, + zone=zone, + body={ + 'name': group_name, + 'instanceTemplate': instance_template_url, + 'target_size': size, + 'instanceLifecyclePolicy': { + 'defaultActionOnFailure': 'DO_NOTHING', + }, + 'updatePolicy': { + 'type': 'OPPORTUNISTIC', + }, + }).execute() + return operation + + +def resize_managed_instance_group(project_id: str, zone: str, group_name: str, + resize_by: int, run_duration: int) -> dict: + logger.debug(f'Resizing managed instance group {group_name!r} by ' + f'{resize_by} with run duration {run_duration}.') + compute = gcp.build('compute', + 'beta', + credentials=None, + cache_discovery=False) + operation = compute.instanceGroupManagerResizeRequests().insert( + project=project_id, + zone=zone, + instanceGroupManager=group_name, + body={ + 'name': group_name, + 'resizeBy': resize_by, + 'requestedRunDuration': { + 'seconds': run_duration, + } + }).execute() + return operation + + +def cancel_all_resize_request_for_mig(project_id: str, zone: str, + group_name: str) -> None: + logger.debug(f'Cancelling all resize requests for MIG {group_name!r}.') + try: + compute = gcp.build('compute', + 'beta', + credentials=None, + cache_discovery=False) + operation = compute.instanceGroupManagerResizeRequests().list( + project=project_id, + zone=zone, + instanceGroupManager=group_name, + filter='state eq ACCEPTED').execute() + for request in operation.get('items', []): + try: + compute.instanceGroupManagerResizeRequests().cancel( + project=project_id, + zone=zone, + instanceGroupManager=group_name, + resizeRequest=request['name']).execute() + except gcp.http_error_exception() as e: + logger.warning('Failed to cancel resize request ' + f'{request["id"]!r}: {e}') + except gcp.http_error_exception() as e: + if re.search(MIG_RESOURCE_NOT_FOUND_PATTERN, str(e)) is None: + raise + logger.warning(f'MIG {group_name!r} does not exist. 
Skip ' + 'resize request cancellation.') + logger.debug(f'Error: {e}') + + +def check_managed_instance_group_exists(project_id: str, zone: str, + group_name: str) -> bool: + compute = gcp.build('compute', + 'v1', + credentials=None, + cache_discovery=False) + try: + compute.instanceGroupManagers().get( + project=project_id, zone=zone, + instanceGroupManager=group_name).execute() + except gcp.http_error_exception() as e: + if MIG_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is not None: + return False + raise + return True + + +def wait_for_managed_group_to_be_stable(project_id: str, zone: str, + group_name: str, timeout: int) -> None: + """Wait until the managed instance group is stable.""" + logger.debug(f'Waiting for MIG {group_name} to be stable with timeout ' + f'{timeout}.') + try: + cmd = ('gcloud compute instance-groups managed wait-until ' + f'{group_name} ' + '--stable ' + f'--zone={zone} ' + f'--project={project_id} ' + f'--timeout={timeout}') + logger.info( + f'Waiting for MIG {group_name} to be stable with command:\n{cmd}') + proc = subprocess.run( + f'yes | {cmd}', + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + check=True, + ) + stdout = proc.stdout.decode('ascii') + logger.info(stdout) + except subprocess.CalledProcessError as e: + stderr = e.stderr.decode('ascii') + logger.info(stderr) diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index 9c2092bdfaf..51a7b332a72 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ b/sky/templates/gcp-ray.yml.j2 @@ -62,6 +62,7 @@ provider: # The upper-level SkyPilot code has make sure there will not be resource # leakage. disable_launch_config_check: true + use_managed_instance_group: {{ gcp_use_managed_instance_group }} auth: ssh_user: gcpuser @@ -79,6 +80,12 @@ available_node_types: {%- for label_key, label_value in labels.items() %} {{ label_key }}: {{ label_value|tojson }} {%- endfor %} + managed-instance-group: {{ gcp_use_managed_instance_group }} + {%- if gcp_use_managed_instance_group %} + managed-instance-group: + run_duration: {{ run_duration }} + provision_timeout: {{ provision_timeout }} + {%- endif %} {%- if specific_reservations %} reservationAffinity: consumeReservationType: SPECIFIC_RESERVATION diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 5bc011abaaa..1c6994d5f7b 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -648,6 +648,19 @@ def get_config_schema(): 'type': 'string', }, }, + 'managed_instance_group': { + 'type': 'object', + 'required': ['run_duration'], + 'additionalProperties': False, + 'properties': { + 'run_duration': { + 'type': 'integer', + }, + 'provision_timeout': { + 'type': 'integer', + } + } + }, **_LABELS_SCHEMA, **_NETWORK_CONFIG_SCHEMA, }, diff --git a/tests/test_smoke.py b/tests/test_smoke.py index d70a9fce4cd..d19863b52fe 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -112,6 +112,8 @@ class Test(NamedTuple): teardown: Optional[str] = None # Timeout for each command in seconds. timeout: int = DEFAULT_CMD_TIMEOUT + # Environment variables to set for each command. + env: Dict[str, str] = None def echo(self, message: str): # pytest's xdist plugin captures stdout; print to stderr so that the @@ -158,6 +160,9 @@ def run_one_test(test: Test) -> Tuple[int, str, str]: suffix='.log', delete=False) test.echo(f'Test started. 
Log: less {log_file.name}') + env_dict = os.environ.copy() + if test.env: + env_dict.update(test.env) for command in test.commands: log_file.write(f'+ {command}\n') log_file.flush() @@ -167,6 +172,7 @@ def run_one_test(test: Test) -> Tuple[int, str, str]: stderr=subprocess.STDOUT, shell=True, executable='/bin/bash', + env=env_dict, ) try: proc.wait(timeout=test.timeout) @@ -761,6 +767,36 @@ def test_clone_disk_gcp(): run_one_test(test) +@pytest.mark.gcp +def test_gcp_mig(): + name = _get_cluster_name() + region = 'us-central1' + test = Test( + 'gcp_mig', + [ + f'sky launch -y -c {name} --gpus t4 --num-nodes 2 --image-id skypilot:gpu-debian-10 --cloud gcp --region {region} tests/test_yamls/minimal.yaml', + f'sky logs {name} 1 --status', # Ensure the job succeeded. + f'sky launch -y -c {name} tests/test_yamls/minimal.yaml', + f'sky logs {name} 2 --status', + f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent. + # Check MIG exists. + f'gcloud compute instance-groups managed list --format="value(name)" | grep "^sky-mig-{name}"', + f'sky autostop -i 0 --down -y {name}', + 'sleep 120', + f'sky status -r {name}; sky status {name} | grep "{name} not found"', + f'gcloud compute instance-templates list | grep "sky-it-{name}"', + # Launch again with the same region. The original instance template + # should be removed. + f'sky launch -y -c {name} --gpus L4 --num-nodes 2 --region {region} nvidia-smi', + f'sky logs {name} 1 | grep "L4"', + f'sky down -y {name}', + f'gcloud compute instance-templates list | grep "sky-it-{name}" && exit 1 || true', + ], + f'sky down -y {name}', + env={'SKYPILOT_CONFIG': 'tests/test_yamls/use_mig_config.yaml'}) + run_one_test(test) + + @pytest.mark.aws def test_image_no_conda(): name = _get_cluster_name() diff --git a/tests/test_yamls/use_mig_config.yaml b/tests/test_yamls/use_mig_config.yaml new file mode 100644 index 00000000000..ef715191a1f --- /dev/null +++ b/tests/test_yamls/use_mig_config.yaml @@ -0,0 +1,4 @@ +gcp: + managed_instance_group: + run_duration: 36000 + provision_timeout: 900
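As a reading aid, a condensed sketch of the DWS provisioning path this patch adds, using only names defined above in GCPManagedInstanceGroup.create_instances and sky/provision/gcp/mig_utils.py; operation polling, error handling, and head-node labeling are omitted, so this is an outline under those assumptions rather than the actual implementation:

    from sky.provision.gcp import constants
    from sky.provision.gcp import mig_utils


    def provision_with_dws_sketch(cluster_name: str, project_id: str, zone: str,
                                  node_config: dict, count: int) -> None:
        """Illustrative outline of GCPManagedInstanceGroup.create_instances."""
        region = zone.rpartition('-')[0]
        template_name = mig_utils.get_instance_template_name(cluster_name)  # sky-it-<cluster>
        mig_name = mig_utils.get_managed_instance_group_name(cluster_name)  # sky-mig-<cluster>

        # 1. Regional instance template (spot scheduling and reservation affinity
        #    are stripped for DWS by create_region_instance_template).
        if not mig_utils.check_instance_template_exits(project_id, region,
                                                       template_name):
            mig_utils.create_region_instance_template(cluster_name, project_id,
                                                      region, template_name,
                                                      node_config)

        # 2. The MIG is created with size 0; DWS is triggered by a resize
        #    request, not by the initial group size.
        if not mig_utils.check_managed_instance_group_exists(project_id, zone,
                                                             mig_name):
            template_url = (f'projects/{project_id}/regions/{region}/'
                            f'instanceTemplates/{template_name}')
            mig_utils.create_managed_instance_group(project_id, zone, mig_name,
                                                    template_url, size=0)

        # 3. The resize request carries the DWS run duration from the user config.
        mig_config = node_config[constants.MANAGED_INSTANCE_GROUP_CONFIG]
        mig_utils.resize_managed_instance_group(
            project_id, zone, mig_name, count,
            run_duration=mig_config['run_duration'])

        # 4. Block until the group is stable or provision_timeout elapses;
        #    failover to other locations is handled by the caller.
        mig_utils.wait_for_managed_group_to_be_stable(
            project_id, zone, mig_name,
            timeout=mig_config.get(
                'provision_timeout',
                constants.DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT))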