From 9f4d62aba6440e25b4838af90d50bc9e6423e4ef Mon Sep 17 00:00:00 2001 From: Gurcan Gercek Date: Tue, 21 May 2024 13:03:05 -0400 Subject: [PATCH 01/29] [GCP] initial take for dws support with migs --- sky/clouds/gcp.py | 1 + sky/clouds/utils/gcp_utils.py | 3 + sky/provision/gcp/constants.py | 5 + sky/provision/gcp/instance.py | 112 +++++++-- sky/provision/gcp/instance_utils.py | 4 + sky/provision/gcp/mig_utils.py | 339 ++++++++++++++++++++++++++++ sky/setup_files/setup.py | 2 +- sky/templates/gcp-ray.yml.j2 | 2 + sky/utils/schemas.py | 3 + 9 files changed, 453 insertions(+), 18 deletions(-) create mode 100644 sky/provision/gcp/mig_utils.py diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 0c494884c61..c6f8fa7c08f 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -422,6 +422,7 @@ def make_deploy_resources_variables( 'custom_resources': None, 'use_spot': r.use_spot, 'gcp_project_id': self.get_project_id(dryrun), + 'gcp_use_managed_instance_group': gcp_utils.is_use_managed_instance_group(), } accelerators = r.accelerators if accelerators is not None: diff --git a/sky/clouds/utils/gcp_utils.py b/sky/clouds/utils/gcp_utils.py index 68e6192d351..eedf19a5aee 100644 --- a/sky/clouds/utils/gcp_utils.py +++ b/sky/clouds/utils/gcp_utils.py @@ -184,3 +184,6 @@ def get_minimal_permissions() -> List[str]: permissions += constants.RESERVATION_PERMISSIONS return permissions + +def is_use_managed_instance_group() -> bool: + return skypilot_config.get_nested(('gcp', 'use_managed_instance_group'), False) diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index 7ed8d3da6e0..e605b5a710c 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -214,3 +214,8 @@ MAX_POLLS = 60 // POLL_INTERVAL # Stopping instances can take several minutes, so we increase the timeout MAX_POLLS_STOP = MAX_POLLS * 8 + +TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' +# Tag uniquely identifying all nodes of a cluster +TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' +TAG_RAY_NODE_KIND = 'ray-node-type' diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index e7f69f8c6eb..ddb9eac7f97 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -9,18 +9,15 @@ from sky import sky_logging from sky import status_lib from sky.adaptors import gcp +from sky.clouds.gcp import gcp_utils from sky.provision import common from sky.provision.gcp import constants from sky.provision.gcp import instance_utils +from sky.provision.gcp import mig_utils from sky.utils import common_utils logger = sky_logging.init_logger(__name__) -TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' -# Tag uniquely identifying all nodes of a cluster -TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' -TAG_RAY_NODE_KIND = 'ray-node-type' - _INSTANCE_RESOURCE_NOT_FOUND_PATTERN = re.compile( r'The resource \'projects/.*/zones/.*/instances/.*\' was not found') @@ -66,7 +63,7 @@ def query_instances( assert provider_config is not None, (cluster_name_on_cloud, provider_config) zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} handler: Type[ instance_utils.GCPInstance] = instance_utils.GCPComputeInstance @@ -131,12 +128,90 @@ def _get_head_instance_id(instances: List) -> Optional[str]: head_instance_id = None for inst in instances: labels = inst.get('labels', {}) - if (labels.get(TAG_RAY_NODE_KIND) == 'head' or - 
labels.get(TAG_SKYPILOT_HEAD_NODE) == '1'): + if (labels.get(constants.TAG_RAY_NODE_KIND) == 'head' or + labels.get(constants.TAG_SKYPILOT_HEAD_NODE) == '1'): head_instance_id = inst['name'] break return head_instance_id +def _run_instances_in_managed_instance_group(region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: + print("Managed instance group is enabled.") + + resumed_instance_ids: List[str] = [] + created_instance_ids: List[str] = [] + + node_type = instance_utils.get_node_type(config.node_config) + project_id = config.provider_config['project_id'] + availability_zone = config.provider_config['availability_zone'] + head_instance_id = "" + + # check instance templates + + # TODO add instance template name to the task definition if the user wants to use it. + # Calculate template instance name based on the config node values. Hash them to get a unique name. + node_config_hash = mig_utils.create_node_config_hash(cluster_name_on_cloud, config.node_config) + instance_template_name = f"{cluster_name_on_cloud}-it-{node_config_hash}" + if not mig_utils.check_instance_template_exits(project_id, region, instance_template_name): + mig_utils.create_regional_instance_template(project_id, + region, + instance_template_name, + config.node_config, + cluster_name_on_cloud) + else: + print(f"Instance template {instance_template_name} already exists...") + + # create managed instance group + instance_template_url = f"projects/{project_id}/regions/{region}/instanceTemplates/{instance_template_name}" + managed_instance_group_name = f"{cluster_name_on_cloud}-mig-{node_config_hash}" + if not mig_utils.check_managed_instance_group_exists(project_id, availability_zone, managed_instance_group_name): + mig_utils.create_managed_instance_group(project_id, + availability_zone, + managed_instance_group_name, + instance_template_url, + size=config.count) + else: + # TODO: if we already have one, we should resize it. 
+ print(f"Managed instance group {managed_instance_group_name} already exists...") + # mig_utils.resize_managed_instance_group(project_id, zone, group_name, size, run_duration) + + mig_utils.wait_for_managed_group_to_be_stable(project_id, availability_zone, managed_instance_group_name) + + + + resource: Type[instance_utils.GCPInstance] + if node_type == instance_utils.GCPNodeType.COMPUTE: + resource = instance_utils.GCPComputeInstance + elif node_type == instance_utils.GCPNodeType.TPU: + resource = instance_utils.GCPTPUVMInstance + else: + raise ValueError(f'Unknown node type {node_type}') + + # filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + # running_instances = resource.filter( + # project_id=project_id, + # zone=availability_zone, + # label_filters=filter_labels, + # status_filters=resource.RUNNING_STATE, + # ).values() + + # # TODO: Can not tag individual nodes as they are part of the mig + # head_instance_id = resource.create_node_tag( + # project_id, + # availability_zone, + # running_instances[0]['name'], + # is_head=True, + # ) + + # created_instance_ids = [n['name'] for n in running_instances] + + return common.ProvisionRecord(provider_name='gcp', + region=region, + zone=availability_zone, + cluster_name=cluster_name_on_cloud, + head_instance_id=head_instance_id, + resumed_instance_ids=resumed_instance_ids, + created_instance_ids=created_instance_ids) def _run_instances(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: @@ -163,7 +238,7 @@ def _run_instances(region: str, cluster_name_on_cloud: str, else: raise ValueError(f'Unknown node type {node_type}') - filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} # wait until all stopping instances are stopped/terminated while True: @@ -346,7 +421,10 @@ def run_instances(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """See sky/provision/__init__.py""" try: - return _run_instances(region, cluster_name_on_cloud, config) + if gcp_utils.is_use_managed_instance_group(): + return _run_instances_in_managed_instance_group(region, cluster_name_on_cloud, config) + else: + return _run_instances(region, cluster_name_on_cloud, config) except gcp.http_error_exception() as e: error_details = getattr(e, 'error_details') errors = [] @@ -387,7 +465,7 @@ def get_cluster_info( assert provider_config is not None, cluster_name_on_cloud zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance @@ -415,7 +493,7 @@ def get_cluster_info( project_id, zone, { - **label_filters, TAG_RAY_NODE_KIND: 'head' + **label_filters, constants.TAG_RAY_NODE_KIND: 'head' }, lambda h: [h.RUNNING_STATE], ) @@ -439,14 +517,14 @@ def stop_instances( assert provider_config is not None, cluster_name_on_cloud zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} tpu_node = provider_config.get('tpu_node') if tpu_node is not None: instance_utils.delete_tpu_node(project_id, zone, tpu_node) if worker_only: - label_filters[TAG_RAY_NODE_KIND] = 'worker' + 
label_filters[constants.TAG_RAY_NODE_KIND] = 'worker' handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance @@ -508,9 +586,9 @@ def terminate_instances( if tpu_node is not None: instance_utils.delete_tpu_node(project_id, zone, tpu_node) - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} if worker_only: - label_filters[TAG_RAY_NODE_KIND] = 'worker' + label_filters[constants.TAG_RAY_NODE_KIND] = 'worker' handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance @@ -553,7 +631,7 @@ def open_ports( project_id = provider_config['project_id'] firewall_rule_name = provider_config['firewall_rule'] - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance, instance_utils.GCPTPUVMInstance, diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index dde0918274d..c3103d18248 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -134,6 +134,10 @@ def instance_to_handler(instance: str): elif instance_type == 'tpu': return GCPTPUVMInstance else: + # Managed Instance Groups breaks this assumption. The suffix is a random value. + if "-mig-" in instance: + # TODO: Implement MIG Instance + return GCPComputeInstance raise ValueError(f'Unknown instance type: {instance_type}') diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py new file mode 100644 index 00000000000..0a3f13f4296 --- /dev/null +++ b/sky/provision/gcp/mig_utils.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +import subprocess +import sys +import zlib +import time +from typing import Any + +from google.api_core.extended_operation import ExtendedOperation +from google.cloud import compute_v1 + +from sky import sky_logging +from sky.provision import common + +from sky.provision.gcp.constants import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_KIND + +logger = sky_logging.init_logger(__name__) + +"""Managed Instance Group Utils""" + +def create_node_config_hash(cluster_name_on_cloud, node_config) -> int: + """Create a hash value for the node config to be used as a unique identifier for the instance template and mig names.""" + properties = create_regional_instance_template_properties(cluster_name_on_cloud, node_config) + return zlib.adler32(repr(properties).encode()) + +def create_regional_instance_template_properties(cluster_name_on_cloud, node_config) -> compute_v1.InstanceProperties: + labels = node_config.get('labels', {}) | { + TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, + # Assume all nodes are workers, we can update the head node once the instances are created + TAG_RAY_NODE_KIND: "worker" + } + # All label values must be string + labels = {key: str(val).lower() for key, val in labels.items()} + + return compute_v1.InstanceProperties( + description=f"A temp instance template for {cluster_name_on_cloud} to support DWS requests.", + machine_type=node_config['machineType'], + # We have to ignore reservations for DWS. + # TODO: Add a warning log for this behvaiour. + reservation_affinity=compute_v1.ReservationAffinity( + consume_reservation_type="NO_RESERVATION" + ), + # We have to ignore user defined scheduling for DWS. + # TODO: Add a warning log for this behvaiour. 
+ scheduling=compute_v1.Scheduling(on_host_maintenance="TERMINATE"), + guest_accelerators=[ + compute_v1.AcceleratorConfig( + accelerator_count=accelerator['acceleratorCount'], + accelerator_type=accelerator['acceleratorType'].split('/')[-1], + ) + for accelerator in node_config.get('guestAccelerators', []) + ], + disks=[ + compute_v1.AttachedDisk( + boot=disk_config['boot'], + auto_delete=disk_config['autoDelete'], + type_=disk_config['type'], + initialize_params=compute_v1.AttachedDiskInitializeParams( + source_image=disk_config['initializeParams']['sourceImage'], + disk_size_gb=disk_config['initializeParams']['diskSizeGb'], + disk_type=disk_config['initializeParams']['diskType'].split('/')[-1] + ), + ) + for disk_config in node_config.get('disks', []) + ], + network_interfaces=[ + compute_v1.NetworkInterface( + subnetwork=network_interface['subnetwork'], + access_configs=[ + compute_v1.AccessConfig( + name=access_config['name'], + type=access_config['type'], + ) + for access_config in network_interface.get('accessConfigs', []) + ], + ) + for network_interface in node_config.get('networkInterfaces', []) + ], + service_accounts=[ + compute_v1.ServiceAccount( + email=service_account['email'], + scopes = service_account['scopes'] + ) + for service_account in node_config.get('serviceAccounts', []) + ], + metadata=compute_v1.Metadata( + items = [ + compute_v1.Items( + key=item['key'], + value=item['value'] + ) + for item in (node_config.get('metadata',{}).get('items', []) + [{'key': 'cluster-name', 'value': cluster_name_on_cloud}]) + ] + ), + # Create labels from node config + labels= labels + ) + +def check_instance_template_exits(project_id, region, template_name) -> bool: + with compute_v1.RegionInstanceTemplatesClient() as compute_client: + request = compute_v1.ListRegionInstanceTemplatesRequest( + filter=f"name eq {template_name}", + project=project_id, + region=region, + ) + page_result = compute_client.list(request) + return len(page_result.items) > 0 and (next(page_result.pages) is not None) + +def create_regional_instance_template( + project_id, + region, + template_name, + node_config, + cluster_name_on_cloud +) -> None: + with compute_v1.RegionInstanceTemplatesClient() as compute_client: + # Create the regional instance template request + + request = compute_v1.InsertRegionInstanceTemplateRequest( + project=project_id, + region=region, + instance_template_resource=compute_v1.InstanceTemplate( + name=template_name, + properties=create_regional_instance_template_properties(cluster_name_on_cloud, node_config), + ), + ) + + # Send the request to create the regional instance template + response = compute_client.insert(request=request) + # Wait for the operation to complete + print(response) + wait_for_extended_operation(response, "create regional instance template", 600) + # TODO: Error handling + # operation = compute_client.wait(response.operation) + # if operation.error: + # raise Exception(f"Failed to create regional instance template: {operation.error}") + + listRequest = compute_v1.ListRegionInstanceTemplatesRequest( + filter=f"name eq {template_name}", + project=project_id, + region=region, + ) + list_response=compute_client.list(listRequest) + print(list_response) + print(f"Regional instance template '{template_name}' created successfully.") + +def delete_regional_instance_template(project_id, region, template_name) -> None: + with compute_v1.RegionInstanceTemplatesClient() as compute_client: + # Create the regional instance template request + request = 
compute_v1.DeleteRegionInstanceTemplateRequest( + project=project_id, + region=region, + instance_template=template_name, + ) + + # Send the request to delete the regional instance template + response = compute_client.delete(request=request) + # Wait for the operation to complete + print(response) + wait_for_extended_operation(response, "delete regional instance template", 600) + + +def create_managed_instance_group( + project_id, zone, group_name, instance_template_url, size +) -> None: + # credentials, project = google.auth.default() + # compute_client = compute_v1.InstanceGroupManagersClient(credentials=credentials) + + with compute_v1.InstanceGroupManagersClient() as compute_client: + # Create the managed instance group request + request = compute_v1.InsertInstanceGroupManagerRequest( + project=project_id, + zone=zone, + instance_group_manager_resource=compute_v1.InstanceGroupManager( + name=group_name, + instance_template=instance_template_url, + target_size=size, + instance_lifecycle_policy=compute_v1.InstanceGroupManagerInstanceLifecyclePolicy( + default_action_on_failure="DO_NOTHING", + ), + update_policy=compute_v1.InstanceGroupManagerUpdatePolicy( + type="OPPORTUNISTIC", + ), + ), + ) + + # Send the request to create the managed instance group + response = compute_client.insert(request=request) + + # Wait for the operation to complete + print(f"Request submitted, waiting for operation to complete. {response}") + wait_for_extended_operation(response, "create managed instance group", 600) + # TODO: Error handling + print(f"Managed instance group '{group_name}' created successfully.") + +def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: + with compute_v1.InstanceGroupManagersClient() as compute_client: + request = compute_v1.ListInstanceGroupManagersRequest( + project=project_id, + zone=zone, + filter=f"name eq {group_name}", + ) + page_result = compute_client.list(request) + return len(page_result.items) > 0 and (next(page_result.pages) is not None) + +def resize_managed_instance_group(project_id, zone, group_name, size, run_duration) -> None: + try: + resize_request_name = f"resize-request-{str(int(time.time()))}" + + cmd = ( + f"gcloud beta compute instance-groups managed resize-requests create {group_name} " + f"--resize-request={resize_request_name} " + f"--resize-by={size} " + f"--requested-run-duration={run_duration} " + f"--zone={zone} " + f"--project={project_id} " + ) + logger.info(f"Resizing MIG {group_name} with command:\n{cmd}") + proc = subprocess.run( + f"yes | {cmd}", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + check=True, + ) + stdout = proc.stdout.decode("ascii") + logger.info(stdout) + wait_for_managed_group_to_be_stable(project_id, zone, group_name) + + except subprocess.CalledProcessError as e: + stderr = e.stderr.decode("ascii") + logger.info(stderr) + provisioner_err = common.ProvisionerError("Failed to resize MIG") + provisioner_err.errors = [{ + 'code': 'UNKNOWN', + 'domain': 'mig', + 'message': stderr + }] + # _log_errors(provisioner_err.errors, e, zone) + raise provisioner_err from e + +def view_resize_requests(project_id, zone, group_name) -> None: + try: + cmd = ( + "gcloud beta compute instance-groups managed resize-requests " + f"list {group_name} " + f"--zone={zone} " + f"--project={project_id}" + ) + logger.info( + f"Listing resize requests for MIG {group_name} with command:\n{cmd}" + ) + proc = subprocess.run( + f"yes | {cmd}", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + 
check=True, + ) + stdout = proc.stdout.decode("ascii") + logger.info(stdout) + except subprocess.CalledProcessError as e: + stderr = e.stderr.decode("ascii") + logger.info(stderr) + +def wait_for_managed_group_to_be_stable(project_id, zone, group_name) -> None: + try: + cmd = ( + "gcloud compute instance-groups managed wait-until " + f"{group_name} " + "--stable " + f"--zone={zone} " + f"--project={project_id}" + ) + logger.info( + f"Waiting for MIG {group_name} to be stable with command:\n{cmd}" + ) + proc = subprocess.run( + f"yes | {cmd}", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + check=True, + ) + stdout = proc.stdout.decode("ascii") + logger.info(stdout) + except subprocess.CalledProcessError as e: + stderr = e.stderr.decode("ascii") + logger.info(stderr) + +def wait_for_extended_operation( + operation: ExtendedOperation, verbose_name: str = "operation", timeout: int = 300 +) -> Any: + # Taken from Google's samples + # https://cloud.google.com/compute/docs/samples/compute-operation-extended-wait?hl=en + """ + Waits for the extended (long-running) operation to complete. + + If the operation is successful, it will return its result. + If the operation ends with an error, an exception will be raised. + If there were any warnings during the execution of the operation + they will be printed to sys.stderr. + + Args: + operation: a long-running operation you want to wait on. + verbose_name: (optional) a more verbose name of the operation, + used only during error and warning reporting. + timeout: how long (in seconds) to wait for operation to finish. + If None, wait indefinitely. + + Returns: + Whatever the operation.result() returns. + + Raises: + This method will raise the exception received from `operation.exception()` + or RuntimeError if there is no exception set, but there is an `error_code` + set for the `operation`. + + In case of an operation taking longer than `timeout` seconds to complete, + a `concurrent.futures.TimeoutError` will be raised. + """ + result = operation.result(timeout=timeout) + + if operation.error_code: + print( + f"Error during {verbose_name}: [Code: {operation.error_code}]: {operation.error_message}", + file=sys.stderr, + flush=True, + ) + print(f"Operation ID: {operation.name}", file=sys.stderr, flush=True) + # TODO gurc: wrap this in a custom skypilot exception + raise operation.exception() or RuntimeError(operation.error_message) + + if operation.warnings: + print(f"Warnings during {verbose_name}:\n", file=sys.stderr, flush=True) + for warning in operation.warnings: + print(f" - {warning.code}: {warning.message}", file=sys.stderr, flush=True) + + return result diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index c05ffcc4f35..9d21be15988 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -222,7 +222,7 @@ def parse_readme(readme: str) -> str: # We need google-api-python-client>=2.69.0 to enable 'discardLocalSsd' # parameter for stopping instances. 
# Reference: https://github.com/googleapis/google-api-python-client/commit/f6e9d3869ed605b06f7cbf2e8cf2db25108506e6 - 'gcp': ['google-api-python-client>=2.69.0', 'google-cloud-storage'], + 'gcp': ['google-api-python-client>=2.69.0', 'google-cloud-storage', 'google-cloud-compute'], 'ibm': [ 'ibm-cloud-sdk-core', 'ibm-vpc', 'ibm-platform-services', 'ibm-cos-sdk' ] + local_ray, diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index c3d75015bc0..e553d98c405 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ b/sky/templates/gcp-ray.yml.j2 @@ -62,6 +62,7 @@ provider: # The upper-level SkyPilot code has make sure there will not be resource # leakage. disable_launch_config_check: true + use_managed_instance_group: {{gcp_use_managed_instance_group}} auth: ssh_user: gcpuser @@ -79,6 +80,7 @@ available_node_types: {%- for tag_key, tag_value in instance_tags.items() %} {{ tag_key }}: {{ tag_value }} {%- endfor %} + managed-instance-group: {{gcp_use_managed_instance_group}} {%- if specific_reservations %} reservationAffinity: consumeReservationType: SPECIFIC_RESERVATION diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 689905d7c71..55c5c3ee2bf 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -588,6 +588,9 @@ def get_config_schema(): 'type': 'string', }, }, + 'use_managed_instance_group':{ + 'type': 'boolean', + }, **_INSTANCE_TAGS_SCHEMA, **_NETWORK_CONFIG_SCHEMA, } From 689c4a691c6db3fb6806051e63fef287f91827c4 Mon Sep 17 00:00:00 2001 From: Gurcan Gercek Date: Tue, 21 May 2024 13:11:30 -0400 Subject: [PATCH 02/29] fix lint errors --- sky/clouds/gcp.py | 3 +- sky/clouds/utils/gcp_utils.py | 4 +- sky/provision/gcp/instance.py | 59 ++++++----- sky/provision/gcp/mig_utils.py | 180 +++++++++++++++++---------------- sky/setup_files/setup.py | 5 +- sky/utils/schemas.py | 2 +- 6 files changed, 135 insertions(+), 118 deletions(-) diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index c6f8fa7c08f..4ec807a07b7 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -422,7 +422,8 @@ def make_deploy_resources_variables( 'custom_resources': None, 'use_spot': r.use_spot, 'gcp_project_id': self.get_project_id(dryrun), - 'gcp_use_managed_instance_group': gcp_utils.is_use_managed_instance_group(), + 'gcp_use_managed_instance_group': + gcp_utils.is_use_managed_instance_group(), } accelerators = r.accelerators if accelerators is not None: diff --git a/sky/clouds/utils/gcp_utils.py b/sky/clouds/utils/gcp_utils.py index eedf19a5aee..bb34aed79b0 100644 --- a/sky/clouds/utils/gcp_utils.py +++ b/sky/clouds/utils/gcp_utils.py @@ -185,5 +185,7 @@ def get_minimal_permissions() -> List[str]: return permissions + def is_use_managed_instance_group() -> bool: - return skypilot_config.get_nested(('gcp', 'use_managed_instance_group'), False) + return skypilot_config.get_nested(('gcp', 'use_managed_instance_group'), + False) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index ddb9eac7f97..82a48fbb4d1 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -134,10 +134,12 @@ def _get_head_instance_id(instances: List) -> Optional[str]: break return head_instance_id -def _run_instances_in_managed_instance_group(region: str, cluster_name_on_cloud: str, - config: common.ProvisionConfig) -> common.ProvisionRecord: + +def _run_instances_in_managed_instance_group( + region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: print("Managed instance group is enabled.") - + resumed_instance_ids: 
List[str] = [] created_instance_ids: List[str] = [] @@ -145,39 +147,42 @@ def _run_instances_in_managed_instance_group(region: str, cluster_name_on_cloud: project_id = config.provider_config['project_id'] availability_zone = config.provider_config['availability_zone'] head_instance_id = "" - + # check instance templates # TODO add instance template name to the task definition if the user wants to use it. # Calculate template instance name based on the config node values. Hash them to get a unique name. - node_config_hash = mig_utils.create_node_config_hash(cluster_name_on_cloud, config.node_config) + node_config_hash = mig_utils.create_node_config_hash( + cluster_name_on_cloud, config.node_config) instance_template_name = f"{cluster_name_on_cloud}-it-{node_config_hash}" - if not mig_utils.check_instance_template_exits(project_id, region, instance_template_name): - mig_utils.create_regional_instance_template(project_id, - region, - instance_template_name, - config.node_config, - cluster_name_on_cloud) + if not mig_utils.check_instance_template_exits(project_id, region, + instance_template_name): + mig_utils.create_regional_instance_template(project_id, region, + instance_template_name, + config.node_config, + cluster_name_on_cloud) else: print(f"Instance template {instance_template_name} already exists...") - + # create managed instance group - instance_template_url = f"projects/{project_id}/regions/{region}/instanceTemplates/{instance_template_name}" + instance_template_url = f"projects/{project_id}/regions/{region}/instanceTemplates/{instance_template_name}" managed_instance_group_name = f"{cluster_name_on_cloud}-mig-{node_config_hash}" - if not mig_utils.check_managed_instance_group_exists(project_id, availability_zone, managed_instance_group_name): - mig_utils.create_managed_instance_group(project_id, - availability_zone, - managed_instance_group_name, - instance_template_url, + if not mig_utils.check_managed_instance_group_exists( + project_id, availability_zone, managed_instance_group_name): + mig_utils.create_managed_instance_group(project_id, + availability_zone, + managed_instance_group_name, + instance_template_url, size=config.count) - else: + else: # TODO: if we already have one, we should resize it. - print(f"Managed instance group {managed_instance_group_name} already exists...") + print( + f"Managed instance group {managed_instance_group_name} already exists..." 
+ ) # mig_utils.resize_managed_instance_group(project_id, zone, group_name, size, run_duration) - mig_utils.wait_for_managed_group_to_be_stable(project_id, availability_zone, managed_instance_group_name) - - + mig_utils.wait_for_managed_group_to_be_stable(project_id, availability_zone, + managed_instance_group_name) resource: Type[instance_utils.GCPInstance] if node_type == instance_utils.GCPNodeType.COMPUTE: @@ -186,7 +191,7 @@ def _run_instances_in_managed_instance_group(region: str, cluster_name_on_cloud: resource = instance_utils.GCPTPUVMInstance else: raise ValueError(f'Unknown node type {node_type}') - + # filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} # running_instances = resource.filter( # project_id=project_id, @@ -202,7 +207,7 @@ def _run_instances_in_managed_instance_group(region: str, cluster_name_on_cloud: # running_instances[0]['name'], # is_head=True, # ) - + # created_instance_ids = [n['name'] for n in running_instances] return common.ProvisionRecord(provider_name='gcp', @@ -213,6 +218,7 @@ def _run_instances_in_managed_instance_group(region: str, cluster_name_on_cloud: resumed_instance_ids=resumed_instance_ids, created_instance_ids=created_instance_ids) + def _run_instances(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """See sky/provision/__init__.py""" @@ -422,7 +428,8 @@ def run_instances(region: str, cluster_name_on_cloud: str, """See sky/provision/__init__.py""" try: if gcp_utils.is_use_managed_instance_group(): - return _run_instances_in_managed_instance_group(region, cluster_name_on_cloud, config) + return _run_instances_in_managed_instance_group( + region, cluster_name_on_cloud, config) else: return _run_instances(region, cluster_name_on_cloud, config) except gcp.http_error_exception() as e: diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 0a3f13f4296..6d9ab9454e5 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -2,44 +2,47 @@ import subprocess import sys -import zlib import time from typing import Any +import zlib from google.api_core.extended_operation import ExtendedOperation from google.cloud import compute_v1 from sky import sky_logging from sky.provision import common - -from sky.provision.gcp.constants import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_KIND +from sky.provision.gcp.constants import TAG_RAY_CLUSTER_NAME +from sky.provision.gcp.constants import TAG_RAY_NODE_KIND logger = sky_logging.init_logger(__name__) - """Managed Instance Group Utils""" + def create_node_config_hash(cluster_name_on_cloud, node_config) -> int: """Create a hash value for the node config to be used as a unique identifier for the instance template and mig names.""" - properties = create_regional_instance_template_properties(cluster_name_on_cloud, node_config) + properties = create_regional_instance_template_properties( + cluster_name_on_cloud, node_config) return zlib.adler32(repr(properties).encode()) - -def create_regional_instance_template_properties(cluster_name_on_cloud, node_config) -> compute_v1.InstanceProperties: + + +def create_regional_instance_template_properties( + cluster_name_on_cloud, node_config) -> compute_v1.InstanceProperties: labels = node_config.get('labels', {}) | { - TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, - # Assume all nodes are workers, we can update the head node once the instances are created - TAG_RAY_NODE_KIND: "worker" - } + TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, + # Assume all nodes are workers, 
we can update the head node once the instances are created + TAG_RAY_NODE_KIND: "worker" + } # All label values must be string labels = {key: str(val).lower() for key, val in labels.items()} return compute_v1.InstanceProperties( - description=f"A temp instance template for {cluster_name_on_cloud} to support DWS requests.", + description= + f"A temp instance template for {cluster_name_on_cloud} to support DWS requests.", machine_type=node_config['machineType'], # We have to ignore reservations for DWS. # TODO: Add a warning log for this behvaiour. reservation_affinity=compute_v1.ReservationAffinity( - consume_reservation_type="NO_RESERVATION" - ), + consume_reservation_type="NO_RESERVATION"), # We have to ignore user defined scheduling for DWS. # TODO: Add a warning log for this behvaiour. scheduling=compute_v1.Scheduling(on_host_maintenance="TERMINATE"), @@ -47,8 +50,7 @@ def create_regional_instance_template_properties(cluster_name_on_cloud, node_con compute_v1.AcceleratorConfig( accelerator_count=accelerator['acceleratorCount'], accelerator_type=accelerator['acceleratorType'].split('/')[-1], - ) - for accelerator in node_config.get('guestAccelerators', []) + ) for accelerator in node_config.get('guestAccelerators', []) ], disks=[ compute_v1.AttachedDisk( @@ -56,12 +58,11 @@ def create_regional_instance_template_properties(cluster_name_on_cloud, node_con auto_delete=disk_config['autoDelete'], type_=disk_config['type'], initialize_params=compute_v1.AttachedDiskInitializeParams( - source_image=disk_config['initializeParams']['sourceImage'], + source_image=disk_config['initializeParams']['sourceImage'], disk_size_gb=disk_config['initializeParams']['diskSizeGb'], - disk_type=disk_config['initializeParams']['diskType'].split('/')[-1] - ), - ) - for disk_config in node_config.get('disks', []) + disk_type=disk_config['initializeParams']['diskType'].split( + '/')[-1]), + ) for disk_config in node_config.get('disks', []) ], network_interfaces=[ compute_v1.NetworkInterface( @@ -70,31 +71,26 @@ def create_regional_instance_template_properties(cluster_name_on_cloud, node_con compute_v1.AccessConfig( name=access_config['name'], type=access_config['type'], - ) - for access_config in network_interface.get('accessConfigs', []) + ) for access_config in network_interface.get( + 'accessConfigs', []) ], - ) - for network_interface in node_config.get('networkInterfaces', []) + ) for network_interface in node_config.get('networkInterfaces', []) ], service_accounts=[ - compute_v1.ServiceAccount( - email=service_account['email'], - scopes = service_account['scopes'] - ) + compute_v1.ServiceAccount(email=service_account['email'], + scopes=service_account['scopes']) for service_account in node_config.get('serviceAccounts', []) ], - metadata=compute_v1.Metadata( - items = [ - compute_v1.Items( - key=item['key'], - value=item['value'] - ) - for item in (node_config.get('metadata',{}).get('items', []) + [{'key': 'cluster-name', 'value': cluster_name_on_cloud}]) - ] - ), + metadata=compute_v1.Metadata(items=[ + compute_v1.Items(key=item['key'], value=item['value']) + for item in (node_config.get('metadata', {}).get('items', []) + [{ + 'key': 'cluster-name', + 'value': cluster_name_on_cloud + }]) + ]), # Create labels from node config - labels= labels - ) + labels=labels) + def check_instance_template_exits(project_id, region, template_name) -> bool: with compute_v1.RegionInstanceTemplatesClient() as compute_client: @@ -104,15 +100,13 @@ def check_instance_template_exits(project_id, region, template_name) -> bool: 
region=region, ) page_result = compute_client.list(request) - return len(page_result.items) > 0 and (next(page_result.pages) is not None) - -def create_regional_instance_template( - project_id, - region, - template_name, - node_config, - cluster_name_on_cloud -) -> None: + return len(page_result.items) > 0 and (next(page_result.pages) + is not None) + + +def create_regional_instance_template(project_id, region, template_name, + node_config, + cluster_name_on_cloud) -> None: with compute_v1.RegionInstanceTemplatesClient() as compute_client: # Create the regional instance template request @@ -121,7 +115,8 @@ def create_regional_instance_template( region=region, instance_template_resource=compute_v1.InstanceTemplate( name=template_name, - properties=create_regional_instance_template_properties(cluster_name_on_cloud, node_config), + properties=create_regional_instance_template_properties( + cluster_name_on_cloud, node_config), ), ) @@ -129,7 +124,8 @@ def create_regional_instance_template( response = compute_client.insert(request=request) # Wait for the operation to complete print(response) - wait_for_extended_operation(response, "create regional instance template", 600) + wait_for_extended_operation(response, + "create regional instance template", 600) # TODO: Error handling # operation = compute_client.wait(response.operation) # if operation.error: @@ -140,11 +136,15 @@ def create_regional_instance_template( project=project_id, region=region, ) - list_response=compute_client.list(listRequest) + list_response = compute_client.list(listRequest) print(list_response) - print(f"Regional instance template '{template_name}' created successfully.") + print( + f"Regional instance template '{template_name}' created successfully." + ) + -def delete_regional_instance_template(project_id, region, template_name) -> None: +def delete_regional_instance_template(project_id, region, + template_name) -> None: with compute_v1.RegionInstanceTemplatesClient() as compute_client: # Create the regional instance template request request = compute_v1.DeleteRegionInstanceTemplateRequest( @@ -157,12 +157,12 @@ def delete_regional_instance_template(project_id, region, template_name) -> None response = compute_client.delete(request=request) # Wait for the operation to complete print(response) - wait_for_extended_operation(response, "delete regional instance template", 600) + wait_for_extended_operation(response, + "delete regional instance template", 600) -def create_managed_instance_group( - project_id, zone, group_name, instance_template_url, size -) -> None: +def create_managed_instance_group(project_id, zone, group_name, + instance_template_url, size) -> None: # credentials, project = google.auth.default() # compute_client = compute_v1.InstanceGroupManagersClient(credentials=credentials) @@ -175,12 +175,11 @@ def create_managed_instance_group( name=group_name, instance_template=instance_template_url, target_size=size, - instance_lifecycle_policy=compute_v1.InstanceGroupManagerInstanceLifecyclePolicy( - default_action_on_failure="DO_NOTHING", - ), + instance_lifecycle_policy=compute_v1. + InstanceGroupManagerInstanceLifecyclePolicy( + default_action_on_failure="DO_NOTHING",), update_policy=compute_v1.InstanceGroupManagerUpdatePolicy( - type="OPPORTUNISTIC", - ), + type="OPPORTUNISTIC",), ), ) @@ -188,11 +187,14 @@ def create_managed_instance_group( response = compute_client.insert(request=request) # Wait for the operation to complete - print(f"Request submitted, waiting for operation to complete. 
{response}") - wait_for_extended_operation(response, "create managed instance group", 600) + print( + f"Request submitted, waiting for operation to complete. {response}") + wait_for_extended_operation(response, "create managed instance group", + 600) # TODO: Error handling print(f"Managed instance group '{group_name}' created successfully.") + def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: with compute_v1.InstanceGroupManagersClient() as compute_client: request = compute_v1.ListInstanceGroupManagersRequest( @@ -201,9 +203,12 @@ def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: filter=f"name eq {group_name}", ) page_result = compute_client.list(request) - return len(page_result.items) > 0 and (next(page_result.pages) is not None) + return len(page_result.items) > 0 and (next(page_result.pages) + is not None) + -def resize_managed_instance_group(project_id, zone, group_name, size, run_duration) -> None: +def resize_managed_instance_group(project_id, zone, group_name, size, + run_duration) -> None: try: resize_request_name = f"resize-request-{str(int(time.time()))}" @@ -213,8 +218,7 @@ def resize_managed_instance_group(project_id, zone, group_name, size, run_durati f"--resize-by={size} " f"--requested-run-duration={run_duration} " f"--zone={zone} " - f"--project={project_id} " - ) + f"--project={project_id} ") logger.info(f"Resizing MIG {group_name} with command:\n{cmd}") proc = subprocess.run( f"yes | {cmd}", @@ -239,14 +243,13 @@ def resize_managed_instance_group(project_id, zone, group_name, size, run_durati # _log_errors(provisioner_err.errors, e, zone) raise provisioner_err from e + def view_resize_requests(project_id, zone, group_name) -> None: try: - cmd = ( - "gcloud beta compute instance-groups managed resize-requests " - f"list {group_name} " - f"--zone={zone} " - f"--project={project_id}" - ) + cmd = ("gcloud beta compute instance-groups managed resize-requests " + f"list {group_name} " + f"--zone={zone} " + f"--project={project_id}") logger.info( f"Listing resize requests for MIG {group_name} with command:\n{cmd}" ) @@ -263,18 +266,16 @@ def view_resize_requests(project_id, zone, group_name) -> None: stderr = e.stderr.decode("ascii") logger.info(stderr) + def wait_for_managed_group_to_be_stable(project_id, zone, group_name) -> None: try: - cmd = ( - "gcloud compute instance-groups managed wait-until " - f"{group_name} " - "--stable " - f"--zone={zone} " - f"--project={project_id}" - ) + cmd = ("gcloud compute instance-groups managed wait-until " + f"{group_name} " + "--stable " + f"--zone={zone} " + f"--project={project_id}") logger.info( - f"Waiting for MIG {group_name} to be stable with command:\n{cmd}" - ) + f"Waiting for MIG {group_name} to be stable with command:\n{cmd}") proc = subprocess.run( f"yes | {cmd}", stdout=subprocess.PIPE, @@ -288,9 +289,10 @@ def wait_for_managed_group_to_be_stable(project_id, zone, group_name) -> None: stderr = e.stderr.decode("ascii") logger.info(stderr) -def wait_for_extended_operation( - operation: ExtendedOperation, verbose_name: str = "operation", timeout: int = 300 -) -> Any: + +def wait_for_extended_operation(operation: ExtendedOperation, + verbose_name: str = "operation", + timeout: int = 300) -> Any: # Taken from Google's samples # https://cloud.google.com/compute/docs/samples/compute-operation-extended-wait?hl=en """ @@ -334,6 +336,8 @@ def wait_for_extended_operation( if operation.warnings: print(f"Warnings during {verbose_name}:\n", file=sys.stderr, flush=True) for warning 
in operation.warnings: - print(f" - {warning.code}: {warning.message}", file=sys.stderr, flush=True) + print(f" - {warning.code}: {warning.message}", + file=sys.stderr, + flush=True) return result diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index 9d21be15988..6bf86111fca 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -222,7 +222,10 @@ def parse_readme(readme: str) -> str: # We need google-api-python-client>=2.69.0 to enable 'discardLocalSsd' # parameter for stopping instances. # Reference: https://github.com/googleapis/google-api-python-client/commit/f6e9d3869ed605b06f7cbf2e8cf2db25108506e6 - 'gcp': ['google-api-python-client>=2.69.0', 'google-cloud-storage', 'google-cloud-compute'], + 'gcp': [ + 'google-api-python-client>=2.69.0', 'google-cloud-storage', + 'google-cloud-compute' + ], 'ibm': [ 'ibm-cloud-sdk-core', 'ibm-vpc', 'ibm-platform-services', 'ibm-cos-sdk' ] + local_ray, diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 55c5c3ee2bf..a5f2f84dd09 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -588,7 +588,7 @@ def get_config_schema(): 'type': 'string', }, }, - 'use_managed_instance_group':{ + 'use_managed_instance_group': { 'type': 'boolean', }, **_INSTANCE_TAGS_SCHEMA, From 3c8a236a830a53f7e57cad508c0916ee67e9518a Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 4 Jun 2024 21:29:01 +0000 Subject: [PATCH 03/29] dependency and format fix --- sky/adaptors/gcp.py | 2 + sky/clouds/gcp.py | 5 +- sky/provision/gcp/mig_utils.py | 141 ++++++++++++++++++--------------- 3 files changed, 82 insertions(+), 66 deletions(-) diff --git a/sky/adaptors/gcp.py b/sky/adaptors/gcp.py index 6465709d42c..54f8fc84444 100644 --- a/sky/adaptors/gcp.py +++ b/sky/adaptors/gcp.py @@ -10,6 +10,8 @@ googleapiclient = common.LazyImport('googleapiclient', import_error_message=_IMPORT_ERROR_MESSAGE) google = common.LazyImport('google', import_error_message=_IMPORT_ERROR_MESSAGE) +compute_v1 = common.LazyImport('google.cloud.compute_v1', + import_error_message=_IMPORT_ERROR_MESSAGE) _LAZY_MODULES = (google, googleapiclient) diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 4ec807a07b7..443c1f4bf14 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -148,7 +148,7 @@ class GCP(clouds.Cloud): _DEPENDENCY_HINT = ( 'GCP tools are not installed. Run the following commands:\n' # Install the Google Cloud SDK: - f'{_INDENT_PREFIX} $ pip install google-api-python-client\n' + f'{_INDENT_PREFIX} $ pip install google-api-python-client google-cloud-compute\n' f'{_INDENT_PREFIX} $ conda install -c conda-forge ' 'google-cloud-sdk -y') @@ -635,6 +635,9 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: from google import auth # type: ignore import googleapiclient + if gcp_utils.is_use_managed_instance_group(): + from google.cloud import compute_v1 + # Check the installation of google-cloud-sdk. 
_run_output('gcloud --version') except (ImportError, subprocess.CalledProcessError) as e: diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 6d9ab9454e5..bf11129034d 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -1,51 +1,58 @@ -from __future__ import annotations - +"""Managed Instance Group Utils""" import subprocess import sys import time +import typing from typing import Any import zlib from google.api_core.extended_operation import ExtendedOperation -from google.cloud import compute_v1 from sky import sky_logging from sky.provision import common from sky.provision.gcp.constants import TAG_RAY_CLUSTER_NAME from sky.provision.gcp.constants import TAG_RAY_NODE_KIND +if typing.TYPE_CHECKING: + from google.cloud import compute_v1 +else: + from sky.adaptors.gcp import compute_v1 + logger = sky_logging.init_logger(__name__) -"""Managed Instance Group Utils""" def create_node_config_hash(cluster_name_on_cloud, node_config) -> int: - """Create a hash value for the node config to be used as a unique identifier for the instance template and mig names.""" + """Create a hash value for the node config. + + This is to be used as a unique identifier for the instance template and mig + names. + """ properties = create_regional_instance_template_properties( cluster_name_on_cloud, node_config) return zlib.adler32(repr(properties).encode()) def create_regional_instance_template_properties( - cluster_name_on_cloud, node_config) -> compute_v1.InstanceProperties: + cluster_name_on_cloud, node_config) -> 'compute_v1.InstanceProperties': labels = node_config.get('labels', {}) | { TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, # Assume all nodes are workers, we can update the head node once the instances are created - TAG_RAY_NODE_KIND: "worker" + TAG_RAY_NODE_KIND: 'worker' } # All label values must be string labels = {key: str(val).lower() for key, val in labels.items()} return compute_v1.InstanceProperties( description= - f"A temp instance template for {cluster_name_on_cloud} to support DWS requests.", + f'A temp instance template for {cluster_name_on_cloud} to support DWS requests.', machine_type=node_config['machineType'], # We have to ignore reservations for DWS. # TODO: Add a warning log for this behvaiour. reservation_affinity=compute_v1.ReservationAffinity( - consume_reservation_type="NO_RESERVATION"), + consume_reservation_type='NO_RESERVATION'), # We have to ignore user defined scheduling for DWS. # TODO: Add a warning log for this behvaiour. 
- scheduling=compute_v1.Scheduling(on_host_maintenance="TERMINATE"), + scheduling=compute_v1.Scheduling(on_host_maintenance='TERMINATE'), guest_accelerators=[ compute_v1.AcceleratorConfig( accelerator_count=accelerator['acceleratorCount'], @@ -95,7 +102,7 @@ def create_regional_instance_template_properties( def check_instance_template_exits(project_id, region, template_name) -> bool: with compute_v1.RegionInstanceTemplatesClient() as compute_client: request = compute_v1.ListRegionInstanceTemplatesRequest( - filter=f"name eq {template_name}", + filter=f'name eq {template_name}', project=project_id, region=region, ) @@ -123,24 +130,23 @@ def create_regional_instance_template(project_id, region, template_name, # Send the request to create the regional instance template response = compute_client.insert(request=request) # Wait for the operation to complete - print(response) + # logger.debug(response) wait_for_extended_operation(response, - "create regional instance template", 600) + 'create regional instance template', 600) # TODO: Error handling # operation = compute_client.wait(response.operation) # if operation.error: - # raise Exception(f"Failed to create regional instance template: {operation.error}") + # raise Exception(f'Failed to create regional instance template: {operation.error}') listRequest = compute_v1.ListRegionInstanceTemplatesRequest( - filter=f"name eq {template_name}", + filter=f'name eq {template_name}', project=project_id, region=region, ) list_response = compute_client.list(listRequest) - print(list_response) - print( - f"Regional instance template '{template_name}' created successfully." - ) + # logger.debug(list_response) + logger.debug(f'Regional instance template {template_name!r} ' + 'created successfully.') def delete_regional_instance_template(project_id, region, @@ -156,9 +162,9 @@ def delete_regional_instance_template(project_id, region, # Send the request to delete the regional instance template response = compute_client.delete(request=request) # Wait for the operation to complete - print(response) + logger.debug(response) wait_for_extended_operation(response, - "delete regional instance template", 600) + 'delete regional instance template', 600) def create_managed_instance_group(project_id, zone, group_name, @@ -177,9 +183,9 @@ def create_managed_instance_group(project_id, zone, group_name, target_size=size, instance_lifecycle_policy=compute_v1. InstanceGroupManagerInstanceLifecyclePolicy( - default_action_on_failure="DO_NOTHING",), + default_action_on_failure='DO_NOTHING',), update_policy=compute_v1.InstanceGroupManagerUpdatePolicy( - type="OPPORTUNISTIC",), + type='OPPORTUNISTIC',), ), ) @@ -187,12 +193,13 @@ def create_managed_instance_group(project_id, zone, group_name, response = compute_client.insert(request=request) # Wait for the operation to complete - print( - f"Request submitted, waiting for operation to complete. {response}") - wait_for_extended_operation(response, "create managed instance group", + logger.debug( + f'Request submitted, waiting for operation to complete. 
{response}') + wait_for_extended_operation(response, 'create managed instance group', 600) # TODO: Error handling - print(f"Managed instance group '{group_name}' created successfully.") + logger.debug( + f'Managed instance group {group_name!r} created successfully.') def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: @@ -200,7 +207,7 @@ def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: request = compute_v1.ListInstanceGroupManagersRequest( project=project_id, zone=zone, - filter=f"name eq {group_name}", + filter=f'name eq {group_name}', ) page_result = compute_client.list(request) return len(page_result.items) > 0 and (next(page_result.pages) @@ -210,31 +217,31 @@ def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: def resize_managed_instance_group(project_id, zone, group_name, size, run_duration) -> None: try: - resize_request_name = f"resize-request-{str(int(time.time()))}" + resize_request_name = f'resize-request-{str(int(time.time()))}' cmd = ( - f"gcloud beta compute instance-groups managed resize-requests create {group_name} " - f"--resize-request={resize_request_name} " - f"--resize-by={size} " - f"--requested-run-duration={run_duration} " - f"--zone={zone} " - f"--project={project_id} ") - logger.info(f"Resizing MIG {group_name} with command:\n{cmd}") + f'gcloud beta compute instance-groups managed resize-requests create {group_name} ' + f'--resize-request={resize_request_name} ' + f'--resize-by={size} ' + f'--requested-run-duration={run_duration} ' + f'--zone={zone} ' + f'--project={project_id} ') + logger.info(f'Resizing MIG {group_name} with command:\n{cmd}') proc = subprocess.run( - f"yes | {cmd}", + f'yes | {cmd}', stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, check=True, ) - stdout = proc.stdout.decode("ascii") + stdout = proc.stdout.decode('ascii') logger.info(stdout) wait_for_managed_group_to_be_stable(project_id, zone, group_name) except subprocess.CalledProcessError as e: - stderr = e.stderr.decode("ascii") + stderr = e.stderr.decode('ascii') logger.info(stderr) - provisioner_err = common.ProvisionerError("Failed to resize MIG") + provisioner_err = common.ProvisionerError('Failed to resize MIG') provisioner_err.errors = [{ 'code': 'UNKNOWN', 'domain': 'mig', @@ -246,52 +253,52 @@ def resize_managed_instance_group(project_id, zone, group_name, size, def view_resize_requests(project_id, zone, group_name) -> None: try: - cmd = ("gcloud beta compute instance-groups managed resize-requests " - f"list {group_name} " - f"--zone={zone} " - f"--project={project_id}") + cmd = ('gcloud beta compute instance-groups managed resize-requests ' + f'list {group_name} ' + f'--zone={zone} ' + f'--project={project_id}') logger.info( - f"Listing resize requests for MIG {group_name} with command:\n{cmd}" + f'Listing resize requests for MIG {group_name} with command:\n{cmd}' ) proc = subprocess.run( - f"yes | {cmd}", + f'yes | {cmd}', stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, check=True, ) - stdout = proc.stdout.decode("ascii") + stdout = proc.stdout.decode('ascii') logger.info(stdout) except subprocess.CalledProcessError as e: - stderr = e.stderr.decode("ascii") + stderr = e.stderr.decode('ascii') logger.info(stderr) def wait_for_managed_group_to_be_stable(project_id, zone, group_name) -> None: try: - cmd = ("gcloud compute instance-groups managed wait-until " - f"{group_name} " - "--stable " - f"--zone={zone} " - f"--project={project_id}") + cmd = ('gcloud compute 
instance-groups managed wait-until ' + f'{group_name} ' + '--stable ' + f'--zone={zone} ' + f'--project={project_id}') logger.info( - f"Waiting for MIG {group_name} to be stable with command:\n{cmd}") + f'Waiting for MIG {group_name} to be stable with command:\n{cmd}') proc = subprocess.run( - f"yes | {cmd}", + f'yes | {cmd}', stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, check=True, ) - stdout = proc.stdout.decode("ascii") + stdout = proc.stdout.decode('ascii') logger.info(stdout) except subprocess.CalledProcessError as e: - stderr = e.stderr.decode("ascii") + stderr = e.stderr.decode('ascii') logger.info(stderr) def wait_for_extended_operation(operation: ExtendedOperation, - verbose_name: str = "operation", + verbose_name: str = 'operation', timeout: int = 300) -> Any: # Taken from Google's samples # https://cloud.google.com/compute/docs/samples/compute-operation-extended-wait?hl=en @@ -324,20 +331,24 @@ def wait_for_extended_operation(operation: ExtendedOperation, result = operation.result(timeout=timeout) if operation.error_code: - print( - f"Error during {verbose_name}: [Code: {operation.error_code}]: {operation.error_message}", + logger.debug( + f'Error during {verbose_name}: [Code: {operation.error_code}]: {operation.error_message}', file=sys.stderr, flush=True, ) - print(f"Operation ID: {operation.name}", file=sys.stderr, flush=True) + logger.debug(f'Operation ID: {operation.name}', + file=sys.stderr, + flush=True) # TODO gurc: wrap this in a custom skypilot exception raise operation.exception() or RuntimeError(operation.error_message) if operation.warnings: - print(f"Warnings during {verbose_name}:\n", file=sys.stderr, flush=True) + logger.debug(f'Warnings during {verbose_name}:\n', + file=sys.stderr, + flush=True) for warning in operation.warnings: - print(f" - {warning.code}: {warning.message}", - file=sys.stderr, - flush=True) + logger.debug(f' - {warning.code}: {warning.message}', + file=sys.stderr, + flush=True) return result From 034a2afce3ea9735efbd2acd654e04e62c34974a Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 01:46:24 +0000 Subject: [PATCH 04/29] refactor mig instance creation --- sky/provision/gcp/config.py | 6 +- sky/provision/gcp/constants.py | 6 ++ sky/provision/gcp/instance.py | 139 +++++++++++++++++++++------- sky/provision/gcp/instance_utils.py | 137 ++++++++++++++++++++++++--- sky/provision/gcp/mig_utils.py | 69 +++++++------- 5 files changed, 271 insertions(+), 86 deletions(-) diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index b0ed2be9cec..6a205545ee2 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -151,8 +151,7 @@ def bootstrap_instances( config: common.ProvisionConfig) -> common.ProvisionConfig: # Check if we have any TPUs defined, and if so, # insert that information into the provider config - if instance_utils.get_node_type( - config.node_config) == instance_utils.GCPNodeType.TPU: + if instance_utils.get_node_type(config) == instance_utils.GCPNodeType.TPU: config.provider_config[constants.HAS_TPU_PROVIDER_FIELD] = True crm, iam, compute, _ = construct_clients_from_provider_config( @@ -375,8 +374,7 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam) -> dict: 'scopes': ['https://www.googleapis.com/auth/cloud-platform'], } iam_role: Dict[str, Any] - if (instance_utils.get_node_type( - config.node_config) == instance_utils.GCPNodeType.TPU): + if (instance_utils.get_node_type(config) == instance_utils.GCPNodeType.TPU): # SKY: The API for TPU VM is slightly 
different from normal compute # instances. # See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node # pylint: disable=line-too-long diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index e605b5a710c..dc0838ab661 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -219,3 +219,9 @@ # Tag uniquely identifying all nodes of a cluster TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' TAG_RAY_NODE_KIND = 'ray-node-type' +TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' + +# MIG constants +USE_MANAGED_INSTANCE_GROUP_CONFIG = 'use_managed_instance_group' +TAG_MANAGED_INSTANCE_GROUP_NAME = 'managed-instance-group-name' +DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT = 1200 # 20 minutes diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 82a48fbb4d1..5c137c81bca 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -3,6 +3,7 @@ import copy from multiprocessing import pool import re +import sys import time from typing import Any, Callable, Dict, Iterable, List, Optional, Type @@ -138,15 +139,93 @@ def _get_head_instance_id(instances: List) -> Optional[str]: def _run_instances_in_managed_instance_group( region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: - print("Managed instance group is enabled.") + logger.debug("Managed instance group is enabled.") resumed_instance_ids: List[str] = [] created_instance_ids: List[str] = [] - node_type = instance_utils.get_node_type(config.node_config) + node_type = instance_utils.get_node_type(config) project_id = config.provider_config['project_id'] availability_zone = config.provider_config['availability_zone'] - head_instance_id = "" + head_instance_id = '' + + resource: Type[instance_utils.GCPInstance] + if node_type == instance_utils.GCPNodeType.COMPUTE: + resource = instance_utils.GCPComputeInstance + elif node_type == instance_utils.GCPNodeType.TPU: + resource = instance_utils.GCPTPUVMInstance + else: + raise ValueError(f'Unknown node type {node_type}') + filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + exist_instances = resource.filter( + project_id=project_id, + zone=availability_zone, + label_filters=filter_labels, + status_filters=None, + ) + exist_instances = list(exist_instances.values()) + head_instance_id = _get_head_instance_id(exist_instances) + + # NOTE: We are not handling REPAIRING, SUSPENDING, SUSPENDED status. + pending_instances = [] + running_instances = [] + stopping_instances = [] + stopped_instances = [] + + # SkyPilot: We try to use the instances with the same matching launch_config + # first. If there is not enough instances with matching launch_config, we + # then use all the instances with the same matching launch_config plus some + # instances with wrong launch_config. 
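A quick aside on the reuse ordering in this function: existing instances are preferred in most-recently-started order. A minimal standalone sketch of that ordering, using the same `lastStartTimestamp` format as the helper below (the sample records are made up for illustration):

    import datetime

    def order_key(node: dict):
        # Sort by the most recent start time; fall back to the instance id
        # when the node has never been started.
        timestamp = node.get('lastStartTimestamp')
        if timestamp is not None:
            return datetime.datetime.strptime(timestamp,
                                              '%Y-%m-%dT%H:%M:%S.%f%z')
        return node['id']

    nodes = [
        {'id': 'a', 'lastStartTimestamp': '2024-06-05T01:46:24.000-07:00'},
        {'id': 'b', 'lastStartTimestamp': '2024-06-04T22:10:05.000-07:00'},
    ]
    nodes.sort(key=order_key, reverse=True)  # most recently started first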
+ def get_order_key(node): + import datetime # pylint: disable=import-outside-toplevel + + timestamp = node.get('lastStartTimestamp') + if timestamp is not None: + return datetime.datetime.strptime(timestamp, + '%Y-%m-%dT%H:%M:%S.%f%z') + return node['id'] + + logger.info(str(exist_instances)) + for inst in exist_instances: + state = inst[resource.STATUS_FIELD] + if state in resource.PENDING_STATES: + pending_instances.append(inst) + elif state == resource.RUNNING_STATE: + running_instances.append(inst) + elif state in resource.STOPPING_STATES: + stopping_instances.append(inst) + elif state in resource.STOPPED_STATES: + stopped_instances.append(inst) + else: + raise RuntimeError(f'Unsupported state "{state}".') + + pending_instances.sort(key=get_order_key, reverse=True) + running_instances.sort(key=get_order_key, reverse=True) + stopping_instances.sort(key=get_order_key, reverse=True) + stopped_instances.sort(key=get_order_key, reverse=True) + + if stopping_instances: + raise RuntimeError( + 'Some instances are being stopped during provisioning. ' + 'Please wait a while and retry.') + + if head_instance_id is None: + if running_instances: + head_instance_id = resource.create_node_tag( + project_id, + availability_zone, + running_instances[0]['name'], + is_head=True, + ) + elif pending_instances: + head_instance_id = resource.create_node_tag( + project_id, + availability_zone, + pending_instances[0]['name'], + is_head=True, + ) + + # check instance templates @@ -162,7 +241,7 @@ def _run_instances_in_managed_instance_group( config.node_config, cluster_name_on_cloud) else: - print(f"Instance template {instance_template_name} already exists...") + logger.debug(f"Instance template {instance_template_name} already exists...") # create managed instance group instance_template_url = f"projects/{project_id}/regions/{region}/instanceTemplates/{instance_template_name}" @@ -176,7 +255,7 @@ def _run_instances_in_managed_instance_group( size=config.count) else: # TODO: if we already have one, we should resize it. - print( + logger.debug( f"Managed instance group {managed_instance_group_name} already exists..." 
) # mig_utils.resize_managed_instance_group(project_id, zone, group_name, size, run_duration) @@ -184,29 +263,21 @@ def _run_instances_in_managed_instance_group( mig_utils.wait_for_managed_group_to_be_stable(project_id, availability_zone, managed_instance_group_name) - resource: Type[instance_utils.GCPInstance] - if node_type == instance_utils.GCPNodeType.COMPUTE: - resource = instance_utils.GCPComputeInstance - elif node_type == instance_utils.GCPNodeType.TPU: - resource = instance_utils.GCPTPUVMInstance - else: - raise ValueError(f'Unknown node type {node_type}') - # filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} - # running_instances = resource.filter( - # project_id=project_id, - # zone=availability_zone, - # label_filters=filter_labels, - # status_filters=resource.RUNNING_STATE, - # ).values() - - # # TODO: Can not tag individual nodes as they are part of the mig - # head_instance_id = resource.create_node_tag( - # project_id, - # availability_zone, - # running_instances[0]['name'], - # is_head=True, - # ) + running_instances = list(resource.filter( + project_id=project_id, + zone=availability_zone, + label_filters=filter_labels, + status_filters=resource.RUNNING_STATE, + ).values()) + + # TODO: Can not tag individual nodes as they are part of the mig + head_instance_id = resource.create_node_tag( + project_id, + availability_zone, + running_instances[0]['name'], + is_head=True, + ) # created_instance_ids = [n['name'] for n in running_instances] @@ -229,7 +300,7 @@ def _run_instances(region: str, cluster_name_on_cloud: str, resumed_instance_ids: List[str] = [] created_instance_ids: List[str] = [] - node_type = instance_utils.get_node_type(config.node_config) + node_type = instance_utils.get_node_type(config) project_id = config.provider_config['project_id'] availability_zone = config.provider_config['availability_zone'] @@ -239,6 +310,8 @@ def _run_instances(region: str, cluster_name_on_cloud: str, resource: Type[instance_utils.GCPInstance] if node_type == instance_utils.GCPNodeType.COMPUTE: resource = instance_utils.GCPComputeInstance + elif node_type == instance_utils.GCPNodeType.MIG: + resource = instance_utils.GCPMIGComputeInstance elif node_type == instance_utils.GCPNodeType.TPU: resource = instance_utils.GCPTPUVMInstance else: @@ -363,8 +436,8 @@ def get_order_key(node): if to_start_count > 0: errors, created_instance_ids = resource.create_instances( cluster_name_on_cloud, project_id, availability_zone, - config.node_config, labels, to_start_count, - head_instance_id is None) + config.node_config, labels, to_start_count, total_count=config.count, + include_head_node=head_instance_id is None) if errors: error = common.ProvisionerError('Failed to launch instances.') error.errors = errors @@ -427,11 +500,7 @@ def run_instances(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """See sky/provision/__init__.py""" try: - if gcp_utils.is_use_managed_instance_group(): - return _run_instances_in_managed_instance_group( - region, cluster_name_on_cloud, config) - else: - return _run_instances(region, cluster_name_on_cloud, config) + return _run_instances(region, cluster_name_on_cloud, config) except gcp.http_error_exception() as e: error_details = getattr(e, 'error_details') errors = [] diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index c3103d18248..5ac919b8c1b 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -14,12 +14,10 @@ from 
sky.clouds import gcp as gcp_cloud from sky.provision import common from sky.provision.gcp import constants +from sky.provision.gcp import mig_utils from sky.utils import common_utils from sky.utils import ux_utils -# Tag uniquely identifying all nodes of a cluster -TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' -TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' # Tag for the name of the node INSTANCE_NAME_MAX_LEN = 64 INSTANCE_NAME_UUID_LEN = 8 @@ -219,7 +217,7 @@ def create_or_update_firewall_rule( firewall_rule_name: str, project_id: str, vpc_name: str, - cluster_name_on_cloud: str, + cluster_name: str, ports: List[str], ) -> dict: raise NotImplementedError @@ -243,6 +241,7 @@ def create_instances( node_config: dict, labels: dict, count: int, + total_count: int, include_head_node: bool, ) -> Tuple[Optional[List], List[str]]: """Creates multiple instances and returns result. @@ -621,6 +620,7 @@ def create_instances( node_config: dict, labels: dict, count: int, + total_count: int, include_head_node: bool, ) -> Tuple[Optional[List], List[str]]: # NOTE: The syntax for bulkInsert() is different from insert(). @@ -647,8 +647,8 @@ def create_instances( config.update({ 'labels': dict( labels, **{ - TAG_RAY_CLUSTER_NAME: cluster_name, - TAG_SKYPILOT_CLUSTER_NAME: cluster_name + constants.TAG_RAY_CLUSTER_NAME: cluster_name, + constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name }), }) @@ -941,6 +941,110 @@ def resize_disk(cls, project_id: str, availability_zone: str, cls.wait_for_operation(operation, project_id, availability_zone) +class GCPMIGComputeInstance(GCPComputeInstance): + @classmethod + def create_instances( + cls, + cluster_name: str, + project_id: str, + zone: str, + node_config: dict, + labels: dict, + count: int, + total_count: int, + include_head_node: bool, + ) -> Tuple[Optional[List], List[str]]: + logger.debug(f'Creating cluster with MIG: mig-{cluster_name!r}') + config = copy.deepcopy(node_config) + labels = dict(config.get('labels', {}), **labels) + + config.update({ + 'labels': dict( + labels, **{ + constants.TAG_RAY_CLUSTER_NAME: cluster_name, + # Assume all nodes are workers, we can update the head node once the instances are created + constants.TAG_RAY_NODE_KIND: 'worker', + constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name, + }), + }) + # Convert label values to string and lowercase per MIG API requirement. + config['labels'] = {k: str(v).lower() for k, v in config['labels'].items()} + + label_filters = { + constants.TAG_RAY_CLUSTER_NAME: cluster_name, + } + potential_head_instances = cls.filter(project_id, zone, label_filters={ + constants.TAG_RAY_NODE_KIND: 'head', + **label_filters, + }, status_filters=cls.NEED_TO_TERMINATE_STATES) + + # TODO: add instance template name to the task definition if the user + # wants to use it. 
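For context on the instance-template check performed further down: existence is decided by listing regional instance templates with a name filter. A hedged sketch of that check against the same google-cloud-compute client surface (the helper name and return handling here are illustrative; the real logic lives in mig_utils):

    from sky.adaptors.gcp import compute_v1

    def instance_template_exists(project_id: str, region: str,
                                 template_name: str) -> bool:
        # List regional instance templates whose name matches; a non-empty
        # result means the template already exists.
        with compute_v1.RegionInstanceTemplatesClient() as client:
            request = compute_v1.ListRegionInstanceTemplatesRequest(
                filter=f'name eq {template_name}',
                project=project_id,
                region=region,
            )
            return any(template.name == template_name
                       for template in client.list(request))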
+ region = zone.rpartition('-')[0] + instance_template_name = 'it-' + cluster_name + if not mig_utils.check_instance_template_exits(project_id, region, + instance_template_name): + mig_utils.create_regional_instance_template(project_id, region, + instance_template_name, + config, + cluster_name) + else: + logger.debug(f'Instance template {instance_template_name} already ' + 'exists...') + + # create managed instance group + instance_template_url = (f'projects/{project_id}/regions/{region}/' + f'instanceTemplates/{instance_template_name}') + managed_instance_group_name = 'mig-' + cluster_name + mig_exist = mig_utils.check_managed_instance_group_exists( + project_id, zone, managed_instance_group_name) + if mig_exist: + # TODO: if we already have one, we should resize it. + logger.debug( + f'Managed instance group {managed_instance_group_name!r} ' + 'already exists. Resizing it...' + ) + mig_utils.resize_managed_instance_group(project_id, zone, managed_instance_group_name, total_count) + else: + assert count == total_count, ( + f'The count {count} should be the same as the total_count ' + f'{total_count} when creating a new managed instance group.') + mig_utils.create_managed_instance_group(project_id, + zone, + managed_instance_group_name, + instance_template_url, + size=total_count) + + # TODO: This will block the provisioning until the nodes are ready, + # which makes the failover not effective. + mig_utils.wait_for_managed_group_to_be_stable(project_id, zone, + managed_instance_group_name) + + head_instance_name = None + running_instances = cls.filter(project_id, zone, label_filters, status_filters=[cls.RUNNING_STATE]) + for running_instance_name in running_instances.keys(): + if running_instance_name in potential_head_instances: + head_instance_name = running_instance_name + break + else: + head_instance_name = list(running_instances.keys())[0] + + if mig_exist: + # We need to update the node's label if mig already exists, as the + # config is not updated during the resize operation. + for instance_name in running_instances.keys(): + cls.set_labels(project_id=project_id, + availability_zone=zone, + node_id=instance_name, + labels=config['labels']) + cls.create_node_tag( + project_id, + zone, + head_instance_name, + is_head=True, + ) + + class GCPTPUVMInstance(GCPInstance): """Instance handler for GCP TPU VM.""" @@ -1180,6 +1284,7 @@ def create_instances( node_config: dict, labels: dict, count: int, + total_count: int, include_head_node: bool, ) -> Tuple[Optional[List], List[str]]: config = copy.deepcopy(node_config) @@ -1202,8 +1307,8 @@ def create_instances( config.update({ 'labels': dict( labels, **{ - TAG_RAY_CLUSTER_NAME: cluster_name, - TAG_SKYPILOT_CLUSTER_NAME: cluster_name + constants.TAG_RAY_CLUSTER_NAME: cluster_name, + constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name }), }) @@ -1410,10 +1515,11 @@ class GCPNodeType(enum.Enum): """Enum for GCP node types (compute & tpu)""" COMPUTE = 'compute' + MIG = 'mig' TPU = 'tpu' -def get_node_type(node: dict) -> GCPNodeType: +def get_node_type(config: common.ProvisionConfig) -> GCPNodeType: """Returns node type based on the keys in ``node``. This is a very simple check. If we have a ``machineType`` key, @@ -1423,17 +1529,22 @@ def get_node_type(node: dict) -> GCPNodeType: This works for both node configs and API returned nodes. 
""" - - if 'machineType' not in node and 'acceleratorType' not in node: + provider_config = config.provider_config + node_config = config.node_config + if 'machineType' not in node_config and 'acceleratorType' not in node_config: raise ValueError( 'Invalid node. For a Compute instance, "machineType" is ' 'required. ' 'For a TPU instance, "acceleratorType" and no "machineType" ' 'is required. ' - f'Got {list(node)}') + f'Got {list(node_config)}') - if 'machineType' not in node and 'acceleratorType' in node: + if 'machineType' not in node_config and 'acceleratorType' in node_config: return GCPNodeType.TPU + + if provider_config.get(constants.USE_MANAGED_INSTANCE_GROUP_CONFIG, False): + return GCPNodeType.MIG + return GCPNodeType.COMPUTE diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index bf11129034d..31563b0af21 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -3,15 +3,14 @@ import sys import time import typing -from typing import Any +from typing import Any, Dict import zlib from google.api_core.extended_operation import ExtendedOperation from sky import sky_logging from sky.provision import common -from sky.provision.gcp.constants import TAG_RAY_CLUSTER_NAME -from sky.provision.gcp.constants import TAG_RAY_NODE_KIND +from sky.provision.gcp import constants if typing.TYPE_CHECKING: from google.cloud import compute_v1 @@ -21,7 +20,7 @@ logger = sky_logging.init_logger(__name__) -def create_node_config_hash(cluster_name_on_cloud, node_config) -> int: +def create_node_config_hash(cluster_name_on_cloud: str, node_config: Dict[str, Any]) -> int: """Create a hash value for the node config. This is to be used as a unique identifier for the instance template and mig @@ -34,13 +33,6 @@ def create_node_config_hash(cluster_name_on_cloud, node_config) -> int: def create_regional_instance_template_properties( cluster_name_on_cloud, node_config) -> 'compute_v1.InstanceProperties': - labels = node_config.get('labels', {}) | { - TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, - # Assume all nodes are workers, we can update the head node once the instances are created - TAG_RAY_NODE_KIND: 'worker' - } - # All label values must be string - labels = {key: str(val).lower() for key, val in labels.items()} return compute_v1.InstanceProperties( description= @@ -96,7 +88,7 @@ def create_regional_instance_template_properties( }]) ]), # Create labels from node config - labels=labels) + labels=node_config.get('labels', {})) def check_instance_template_exits(project_id, region, template_name) -> bool: @@ -214,28 +206,33 @@ def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: is not None) -def resize_managed_instance_group(project_id, zone, group_name, size, - run_duration) -> None: +def resize_managed_instance_group(project_id: str, zone: str, group_name: str, size: int) -> None: try: - resize_request_name = f'resize-request-{str(int(time.time()))}' - - cmd = ( - f'gcloud beta compute instance-groups managed resize-requests create {group_name} ' - f'--resize-request={resize_request_name} ' - f'--resize-by={size} ' - f'--requested-run-duration={run_duration} ' - f'--zone={zone} ' - f'--project={project_id} ') - logger.info(f'Resizing MIG {group_name} with command:\n{cmd}') - proc = subprocess.run( - f'yes | {cmd}', - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True, - check=True, - ) - stdout = proc.stdout.decode('ascii') - logger.info(stdout) + with compute_v1.InstanceGroupManagersClient() as compute_client: + 
response = compute_client.resize(project=project_id, + zone=zone, + instance_group_manager=group_name, + size=size) + wait_for_extended_operation(response, 'resize managed instance group', timeout=constants.DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT) + # resize_request_name = f'resize-request-{str(int(time.time()))}' + + # cmd = ( + # f'gcloud beta compute instance-groups managed resize-requests create {group_name} ' + # f'--resize-request={resize_request_name} ' + # f'--resize-by={size} ' + # f'--requested-run-duration={run_duration} ' + # f'--zone={zone} ' + # f'--project={project_id} ') + # logger.info(f'Resizing MIG {group_name} with command:\n{cmd}') + # proc = subprocess.run( + # f'yes | {cmd}', + # stdout=subprocess.PIPE, + # stderr=subprocess.PIPE, + # shell=True, + # check=True, + # ) + # stdout = proc.stdout.decode('ascii') + # logger.info(stdout) wait_for_managed_group_to_be_stable(project_id, zone, group_name) except subprocess.CalledProcessError as e: @@ -275,12 +272,16 @@ def view_resize_requests(project_id, zone, group_name) -> None: def wait_for_managed_group_to_be_stable(project_id, zone, group_name) -> None: + """Wait until the managed instance group is stable.""" try: cmd = ('gcloud compute instance-groups managed wait-until ' f'{group_name} ' '--stable ' f'--zone={zone} ' - f'--project={project_id}') + f'--project={project_id} ' + # TODO(zhwu): Allow users to specify timeout. + # 20 minutes timeout + '--timeout=1200') logger.info( f'Waiting for MIG {group_name} to be stable with command:\n{cmd}') proc = subprocess.run( From 6bafabf6361ec2d46dc9119f1131707ec52375d5 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 02:07:14 +0000 Subject: [PATCH 05/29] fix --- sky/provision/gcp/instance_utils.py | 1 + sky/provision/gcp/mig_utils.py | 78 ++++++++++++++--------------- 2 files changed, 40 insertions(+), 39 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 5ac919b8c1b..412a063e9c7 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -1043,6 +1043,7 @@ def create_instances( head_instance_name, is_head=True, ) + return None, list(running_instances.keys()) diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 31563b0af21..0f537f66305 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -207,45 +207,45 @@ def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: def resize_managed_instance_group(project_id: str, zone: str, group_name: str, size: int) -> None: - try: - with compute_v1.InstanceGroupManagersClient() as compute_client: - response = compute_client.resize(project=project_id, - zone=zone, - instance_group_manager=group_name, - size=size) - wait_for_extended_operation(response, 'resize managed instance group', timeout=constants.DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT) - # resize_request_name = f'resize-request-{str(int(time.time()))}' - - # cmd = ( - # f'gcloud beta compute instance-groups managed resize-requests create {group_name} ' - # f'--resize-request={resize_request_name} ' - # f'--resize-by={size} ' - # f'--requested-run-duration={run_duration} ' - # f'--zone={zone} ' - # f'--project={project_id} ') - # logger.info(f'Resizing MIG {group_name} with command:\n{cmd}') - # proc = subprocess.run( - # f'yes | {cmd}', - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE, - # shell=True, - # check=True, - # ) - # stdout = proc.stdout.decode('ascii') - # 
logger.info(stdout) - wait_for_managed_group_to_be_stable(project_id, zone, group_name) - - except subprocess.CalledProcessError as e: - stderr = e.stderr.decode('ascii') - logger.info(stderr) - provisioner_err = common.ProvisionerError('Failed to resize MIG') - provisioner_err.errors = [{ - 'code': 'UNKNOWN', - 'domain': 'mig', - 'message': stderr - }] - # _log_errors(provisioner_err.errors, e, zone) - raise provisioner_err from e + # try: + with compute_v1.InstanceGroupManagersClient() as compute_client: + response = compute_client.resize(project=project_id, + zone=zone, + instance_group_manager=group_name, + size=size) + wait_for_extended_operation(response, 'resize managed instance group', timeout=constants.DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT) + # resize_request_name = f'resize-request-{str(int(time.time()))}' + + # cmd = ( + # f'gcloud beta compute instance-groups managed resize-requests create {group_name} ' + # f'--resize-request={resize_request_name} ' + # f'--resize-by={size} ' + # f'--requested-run-duration={run_duration} ' + # f'--zone={zone} ' + # f'--project={project_id} ') + # logger.info(f'Resizing MIG {group_name} with command:\n{cmd}') + # proc = subprocess.run( + # f'yes | {cmd}', + # stdout=subprocess.PIPE, + # stderr=subprocess.PIPE, + # shell=True, + # check=True, + # ) + # stdout = proc.stdout.decode('ascii') + # logger.info(stdout) + wait_for_managed_group_to_be_stable(project_id, zone, group_name) + + # except subprocess.CalledProcessError as e: + # stderr = e.stderr.decode('ascii') + # logger.info(stderr) + # provisioner_err = common.ProvisionerError('Failed to resize MIG') + # provisioner_err.errors = [{ + # 'code': 'UNKNOWN', + # 'domain': 'mig', + # 'message': stderr + # }] + # # _log_errors(provisioner_err.errors, e, zone) + # raise provisioner_err from e def view_resize_requests(project_id, zone, group_name) -> None: From af0cde50b7be18faba2ada8f4c0ee2fd8d1a0ef9 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 02:40:42 +0000 Subject: [PATCH 06/29] remove unecessary instance creation code for mig --- sky/provision/gcp/instance.py | 161 ++-------------------------- sky/provision/gcp/instance_utils.py | 12 ++- sky/provision/gcp/mig_utils.py | 31 +++--- 3 files changed, 33 insertions(+), 171 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 5c137c81bca..8c116584d6e 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -136,160 +136,6 @@ def _get_head_instance_id(instances: List) -> Optional[str]: return head_instance_id -def _run_instances_in_managed_instance_group( - region: str, cluster_name_on_cloud: str, - config: common.ProvisionConfig) -> common.ProvisionRecord: - logger.debug("Managed instance group is enabled.") - - resumed_instance_ids: List[str] = [] - created_instance_ids: List[str] = [] - - node_type = instance_utils.get_node_type(config) - project_id = config.provider_config['project_id'] - availability_zone = config.provider_config['availability_zone'] - head_instance_id = '' - - resource: Type[instance_utils.GCPInstance] - if node_type == instance_utils.GCPNodeType.COMPUTE: - resource = instance_utils.GCPComputeInstance - elif node_type == instance_utils.GCPNodeType.TPU: - resource = instance_utils.GCPTPUVMInstance - else: - raise ValueError(f'Unknown node type {node_type}') - filter_labels = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} - exist_instances = resource.filter( - project_id=project_id, - zone=availability_zone, - 
label_filters=filter_labels, - status_filters=None, - ) - exist_instances = list(exist_instances.values()) - head_instance_id = _get_head_instance_id(exist_instances) - - # NOTE: We are not handling REPAIRING, SUSPENDING, SUSPENDED status. - pending_instances = [] - running_instances = [] - stopping_instances = [] - stopped_instances = [] - - # SkyPilot: We try to use the instances with the same matching launch_config - # first. If there is not enough instances with matching launch_config, we - # then use all the instances with the same matching launch_config plus some - # instances with wrong launch_config. - def get_order_key(node): - import datetime # pylint: disable=import-outside-toplevel - - timestamp = node.get('lastStartTimestamp') - if timestamp is not None: - return datetime.datetime.strptime(timestamp, - '%Y-%m-%dT%H:%M:%S.%f%z') - return node['id'] - - logger.info(str(exist_instances)) - for inst in exist_instances: - state = inst[resource.STATUS_FIELD] - if state in resource.PENDING_STATES: - pending_instances.append(inst) - elif state == resource.RUNNING_STATE: - running_instances.append(inst) - elif state in resource.STOPPING_STATES: - stopping_instances.append(inst) - elif state in resource.STOPPED_STATES: - stopped_instances.append(inst) - else: - raise RuntimeError(f'Unsupported state "{state}".') - - pending_instances.sort(key=get_order_key, reverse=True) - running_instances.sort(key=get_order_key, reverse=True) - stopping_instances.sort(key=get_order_key, reverse=True) - stopped_instances.sort(key=get_order_key, reverse=True) - - if stopping_instances: - raise RuntimeError( - 'Some instances are being stopped during provisioning. ' - 'Please wait a while and retry.') - - if head_instance_id is None: - if running_instances: - head_instance_id = resource.create_node_tag( - project_id, - availability_zone, - running_instances[0]['name'], - is_head=True, - ) - elif pending_instances: - head_instance_id = resource.create_node_tag( - project_id, - availability_zone, - pending_instances[0]['name'], - is_head=True, - ) - - - - # check instance templates - - # TODO add instance template name to the task definition if the user wants to use it. - # Calculate template instance name based on the config node values. Hash them to get a unique name. - node_config_hash = mig_utils.create_node_config_hash( - cluster_name_on_cloud, config.node_config) - instance_template_name = f"{cluster_name_on_cloud}-it-{node_config_hash}" - if not mig_utils.check_instance_template_exits(project_id, region, - instance_template_name): - mig_utils.create_regional_instance_template(project_id, region, - instance_template_name, - config.node_config, - cluster_name_on_cloud) - else: - logger.debug(f"Instance template {instance_template_name} already exists...") - - # create managed instance group - instance_template_url = f"projects/{project_id}/regions/{region}/instanceTemplates/{instance_template_name}" - managed_instance_group_name = f"{cluster_name_on_cloud}-mig-{node_config_hash}" - if not mig_utils.check_managed_instance_group_exists( - project_id, availability_zone, managed_instance_group_name): - mig_utils.create_managed_instance_group(project_id, - availability_zone, - managed_instance_group_name, - instance_template_url, - size=config.count) - else: - # TODO: if we already have one, we should resize it. - logger.debug( - f"Managed instance group {managed_instance_group_name} already exists..." 
- ) - # mig_utils.resize_managed_instance_group(project_id, zone, group_name, size, run_duration) - - mig_utils.wait_for_managed_group_to_be_stable(project_id, availability_zone, - managed_instance_group_name) - - - running_instances = list(resource.filter( - project_id=project_id, - zone=availability_zone, - label_filters=filter_labels, - status_filters=resource.RUNNING_STATE, - ).values()) - - # TODO: Can not tag individual nodes as they are part of the mig - head_instance_id = resource.create_node_tag( - project_id, - availability_zone, - running_instances[0]['name'], - is_head=True, - ) - - # created_instance_ids = [n['name'] for n in running_instances] - - return common.ProvisionRecord(provider_name='gcp', - region=region, - zone=availability_zone, - cluster_name=cluster_name_on_cloud, - head_instance_id=head_instance_id, - resumed_instance_ids=resumed_instance_ids, - created_instance_ids=created_instance_ids) - - def _run_instances(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """See sky/provision/__init__.py""" @@ -661,6 +507,13 @@ def terminate_instances( tpu_node = provider_config.get('tpu_node') if tpu_node is not None: instance_utils.delete_tpu_node(project_id, zone, tpu_node) + use_mig = provider_config.get(constants.USE_MANAGED_INSTANCE_GROUP_CONFIG, False) + if use_mig: + # Deleting the MIG will also delete the instances. + instance_utils.GCPMIGComputeInstance.delete_mig( + project_id, zone, cluster_name_on_cloud) + return + label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} if worker_only: diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 412a063e9c7..d99ec84de0e 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -131,11 +131,9 @@ def instance_to_handler(instance: str): return GCPComputeInstance elif instance_type == 'tpu': return GCPTPUVMInstance + elif instance.startswith('mig-'): + return GCPMIGComputeInstance else: - # Managed Instance Groups breaks this assumption. The suffix is a random value. - if "-mig-" in instance: - # TODO: Implement MIG Instance - return GCPComputeInstance raise ValueError(f'Unknown instance type: {instance_type}') @@ -1045,7 +1043,11 @@ def create_instances( ) return None, list(running_instances.keys()) - + @classmethod + def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> dict: + mig_utils.delete_managed_instance_group(project_id, zone, ) + mig_utils.delete_regional_instance_template(project_id, zone, ) + return {} class GCPTPUVMInstance(GCPInstance): """Instance handler for GCP TPU VM.""" diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 0f537f66305..683680f4218 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -9,7 +9,6 @@ from google.api_core.extended_operation import ExtendedOperation from sky import sky_logging -from sky.provision import common from sky.provision.gcp import constants if typing.TYPE_CHECKING: @@ -20,15 +19,11 @@ logger = sky_logging.init_logger(__name__) -def create_node_config_hash(cluster_name_on_cloud: str, node_config: Dict[str, Any]) -> int: - """Create a hash value for the node config. +def get_instance_template_name(cluster_name: str) -> str: + return f'it-{cluster_name}' - This is to be used as a unique identifier for the instance template and mig - names. 
- """ - properties = create_regional_instance_template_properties( - cluster_name_on_cloud, node_config) - return zlib.adler32(repr(properties).encode()) +def get_managed_instance_group_name(cluster_name: str) -> str: + return f'mig-{cluster_name}' def create_regional_instance_template_properties( @@ -161,9 +156,6 @@ def delete_regional_instance_template(project_id, region, def create_managed_instance_group(project_id, zone, group_name, instance_template_url, size) -> None: - # credentials, project = google.auth.default() - # compute_client = compute_v1.InstanceGroupManagersClient(credentials=credentials) - with compute_v1.InstanceGroupManagersClient() as compute_client: # Create the managed instance group request request = compute_v1.InsertInstanceGroupManagerRequest( @@ -193,6 +185,21 @@ def create_managed_instance_group(project_id, zone, group_name, logger.debug( f'Managed instance group {group_name!r} created successfully.') +def delete_managed_instance_group(project_id, zone, group_name) -> None: + with compute_v1.InstanceGroupManagersClient() as compute_client: + # Create the managed instance group request + request = compute_v1.DeleteInstanceGroupManagerRequest( + project=project_id, + zone=zone, + instance_group_manager=group_name, + ) + + # Send the request to delete the managed instance group + response = compute_client.delete(request=request) + # Wait for the operation to complete + logger.debug(response) + wait_for_extended_operation(response, + 'delete managed instance group', 600) def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: with compute_v1.InstanceGroupManagersClient() as compute_client: From 5c7850be7f60c23f84dfb5a4b670221926f9f31c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 03:32:22 +0000 Subject: [PATCH 07/29] Fix deletion --- sky/provision/gcp/constants.py | 2 +- sky/provision/gcp/instance.py | 17 +++-- sky/provision/gcp/instance_utils.py | 75 +++++++++++++-------- sky/provision/gcp/mig_utils.py | 100 ++++++++++++++++++---------- 4 files changed, 122 insertions(+), 72 deletions(-) diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index dc0838ab661..3f7f6e2d23b 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -224,4 +224,4 @@ # MIG constants USE_MANAGED_INSTANCE_GROUP_CONFIG = 'use_managed_instance_group' TAG_MANAGED_INSTANCE_GROUP_NAME = 'managed-instance-group-name' -DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT = 1200 # 20 minutes +DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT = 1200 # 20 minutes diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 8c116584d6e..a32135e5dc1 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -157,7 +157,7 @@ def _run_instances(region: str, cluster_name_on_cloud: str, if node_type == instance_utils.GCPNodeType.COMPUTE: resource = instance_utils.GCPComputeInstance elif node_type == instance_utils.GCPNodeType.MIG: - resource = instance_utils.GCPMIGComputeInstance + resource = instance_utils.GCPManagedInstanceGroup elif node_type == instance_utils.GCPNodeType.TPU: resource = instance_utils.GCPTPUVMInstance else: @@ -281,8 +281,13 @@ def get_order_key(node): if to_start_count > 0: errors, created_instance_ids = resource.create_instances( - cluster_name_on_cloud, project_id, availability_zone, - config.node_config, labels, to_start_count, total_count=config.count, + cluster_name_on_cloud, + project_id, + availability_zone, + config.node_config, + labels, + 
to_start_count, + total_count=config.count, include_head_node=head_instance_id is None) if errors: error = common.ProvisionerError('Failed to launch instances.') @@ -507,14 +512,14 @@ def terminate_instances( tpu_node = provider_config.get('tpu_node') if tpu_node is not None: instance_utils.delete_tpu_node(project_id, zone, tpu_node) - use_mig = provider_config.get(constants.USE_MANAGED_INSTANCE_GROUP_CONFIG, False) + use_mig = provider_config.get(constants.USE_MANAGED_INSTANCE_GROUP_CONFIG, + False) if use_mig: # Deleting the MIG will also delete the instances. - instance_utils.GCPMIGComputeInstance.delete_mig( + instance_utils.GCPManagedInstanceGroup.delete_mig( project_id, zone, cluster_name_on_cloud) return - label_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} if worker_only: label_filters[constants.TAG_RAY_NODE_KIND] = 'worker' diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index d99ec84de0e..d986ef57c84 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -132,7 +132,7 @@ def instance_to_handler(instance: str): elif instance_type == 'tpu': return GCPTPUVMInstance elif instance.startswith('mig-'): - return GCPMIGComputeInstance + return GCPManagedInstanceGroup else: raise ValueError(f'Unknown instance type: {instance_type}') @@ -939,7 +939,9 @@ def resize_disk(cls, project_id: str, availability_zone: str, cls.wait_for_operation(operation, project_id, availability_zone) -class GCPMIGComputeInstance(GCPComputeInstance): + +class GCPManagedInstanceGroup(GCPComputeInstance): + @classmethod def create_instances( cls, @@ -949,7 +951,7 @@ def create_instances( node_config: dict, labels: dict, count: int, - total_count: int, + total_count: int, include_head_node: bool, ) -> Tuple[Optional[List], List[str]]: logger.debug(f'Creating cluster with MIG: mig-{cluster_name!r}') @@ -958,7 +960,8 @@ def create_instances( config.update({ 'labels': dict( - labels, **{ + labels, + **{ constants.TAG_RAY_CLUSTER_NAME: cluster_name, # Assume all nodes are workers, we can update the head node once the instances are created constants.TAG_RAY_NODE_KIND: 'worker', @@ -966,26 +969,31 @@ def create_instances( }), }) # Convert label values to string and lowercase per MIG API requirement. - config['labels'] = {k: str(v).lower() for k, v in config['labels'].items()} + config['labels'] = { + k: str(v).lower() for k, v in config['labels'].items() + } label_filters = { constants.TAG_RAY_CLUSTER_NAME: cluster_name, } - potential_head_instances = cls.filter(project_id, zone, label_filters={ - constants.TAG_RAY_NODE_KIND: 'head', - **label_filters, - }, status_filters=cls.NEED_TO_TERMINATE_STATES) + potential_head_instances = cls.filter( + project_id, + zone, + label_filters={ + constants.TAG_RAY_NODE_KIND: 'head', + **label_filters, + }, + status_filters=cls.NEED_TO_TERMINATE_STATES) # TODO: add instance template name to the task definition if the user # wants to use it. 
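A small illustration of the naming scheme used just below: both the instance template and the managed instance group are derived deterministically from the cluster name, so re-provisioning the same cluster finds and reuses (or resizes) the same resources. The project and zone values here are placeholders:

    def instance_template_name(cluster_name: str) -> str:
        return f'it-{cluster_name}'

    def managed_instance_group_name(cluster_name: str) -> str:
        return f'mig-{cluster_name}'

    zone = 'us-central1-a'            # placeholder
    region = zone.rpartition('-')[0]  # -> 'us-central1'
    template_url = (f'projects/my-project/regions/{region}/instanceTemplates/'
                    f'{instance_template_name("my-cluster")}')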
region = zone.rpartition('-')[0] instance_template_name = 'it-' + cluster_name if not mig_utils.check_instance_template_exits(project_id, region, - instance_template_name): + instance_template_name): mig_utils.create_regional_instance_template(project_id, region, instance_template_name, - config, - cluster_name) + config, cluster_name) else: logger.debug(f'Instance template {instance_template_name} already ' 'exists...') @@ -995,14 +1003,14 @@ def create_instances( f'instanceTemplates/{instance_template_name}') managed_instance_group_name = 'mig-' + cluster_name mig_exist = mig_utils.check_managed_instance_group_exists( - project_id, zone, managed_instance_group_name) + project_id, zone, managed_instance_group_name) if mig_exist: # TODO: if we already have one, we should resize it. logger.debug( f'Managed instance group {managed_instance_group_name!r} ' - 'already exists. Resizing it...' - ) - mig_utils.resize_managed_instance_group(project_id, zone, managed_instance_group_name, total_count) + 'already exists. Resizing it...') + mig_utils.resize_managed_instance_group( + project_id, zone, managed_instance_group_name, total_count) else: assert count == total_count, ( f'The count {count} should be the same as the total_count ' @@ -1015,11 +1023,14 @@ def create_instances( # TODO: This will block the provisioning until the nodes are ready, # which makes the failover not effective. - mig_utils.wait_for_managed_group_to_be_stable(project_id, zone, - managed_instance_group_name) - + mig_utils.wait_for_managed_group_to_be_stable( + project_id, zone, managed_instance_group_name) + head_instance_name = None - running_instances = cls.filter(project_id, zone, label_filters, status_filters=[cls.RUNNING_STATE]) + running_instances = cls.filter(project_id, + zone, + label_filters, + status_filters=[cls.RUNNING_STATE]) for running_instance_name in running_instances.keys(): if running_instance_name in potential_head_instances: head_instance_name = running_instance_name @@ -1032,9 +1043,9 @@ def create_instances( # config is not updated during the resize operation. 
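The head selection above keeps an instance that already carries the head tag (relevant when an existing group is resized) and only falls back to an arbitrary running instance. A condensed sketch of that rule:

    from typing import Iterable, Optional

    def pick_head(running: Iterable[str],
                  previously_tagged_heads: Iterable[str]) -> Optional[str]:
        running = list(running)
        if not running:
            return None
        tagged = set(previously_tagged_heads)
        for name in running:
            if name in tagged:
                # Keep the existing head so its identity stays stable
                # across resizes of the managed instance group.
                return name
        return running[0]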
for instance_name in running_instances.keys(): cls.set_labels(project_id=project_id, - availability_zone=zone, - node_id=instance_name, - labels=config['labels']) + availability_zone=zone, + node_id=instance_name, + labels=config['labels']) cls.create_node_tag( project_id, zone, @@ -1042,12 +1053,18 @@ def create_instances( is_head=True, ) return None, list(running_instances.keys()) - + @classmethod - def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> dict: - mig_utils.delete_managed_instance_group(project_id, zone, ) - mig_utils.delete_regional_instance_template(project_id, zone, ) - return {} + def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: + mig_utils.delete_managed_instance_group( + project_id, zone, + mig_utils.get_managed_instance_group_name(cluster_name)) + response = mig_utils.delete_regional_instance_template( + project_id, zone, + mig_utils.get_instance_template_name(cluster_name)) + mig_utils.wait_for_extended_operation( + response, verbose_name='Deletion of Instance Template') + class GCPTPUVMInstance(GCPInstance): """Instance handler for GCP TPU VM.""" @@ -1547,7 +1564,7 @@ def get_node_type(config: common.ProvisionConfig) -> GCPNodeType: if provider_config.get(constants.USE_MANAGED_INSTANCE_GROUP_CONFIG, False): return GCPNodeType.MIG - + return GCPNodeType.COMPUTE diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 683680f4218..b88e6983c56 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -1,4 +1,5 @@ """Managed Instance Group Utils""" +import re import subprocess import sys import time @@ -9,6 +10,7 @@ from google.api_core.extended_operation import ExtendedOperation from sky import sky_logging +from sky.adaptors import gcp from sky.provision.gcp import constants if typing.TYPE_CHECKING: @@ -18,10 +20,18 @@ logger = sky_logging.init_logger(__name__) +_MIG_RESOURCE_NOT_FOUND_PATTERN = re.compile( + r'The resource \'projects/.*/zones/.*/instanceGroupManagers/.*\' was not found' +) + +_IT_RESOURCE_NOT_FOUND_PATTERN = re.compile( + r'The resource \'projects/.*/zones/.*/instanceTemplates/.*\' was not found') + def get_instance_template_name(cluster_name: str) -> str: return f'it-{cluster_name}' + def get_managed_instance_group_name(cluster_name: str) -> str: return f'mig-{cluster_name}' @@ -136,8 +146,8 @@ def create_regional_instance_template(project_id, region, template_name, 'created successfully.') -def delete_regional_instance_template(project_id, region, - template_name) -> None: +def delete_regional_instance_template(project_id, zone, template_name) -> dict: + region = zone.rsplit('-', 1)[0] with compute_v1.RegionInstanceTemplatesClient() as compute_client: # Create the regional instance template request request = compute_v1.DeleteRegionInstanceTemplateRequest( @@ -145,13 +155,41 @@ def delete_regional_instance_template(project_id, region, region=region, instance_template=template_name, ) + try: + # Send the request to delete the regional instance template + response = compute_client.delete(request=request) + return response + except gcp.google.api_core.exceptions.NotFound as e: + if re.search(_IT_RESOURCE_NOT_FOUND_PATTERN, str(e)) is None: + raise + logger.warning(f'Instance template {template_name!r} does not ' + 'exist. 
Skip deletion.') + return {} + + +def delete_managed_instance_group(project_id, zone, group_name) -> dict: + with compute_v1.InstanceGroupManagersClient() as compute_client: + # Create the managed instance group request + request = compute_v1.DeleteInstanceGroupManagerRequest( + project=project_id, + zone=zone, + instance_group_manager=group_name, + ) - # Send the request to delete the regional instance template - response = compute_client.delete(request=request) - # Wait for the operation to complete - logger.debug(response) - wait_for_extended_operation(response, - 'delete regional instance template', 600) + try: + # Send the request to delete the managed instance group + response = compute_client.delete(request=request) + # Do not wait for the deletion of MIG, so we can send the deletion + # request for the instance template, immediately after this, which is + # important when we are autodown a cluster from the head node. + return response + except gcp.google.api_core.exceptions.NotFound as e: + print(e) + if re.search(_MIG_RESOURCE_NOT_FOUND_PATTERN, str(e)) is None: + raise + logger.warning(f'MIG {group_name!r} does not exist. Skip ' + 'deletion.') + return {} def create_managed_instance_group(project_id, zone, group_name, @@ -185,21 +223,6 @@ def create_managed_instance_group(project_id, zone, group_name, logger.debug( f'Managed instance group {group_name!r} created successfully.') -def delete_managed_instance_group(project_id, zone, group_name) -> None: - with compute_v1.InstanceGroupManagersClient() as compute_client: - # Create the managed instance group request - request = compute_v1.DeleteInstanceGroupManagerRequest( - project=project_id, - zone=zone, - instance_group_manager=group_name, - ) - - # Send the request to delete the managed instance group - response = compute_client.delete(request=request) - # Wait for the operation to complete - logger.debug(response) - wait_for_extended_operation(response, - 'delete managed instance group', 600) def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: with compute_v1.InstanceGroupManagersClient() as compute_client: @@ -213,14 +236,18 @@ def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: is not None) -def resize_managed_instance_group(project_id: str, zone: str, group_name: str, size: int) -> None: +def resize_managed_instance_group(project_id: str, zone: str, group_name: str, + size: int) -> None: # try: with compute_v1.InstanceGroupManagersClient() as compute_client: response = compute_client.resize(project=project_id, - zone=zone, - instance_group_manager=group_name, - size=size) - wait_for_extended_operation(response, 'resize managed instance group', timeout=constants.DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT) + zone=zone, + instance_group_manager=group_name, + size=size) + wait_for_extended_operation( + response, + 'resize managed instance group', + timeout=constants.DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT) # resize_request_name = f'resize-request-{str(int(time.time()))}' # cmd = ( @@ -281,14 +308,15 @@ def view_resize_requests(project_id, zone, group_name) -> None: def wait_for_managed_group_to_be_stable(project_id, zone, group_name) -> None: """Wait until the managed instance group is stable.""" try: - cmd = ('gcloud compute instance-groups managed wait-until ' - f'{group_name} ' - '--stable ' - f'--zone={zone} ' - f'--project={project_id} ' - # TODO(zhwu): Allow users to specify timeout. 
- # 20 minutes timeout - '--timeout=1200') + cmd = ( + 'gcloud compute instance-groups managed wait-until ' + f'{group_name} ' + '--stable ' + f'--zone={zone} ' + f'--project={project_id} ' + # TODO(zhwu): Allow users to specify timeout. + # 20 minutes timeout + '--timeout=1200') logger.info( f'Waiting for MIG {group_name} to be stable with command:\n{cmd}') proc = subprocess.run( From 91bba396402329a9ea2e5963b342e38d9e0fdddd Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 08:32:46 +0000 Subject: [PATCH 08/29] Fix instance template logic --- sky/clouds/gcp.py | 13 ++-- sky/provision/gcp/config.py | 2 +- sky/provision/gcp/constants.py | 2 + sky/provision/gcp/instance.py | 3 - sky/provision/gcp/instance_utils.py | 102 ++++++++++++++++++---------- sky/provision/gcp/mig_utils.py | 71 ++++++++++--------- 6 files changed, 113 insertions(+), 80 deletions(-) diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 443c1f4bf14..993effc9654 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -179,20 +179,19 @@ class GCP(clouds.Cloud): def _unsupported_features_for_resources( cls, resources: 'resources.Resources' ) -> Dict[clouds.CloudImplementationFeatures, str]: + unsupported = {} if gcp_utils.is_tpu_vm_pod(resources): - return { + unsupported = { clouds.CloudImplementationFeatures.STOP: ( 'TPU VM pods cannot be stopped. Please refer to: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources' ) } if gcp_utils.is_tpu(resources) and not gcp_utils.is_tpu_vm(resources): # TPU node does not support multi-node. - return { - clouds.CloudImplementationFeatures.MULTI_NODE: - ('TPU node does not support multi-node. Please set ' - 'num_nodes to 1.') - } - return {} + unsupported[clouds.CloudImplementationFeatures.MULTI_NODE] = ( + 'TPU node does not support multi-node. Please set ' + 'num_nodes to 1.') + return unsupported @classmethod def max_cluster_name_length(cls) -> Optional[int]: diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index 6a205545ee2..6d576e13bc9 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -374,7 +374,7 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam) -> dict: 'scopes': ['https://www.googleapis.com/auth/cloud-platform'], } iam_role: Dict[str, Any] - if (instance_utils.get_node_type(config) == instance_utils.GCPNodeType.TPU): + if instance_utils.get_node_type(config) == instance_utils.GCPNodeType.TPU: # SKY: The API for TPU VM is slightly different from normal compute # instances. 
# See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node # pylint: disable=line-too-long diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index 3f7f6e2d23b..31f7ef787ae 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -225,3 +225,5 @@ USE_MANAGED_INSTANCE_GROUP_CONFIG = 'use_managed_instance_group' TAG_MANAGED_INSTANCE_GROUP_NAME = 'managed-instance-group-name' DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT = 1200 # 20 minutes +MIG_NAME_PREFIX = 'sky-mig-' +INSTANCE_TEMPLATE_NAME_PREFIX = 'sky-it-' diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index a32135e5dc1..93cbea5cb3c 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -3,18 +3,15 @@ import copy from multiprocessing import pool import re -import sys import time from typing import Any, Callable, Dict, Iterable, List, Optional, Type from sky import sky_logging from sky import status_lib from sky.adaptors import gcp -from sky.clouds.gcp import gcp_utils from sky.provision import common from sky.provision.gcp import constants from sky.provision.gcp import instance_utils -from sky.provision.gcp import mig_utils from sky.utils import common_utils logger = sky_logging.init_logger(__name__) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index d986ef57c84..0bf8ab925a0 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -131,7 +131,7 @@ def instance_to_handler(instance: str): return GCPComputeInstance elif instance_type == 'tpu': return GCPTPUVMInstance - elif instance.startswith('mig-'): + elif instance.startswith(constants.MIG_NAME_PREFIX): return GCPManagedInstanceGroup else: raise ValueError(f'Unknown instance type: {instance_type}') @@ -215,7 +215,7 @@ def create_or_update_firewall_rule( firewall_rule_name: str, project_id: str, vpc_name: str, - cluster_name: str, + cluster_name_on_cloud: str, ports: List[str], ) -> dict: raise NotImplementedError @@ -941,6 +941,7 @@ def resize_disk(cls, project_id: str, availability_zone: str, class GCPManagedInstanceGroup(GCPComputeInstance): + """Handler for GCP Managed Instance Group.""" @classmethod def create_instances( @@ -954,7 +955,7 @@ def create_instances( total_count: int, include_head_node: bool, ) -> Tuple[Optional[List], List[str]]: - logger.debug(f'Creating cluster with MIG: mig-{cluster_name!r}') + logger.debug(f'Creating cluster with MIG: {cluster_name!r}') config = copy.deepcopy(node_config) labels = dict(config.get('labels', {}), **labels) @@ -963,48 +964,70 @@ def create_instances( labels, **{ constants.TAG_RAY_CLUSTER_NAME: cluster_name, - # Assume all nodes are workers, we can update the head node once the instances are created + # Assume all nodes are workers, we can update the head node + # once the instances are created. constants.TAG_RAY_NODE_KIND: 'worker', constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name, }), }) # Convert label values to string and lowercase per MIG API requirement. 
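On the label normalization mentioned above: GCP label values must be strings (lowercase letters, digits, underscores, and dashes), so every value coming from the node config, including booleans and numbers, is stringified and lowercased. A tiny illustration:

    def normalize_labels(labels: dict) -> dict:
        # MIG / instance-template labels must be string-valued and lowercase.
        return {key: str(value).lower() for key, value in labels.items()}

    print(normalize_labels({'ray-node-type': 'worker',
                            'skypilot-head-node': 1,
                            'use-spot': False}))
    # {'ray-node-type': 'worker', 'skypilot-head-node': '1', 'use-spot': 'false'}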
- config['labels'] = { - k: str(v).lower() for k, v in config['labels'].items() - } + region = zone.rpartition('-')[0] + instance_template_name = mig_utils.get_instance_template_name( + cluster_name) + managed_instance_group_name = mig_utils.get_managed_instance_group_name( + cluster_name) + + instance_template_exists = mig_utils.check_instance_template_exits( + project_id, region, instance_template_name) + mig_exists = mig_utils.check_managed_instance_group_exists( + project_id, zone, managed_instance_group_name) label_filters = { constants.TAG_RAY_CLUSTER_NAME: cluster_name, } - potential_head_instances = cls.filter( - project_id, - zone, - label_filters={ - constants.TAG_RAY_NODE_KIND: 'head', - **label_filters, - }, - status_filters=cls.NEED_TO_TERMINATE_STATES) - - # TODO: add instance template name to the task definition if the user - # wants to use it. - region = zone.rpartition('-')[0] - instance_template_name = 'it-' + cluster_name - if not mig_utils.check_instance_template_exits(project_id, region, - instance_template_name): + potential_head_instances = {} + if mig_exists: + potential_head_instances = cls.filter( + project_id, + zone, + label_filters={ + constants.TAG_RAY_NODE_KIND: 'head', + **label_filters, + }, + status_filters=cls.NEED_TO_TERMINATE_STATES) + + config['labels'] = { + k: str(v).lower() for k, v in config['labels'].items() + } + if instance_template_exists: + if mig_exists: + logger.debug( + f'Instance template {instance_template_name} already ' + 'exists. Skip creating it.') + else: + logger.debug( + f'Instance template {instance_template_name!r} ' + 'exists and no instance group is using it. This is a leftover ' + 'of a previous autodown. Delete it and recreate it.') + # TODO(zhwu): this is a bit hacky as we cannot delete instance + # template during an autodown, we can only defer the deletion + # to the next launch of a cluster with the same name. We should + # find a better way to handle this. + operation = mig_utils.delete_regional_instance_template( + project_id, region, instance_template_name) + mig_utils.wait_for_extended_operation( + operation, verbose_name='Deletion of Instance Template') + instance_template_exists = False + + if not instance_template_exists: mig_utils.create_regional_instance_template(project_id, region, instance_template_name, config, cluster_name) - else: - logger.debug(f'Instance template {instance_template_name} already ' - 'exists...') # create managed instance group instance_template_url = (f'projects/{project_id}/regions/{region}/' f'instanceTemplates/{instance_template_name}') - managed_instance_group_name = 'mig-' + cluster_name - mig_exist = mig_utils.check_managed_instance_group_exists( - project_id, zone, managed_instance_group_name) - if mig_exist: + if mig_exists: # TODO: if we already have one, we should resize it. logger.debug( f'Managed instance group {managed_instance_group_name!r} ' @@ -1026,7 +1049,7 @@ def create_instances( mig_utils.wait_for_managed_group_to_be_stable( project_id, zone, managed_instance_group_name) - head_instance_name = None + head_instance_name: Optional[str] = None running_instances = cls.filter(project_id, zone, label_filters, @@ -1038,7 +1061,7 @@ def create_instances( else: head_instance_name = list(running_instances.keys())[0] - if mig_exist: + if mig_exists: # We need to update the node's label if mig already exists, as the # config is not updated during the resize operation. 
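To summarize the idempotency handling sketched above in one place, here is a compact, illustrative decision helper covering the four template/group combinations this method has to deal with:

    def plan_provisioning(template_exists: bool, group_exists: bool) -> list:
        # Mirrors the branches above: a template without a matching group is
        # treated as a leftover from a previous autodown and is recreated.
        steps = []
        if template_exists and not group_exists:
            steps.append('delete leftover instance template')
            template_exists = False
        if not template_exists:
            steps.append('create regional instance template')
        steps.append('resize managed instance group' if group_exists else
                     'create managed instance group')
        steps.append('wait for the group to become stable')
        return steps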
for instance_name in running_instances.keys(): @@ -1056,14 +1079,20 @@ def create_instances( @classmethod def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: - mig_utils.delete_managed_instance_group( + response = mig_utils.delete_managed_instance_group( project_id, zone, mig_utils.get_managed_instance_group_name(cluster_name)) - response = mig_utils.delete_regional_instance_template( - project_id, zone, - mig_utils.get_instance_template_name(cluster_name)) mig_utils.wait_for_extended_operation( response, verbose_name='Deletion of Instance Template') + # In the autostop case, the following deletion of instance template + # will not be executed as the instance that runs the deletion will be + # terminated with the managed instance group. It is ok to leave the + # instance template there as when a user creates a new cluster with the + # same name, the instance template will be updated in our + # create_instances method. + mig_utils.delete_regional_instance_template( + project_id, zone, + mig_utils.get_instance_template_name(cluster_name)) class GCPTPUVMInstance(GCPInstance): @@ -1551,7 +1580,8 @@ def get_node_type(config: common.ProvisionConfig) -> GCPNodeType: """ provider_config = config.provider_config node_config = config.node_config - if 'machineType' not in node_config and 'acceleratorType' not in node_config: + if ('machineType' not in node_config and + 'acceleratorType' not in node_config): raise ValueError( 'Invalid node. For a Compute instance, "machineType" is ' 'required. ' diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index b88e6983c56..98e6053b2e8 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -2,18 +2,15 @@ import re import subprocess import sys -import time import typing -from typing import Any, Dict -import zlib - -from google.api_core.extended_operation import ExtendedOperation +from typing import Any, Optional from sky import sky_logging from sky.adaptors import gcp from sky.provision.gcp import constants if typing.TYPE_CHECKING: + from google.api_core import extended_operation from google.cloud import compute_v1 else: from sky.adaptors.gcp import compute_v1 @@ -21,27 +18,27 @@ logger = sky_logging.init_logger(__name__) _MIG_RESOURCE_NOT_FOUND_PATTERN = re.compile( - r'The resource \'projects/.*/zones/.*/instanceGroupManagers/.*\' was not found' -) + r'The resource \'projects/.*/zones/.*/instanceGroupManagers/.*\' was not ' + r'found') _IT_RESOURCE_NOT_FOUND_PATTERN = re.compile( r'The resource \'projects/.*/zones/.*/instanceTemplates/.*\' was not found') def get_instance_template_name(cluster_name: str) -> str: - return f'it-{cluster_name}' + return f'{constants.INSTANCE_TEMPLATE_NAME_PREFIX}{cluster_name}' def get_managed_instance_group_name(cluster_name: str) -> str: - return f'mig-{cluster_name}' + return f'{constants.MIG_NAME_PREFIX}{cluster_name}' def create_regional_instance_template_properties( cluster_name_on_cloud, node_config) -> 'compute_v1.InstanceProperties': return compute_v1.InstanceProperties( - description= - f'A temp instance template for {cluster_name_on_cloud} to support DWS requests.', + description=(f'A temp instance template for {cluster_name_on_cloud} to ' + 'support DWS requests.'), machine_type=node_config['machineType'], # We have to ignore reservations for DWS. # TODO: Add a warning log for this behvaiour. 
@@ -135,18 +132,20 @@ def create_regional_instance_template(project_id, region, template_name, # if operation.error: # raise Exception(f'Failed to create regional instance template: {operation.error}') - listRequest = compute_v1.ListRegionInstanceTemplatesRequest( + list_request = compute_v1.ListRegionInstanceTemplatesRequest( filter=f'name eq {template_name}', project=project_id, region=region, ) - list_response = compute_client.list(listRequest) + compute_client.list(list_request) # logger.debug(list_response) logger.debug(f'Regional instance template {template_name!r} ' 'created successfully.') -def delete_regional_instance_template(project_id, zone, template_name) -> dict: +def delete_regional_instance_template( + project_id, zone, + template_name) -> Optional['extended_operation.ExtendedOperation']: region = zone.rsplit('-', 1)[0] with compute_v1.RegionInstanceTemplatesClient() as compute_client: # Create the regional instance template request @@ -164,10 +163,12 @@ def delete_regional_instance_template(project_id, zone, template_name) -> dict: raise logger.warning(f'Instance template {template_name!r} does not ' 'exist. Skip deletion.') - return {} + return None -def delete_managed_instance_group(project_id, zone, group_name) -> dict: +def delete_managed_instance_group( + project_id, zone, + group_name) -> Optional['extended_operation.ExtendedOperation']: with compute_v1.InstanceGroupManagersClient() as compute_client: # Create the managed instance group request request = compute_v1.DeleteInstanceGroupManagerRequest( @@ -180,16 +181,15 @@ def delete_managed_instance_group(project_id, zone, group_name) -> dict: # Send the request to delete the managed instance group response = compute_client.delete(request=request) # Do not wait for the deletion of MIG, so we can send the deletion - # request for the instance template, immediately after this, which is - # important when we are autodown a cluster from the head node. + # request for the instance template, immediately after this, which + # is important when we are autodown a cluster from the head node. return response except gcp.google.api_core.exceptions.NotFound as e: - print(e) if re.search(_MIG_RESOURCE_NOT_FOUND_PATTERN, str(e)) is None: raise logger.warning(f'MIG {group_name!r} does not exist. Skip ' 'deletion.') - return {} + return None def create_managed_instance_group(project_id, zone, group_name, @@ -333,13 +333,14 @@ def wait_for_managed_group_to_be_stable(project_id, zone, group_name) -> None: logger.info(stderr) -def wait_for_extended_operation(operation: ExtendedOperation, - verbose_name: str = 'operation', - timeout: int = 300) -> Any: - # Taken from Google's samples - # https://cloud.google.com/compute/docs/samples/compute-operation-extended-wait?hl=en - """ - Waits for the extended (long-running) operation to complete. +def wait_for_extended_operation( + operation: 'extended_operation.ExtendedOperation', + verbose_name: str = 'operation', + timeout: int = 300) -> Any: + """Waits for the extended (long-running) operation to complete. + + Taken from Google's samples + https://cloud.google.com/compute/docs/samples/compute-operation-extended-wait?hl=en If the operation is successful, it will return its result. If the operation ends with an error, an exception will be raised. @@ -357,18 +358,22 @@ def wait_for_extended_operation(operation: ExtendedOperation, Whatever the operation.result() returns. 
Raises: - This method will raise the exception received from `operation.exception()` - or RuntimeError if there is no exception set, but there is an `error_code` - set for the `operation`. + This method will raise the exception received from + `operation.exception()` or RuntimeError if there is no exception set, + but there is an `error_code` set for the `operation`. - In case of an operation taking longer than `timeout` seconds to complete, - a `concurrent.futures.TimeoutError` will be raised. + In case of an operation taking longer than `timeout` seconds to + complete, a `concurrent.futures.TimeoutError` will be raised. """ + if operation is None: + return None + result = operation.result(timeout=timeout) if operation.error_code: logger.debug( - f'Error during {verbose_name}: [Code: {operation.error_code}]: {operation.error_message}', + f'Error during {verbose_name}: [Code: {operation.error_code}]: ' + f'{operation.error_message}', file=sys.stderr, flush=True, ) From 75241869d4c93a9b271ad52623e935964bafc645 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 09:36:01 +0000 Subject: [PATCH 09/29] Restart --- sky/provision/gcp/instance.py | 16 ++-- sky/provision/gcp/instance_utils.py | 110 ++++++++++++++++++++-------- sky/provision/gcp/mig_utils.py | 25 ++++++- 3 files changed, 112 insertions(+), 39 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 93cbea5cb3c..8c0e3804ef6 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -261,12 +261,16 @@ def get_order_key(node): if config.resume_stopped_nodes and to_start_count > 0 and stopped_instances: resumed_instance_ids = [n['name'] for n in stopped_instances] if resumed_instance_ids: - for instance_id in resumed_instance_ids: - resource.start_instance(instance_id, project_id, - availability_zone) - resource.set_labels(project_id, availability_zone, instance_id, - labels) - to_start_count -= len(resumed_instance_ids) + resumed_instance_ids = resource.start_instances( + cluster_name_on_cloud, project_id, availability_zone, + resumed_instance_ids, labels) + # In MIG case, the resumed_instance_ids will include the previously + # PENDING and RUNNING instances. To avoid double counting, we need to + # remove them from the resumed_instance_ids. + ready_instances = set(resumed_instance_ids) + ready_instances |= set([n['name'] for n in running_instances]) + ready_instances |= set([n['name'] for n in pending_instances]) + to_start_count = config.count - len(ready_instances) if head_instance_id is None: head_instance_id = resource.create_node_tag( diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 0bf8ab925a0..663716a7fd8 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -248,6 +248,20 @@ def create_instances( """ raise NotImplementedError + @classmethod + def start_instances(cls, cluster_name: str, project_id: str, zone: str, + instances: List[str], labels: Dict[str, + str]) -> List[str]: + """Start multiple instances. + + Returns: + List of instance names that are started. 
+ """ + for instance_id in instances: + cls.start_instance(instance_id, project_id, zone) + cls.set_labels(project_id, zone, instance_id, labels) + return instances + @classmethod def start_instance(cls, node_id: str, project_id: str, zone: str) -> None: """Start a stopped instance.""" @@ -985,16 +999,16 @@ def create_instances( label_filters = { constants.TAG_RAY_CLUSTER_NAME: cluster_name, } - potential_head_instances = {} + potential_head_instances = [] if mig_exists: - potential_head_instances = cls.filter( - project_id, - zone, - label_filters={ - constants.TAG_RAY_NODE_KIND: 'head', - **label_filters, - }, - status_filters=cls.NEED_TO_TERMINATE_STATES) + instances = cls.filter(project_id, + zone, + label_filters={ + constants.TAG_RAY_NODE_KIND: 'head', + **label_filters, + }, + status_filters=cls.NEED_TO_TERMINATE_STATES) + potential_head_instances = list(instances.keys()) config['labels'] = { k: str(v).lower() for k, v in config['labels'].items() @@ -1049,33 +1063,17 @@ def create_instances( mig_utils.wait_for_managed_group_to_be_stable( project_id, zone, managed_instance_group_name) - head_instance_name: Optional[str] = None - running_instances = cls.filter(project_id, - zone, - label_filters, - status_filters=[cls.RUNNING_STATE]) - for running_instance_name in running_instances.keys(): - if running_instance_name in potential_head_instances: - head_instance_name = running_instance_name - break - else: - head_instance_name = list(running_instances.keys())[0] - - if mig_exists: - # We need to update the node's label if mig already exists, as the - # config is not updated during the resize operation. - for instance_name in running_instances.keys(): - cls.set_labels(project_id=project_id, - availability_zone=zone, - node_id=instance_name, - labels=config['labels']) + running_instance_names = cls._add_labels_and_find_head( + cluster_name, project_id, zone, labels, potential_head_instances) + assert len(running_instance_names) == total_count, ( + running_instance_names, total_count) cls.create_node_tag( project_id, zone, - head_instance_name, + running_instance_names[0], is_head=True, ) - return None, list(running_instances.keys()) + return None, running_instance_names @classmethod def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: @@ -1094,6 +1092,56 @@ def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: project_id, zone, mig_utils.get_instance_template_name(cluster_name)) + # TODO(zhwu): We want to restart the instances with MIG instead of the + # normal instance start API to take advantage of DWS. 
+ # @classmethod + # def start_instances(cls, cluster_name: str, project_id: str, zone: str, + # instances: List[str], labels: Dict[str, + # str]) -> List[str]: + # del instances # unused + # potential_head_instances = cls.filter( + # project_id, + # zone, + # label_filters={ + # constants.TAG_RAY_NODE_KIND: 'head', + # constants.TAG_RAY_CLUSTER_NAME: cluster_name, + # }, + # status_filters=cls.NEED_TO_TERMINATE_STATES) + # mig_name = mig_utils.get_managed_instance_group_name(cluster_name) + # mig_utils.start_managed_instance_group(project_id, zone, mig_name) + # mig_utils.wait_for_managed_group_to_be_stable(project_id, zone, + # mig_name) + # return cls._add_labels_and_find_head(cluster_name, project_id, zone, + # labels, potential_head_instances) + + @classmethod + def _add_labels_and_find_head( + cls, cluster_name: str, project_id: str, zone: str, + labels: Dict[str, str], + potential_head_instances: List[str]) -> List[str]: + running_instances = cls.filter( + project_id, + zone, {constants.TAG_RAY_CLUSTER_NAME: cluster_name}, + status_filters=[cls.RUNNING_STATE]) + for running_instance_name in running_instances.keys(): + if running_instance_name in potential_head_instances: + head_instance_name = running_instance_name + break + else: + head_instance_name = list(running_instances.keys())[0] + # We need to update the node's label if mig already exists, as the + # config is not updated during the resize operation. + for instance_name in running_instances.keys(): + cls.set_labels(project_id=project_id, + availability_zone=zone, + node_id=instance_name, + labels=labels) + + running_instance_names = list(running_instances.keys()) + running_instance_names.remove(head_instance_name) + # Label for head node type will be set by caller + return [head_instance_name] + running_instance_names + class GCPTPUVMInstance(GCPInstance): """Instance handler for GCP TPU VM.""" diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 98e6053b2e8..73b8fff8f35 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -37,8 +37,8 @@ def create_regional_instance_template_properties( cluster_name_on_cloud, node_config) -> 'compute_v1.InstanceProperties': return compute_v1.InstanceProperties( - description=(f'A temp instance template for {cluster_name_on_cloud} to ' - 'support DWS requests.'), + description=('SkyPilot instance template for ' + f'{cluster_name_on_cloud!r} to support DWS requests.'), machine_type=node_config['machineType'], # We have to ignore reservations for DWS. # TODO: Add a warning log for this behvaiour. @@ -224,6 +224,27 @@ def create_managed_instance_group(project_id, zone, group_name, f'Managed instance group {group_name!r} created successfully.') +def start_managed_instance_group(project_id, zone, group_name) -> None: + with compute_v1.InstanceGroupManagersClient() as compute_client: + # Create the managed instance group request + request = compute_v1.ApplyUpdatesToInstancesInstanceGroupManagerRequest( + instance_group_manager=group_name, + project=project_id, + zone=zone, + instance_group_managers_apply_updates_request_resource=compute_v1. 
+ InstanceGroupManagersApplyUpdatesRequest( + all_instances=True, + minimal_action='NONE', + most_disruptive_allowed_action='RESTART', + ), + ) + + response = compute_client.apply_updates_to_instances(request) + wait_for_extended_operation(response, 'restart managed instance group', + 600) + logger.debug('Managed instance group restarted successfully.') + + def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: with compute_v1.InstanceGroupManagersClient() as compute_client: request = compute_v1.ListInstanceGroupManagersRequest( From 4d29c5ba6b93a2a38a91323663fc6780a1d156cd Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 09:57:05 +0000 Subject: [PATCH 10/29] format --- sky/provision/gcp/instance_utils.py | 10 ++++++---- sky/provision/gcp/mig_utils.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 663716a7fd8..1b8ed6c7c1f 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -253,10 +253,11 @@ def start_instances(cls, cluster_name: str, project_id: str, zone: str, instances: List[str], labels: Dict[str, str]) -> List[str]: """Start multiple instances. - + Returns: List of instance names that are started. """ + del cluster_name # Unused for instance_id in instances: cls.start_instance(instance_id, project_id, zone) cls.set_labels(project_id, zone, instance_id, labels) @@ -1021,8 +1022,9 @@ def create_instances( else: logger.debug( f'Instance template {instance_template_name!r} ' - 'exists and no instance group is using it. This is a leftover ' - 'of a previous autodown. Delete it and recreate it.') + 'exists and no instance group is using it. This is a ' + 'leftover of a previous autodown. Delete it and recreate ' + 'it.') # TODO(zhwu): this is a bit hacky as we cannot delete instance # template during an autodown, we can only defer the deletion # to the next launch of a cluster with the same name. We should @@ -1042,7 +1044,7 @@ def create_instances( instance_template_url = (f'projects/{project_id}/regions/{region}/' f'instanceTemplates/{instance_template_name}') if mig_exists: - # TODO: if we already have one, we should resize it. + # If we already have one, we should resize it. logger.debug( f'Managed instance group {managed_instance_group_name!r} ' 'already exists. Resizing it...') diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 73b8fff8f35..2df12629ff1 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -38,7 +38,7 @@ def create_regional_instance_template_properties( return compute_v1.InstanceProperties( description=('SkyPilot instance template for ' - f'{cluster_name_on_cloud!r} to support DWS requests.'), + f'{cluster_name_on_cloud!r} to support DWS requests.'), machine_type=node_config['machineType'], # We have to ignore reservations for DWS. # TODO: Add a warning log for this behvaiour. 
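For clarity, a minimal googleapiclient sketch of the restart-in-place request that start_managed_instance_group issues above, and that the commented-out MIG start_instances path would rely on. The project, zone, and group name are placeholders; this mirrors the compute_v1 applyUpdatesToInstances call using the REST-style client that the later patches in this series switch to.

    # Sketch only: restart all instances of a MIG in place via the REST client.
    from sky.adaptors import gcp

    compute = gcp.build('compute', 'v1')
    operation = compute.instanceGroupManagers().applyUpdatesToInstances(
        project='my-project',                        # placeholder
        zone='us-central1-a',                        # placeholder
        instanceGroupManager='sky-mig-my-cluster',   # placeholder
        body={
            'allInstances': True,
            'minimalAction': 'NONE',
            'mostDisruptiveAllowedAction': 'RESTART',
        }).execute()
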
From d839357be95b4943dc0536c23027613728b9c446 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 09:59:38 +0000 Subject: [PATCH 11/29] format --- sky/provision/gcp/mig_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 2df12629ff1..0d04dc33589 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -36,7 +36,7 @@ def get_managed_instance_group_name(cluster_name: str) -> str: def create_regional_instance_template_properties( cluster_name_on_cloud, node_config) -> 'compute_v1.InstanceProperties': - return compute_v1.InstanceProperties( + return compute_v1.InstanceProperties( # pylint: disable=used-before-assignment description=('SkyPilot instance template for ' f'{cluster_name_on_cloud!r} to support DWS requests.'), machine_type=node_config['machineType'], @@ -130,7 +130,8 @@ def create_regional_instance_template(project_id, region, template_name, # TODO: Error handling # operation = compute_client.wait(response.operation) # if operation.error: - # raise Exception(f'Failed to create regional instance template: {operation.error}') + # raise Exception(f'Failed to create regional instance template: ' + # f'{operation.error}') list_request = compute_v1.ListRegionInstanceTemplatesRequest( filter=f'name eq {template_name}', @@ -272,7 +273,8 @@ def resize_managed_instance_group(project_id: str, zone: str, group_name: str, # resize_request_name = f'resize-request-{str(int(time.time()))}' # cmd = ( - # f'gcloud beta compute instance-groups managed resize-requests create {group_name} ' + # f'gcloud beta compute instance-groups managed resize-requests ' + # f'create {group_name} ' # f'--resize-request={resize_request_name} ' # f'--resize-by={size} ' # f'--requested-run-duration={run_duration} ' From b4b8266a455e962d422f36469e12dc1e85653fc5 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 22:58:38 +0000 Subject: [PATCH 12/29] move to REST APIs instead of python APIs --- docs/source/reference/config.rst | 17 + sky/clouds/gcp.py | 22 +- sky/clouds/utils/gcp_utils.py | 5 - sky/provision/gcp/config.py | 6 +- sky/provision/gcp/constants.py | 3 +- sky/provision/gcp/instance.py | 8 +- sky/provision/gcp/instance_utils.py | 179 +++++++---- sky/provision/gcp/mig_utils.py | 482 +++++++--------------------- sky/setup_files/setup.py | 5 +- sky/templates/gcp-ray.yml.j2 | 9 +- sky/utils/schemas.py | 16 +- 11 files changed, 290 insertions(+), 462 deletions(-) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index a550ed819fc..f13b35a271a 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -219,6 +219,23 @@ Available fields and semantics: - projects/my-project/reservations/my-reservation2 + # Managed instance group / DWS (optional). + # + # SkyPilot supports launching instances in a managed instance group (MIG) + # which schedules the GPU instance creation through DWS, offering a better + # availability. This feature is only applied when a resource request + # contains GPU instances. + managed_instance_group: + # Seconds for the created instances to be kept alive. This is required + # for the DWS to work properly. After the specified duration, the + # instances will be terminated. + run_duration_seconds: 3600 + # Seconds to wait for MIG/DWS to create the requested resources. 
If the + # resources are not be able to create within the specified duration, + # SkyPilot will start failover to other clouds/regions/zones. + creation_timeout_seconds: 300 + + # Identity to use for all GCP instances (optional). # # LOCAL_CREDENTIALS: The user's local credential files will be uploaded to diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 993effc9654..1b47869284b 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -14,6 +14,7 @@ from sky import clouds from sky import exceptions from sky import sky_logging +from sky import skypilot_config from sky.adaptors import gcp from sky.clouds import service_catalog from sky.clouds.utils import gcp_utils @@ -191,6 +192,16 @@ def _unsupported_features_for_resources( unsupported[clouds.CloudImplementationFeatures.MULTI_NODE] = ( 'TPU node does not support multi-node. Please set ' 'num_nodes to 1.') + # TODO(zhwu): We probably need to store the MIG requirement in resources + # because `skypilot_config` may change for an existing cluster. + # Clusters created with MIG (only GPU clusters) cannot be stopped. + if (skypilot_config.get_nested(('gcp', 'managed_instance_group'), None) + is not None and resources.accelerators): + unsupported[clouds.CloudImplementationFeatures.MULTI_NODE] = ( + 'Managed Instance Group (MIG) does not support multi-node yet. ' + 'Please set num_nodes to 1.') + unsupported[clouds.CloudImplementationFeatures.STOP] = ( + 'Managed Instance Group (MIG) does not support stopping yet.') return unsupported @classmethod @@ -421,8 +432,6 @@ def make_deploy_resources_variables( 'custom_resources': None, 'use_spot': r.use_spot, 'gcp_project_id': self.get_project_id(dryrun), - 'gcp_use_managed_instance_group': - gcp_utils.is_use_managed_instance_group(), } accelerators = r.accelerators if accelerators is not None: @@ -497,6 +506,12 @@ def make_deploy_resources_variables( resources_vars['tpu_node_name'] = tpu_node_name + managed_instance_group_config = skypilot_config.get_nested( + ('gcp', 'managed_instance_group'), None) + use_mig = managed_instance_group_config is not None + resources_vars['gcp_use_managed_instance_group'] = use_mig + if use_mig: + resources_vars.update(managed_instance_group_config) return resources_vars def _get_feasible_launchable_resources( @@ -634,9 +649,6 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: from google import auth # type: ignore import googleapiclient - if gcp_utils.is_use_managed_instance_group(): - from google.cloud import compute_v1 - # Check the installation of google-cloud-sdk. 
_run_output('gcloud --version') except (ImportError, subprocess.CalledProcessError) as e: diff --git a/sky/clouds/utils/gcp_utils.py b/sky/clouds/utils/gcp_utils.py index bb34aed79b0..68e6192d351 100644 --- a/sky/clouds/utils/gcp_utils.py +++ b/sky/clouds/utils/gcp_utils.py @@ -184,8 +184,3 @@ def get_minimal_permissions() -> List[str]: permissions += constants.RESERVATION_PERMISSIONS return permissions - - -def is_use_managed_instance_group() -> bool: - return skypilot_config.get_nested(('gcp', 'use_managed_instance_group'), - False) diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index 6d576e13bc9..b3ef214be70 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -151,7 +151,8 @@ def bootstrap_instances( config: common.ProvisionConfig) -> common.ProvisionConfig: # Check if we have any TPUs defined, and if so, # insert that information into the provider config - if instance_utils.get_node_type(config) == instance_utils.GCPNodeType.TPU: + if (instance_utils.get_node_type( + config.node_config) == instance_utils.GCPNodeType.TPU): config.provider_config[constants.HAS_TPU_PROVIDER_FIELD] = True crm, iam, compute, _ = construct_clients_from_provider_config( @@ -374,7 +375,8 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam) -> dict: 'scopes': ['https://www.googleapis.com/auth/cloud-platform'], } iam_role: Dict[str, Any] - if instance_utils.get_node_type(config) == instance_utils.GCPNodeType.TPU: + if (instance_utils.get_node_type( + config.node_config) == instance_utils.GCPNodeType.TPU): # SKY: The API for TPU VM is slightly different from normal compute # instances. # See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node # pylint: disable=line-too-long diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index 31f7ef787ae..c830fde9e78 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -222,8 +222,7 @@ TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' # MIG constants -USE_MANAGED_INSTANCE_GROUP_CONFIG = 'use_managed_instance_group' -TAG_MANAGED_INSTANCE_GROUP_NAME = 'managed-instance-group-name' +MANAGED_INSTANCE_GROUP_CONFIG = 'managed-instance-group' DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT = 1200 # 20 minutes MIG_NAME_PREFIX = 'sky-mig-' INSTANCE_TEMPLATE_NAME_PREFIX = 'sky-it-' diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 8c0e3804ef6..a7776f5bf2b 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -119,7 +119,7 @@ def _wait_for_operations( logger.debug( f'wait_for_compute_{op_type}_operation: ' f'Waiting for operation {operation["name"]} to finish...') - handler.wait_for_operation(operation, project_id, zone) + handler.wait_for_operation(operation, project_id, zone=zone) def _get_head_instance_id(instances: List) -> Optional[str]: @@ -143,7 +143,7 @@ def _run_instances(region: str, cluster_name_on_cloud: str, resumed_instance_ids: List[str] = [] created_instance_ids: List[str] = [] - node_type = instance_utils.get_node_type(config) + node_type = instance_utils.get_node_type(config.node_config) project_id = config.provider_config['project_id'] availability_zone = config.provider_config['availability_zone'] @@ -513,8 +513,8 @@ def terminate_instances( tpu_node = provider_config.get('tpu_node') if tpu_node is not None: instance_utils.delete_tpu_node(project_id, zone, tpu_node) - use_mig = provider_config.get(constants.USE_MANAGED_INSTANCE_GROUP_CONFIG, - False) 
+ + use_mig = provider_config.get('use_managed_instance_group', False) if use_mig: # Deleting the MIG will also delete the instances. instance_utils.GCPManagedInstanceGroup.delete_mig( diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 1b8ed6c7c1f..3083ee9e4a9 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -176,8 +176,11 @@ def terminate( raise NotImplementedError @classmethod - def wait_for_operation(cls, operation: dict, project_id: str, - zone: Optional[str]) -> None: + def wait_for_operation(cls, + operation: dict, + project_id: str, + region: Optional[str] = None, + zone: Optional[str] = None) -> None: raise NotImplementedError @classmethod @@ -416,11 +419,17 @@ def filter( return instances @classmethod - def wait_for_operation(cls, operation: dict, project_id: str, - zone: Optional[str]) -> None: + def wait_for_operation(cls, + operation: dict, + project_id: str, + region: Optional[str] = None, + zone: Optional[str] = None) -> None: if zone is not None: kwargs = {'zone': zone} operation_caller = cls.load_resource().zoneOperations() + elif region is not None: + kwargs = {'region': region} + operation_caller = cls.load_resource().regionOperations() else: kwargs = {} operation_caller = cls.load_resource().globalOperations() @@ -622,7 +631,7 @@ def set_labels(cls, project_id: str, availability_zone: str, node_id: str, body=body, ).execute(num_retries=GCP_CREATE_MAX_RETRIES)) - cls.wait_for_operation(operation, project_id, availability_zone) + cls.wait_for_operation(operation, project_id, zone=availability_zone) @classmethod def create_instances( @@ -756,6 +765,19 @@ def _insert(cls, names: List[str], project_id: str, zone: str, logger.debug('"insert" operation requested ...') return operations + @classmethod + def _convert_selflinks_in_config(cls, config: dict) -> None: + """Convert selflinks to names in the config.""" + for disk in config.get('disks', []): + disk_type = disk.get('initializeParams', {}).get('diskType') + if disk_type is not None: + disk['initializeParams']['diskType'] = selflink_to_name( + disk_type) + config['machineType'] = selflink_to_name(config['machineType']) + for accelerator in config.get('guestAccelerators', []): + accelerator['acceleratorType'] = selflink_to_name( + accelerator['acceleratorType']) + @classmethod def _bulk_insert(cls, names: List[str], project_id: str, zone: str, config: dict) -> List[dict]: @@ -769,15 +791,7 @@ def _bulk_insert(cls, names: List[str], project_id: str, zone: str, k: v for d in config['scheduling'] for k, v in d.items() } - for disk in config.get('disks', []): - disk_type = disk.get('initializeParams', {}).get('diskType') - if disk_type is not None: - disk['initializeParams']['diskType'] = selflink_to_name( - disk_type) - config['machineType'] = selflink_to_name(config['machineType']) - for accelerator in config.get('guestAccelerators', []): - accelerator['acceleratorType'] = selflink_to_name( - accelerator['acceleratorType']) + cls._convert_selflinks_in_config(config) body = { 'count': len(names), @@ -872,7 +886,7 @@ def _handle_http_error(e): logger.debug('Waiting GCP instances to be ready ...') try: for operation in operations: - cls.wait_for_operation(operation, project_id, zone) + cls.wait_for_operation(operation, project_id, zone=zone) except common.ProvisionerError as e: return e.errors except gcp.http_error_exception() as e: @@ -893,7 +907,7 @@ def start_instance(cls, node_id: str, project_id: str, zone: str) -> None: instance=node_id, 
).execute()) - cls.wait_for_operation(operation, project_id, zone) + cls.wait_for_operation(operation, project_id, zone=zone) @classmethod def get_instance_info(cls, project_id: str, availability_zone: str, @@ -952,7 +966,7 @@ def resize_disk(cls, project_id: str, availability_zone: str, logger.warning(f'googleapiclient.errors.HttpError: {e.reason}') return - cls.wait_for_operation(operation, project_id, availability_zone) + cls.wait_for_operation(operation, project_id, zone=availability_zone) class GCPManagedInstanceGroup(GCPComputeInstance): @@ -985,6 +999,8 @@ def create_instances( constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name, }), }) + cls._convert_selflinks_in_config(config) + # Convert label values to string and lowercase per MIG API requirement. region = zone.rpartition('-')[0] instance_template_name = mig_utils.get_instance_template_name( @@ -1029,41 +1045,52 @@ def create_instances( # template during an autodown, we can only defer the deletion # to the next launch of a cluster with the same name. We should # find a better way to handle this. - operation = mig_utils.delete_regional_instance_template( - project_id, region, instance_template_name) - mig_utils.wait_for_extended_operation( - operation, verbose_name='Deletion of Instance Template') + cls._delete_instance_template(project_id, zone, + instance_template_name) instance_template_exists = False if not instance_template_exists: - mig_utils.create_regional_instance_template(project_id, region, - instance_template_name, - config, cluster_name) - + operation = mig_utils.create_region_instance_template( + cluster_name, project_id, region, instance_template_name, + config) + cls.wait_for_operation(operation, project_id, region=region) # create managed instance group instance_template_url = (f'projects/{project_id}/regions/{region}/' f'instanceTemplates/{instance_template_name}') - if mig_exists: - # If we already have one, we should resize it. - logger.debug( - f'Managed instance group {managed_instance_group_name!r} ' - 'already exists. Resizing it...') - mig_utils.resize_managed_instance_group( - project_id, zone, managed_instance_group_name, total_count) - else: - assert count == total_count, ( - f'The count {count} should be the same as the total_count ' - f'{total_count} when creating a new managed instance group.') - mig_utils.create_managed_instance_group(project_id, - zone, - managed_instance_group_name, - instance_template_url, - size=total_count) - - # TODO: This will block the provisioning until the nodes are ready, - # which makes the failover not effective. + if not mig_exists: + # Create a new MIG with size 0 and resize it later for triggering + # DWS, according to the doc: https://cloud.google.com/compute/docs/instance-groups/create-mig-with-gpu-vms # pylint: disable=line-too-long + operation = mig_utils.create_managed_instance_group( + project_id, + zone, + managed_instance_group_name, + instance_template_url, + size=0) + cls.wait_for_operation(operation, project_id, zone=zone) + + managed_instance_group_config = config[ + constants.MANAGED_INSTANCE_GROUP_CONFIG] + if count > 0: + # Use resize to trigger DWS for creating VMs. 
+ logger.debug(f'Resizing Managed instance group ' + f'{managed_instance_group_name!r} by {count}...') + operation = mig_utils.resize_managed_instance_group( + project_id, + zone, + managed_instance_group_name, + count, + run_duration_seconds=managed_instance_group_config[ + 'run_duration_seconds']) + cls.wait_for_operation(operation, project_id, zone=zone) + + # This will block the provisioning until the nodes are ready, which + # makes the failover not effective. We rely on the request timeout set + # by user to trigger failover. mig_utils.wait_for_managed_group_to_be_stable( - project_id, zone, managed_instance_group_name) + project_id, + zone, + managed_instance_group_name, + timeout=managed_instance_group_config['creation_timeout_seconds']) running_instance_names = cls._add_labels_and_find_head( cluster_name, project_id, zone, labels, potential_head_instances) @@ -1077,20 +1104,46 @@ def create_instances( ) return None, running_instance_names + @classmethod + def _delete_instance_template(cls, project_id: str, zone: str, + instance_template_name: str) -> None: + region = zone.rpartition('-')[0] + try: + operation = cls.load_resource().regionInstanceTemplates().delete( + project=project_id, + region=region, + instanceTemplate=instance_template_name).execute() + cls.wait_for_operation(operation, project_id, region=region) + except gcp.http_error_exception() as e: + if re.search(mig_utils.IT_RESOURCE_NOT_FOUND_PATTERN, + str(e)) is None: + raise + logger.warning( + f'Instance template {instance_template_name!r} does not exist. ' + 'Skip deletion.') + @classmethod def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: - response = mig_utils.delete_managed_instance_group( - project_id, zone, - mig_utils.get_managed_instance_group_name(cluster_name)) - mig_utils.wait_for_extended_operation( - response, verbose_name='Deletion of Instance Template') + mig_name = mig_utils.get_managed_instance_group_name(cluster_name) + try: + operation = cls.load_resource().instanceGroupManagers().delete( + project=project_id, zone=zone, + instanceGroupManager=mig_name).execute() + cls.wait_for_operation(operation, project_id, zone=zone) + except gcp.http_error_exception() as e: + if re.search(mig_utils.MIG_RESOURCE_NOT_FOUND_PATTERN, + str(e)) is None: + raise + logger.warning(f'MIG {mig_name!r} does not exist. Skip ' + 'deletion.') + # In the autostop case, the following deletion of instance template # will not be executed as the instance that runs the deletion will be # terminated with the managed instance group. It is ok to leave the # instance template there as when a user creates a new cluster with the # same name, the instance template will be updated in our # create_instances method. 
- mig_utils.delete_regional_instance_template( + cls._delete_instance_template( project_id, zone, mig_utils.get_instance_template_name(cluster_name)) @@ -1166,10 +1219,13 @@ def load_resource(cls): discoveryServiceUrl='https://tpu.googleapis.com/$discovery/rest') @classmethod - def wait_for_operation(cls, operation: dict, project_id: str, - zone: Optional[str]) -> None: + def wait_for_operation(cls, + operation: dict, + project_id: str, + region: Optional[str] = None, + zone: Optional[str] = None) -> None: """Poll for TPU operation until finished.""" - del project_id, zone # unused + del project_id, region, zone # unused @_retry_on_http_exception( f'Failed to wait for operation {operation["name"]}') @@ -1618,7 +1674,7 @@ class GCPNodeType(enum.Enum): TPU = 'tpu' -def get_node_type(config: common.ProvisionConfig) -> GCPNodeType: +def get_node_type(config: Dict[str, Any]) -> GCPNodeType: """Returns node type based on the keys in ``node``. This is a very simple check. If we have a ``machineType`` key, @@ -1628,21 +1684,20 @@ def get_node_type(config: common.ProvisionConfig) -> GCPNodeType: This works for both node configs and API returned nodes. """ - provider_config = config.provider_config - node_config = config.node_config - if ('machineType' not in node_config and - 'acceleratorType' not in node_config): + if ('machineType' not in config and 'acceleratorType' not in config): raise ValueError( 'Invalid node. For a Compute instance, "machineType" is ' 'required. ' 'For a TPU instance, "acceleratorType" and no "machineType" ' 'is required. ' - f'Got {list(node_config)}') + f'Got {list(config)}') - if 'machineType' not in node_config and 'acceleratorType' in node_config: + if 'machineType' not in config and 'acceleratorType' in config: return GCPNodeType.TPU - if provider_config.get(constants.USE_MANAGED_INSTANCE_GROUP_CONFIG, False): + if (config.get(constants.MANAGED_INSTANCE_GROUP_CONFIG, None) is not None + and config.get('guestAccelerators', None) is not None): + # DWS in MIG only works for machine with GPUs. 
return GCPNodeType.MIG return GCPNodeType.COMPUTE diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 0d04dc33589..d0bfe1cb650 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -1,28 +1,21 @@ """Managed Instance Group Utils""" import re import subprocess -import sys -import typing -from typing import Any, Optional +from typing import Any, Dict from sky import sky_logging from sky.adaptors import gcp from sky.provision.gcp import constants -if typing.TYPE_CHECKING: - from google.api_core import extended_operation - from google.cloud import compute_v1 -else: - from sky.adaptors.gcp import compute_v1 - logger = sky_logging.init_logger(__name__) -_MIG_RESOURCE_NOT_FOUND_PATTERN = re.compile( +MIG_RESOURCE_NOT_FOUND_PATTERN = re.compile( r'The resource \'projects/.*/zones/.*/instanceGroupManagers/.*\' was not ' r'found') -_IT_RESOURCE_NOT_FOUND_PATTERN = re.compile( - r'The resource \'projects/.*/zones/.*/instanceTemplates/.*\' was not found') +IT_RESOURCE_NOT_FOUND_PATTERN = re.compile( + r'The resource \'projects/.*/regions/.*/instanceTemplates/.*\' was not ' + 'found') def get_instance_template_name(cluster_name: str) -> str: @@ -33,313 +26,116 @@ def get_managed_instance_group_name(cluster_name: str) -> str: return f'{constants.MIG_NAME_PREFIX}{cluster_name}' -def create_regional_instance_template_properties( - cluster_name_on_cloud, node_config) -> 'compute_v1.InstanceProperties': - - return compute_v1.InstanceProperties( # pylint: disable=used-before-assignment - description=('SkyPilot instance template for ' - f'{cluster_name_on_cloud!r} to support DWS requests.'), - machine_type=node_config['machineType'], - # We have to ignore reservations for DWS. - # TODO: Add a warning log for this behvaiour. - reservation_affinity=compute_v1.ReservationAffinity( - consume_reservation_type='NO_RESERVATION'), - # We have to ignore user defined scheduling for DWS. - # TODO: Add a warning log for this behvaiour. 
- scheduling=compute_v1.Scheduling(on_host_maintenance='TERMINATE'), - guest_accelerators=[ - compute_v1.AcceleratorConfig( - accelerator_count=accelerator['acceleratorCount'], - accelerator_type=accelerator['acceleratorType'].split('/')[-1], - ) for accelerator in node_config.get('guestAccelerators', []) - ], - disks=[ - compute_v1.AttachedDisk( - boot=disk_config['boot'], - auto_delete=disk_config['autoDelete'], - type_=disk_config['type'], - initialize_params=compute_v1.AttachedDiskInitializeParams( - source_image=disk_config['initializeParams']['sourceImage'], - disk_size_gb=disk_config['initializeParams']['diskSizeGb'], - disk_type=disk_config['initializeParams']['diskType'].split( - '/')[-1]), - ) for disk_config in node_config.get('disks', []) - ], - network_interfaces=[ - compute_v1.NetworkInterface( - subnetwork=network_interface['subnetwork'], - access_configs=[ - compute_v1.AccessConfig( - name=access_config['name'], - type=access_config['type'], - ) for access_config in network_interface.get( - 'accessConfigs', []) - ], - ) for network_interface in node_config.get('networkInterfaces', []) - ], - service_accounts=[ - compute_v1.ServiceAccount(email=service_account['email'], - scopes=service_account['scopes']) - for service_account in node_config.get('serviceAccounts', []) - ], - metadata=compute_v1.Metadata(items=[ - compute_v1.Items(key=item['key'], value=item['value']) - for item in (node_config.get('metadata', {}).get('items', []) + [{ - 'key': 'cluster-name', - 'value': cluster_name_on_cloud - }]) - ]), - # Create labels from node config - labels=node_config.get('labels', {})) - - def check_instance_template_exits(project_id, region, template_name) -> bool: - with compute_v1.RegionInstanceTemplatesClient() as compute_client: - request = compute_v1.ListRegionInstanceTemplatesRequest( - filter=f'name eq {template_name}', - project=project_id, - region=region, - ) - page_result = compute_client.list(request) - return len(page_result.items) > 0 and (next(page_result.pages) - is not None) - - -def create_regional_instance_template(project_id, region, template_name, - node_config, - cluster_name_on_cloud) -> None: - with compute_v1.RegionInstanceTemplatesClient() as compute_client: - # Create the regional instance template request - - request = compute_v1.InsertRegionInstanceTemplateRequest( - project=project_id, - region=region, - instance_template_resource=compute_v1.InstanceTemplate( - name=template_name, - properties=create_regional_instance_template_properties( - cluster_name_on_cloud, node_config), - ), - ) - - # Send the request to create the regional instance template - response = compute_client.insert(request=request) - # Wait for the operation to complete - # logger.debug(response) - wait_for_extended_operation(response, - 'create regional instance template', 600) - # TODO: Error handling - # operation = compute_client.wait(response.operation) - # if operation.error: - # raise Exception(f'Failed to create regional instance template: ' - # f'{operation.error}') - - list_request = compute_v1.ListRegionInstanceTemplatesRequest( - filter=f'name eq {template_name}', - project=project_id, - region=region, - ) - compute_client.list(list_request) - # logger.debug(list_response) - logger.debug(f'Regional instance template {template_name!r} ' - 'created successfully.') - - -def delete_regional_instance_template( - project_id, zone, - template_name) -> Optional['extended_operation.ExtendedOperation']: - region = zone.rsplit('-', 1)[0] - with 
compute_v1.RegionInstanceTemplatesClient() as compute_client: - # Create the regional instance template request - request = compute_v1.DeleteRegionInstanceTemplateRequest( - project=project_id, - region=region, - instance_template=template_name, - ) - try: - # Send the request to delete the regional instance template - response = compute_client.delete(request=request) - return response - except gcp.google.api_core.exceptions.NotFound as e: - if re.search(_IT_RESOURCE_NOT_FOUND_PATTERN, str(e)) is None: - raise - logger.warning(f'Instance template {template_name!r} does not ' - 'exist. Skip deletion.') - return None - - -def delete_managed_instance_group( - project_id, zone, - group_name) -> Optional['extended_operation.ExtendedOperation']: - with compute_v1.InstanceGroupManagersClient() as compute_client: - # Create the managed instance group request - request = compute_v1.DeleteInstanceGroupManagerRequest( - project=project_id, - zone=zone, - instance_group_manager=group_name, - ) - - try: - # Send the request to delete the managed instance group - response = compute_client.delete(request=request) - # Do not wait for the deletion of MIG, so we can send the deletion - # request for the instance template, immediately after this, which - # is important when we are autodown a cluster from the head node. - return response - except gcp.google.api_core.exceptions.NotFound as e: - if re.search(_MIG_RESOURCE_NOT_FOUND_PATTERN, str(e)) is None: - raise - logger.warning(f'MIG {group_name!r} does not exist. Skip ' - 'deletion.') - return None - - -def create_managed_instance_group(project_id, zone, group_name, - instance_template_url, size) -> None: - with compute_v1.InstanceGroupManagersClient() as compute_client: - # Create the managed instance group request - request = compute_v1.InsertInstanceGroupManagerRequest( - project=project_id, - zone=zone, - instance_group_manager_resource=compute_v1.InstanceGroupManager( - name=group_name, - instance_template=instance_template_url, - target_size=size, - instance_lifecycle_policy=compute_v1. - InstanceGroupManagerInstanceLifecyclePolicy( - default_action_on_failure='DO_NOTHING',), - update_policy=compute_v1.InstanceGroupManagerUpdatePolicy( - type='OPPORTUNISTIC',), - ), - ) - - # Send the request to create the managed instance group - response = compute_client.insert(request=request) - - # Wait for the operation to complete - logger.debug( - f'Request submitted, waiting for operation to complete. {response}') - wait_for_extended_operation(response, 'create managed instance group', - 600) - # TODO: Error handling - logger.debug( - f'Managed instance group {group_name!r} created successfully.') - - -def start_managed_instance_group(project_id, zone, group_name) -> None: - with compute_v1.InstanceGroupManagersClient() as compute_client: - # Create the managed instance group request - request = compute_v1.ApplyUpdatesToInstancesInstanceGroupManagerRequest( - instance_group_manager=group_name, - project=project_id, - zone=zone, - instance_group_managers_apply_updates_request_resource=compute_v1. 
- InstanceGroupManagersApplyUpdatesRequest( - all_instances=True, - minimal_action='NONE', - most_disruptive_allowed_action='RESTART', - ), - ) - - response = compute_client.apply_updates_to_instances(request) - wait_for_extended_operation(response, 'restart managed instance group', - 600) - logger.debug('Managed instance group restarted successfully.') - - -def check_managed_instance_group_exists(project_id, zone, group_name) -> bool: - with compute_v1.InstanceGroupManagersClient() as compute_client: - request = compute_v1.ListInstanceGroupManagersRequest( - project=project_id, - zone=zone, - filter=f'name eq {group_name}', - ) - page_result = compute_client.list(request) - return len(page_result.items) > 0 and (next(page_result.pages) - is not None) + compute = gcp.build('compute', 'v1') + try: + compute.regionInstanceTemplates().get( + project=project_id, region=region, + instanceTemplate=template_name).execute() + except gcp.http_error_exception() as e: + if IT_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is not None: + return False + raise + return True + + +def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, + region: str, template_name: str, + node_config: Dict[str, Any]): + """Create a regional instance template.""" + compute = gcp.build('compute', 'v1') + config = node_config.copy() + config.pop(constants.MANAGED_INSTANCE_GROUP_CONFIG, None) + + # We have to ignore user defined scheduling for DWS. + # TODO: Add a warning log for this behvaiour. + config.pop('scheduling', None) + # We have to ignore reservations for DWS. + # TODO: Add a warning log for this behvaiour. + config.pop('reservation_affinity', None) + + # Create the regional instance template request + operation = compute.regionInstanceTemplates().insert( + project=project_id, + region=region, + body={ + 'name': template_name, + 'properties': dict( + description=( + 'SkyPilot instance template for ' + f'{cluster_name_on_cloud!r} to support DWS requests.'), + reservationAffinity=dict( + consumeReservationType='NO_RESERVATION'), + scheduling=dict(onHostMaintenance='TERMINATE'), + **config, + ) + }).execute() + return operation + + +def create_managed_instance_group(project_id: str, zone: str, group_name: str, + instance_template_url: str, size: int): + compute = gcp.build('compute', 'v1') + operation = compute.instanceGroupManagers().insert( + project=project_id, + zone=zone, + body={ + 'name': group_name, + 'instanceTemplate': instance_template_url, + 'target_size': size, + 'instanceLifecyclePolicy': { + 'defaultActionOnFailure': 'DO_NOTHING', + }, + 'updatePolicy': { + 'type': 'OPPORTUNISTIC', + }, + }).execute() + return operation def resize_managed_instance_group(project_id: str, zone: str, group_name: str, - size: int) -> None: - # try: - with compute_v1.InstanceGroupManagersClient() as compute_client: - response = compute_client.resize(project=project_id, - zone=zone, - instance_group_manager=group_name, - size=size) - wait_for_extended_operation( - response, - 'resize managed instance group', - timeout=constants.DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT) - # resize_request_name = f'resize-request-{str(int(time.time()))}' - - # cmd = ( - # f'gcloud beta compute instance-groups managed resize-requests ' - # f'create {group_name} ' - # f'--resize-request={resize_request_name} ' - # f'--resize-by={size} ' - # f'--requested-run-duration={run_duration} ' - # f'--zone={zone} ' - # f'--project={project_id} ') - # logger.info(f'Resizing MIG {group_name} with command:\n{cmd}') - # proc = 
subprocess.run( - # f'yes | {cmd}', - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE, - # shell=True, - # check=True, - # ) - # stdout = proc.stdout.decode('ascii') - # logger.info(stdout) - wait_for_managed_group_to_be_stable(project_id, zone, group_name) - - # except subprocess.CalledProcessError as e: - # stderr = e.stderr.decode('ascii') - # logger.info(stderr) - # provisioner_err = common.ProvisionerError('Failed to resize MIG') - # provisioner_err.errors = [{ - # 'code': 'UNKNOWN', - # 'domain': 'mig', - # 'message': stderr - # }] - # # _log_errors(provisioner_err.errors, e, zone) - # raise provisioner_err from e - - -def view_resize_requests(project_id, zone, group_name) -> None: + resize_by: int, run_duration_seconds: int): + compute = gcp.build('compute', 'beta') + operation = compute.instanceGroupManagerResizeRequests().insert( + project=project_id, + zone=zone, + instanceGroupManager=group_name, + body={ + 'name': group_name, + 'resizeBy': resize_by, + 'requestedRunDuration': { + 'seconds': run_duration_seconds, + } + }).execute() + return operation + + +def check_managed_instance_group_exists(project_id: str, zone: str, + group_name: str) -> bool: + compute = gcp.build('compute', 'v1') try: - cmd = ('gcloud beta compute instance-groups managed resize-requests ' - f'list {group_name} ' - f'--zone={zone} ' - f'--project={project_id}') - logger.info( - f'Listing resize requests for MIG {group_name} with command:\n{cmd}' - ) - proc = subprocess.run( - f'yes | {cmd}', - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True, - check=True, - ) - stdout = proc.stdout.decode('ascii') - logger.info(stdout) - except subprocess.CalledProcessError as e: - stderr = e.stderr.decode('ascii') - logger.info(stderr) - - -def wait_for_managed_group_to_be_stable(project_id, zone, group_name) -> None: + compute.instanceGroupManagers().get( + project=project_id, zone=zone, + instanceGroupManager=group_name).execute() + except gcp.http_error_exception() as e: + if MIG_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is not None: + return False + raise + return True + + +def wait_for_managed_group_to_be_stable(project_id: str, + zone: str, + group_name: str, + timeout: int = 1200) -> None: """Wait until the managed instance group is stable.""" try: - cmd = ( - 'gcloud compute instance-groups managed wait-until ' - f'{group_name} ' - '--stable ' - f'--zone={zone} ' - f'--project={project_id} ' - # TODO(zhwu): Allow users to specify timeout. - # 20 minutes timeout - '--timeout=1200') + cmd = ('gcloud compute instance-groups managed wait-until ' + f'{group_name} ' + '--stable ' + f'--zone={zone} ' + f'--project={project_id} ' + f'--timeout={timeout}') logger.info( f'Waiting for MIG {group_name} to be stable with command:\n{cmd}') proc = subprocess.run( @@ -354,65 +150,3 @@ def wait_for_managed_group_to_be_stable(project_id, zone, group_name) -> None: except subprocess.CalledProcessError as e: stderr = e.stderr.decode('ascii') logger.info(stderr) - - -def wait_for_extended_operation( - operation: 'extended_operation.ExtendedOperation', - verbose_name: str = 'operation', - timeout: int = 300) -> Any: - """Waits for the extended (long-running) operation to complete. - - Taken from Google's samples - https://cloud.google.com/compute/docs/samples/compute-operation-extended-wait?hl=en - - If the operation is successful, it will return its result. - If the operation ends with an error, an exception will be raised. 
- If there were any warnings during the execution of the operation - they will be printed to sys.stderr. - - Args: - operation: a long-running operation you want to wait on. - verbose_name: (optional) a more verbose name of the operation, - used only during error and warning reporting. - timeout: how long (in seconds) to wait for operation to finish. - If None, wait indefinitely. - - Returns: - Whatever the operation.result() returns. - - Raises: - This method will raise the exception received from - `operation.exception()` or RuntimeError if there is no exception set, - but there is an `error_code` set for the `operation`. - - In case of an operation taking longer than `timeout` seconds to - complete, a `concurrent.futures.TimeoutError` will be raised. - """ - if operation is None: - return None - - result = operation.result(timeout=timeout) - - if operation.error_code: - logger.debug( - f'Error during {verbose_name}: [Code: {operation.error_code}]: ' - f'{operation.error_message}', - file=sys.stderr, - flush=True, - ) - logger.debug(f'Operation ID: {operation.name}', - file=sys.stderr, - flush=True) - # TODO gurc: wrap this in a custom skypilot exception - raise operation.exception() or RuntimeError(operation.error_message) - - if operation.warnings: - logger.debug(f'Warnings during {verbose_name}:\n', - file=sys.stderr, - flush=True) - for warning in operation.warnings: - logger.debug(f' - {warning.code}: {warning.message}', - file=sys.stderr, - flush=True) - - return result diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index 6bf86111fca..c05ffcc4f35 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -222,10 +222,7 @@ def parse_readme(readme: str) -> str: # We need google-api-python-client>=2.69.0 to enable 'discardLocalSsd' # parameter for stopping instances. # Reference: https://github.com/googleapis/google-api-python-client/commit/f6e9d3869ed605b06f7cbf2e8cf2db25108506e6 - 'gcp': [ - 'google-api-python-client>=2.69.0', 'google-cloud-storage', - 'google-cloud-compute' - ], + 'gcp': ['google-api-python-client>=2.69.0', 'google-cloud-storage'], 'ibm': [ 'ibm-cloud-sdk-core', 'ibm-vpc', 'ibm-platform-services', 'ibm-cos-sdk' ] + local_ray, diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index e553d98c405..51f901bfc3f 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ b/sky/templates/gcp-ray.yml.j2 @@ -62,7 +62,7 @@ provider: # The upper-level SkyPilot code has make sure there will not be resource # leakage. 
disable_launch_config_check: true - use_managed_instance_group: {{gcp_use_managed_instance_group}} + use_managed_instance_group: {{ gcp_use_managed_instance_group }} auth: ssh_user: gcpuser @@ -80,7 +80,12 @@ available_node_types: {%- for tag_key, tag_value in instance_tags.items() %} {{ tag_key }}: {{ tag_value }} {%- endfor %} - managed-instance-group: {{gcp_use_managed_instance_group}} + managed-instance-group: {{ gcp_use_managed_instance_group }} + {%- if gcp_use_managed_instance_group %} + managed-instance-group: + run_duration_seconds: {{ run_duration_seconds }} + creation_timeout_seconds: {{ creation_timeout_seconds }} + {%- endif %} {%- if specific_reservations %} reservationAffinity: consumeReservationType: SPECIFIC_RESERVATION diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index a5f2f84dd09..6e26beb9b56 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -588,8 +588,20 @@ def get_config_schema(): 'type': 'string', }, }, - 'use_managed_instance_group': { - 'type': 'boolean', + 'managed_instance_group': { + 'type': 'object', + 'required': [ + 'run_duration_seconds', 'creation_timeout_seconds' + ], + 'additionalProperties': False, + 'properties': { + 'run_duration_seconds': { + 'type': 'integer', + }, + 'creation_timeout_seconds': { + 'type': 'integer', + } + } }, **_INSTANCE_TAGS_SCHEMA, **_NETWORK_CONFIG_SCHEMA, From ea6aefb9303e51537f690f7d6253fed1795dd9e9 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 5 Jun 2024 23:01:23 +0000 Subject: [PATCH 13/29] add multi-node back --- sky/clouds/gcp.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 1b47869284b..22b6225c967 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -197,9 +197,6 @@ def _unsupported_features_for_resources( # Clusters created with MIG (only GPU clusters) cannot be stopped. if (skypilot_config.get_nested(('gcp', 'managed_instance_group'), None) is not None and resources.accelerators): - unsupported[clouds.CloudImplementationFeatures.MULTI_NODE] = ( - 'Managed Instance Group (MIG) does not support multi-node yet. ' - 'Please set num_nodes to 1.') unsupported[clouds.CloudImplementationFeatures.STOP] = ( 'Managed Instance Group (MIG) does not support stopping yet.') return unsupported From 504f0c6e0cdf3cc3e6bdaca99d56c56ee0a35e5c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 03:04:34 +0000 Subject: [PATCH 14/29] Fix multi-node --- docs/source/reference/config.rst | 3 +- sky/provision/gcp/instance_utils.py | 43 ++++++++++++++++------------- sky/provision/gcp/mig_utils.py | 33 ++++++++++++++++++---- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index f13b35a271a..67e754863ed 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -233,7 +233,8 @@ Available fields and semantics: # Seconds to wait for MIG/DWS to create the requested resources. If the # resources are not be able to create within the specified duration, # SkyPilot will start failover to other clouds/regions/zones. - creation_timeout_seconds: 300 + # TODO: aligh with k8s provision_timeout + creation_timeout_seconds: 900 # Identity to use for all GCP instances (optional). 
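To make the wiring above concrete: with the gcp-ray.yml.j2 and schema changes, the rendered node config for a GPU node carries a managed-instance-group block, which is what get_node_type() keys on when deciding to provision through a MIG. A small illustrative sketch follows; the machine type, accelerator, and values are examples, not defaults.

    # Sketch only: shape of the node config rendered by gcp-ray.yml.j2 above.
    node_config = {
        'machineType': 'a2-highgpu-1g',                 # example
        'guestAccelerators': [{
            'acceleratorType': 'nvidia-tesla-a100',     # example
            'acceleratorCount': 1,
        }],
        'managed-instance-group': {
            'run_duration_seconds': 3600,
            'creation_timeout_seconds': 900,
        },
    }
    # Both a GPU and a managed-instance-group entry are present, so
    # get_node_type(node_config) resolves to GCPNodeType.MIG; without the GPU
    # it falls back to a plain COMPUTE instance.
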
diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 3083ee9e4a9..4f75fa65759 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -423,7 +423,8 @@ def wait_for_operation(cls, operation: dict, project_id: str, region: Optional[str] = None, - zone: Optional[str] = None) -> None: + zone: Optional[str] = None, + timeout: int = GCP_TIMEOUT) -> None: if zone is not None: kwargs = {'zone': zone} operation_caller = cls.load_resource().zoneOperations() @@ -448,13 +449,13 @@ def call_operation(fn, timeout: int): return request.execute(num_retries=GCP_MAX_RETRIES) wait_start = time.time() - while time.time() - wait_start < GCP_TIMEOUT: + while time.time() - wait_start < timeout: # Retry the wait() call until it succeeds or times out. # This is because the wait() call is only best effort, and does not # guarantee that the operation is done when it returns. # Reference: https://cloud.google.com/workflows/docs/reference/googleapis/compute/v1/zoneOperations/wait # pylint: disable=line-too-long - timeout = max(GCP_TIMEOUT - (time.time() - wait_start), 1) - result = call_operation(operation_caller.wait, timeout) + remaining_timeout = max(timeout - (time.time() - wait_start), 1) + result = call_operation(operation_caller.wait, remaining_timeout) if result['status'] == 'DONE': # NOTE: Error example: # { @@ -476,9 +477,9 @@ def call_operation(fn, timeout: int): else: logger.warning('wait_for_operation: Timeout waiting for creation ' 'operation, cancelling the operation ...') - timeout = max(GCP_TIMEOUT - (time.time() - wait_start), 1) + remaining_timeout = max(timeout - (time.time() - wait_start), 1) try: - result = call_operation(operation_caller.delete, timeout) + result = call_operation(operation_caller.delete, remaining_timeout) except gcp.http_error_exception() as e: logger.debug('wait_for_operation: failed to cancel operation ' f'due to error: {e}') @@ -1092,17 +1093,18 @@ def create_instances( managed_instance_group_name, timeout=managed_instance_group_config['creation_timeout_seconds']) - running_instance_names = cls._add_labels_and_find_head( + + pending_running_instance_names = cls._add_labels_and_find_head( cluster_name, project_id, zone, labels, potential_head_instances) - assert len(running_instance_names) == total_count, ( - running_instance_names, total_count) + assert len(pending_running_instance_names) == total_count, ( + pending_running_instance_names, total_count) cls.create_node_tag( project_id, zone, - running_instance_names[0], + pending_running_instance_names[0], is_head=True, ) - return None, running_instance_names + return None, pending_running_instance_names @classmethod def _delete_instance_template(cls, project_id: str, zone: str, @@ -1125,6 +1127,8 @@ def _delete_instance_template(cls, project_id: str, zone: str, @classmethod def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: mig_name = mig_utils.get_managed_instance_group_name(cluster_name) + # Get all resize request of the MIG and cancel them. 
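+        # (Only requests still in the ACCEPTED state are listed and
+        # cancelled; completed or failed requests are left untouched.)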
+ mig_utils.cancel_all_resize_requests(project_id, zone, mig_name) try: operation = cls.load_resource().instanceGroupManagers().delete( project=project_id, zone=zone, @@ -1174,28 +1178,29 @@ def _add_labels_and_find_head( cls, cluster_name: str, project_id: str, zone: str, labels: Dict[str, str], potential_head_instances: List[str]) -> List[str]: - running_instances = cls.filter( + pending_running_instances = cls.filter( project_id, zone, {constants.TAG_RAY_CLUSTER_NAME: cluster_name}, - status_filters=[cls.RUNNING_STATE]) - for running_instance_name in running_instances.keys(): + # Find all provisioning and running instances. + status_filters=cls.NEED_TO_STOP_STATES) + for running_instance_name in pending_running_instances.keys(): if running_instance_name in potential_head_instances: head_instance_name = running_instance_name break else: - head_instance_name = list(running_instances.keys())[0] + head_instance_name = list(pending_running_instances.keys())[0] # We need to update the node's label if mig already exists, as the # config is not updated during the resize operation. - for instance_name in running_instances.keys(): + for instance_name in pending_running_instances.keys(): cls.set_labels(project_id=project_id, availability_zone=zone, node_id=instance_name, labels=labels) - running_instance_names = list(running_instances.keys()) - running_instance_names.remove(head_instance_name) + pending_running_instance_names = list(pending_running_instances.keys()) + pending_running_instance_names.remove(head_instance_name) # Label for head node type will be set by caller - return [head_instance_name] + running_instance_names + return [head_instance_name] + pending_running_instance_names class GCPTPUVMInstance(GCPInstance): diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index d0bfe1cb650..c4bd5363760 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -27,7 +27,7 @@ def get_managed_instance_group_name(cluster_name: str) -> str: def check_instance_template_exits(project_id, region, template_name) -> bool: - compute = gcp.build('compute', 'v1') + compute = gcp.build('compute', 'v1', credentials=None, cache_discovery=False) try: compute.regionInstanceTemplates().get( project=project_id, region=region, @@ -43,7 +43,7 @@ def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, region: str, template_name: str, node_config: Dict[str, Any]): """Create a regional instance template.""" - compute = gcp.build('compute', 'v1') + compute = gcp.build('compute', 'v1', credentials=None, cache_discovery=False) config = node_config.copy() config.pop(constants.MANAGED_INSTANCE_GROUP_CONFIG, None) @@ -75,7 +75,7 @@ def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, def create_managed_instance_group(project_id: str, zone: str, group_name: str, instance_template_url: str, size: int): - compute = gcp.build('compute', 'v1') + compute = gcp.build('compute', 'v1', credentials=None, cache_discovery=False) operation = compute.instanceGroupManagers().insert( project=project_id, zone=zone, @@ -95,7 +95,7 @@ def create_managed_instance_group(project_id: str, zone: str, group_name: str, def resize_managed_instance_group(project_id: str, zone: str, group_name: str, resize_by: int, run_duration_seconds: int): - compute = gcp.build('compute', 'beta') + compute = gcp.build('compute', 'beta', credentials=None, cache_discovery=False) operation = compute.instanceGroupManagerResizeRequests().insert( project=project_id, 
zone=zone, @@ -109,10 +109,33 @@ def resize_managed_instance_group(project_id: str, zone: str, group_name: str, }).execute() return operation +def cancel_all_resize_request_for_mig(project_id: str, zone: str, group_name: str): + try: + compute = gcp.build('compute', 'beta', credentials=None, cache_discovery=False) + operation = compute.instanceGroupManagerResizeRequests().list( + project=project_id, zone=zone, + instanceGroupManager=group_name, + filter='state eq ACCEPTED').execute() + for request in operation.get('items', []): + try: + compute.instanceGroupManagerResizeRequests().cancel( + project=project_id, zone=zone, + instanceGroupManager=group_name, + requestId=request['id']).execute() + except gcp.http_error_exception() as e: + logger.warning('Failed to cancel resize request ' + f'{request["id"]!r}: {e}') + except gcp.http_error_exception() as e: + if re.search(MIG_RESOURCE_NOT_FOUND_PATTERN, + str(e)) is None: + raise + logger.warning(f'MIG {group_name!r} does not exist. Skip ' + 'resize request cancellation.') + logger.debug(f'Error: {e}') def check_managed_instance_group_exists(project_id: str, zone: str, group_name: str) -> bool: - compute = gcp.build('compute', 'v1') + compute = gcp.build('compute', 'v1', credentials=None, cache_discovery=False) try: compute.instanceGroupManagers().get( project=project_id, zone=zone, From b5484c6baf8ecdf7ad8f8e2dc7071489c605cfb7 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 04:16:28 +0000 Subject: [PATCH 15/29] Avoid spot --- sky/clouds/gcp.py | 11 +++-- sky/provision/gcp/instance_utils.py | 9 +++-- sky/provision/gcp/mig_utils.py | 63 ++++++++++++++++++++--------- 3 files changed, 58 insertions(+), 25 deletions(-) diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 22b6225c967..1d8abae220b 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -184,7 +184,8 @@ def _unsupported_features_for_resources( if gcp_utils.is_tpu_vm_pod(resources): unsupported = { clouds.CloudImplementationFeatures.STOP: ( - 'TPU VM pods cannot be stopped. Please refer to: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources' + 'TPU VM pods cannot be stopped. Please refer to: ' + 'https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#stopping_your_resources' ) } if gcp_utils.is_tpu(resources) and not gcp_utils.is_tpu_vm(resources): @@ -195,10 +196,14 @@ def _unsupported_features_for_resources( # TODO(zhwu): We probably need to store the MIG requirement in resources # because `skypilot_config` may change for an existing cluster. # Clusters created with MIG (only GPU clusters) cannot be stopped. 
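        # DWS provisions on-demand capacity only, so spot instances are also
        # unsupported for MIG-backed clusters.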
- if (skypilot_config.get_nested(('gcp', 'managed_instance_group'), None) - is not None and resources.accelerators): + if (skypilot_config.get_nested( + ('gcp', 'managed_instance_group'), None) is not None and + resources.accelerators): unsupported[clouds.CloudImplementationFeatures.STOP] = ( 'Managed Instance Group (MIG) does not support stopping yet.') + unsupported[clouds.CloudImplementationFeatures.SPOT_INSTANCE] = ( + 'Managed Instance Group with DWS does not support ' + 'spot instances.') return unsupported @classmethod diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 4f75fa65759..a37fe7196d4 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -479,7 +479,8 @@ def call_operation(fn, timeout: int): 'operation, cancelling the operation ...') remaining_timeout = max(timeout - (time.time() - wait_start), 1) try: - result = call_operation(operation_caller.delete, remaining_timeout) + result = call_operation(operation_caller.delete, + remaining_timeout) except gcp.http_error_exception() as e: logger.debug('wait_for_operation: failed to cancel operation ' f'due to error: {e}') @@ -1093,7 +1094,6 @@ def create_instances( managed_instance_group_name, timeout=managed_instance_group_config['creation_timeout_seconds']) - pending_running_instance_names = cls._add_labels_and_find_head( cluster_name, project_id, zone, labels, potential_head_instances) assert len(pending_running_instance_names) == total_count, ( @@ -1128,7 +1128,7 @@ def _delete_instance_template(cls, project_id: str, zone: str, def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: mig_name = mig_utils.get_managed_instance_group_name(cluster_name) # Get all resize request of the MIG and cancel them. - mig_utils.cancel_all_resize_requests(project_id, zone, mig_name) + mig_utils.cancel_all_resize_request_for_mig(project_id, zone, mig_name) try: operation = cls.load_resource().instanceGroupManagers().delete( project=project_id, zone=zone, @@ -1180,7 +1180,8 @@ def _add_labels_and_find_head( potential_head_instances: List[str]) -> List[str]: pending_running_instances = cls.filter( project_id, - zone, {constants.TAG_RAY_CLUSTER_NAME: cluster_name}, + zone, + {constants.TAG_RAY_CLUSTER_NAME: cluster_name}, # Find all provisioning and running instances. 
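            # (Nodes created by a MIG resize may still be in a provisioning
            #  state at this point, so a RUNNING-only filter could miss them.)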
status_filters=cls.NEED_TO_STOP_STATES) for running_instance_name in pending_running_instances.keys(): diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index c4bd5363760..b492a378bfd 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -27,7 +27,10 @@ def get_managed_instance_group_name(cluster_name: str) -> str: def check_instance_template_exits(project_id, region, template_name) -> bool: - compute = gcp.build('compute', 'v1', credentials=None, cache_discovery=False) + compute = gcp.build('compute', + 'v1', + credentials=None, + cache_discovery=False) try: compute.regionInstanceTemplates().get( project=project_id, region=region, @@ -43,16 +46,24 @@ def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, region: str, template_name: str, node_config: Dict[str, Any]): """Create a regional instance template.""" - compute = gcp.build('compute', 'v1', credentials=None, cache_discovery=False) + compute = gcp.build('compute', + 'v1', + credentials=None, + cache_discovery=False) config = node_config.copy() config.pop(constants.MANAGED_INSTANCE_GROUP_CONFIG, None) # We have to ignore user defined scheduling for DWS. # TODO: Add a warning log for this behvaiour. - config.pop('scheduling', None) - # We have to ignore reservations for DWS. - # TODO: Add a warning log for this behvaiour. - config.pop('reservation_affinity', None) + scheduling = config.pop('scheduling', {}) + assert scheduling.get('provisioningModel') != 'SPOT', ( + 'DWS does not support spot VMs.') + + reservations_affinity = config.pop('reservation_affinity', None) + if reservations_affinity is not None: + logger.warning( + f'Ignoring reservations_affinity {reservations_affinity} ' + 'for DWS.') # Create the regional instance template request operation = compute.regionInstanceTemplates().insert( @@ -66,7 +77,6 @@ def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, f'{cluster_name_on_cloud!r} to support DWS requests.'), reservationAffinity=dict( consumeReservationType='NO_RESERVATION'), - scheduling=dict(onHostMaintenance='TERMINATE'), **config, ) }).execute() @@ -75,7 +85,10 @@ def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, def create_managed_instance_group(project_id: str, zone: str, group_name: str, instance_template_url: str, size: int): - compute = gcp.build('compute', 'v1', credentials=None, cache_discovery=False) + compute = gcp.build('compute', + 'v1', + credentials=None, + cache_discovery=False) operation = compute.instanceGroupManagers().insert( project=project_id, zone=zone, @@ -95,7 +108,10 @@ def create_managed_instance_group(project_id: str, zone: str, group_name: str, def resize_managed_instance_group(project_id: str, zone: str, group_name: str, resize_by: int, run_duration_seconds: int): - compute = gcp.build('compute', 'beta', credentials=None, cache_discovery=False) + compute = gcp.build('compute', + 'beta', + credentials=None, + cache_discovery=False) operation = compute.instanceGroupManagerResizeRequests().insert( project=project_id, zone=zone, @@ -109,33 +125,44 @@ def resize_managed_instance_group(project_id: str, zone: str, group_name: str, }).execute() return operation -def cancel_all_resize_request_for_mig(project_id: str, zone: str, group_name: str): + +def cancel_all_resize_request_for_mig(project_id: str, zone: str, + group_name: str): + logger.debug(f'Cancelling all resize requests for MIG {group_name!r}.') try: - compute = gcp.build('compute', 'beta', 
credentials=None, cache_discovery=False) + compute = gcp.build('compute', + 'beta', + credentials=None, + cache_discovery=False) operation = compute.instanceGroupManagerResizeRequests().list( - project=project_id, zone=zone, + project=project_id, + zone=zone, instanceGroupManager=group_name, filter='state eq ACCEPTED').execute() for request in operation.get('items', []): try: compute.instanceGroupManagerResizeRequests().cancel( - project=project_id, zone=zone, + project=project_id, + zone=zone, instanceGroupManager=group_name, requestId=request['id']).execute() except gcp.http_error_exception() as e: logger.warning('Failed to cancel resize request ' - f'{request["id"]!r}: {e}') + f'{request["id"]!r}: {e}') except gcp.http_error_exception() as e: - if re.search(MIG_RESOURCE_NOT_FOUND_PATTERN, - str(e)) is None: + if re.search(MIG_RESOURCE_NOT_FOUND_PATTERN, str(e)) is None: raise logger.warning(f'MIG {group_name!r} does not exist. Skip ' - 'resize request cancellation.') + 'resize request cancellation.') logger.debug(f'Error: {e}') + def check_managed_instance_group_exists(project_id: str, zone: str, group_name: str) -> bool: - compute = gcp.build('compute', 'v1', credentials=None, cache_discovery=False) + compute = gcp.build('compute', + 'v1', + credentials=None, + cache_discovery=False) try: compute.instanceGroupManagers().get( project=project_id, zone=zone, From 4ec8869a6dcce18288471204f9a97503a227b17d Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 04:43:02 +0000 Subject: [PATCH 16/29] format --- sky/skylet/log_lib.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/sky/skylet/log_lib.pyi b/sky/skylet/log_lib.pyi index 6815905a461..a590e735b4e 100644 --- a/sky/skylet/log_lib.pyi +++ b/sky/skylet/log_lib.pyi @@ -13,6 +13,7 @@ from sky.skylet import constants as constants from sky.skylet import job_lib as job_lib from sky.utils import log_utils as log_utils + class _ProcessingArgs: log_path: str stream_logs: bool From e300898c2f51cd49912efa9b6bd8297e35aa7736 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 04:44:01 +0000 Subject: [PATCH 17/29] format --- sky/skylet/log_lib.pyi | 1 - 1 file changed, 1 deletion(-) diff --git a/sky/skylet/log_lib.pyi b/sky/skylet/log_lib.pyi index a590e735b4e..6815905a461 100644 --- a/sky/skylet/log_lib.pyi +++ b/sky/skylet/log_lib.pyi @@ -13,7 +13,6 @@ from sky.skylet import constants as constants from sky.skylet import job_lib as job_lib from sky.utils import log_utils as log_utils - class _ProcessingArgs: log_path: str stream_logs: bool From 30792a2c2de54b8792ca45d97d7e5195c153a898 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 04:48:34 +0000 Subject: [PATCH 18/29] fix scheduling --- sky/provision/gcp/mig_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index b492a378bfd..e3852f2349b 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -55,7 +55,7 @@ def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, # We have to ignore user defined scheduling for DWS. # TODO: Add a warning log for this behvaiour. 
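    # Using `get` keeps the user's `scheduling` block inside `config`, so it
    # is carried into the instance template properties below; `pop` dropped it.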
- scheduling = config.pop('scheduling', {}) + scheduling = config.get('scheduling', {}) assert scheduling.get('provisioningModel') != 'SPOT', ( 'DWS does not support spot VMs.') From 58768b2a2405c269bd73c2542ed04b2f958d27df Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 05:19:27 +0000 Subject: [PATCH 19/29] fix cancel --- sky/provision/gcp/mig_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index e3852f2349b..235cd706e78 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -145,6 +145,7 @@ def cancel_all_resize_request_for_mig(project_id: str, zone: str, project=project_id, zone=zone, instanceGroupManager=group_name, + resizeRequest=request['name'], requestId=request['id']).execute() except gcp.http_error_exception() as e: logger.warning('Failed to cancel resize request ' From 78b3d2fd71deccaa2fdde362ac882d3d651b32dc Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 06:22:30 +0000 Subject: [PATCH 20/29] Add smoke test --- tests/test_smoke.py | 34 ++++++++++++++++++++++++++++ tests/test_yamls/use_mig_config.yaml | 4 ++++ 2 files changed, 38 insertions(+) create mode 100644 tests/test_yamls/use_mig_config.yaml diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 944818f05dc..7d703ac2432 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -109,6 +109,8 @@ class Test(NamedTuple): teardown: Optional[str] = None # Timeout for each command in seconds. timeout: int = DEFAULT_CMD_TIMEOUT + # Environment variables to set for each command. + env: Dict[str, str] = None def echo(self, message: str): # pytest's xdist plugin captures stdout; print to stderr so that the @@ -155,6 +157,9 @@ def run_one_test(test: Test) -> Tuple[int, str, str]: suffix='.log', delete=False) test.echo(f'Test started. Log: less {log_file.name}') + env_dict = os.environ.copy() + if test.env: + env_dict.update(test.env) for command in test.commands: log_file.write(f'+ {command}\n') log_file.flush() @@ -164,6 +169,7 @@ def run_one_test(test: Test) -> Tuple[int, str, str]: stderr=subprocess.STDOUT, shell=True, executable='/bin/bash', + env=env_dict, ) try: proc.wait(timeout=test.timeout) @@ -701,6 +707,34 @@ def test_clone_disk_gcp(): run_one_test(test) +@pytest.mark.gcp +def test_gcp_mig(): + name = _get_cluster_name() + test = Test( + 'gcp_mig', + [ + f'sky launch -y -c {name} --gpus t4 --num-nodes 2 --image-id skypilot:gpu-debian-10 --cloud gcp tests/test_yamls/minimal.yaml', + f'sky logs {name} 1 --status', # Ensure the job succeeded. + f'sky launch -y -c {name} tests/test_yamls/minimal.yaml', + f'sky logs {name} 2 --status', + f'sky logs {name} --status | grep "Job 2: SUCCEEDED"', # Equivalent. + # Check MIG exists. + f'gcloud compute instance-groups managed list --format="value(name)" | grep "^sky-mig-{name}"', + f'sky autostop -i 0 --down -y {name}', + 'sleep 120', + f'sky status -r {name}; sky status {name} | grep "{name} not found"', + f'gcloud compute instance-templates list --regions | grep "sky-it-{name}"', + # Launch again. 
+ f'sky launch -y -c {name} --gpus L4 --num-nodes 2 nvidia-smi', + f'sky logs {name} 1 | grep "L4"', + f'sky down -y {name}', + f'gcloud compute instance-templates list --regions | grep "sky-it-{name}" && exit 1 || true', + ], + f'sky down -y {name}', + env={'SKYPILOT_CONFIG': 'tests/test_yamls/use_mig_config.yaml'}) + run_one_test(test) + + @pytest.mark.aws def test_image_no_conda(): name = _get_cluster_name() diff --git a/tests/test_yamls/use_mig_config.yaml b/tests/test_yamls/use_mig_config.yaml new file mode 100644 index 00000000000..b2a00c85977 --- /dev/null +++ b/tests/test_yamls/use_mig_config.yaml @@ -0,0 +1,4 @@ +gcp: + managed_instance_group: + run_duration_seconds: 36000 + creation_timeout_seconds: 900 From bade73002a8bff37874d0ea497cea0787327716f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 06:28:40 +0000 Subject: [PATCH 21/29] revert some changes --- sky/adaptors/gcp.py | 2 -- sky/clouds/gcp.py | 2 +- sky/provision/gcp/config.py | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sky/adaptors/gcp.py b/sky/adaptors/gcp.py index 54f8fc84444..6465709d42c 100644 --- a/sky/adaptors/gcp.py +++ b/sky/adaptors/gcp.py @@ -10,8 +10,6 @@ googleapiclient = common.LazyImport('googleapiclient', import_error_message=_IMPORT_ERROR_MESSAGE) google = common.LazyImport('google', import_error_message=_IMPORT_ERROR_MESSAGE) -compute_v1 = common.LazyImport('google.cloud.compute_v1', - import_error_message=_IMPORT_ERROR_MESSAGE) _LAZY_MODULES = (google, googleapiclient) diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 1d8abae220b..4f5fefd9b10 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -149,7 +149,7 @@ class GCP(clouds.Cloud): _DEPENDENCY_HINT = ( 'GCP tools are not installed. Run the following commands:\n' # Install the Google Cloud SDK: - f'{_INDENT_PREFIX} $ pip install google-api-python-client google-cloud-compute\n' + f'{_INDENT_PREFIX} $ pip install google-api-python-client\n' f'{_INDENT_PREFIX} $ conda install -c conda-forge ' 'google-cloud-sdk -y') diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index b3ef214be70..b0ed2be9cec 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -151,8 +151,8 @@ def bootstrap_instances( config: common.ProvisionConfig) -> common.ProvisionConfig: # Check if we have any TPUs defined, and if so, # insert that information into the provider config - if (instance_utils.get_node_type( - config.node_config) == instance_utils.GCPNodeType.TPU): + if instance_utils.get_node_type( + config.node_config) == instance_utils.GCPNodeType.TPU: config.provider_config[constants.HAS_TPU_PROVIDER_FIELD] = True crm, iam, compute, _ = construct_clients_from_provider_config( From 00afacf3514f6dbc0aaf0c787a4fbd0756ad6fe7 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 06:49:29 +0000 Subject: [PATCH 22/29] fix smoke --- tests/test_smoke.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 7d703ac2432..66c98b746e9 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -723,12 +723,12 @@ def test_gcp_mig(): f'sky autostop -i 0 --down -y {name}', 'sleep 120', f'sky status -r {name}; sky status {name} | grep "{name} not found"', - f'gcloud compute instance-templates list --regions | grep "sky-it-{name}"', + f'gcloud compute instance-templates list | grep "sky-it-{name}"', # Launch again. 
f'sky launch -y -c {name} --gpus L4 --num-nodes 2 nvidia-smi', f'sky logs {name} 1 | grep "L4"', f'sky down -y {name}', - f'gcloud compute instance-templates list --regions | grep "sky-it-{name}" && exit 1 || true', + f'gcloud compute instance-templates list | grep "sky-it-{name}" && exit 1 || true', ], f'sky down -y {name}', env={'SKYPILOT_CONFIG': 'tests/test_yamls/use_mig_config.yaml'}) From 439fa2ac6f85492d5329ce6952b91dff0f4f917a Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 09:19:37 +0000 Subject: [PATCH 23/29] Fix --- sky/provision/gcp/instance_utils.py | 1 + sky/provision/gcp/mig_utils.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index a37fe7196d4..4f6995fe7a9 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -462,6 +462,7 @@ def call_operation(fn, timeout: int): # 'code': 'VM_MIN_COUNT_NOT_REACHED', # 'message': 'Requested minimum count of 4 VMs could not be created.' # } + logger.debug(str(result)) errors = result.get('error', {}).get('errors') if errors is not None: logger.debug( diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 235cd706e78..8ade84ed90b 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -26,7 +26,8 @@ def get_managed_instance_group_name(cluster_name: str) -> str: return f'{constants.MIG_NAME_PREFIX}{cluster_name}' -def check_instance_template_exits(project_id, region, template_name) -> bool: +def check_instance_template_exits(project_id: str, region: str, + template_name: str) -> bool: compute = gcp.build('compute', 'v1', credentials=None, @@ -36,7 +37,8 @@ def check_instance_template_exits(project_id, region, template_name) -> bool: project=project_id, region=region, instanceTemplate=template_name).execute() except gcp.http_error_exception() as e: - if IT_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is not None: + if IT_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is None: + # Instance template does not exist. 
return False raise return True @@ -44,7 +46,7 @@ def check_instance_template_exits(project_id, region, template_name) -> bool: def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, region: str, template_name: str, - node_config: Dict[str, Any]): + node_config: Dict[str, Any]) -> dict: """Create a regional instance template.""" compute = gcp.build('compute', 'v1', @@ -84,7 +86,8 @@ def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, def create_managed_instance_group(project_id: str, zone: str, group_name: str, - instance_template_url: str, size: int): + instance_template_url: str, + size: int) -> dict: compute = gcp.build('compute', 'v1', credentials=None, @@ -107,7 +110,8 @@ def create_managed_instance_group(project_id: str, zone: str, group_name: str, def resize_managed_instance_group(project_id: str, zone: str, group_name: str, - resize_by: int, run_duration_seconds: int): + resize_by: int, + run_duration_seconds: int) -> dict: compute = gcp.build('compute', 'beta', credentials=None, @@ -127,7 +131,7 @@ def resize_managed_instance_group(project_id: str, zone: str, group_name: str, def cancel_all_resize_request_for_mig(project_id: str, zone: str, - group_name: str): + group_name: str) -> None: logger.debug(f'Cancelling all resize requests for MIG {group_name!r}.') try: compute = gcp.build('compute', From 2ff8d2760bc288505667f7cba2952c1da8709b07 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 09:24:04 +0000 Subject: [PATCH 24/29] fix --- sky/provision/gcp/mig_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 8ade84ed90b..1cae00c4adf 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -37,7 +37,7 @@ def check_instance_template_exits(project_id: str, region: str, project=project_id, region=region, instanceTemplate=template_name).execute() except gcp.http_error_exception() as e: - if IT_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is None: + if IT_RESOURCE_NOT_FOUND_PATTERN.search(str(e)) is not None: # Instance template does not exist. return False raise From 4c21abced6321cbacc257ebe9342f6cddd456ea4 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 6 Jun 2024 09:44:33 +0000 Subject: [PATCH 25/29] Fix smoke --- sky/provision/gcp/instance_utils.py | 1 - tests/test_smoke.py | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 4f6995fe7a9..a37fe7196d4 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -462,7 +462,6 @@ def call_operation(fn, timeout: int): # 'code': 'VM_MIN_COUNT_NOT_REACHED', # 'message': 'Requested minimum count of 4 VMs could not be created.' 
# } - logger.debug(str(result)) errors = result.get('error', {}).get('errors') if errors is not None: logger.debug( diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 66c98b746e9..3d5de8b97df 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -710,10 +710,11 @@ def test_clone_disk_gcp(): @pytest.mark.gcp def test_gcp_mig(): name = _get_cluster_name() + region = 'us-central1' test = Test( 'gcp_mig', [ - f'sky launch -y -c {name} --gpus t4 --num-nodes 2 --image-id skypilot:gpu-debian-10 --cloud gcp tests/test_yamls/minimal.yaml', + f'sky launch -y -c {name} --gpus t4 --num-nodes 2 --image-id skypilot:gpu-debian-10 --cloud gcp --region {region} tests/test_yamls/minimal.yaml', f'sky logs {name} 1 --status', # Ensure the job succeeded. f'sky launch -y -c {name} tests/test_yamls/minimal.yaml', f'sky logs {name} 2 --status', @@ -724,8 +725,9 @@ def test_gcp_mig(): 'sleep 120', f'sky status -r {name}; sky status {name} | grep "{name} not found"', f'gcloud compute instance-templates list | grep "sky-it-{name}"', - # Launch again. - f'sky launch -y -c {name} --gpus L4 --num-nodes 2 nvidia-smi', + # Launch again with the same region. The original instance template + # should be removed. + f'sky launch -y -c {name} --gpus L4 --num-nodes 2 --region {region} nvidia-smi', f'sky logs {name} 1 | grep "L4"', f'sky down -y {name}', f'gcloud compute instance-templates list | grep "sky-it-{name}" && exit 1 || true', From 1d48fa6d070560e8ea93f179fc6af9f75a9c1281 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 7 Jun 2024 12:14:28 -0700 Subject: [PATCH 26/29] [GCP] Changing the config name for DWS support and fix for resize request cancellation (#5) * Fix config fields * fix cancel * Add loggings * remove useless codes --- docs/source/reference/config.rst | 24 ++++++++++++-------- sky/provision/gcp/constants.py | 2 +- sky/provision/gcp/instance_utils.py | 33 +++++----------------------- sky/provision/gcp/mig_utils.py | 20 +++++++++-------- sky/templates/gcp-ray.yml.j2 | 4 ++-- sky/utils/schemas.py | 8 +++---- tests/test_yamls/use_mig_config.yaml | 4 ++-- 7 files changed, 40 insertions(+), 55 deletions(-) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 67e754863ed..3e1214d958d 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -226,15 +226,21 @@ Available fields and semantics: # availability. This feature is only applied when a resource request # contains GPU instances. managed_instance_group: - # Seconds for the created instances to be kept alive. This is required - # for the DWS to work properly. After the specified duration, the - # instances will be terminated. - run_duration_seconds: 3600 - # Seconds to wait for MIG/DWS to create the requested resources. If the - # resources are not be able to create within the specified duration, - # SkyPilot will start failover to other clouds/regions/zones. - # TODO: aligh with k8s provision_timeout - creation_timeout_seconds: 900 + # Duration for a created instance to be kept alive (in seconds, required). + # + # This is required for the DWS to work properly. After the + # specified duration, the instance will be terminated. + run_duration: 3600 + # Timeout for provisioning an instance by DWS (in seconds, optional). + # + # This timeout determines how long SkyPilot will wait for a managed + # instance group to create the requested resources before giving up, + # deleting the MIG and failing over to other locations. 
Larger timeouts + # may increase the chance for getting a resource, but will blcok failover + # to go to other zones/regions/clouds. + # + # Default: 900 + provision_timeout: 900 # Identity to use for all GCP instances (optional). diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index c830fde9e78..8f9341bd342 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -223,6 +223,6 @@ # MIG constants MANAGED_INSTANCE_GROUP_CONFIG = 'managed-instance-group' -DEFAULT_MAANGED_INSTANCE_GROUP_CREATION_TIMEOUT = 1200 # 20 minutes +DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT = 900 # 15 minutes MIG_NAME_PREFIX = 'sky-mig-' INSTANCE_TEMPLATE_NAME_PREFIX = 'sky-it-' diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index a37fe7196d4..3e23cceb4ac 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -1074,15 +1074,12 @@ def create_instances( constants.MANAGED_INSTANCE_GROUP_CONFIG] if count > 0: # Use resize to trigger DWS for creating VMs. - logger.debug(f'Resizing Managed instance group ' - f'{managed_instance_group_name!r} by {count}...') operation = mig_utils.resize_managed_instance_group( project_id, zone, managed_instance_group_name, count, - run_duration_seconds=managed_instance_group_config[ - 'run_duration_seconds']) + run_duration=managed_instance_group_config['run_duration']) cls.wait_for_operation(operation, project_id, zone=zone) # This will block the provisioning until the nodes are ready, which @@ -1092,7 +1089,9 @@ def create_instances( project_id, zone, managed_instance_group_name, - timeout=managed_instance_group_config['creation_timeout_seconds']) + timeout=managed_instance_group_config.get( + 'provision_timeout', + constants.DEFAULT_MANAGED_INSTANCE_GROUP_PROVISION_TIMEOUT)) pending_running_instance_names = cls._add_labels_and_find_head( cluster_name, project_id, zone, labels, potential_head_instances) @@ -1109,6 +1108,7 @@ def create_instances( @classmethod def _delete_instance_template(cls, project_id: str, zone: str, instance_template_name: str) -> None: + logger.debug(f'Deleting instance template {instance_template_name}...') region = zone.rpartition('-')[0] try: operation = cls.load_resource().regionInstanceTemplates().delete( @@ -1129,6 +1129,7 @@ def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: mig_name = mig_utils.get_managed_instance_group_name(cluster_name) # Get all resize request of the MIG and cancel them. mig_utils.cancel_all_resize_request_for_mig(project_id, zone, mig_name) + logger.debug(f'Deleting MIG {mig_name!r} ...') try: operation = cls.load_resource().instanceGroupManagers().delete( project=project_id, zone=zone, @@ -1151,28 +1152,6 @@ def delete_mig(cls, project_id: str, zone: str, cluster_name: str) -> None: project_id, zone, mig_utils.get_instance_template_name(cluster_name)) - # TODO(zhwu): We want to restart the instances with MIG instead of the - # normal instance start API to take advantage of DWS. 
- # @classmethod - # def start_instances(cls, cluster_name: str, project_id: str, zone: str, - # instances: List[str], labels: Dict[str, - # str]) -> List[str]: - # del instances # unused - # potential_head_instances = cls.filter( - # project_id, - # zone, - # label_filters={ - # constants.TAG_RAY_NODE_KIND: 'head', - # constants.TAG_RAY_CLUSTER_NAME: cluster_name, - # }, - # status_filters=cls.NEED_TO_TERMINATE_STATES) - # mig_name = mig_utils.get_managed_instance_group_name(cluster_name) - # mig_utils.start_managed_instance_group(project_id, zone, mig_name) - # mig_utils.wait_for_managed_group_to_be_stable(project_id, zone, - # mig_name) - # return cls._add_labels_and_find_head(cluster_name, project_id, zone, - # labels, potential_head_instances) - @classmethod def _add_labels_and_find_head( cls, cluster_name: str, project_id: str, zone: str, diff --git a/sky/provision/gcp/mig_utils.py b/sky/provision/gcp/mig_utils.py index 1cae00c4adf..9e33f5171e2 100644 --- a/sky/provision/gcp/mig_utils.py +++ b/sky/provision/gcp/mig_utils.py @@ -48,6 +48,7 @@ def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, region: str, template_name: str, node_config: Dict[str, Any]) -> dict: """Create a regional instance template.""" + logger.debug(f'Creating regional instance template {template_name!r}.') compute = gcp.build('compute', 'v1', credentials=None, @@ -88,6 +89,7 @@ def create_region_instance_template(cluster_name_on_cloud: str, project_id: str, def create_managed_instance_group(project_id: str, zone: str, group_name: str, instance_template_url: str, size: int) -> dict: + logger.debug(f'Creating managed instance group {group_name!r}.') compute = gcp.build('compute', 'v1', credentials=None, @@ -110,8 +112,9 @@ def create_managed_instance_group(project_id: str, zone: str, group_name: str, def resize_managed_instance_group(project_id: str, zone: str, group_name: str, - resize_by: int, - run_duration_seconds: int) -> dict: + resize_by: int, run_duration: int) -> dict: + logger.debug(f'Resizing managed instance group {group_name!r} by ' + f'{resize_by} with run duration {run_duration}.') compute = gcp.build('compute', 'beta', credentials=None, @@ -124,7 +127,7 @@ def resize_managed_instance_group(project_id: str, zone: str, group_name: str, 'name': group_name, 'resizeBy': resize_by, 'requestedRunDuration': { - 'seconds': run_duration_seconds, + 'seconds': run_duration, } }).execute() return operation @@ -149,8 +152,7 @@ def cancel_all_resize_request_for_mig(project_id: str, zone: str, project=project_id, zone=zone, instanceGroupManager=group_name, - resizeRequest=request['name'], - requestId=request['id']).execute() + resizeRequest=request['name']).execute() except gcp.http_error_exception() as e: logger.warning('Failed to cancel resize request ' f'{request["id"]!r}: {e}') @@ -179,11 +181,11 @@ def check_managed_instance_group_exists(project_id: str, zone: str, return True -def wait_for_managed_group_to_be_stable(project_id: str, - zone: str, - group_name: str, - timeout: int = 1200) -> None: +def wait_for_managed_group_to_be_stable(project_id: str, zone: str, + group_name: str, timeout: int) -> None: """Wait until the managed instance group is stable.""" + logger.debug(f'Waiting for MIG {group_name} to be stable with timeout ' + f'{timeout}.') try: cmd = ('gcloud compute instance-groups managed wait-until ' f'{group_name} ' diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index 51f901bfc3f..1a75ba3c380 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ 
b/sky/templates/gcp-ray.yml.j2 @@ -83,8 +83,8 @@ available_node_types: managed-instance-group: {{ gcp_use_managed_instance_group }} {%- if gcp_use_managed_instance_group %} managed-instance-group: - run_duration_seconds: {{ run_duration_seconds }} - creation_timeout_seconds: {{ creation_timeout_seconds }} + run_duration: {{ run_duration }} + provision_timeout: {{ provision_timeout }} {%- endif %} {%- if specific_reservations %} reservationAffinity: diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 6e26beb9b56..fae793af4c9 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -590,15 +590,13 @@ def get_config_schema(): }, 'managed_instance_group': { 'type': 'object', - 'required': [ - 'run_duration_seconds', 'creation_timeout_seconds' - ], + 'required': ['run_duration'], 'additionalProperties': False, 'properties': { - 'run_duration_seconds': { + 'run_duration': { 'type': 'integer', }, - 'creation_timeout_seconds': { + 'provision_timeout': { 'type': 'integer', } } diff --git a/tests/test_yamls/use_mig_config.yaml b/tests/test_yamls/use_mig_config.yaml index b2a00c85977..ef715191a1f 100644 --- a/tests/test_yamls/use_mig_config.yaml +++ b/tests/test_yamls/use_mig_config.yaml @@ -1,4 +1,4 @@ gcp: managed_instance_group: - run_duration_seconds: 36000 - creation_timeout_seconds: 900 + run_duration: 36000 + provision_timeout: 900 From 3b8b040a68407ab1a857199d3a087e16e6426a31 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 11 Jun 2024 08:09:47 +0000 Subject: [PATCH 27/29] Fix labels for GCP TPU --- sky/clouds/gcp.py | 3 +++ sky/templates/gcp-ray.yml.j2 | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 7e7dacc539f..f3cbd330cea 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -509,6 +509,9 @@ def make_deploy_resources_variables( ('gcp', 'managed_instance_group'), None) use_mig = managed_instance_group_config is not None resources_vars['gcp_use_managed_instance_group'] = use_mig + # Convert boolean to 0 or 1 in string, as GCP does not support boolean + # value in labels for TPU VM APIs. + resources_vars['gcp_use_managed_instance_group_value'] = str(int(use_mig)) if use_mig: resources_vars.update(managed_instance_group_config) return resources_vars diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index 51a7b332a72..aabe2fb7842 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ b/sky/templates/gcp-ray.yml.j2 @@ -80,7 +80,7 @@ available_node_types: {%- for label_key, label_value in labels.items() %} {{ label_key }}: {{ label_value|tojson }} {%- endfor %} - managed-instance-group: {{ gcp_use_managed_instance_group }} + managed-instance-group: {{ gcp_use_managed_instance_group_value|tojson }} {%- if gcp_use_managed_instance_group %} managed-instance-group: run_duration: {{ run_duration }} From e6d13964cdec9c177e137fa21f938f433932c2cc Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 11 Jun 2024 08:29:43 +0000 Subject: [PATCH 28/29] format --- sky/clouds/gcp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index f3cbd330cea..94add7fce7d 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -511,7 +511,8 @@ def make_deploy_resources_variables( resources_vars['gcp_use_managed_instance_group'] = use_mig # Convert boolean to 0 or 1 in string, as GCP does not support boolean # value in labels for TPU VM APIs. 
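        # (True -> '1', False -> '0'.)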
- resources_vars['gcp_use_managed_instance_group_value'] = str(int(use_mig)) + resources_vars['gcp_use_managed_instance_group_value'] = str( + int(use_mig)) if use_mig: resources_vars.update(managed_instance_group_config) return resources_vars From ed9073ea6f0bc50affc0e6dadcd303ae216ea579 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 11 Jun 2024 08:57:51 +0000 Subject: [PATCH 29/29] fix key --- sky/templates/gcp-ray.yml.j2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index aabe2fb7842..f4ec10a697d 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ b/sky/templates/gcp-ray.yml.j2 @@ -80,7 +80,7 @@ available_node_types: {%- for label_key, label_value in labels.items() %} {{ label_key }}: {{ label_value|tojson }} {%- endfor %} - managed-instance-group: {{ gcp_use_managed_instance_group_value|tojson }} + use-managed-instance-group: {{ gcp_use_managed_instance_group_value|tojson }} {%- if gcp_use_managed_instance_group %} managed-instance-group: run_duration: {{ run_duration }}
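To summarize how the pieces above fit together, the provisioning path implemented across this series looks roughly like the following. This is a condensed, illustrative sketch only: `provision_with_mig` is not a real function in the patch, operation polling via `wait_for_operation`, head-node labeling, and the group's initial size handling are omitted or simplified, while the `mig_utils` calls themselves are the ones introduced in this PR.

from sky.provision.gcp import mig_utils


def provision_with_mig(cluster_name: str, project_id: str, region: str,
                       zone: str, node_config: dict, count: int,
                       run_duration: int, provision_timeout: int) -> None:
    # 1. Ensure a regional instance template exists for this cluster.
    template_name = mig_utils.get_instance_template_name(cluster_name)
    if not mig_utils.check_instance_template_exits(project_id, region,
                                                   template_name):
        mig_utils.create_region_instance_template(cluster_name, project_id,
                                                  region, template_name,
                                                  node_config)
    template_url = (f'projects/{project_id}/regions/{region}/'
                    f'instanceTemplates/{template_name}')

    # 2. Ensure the managed instance group exists.
    group_name = mig_utils.get_managed_instance_group_name(cluster_name)
    if not mig_utils.check_managed_instance_group_exists(project_id, zone,
                                                         group_name):
        mig_utils.create_managed_instance_group(
            project_id, zone, group_name, template_url,
            size=0)  # Initial size handling is simplified in this sketch.

    # 3. Ask DWS for `count` additional VMs, kept alive for `run_duration`
    #    seconds, then wait (up to `provision_timeout`) for the group to
    #    become stable before labeling and using the nodes.
    if count > 0:
        mig_utils.resize_managed_instance_group(project_id, zone, group_name,
                                                resize_by=count,
                                                run_duration=run_duration)
    mig_utils.wait_for_managed_group_to_be_stable(project_id, zone, group_name,
                                                  timeout=provision_timeout)

Teardown reverses this order: pending resize requests are cancelled first, then the MIG is deleted, and finally the regional instance template is removed (see `delete_mig` and `_delete_instance_template` in the hunks above).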