From fda9c830fd7532b8db046b1e317b4d5891bae6f1 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Thu, 5 Oct 2023 15:45:58 -0700 Subject: [PATCH 01/84] init --- sky/provision/gcp/config.py | 1055 ++++++++++++++++++++++++++++++++++ sky/provision/provisioner.py | 3 +- 2 files changed, 1057 insertions(+), 1 deletion(-) create mode 100644 sky/provision/gcp/config.py diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py new file mode 100644 index 00000000000..6983a6d1a6c --- /dev/null +++ b/sky/provision/gcp/config.py @@ -0,0 +1,1055 @@ +"""GCP config module.""" +import copy +from functools import partial +import json +import logging +import os +import time +from typing import Dict, List, Set, Tuple + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from google.oauth2 import service_account +from google.oauth2.credentials import Credentials as OAuthCredentials +from googleapiclient import discovery +from googleapiclient import errors +from ray.autoscaler._private.util import check_legacy_fields + +from sky.skylet.providers.gcp.constants import FIREWALL_RULES_REQUIRED +from sky.skylet.providers.gcp.constants import FIREWALL_RULES_TEMPLATE +from sky.skylet.providers.gcp.constants import SKYPILOT_VPC_NAME +from sky.skylet.providers.gcp.constants import TPU_MINIMAL_PERMISSIONS +from sky.skylet.providers.gcp.constants import VM_MINIMAL_PERMISSIONS +from sky.skylet.providers.gcp.constants import VPC_TEMPLATE +from sky.skylet.providers.gcp.node import GCPCompute +from sky.skylet.providers.gcp.node import GCPNodeType +from sky.skylet.providers.gcp.node import MAX_POLLS +from sky.skylet.providers.gcp.node import POLL_INTERVAL + +logger = logging.getLogger(__name__) + +VERSION = 'v1' +TPU_VERSION = 'v2alpha' # change once v2 is stable + +RAY = 'ray-autoscaler' +DEFAULT_SERVICE_ACCOUNT_ID = RAY + '-sa-' + VERSION +SERVICE_ACCOUNT_EMAIL_TEMPLATE = '{account_id}@{project_id}.iam.gserviceaccount.com' +DEFAULT_SERVICE_ACCOUNT_CONFIG = { + 'displayName': 'Ray Autoscaler Service Account ({})'.format(VERSION), +} + +SKYPILOT = 'skypilot' +SKYPILOT_SERVICE_ACCOUNT_ID = SKYPILOT + '-' + VERSION +SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE = ( + '{account_id}@{project_id}.iam.gserviceaccount.com' +) +SKYPILOT_SERVICE_ACCOUNT_CONFIG = { + 'displayName': 'SkyPilot Service Account ({})'.format(VERSION), +} + +# Those roles will be always added. +# NOTE: `serviceAccountUser` allows the head node to create workers with +# a serviceAccount. `roleViewer` allows the head node to run bootstrap_gcp. +DEFAULT_SERVICE_ACCOUNT_ROLES = [ + 'roles/storage.objectAdmin', + 'roles/compute.admin', + 'roles/iam.serviceAccountUser', + 'roles/iam.roleViewer', +] +# Those roles will only be added if there are TPU nodes defined in config. +TPU_SERVICE_ACCOUNT_ROLES = ['roles/tpu.admin'] + +# If there are TPU nodes in config, this field will be set +# to True in config['provider']. +HAS_TPU_PROVIDER_FIELD = '_has_tpus' + +# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes +# with ServiceAccounts. + + +def get_node_type(node: dict) -> GCPNodeType: + """Returns node type based on the keys in ``node``. + + This is a very simple check. If we have a ``machineType`` key, + this is a Compute instance. If we don't have a ``machineType`` key, + but we have ``acceleratorType``, this is a TPU. Otherwise, it's + invalid and an exception is raised. 
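+
+    For example, a node config of {'machineType': 'n1-standard-8'} maps to
+    a Compute instance, while {'acceleratorType': 'v2-8'} with no
+    'machineType' maps to a TPU.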
+ + This works for both node configs and API returned nodes. + """ + + if 'machineType' not in node and 'acceleratorType' not in node: + raise ValueError( + 'Invalid node. For a Compute instance, "machineType" is ' + 'required. ' + 'For a TPU instance, "acceleratorType" and no "machineType" ' + 'is required. ' + f'Got {list(node)}' + ) + + if 'machineType' not in node and 'acceleratorType' in node: + return GCPNodeType.TPU + return GCPNodeType.COMPUTE + + +def wait_for_crm_operation(operation, crm): + """Poll for cloud resource manager operation until finished.""" + logger.info( + 'wait_for_crm_operation: ' + 'Waiting for operation {} to finish...'.format(operation) + ) + + for _ in range(MAX_POLLS): + result = crm.operations().get(name=operation['name']).execute() + if 'error' in result: + raise Exception(result['error']) + + if 'done' in result and result['done']: + logger.info('wait_for_crm_operation: Operation done.') + break + + time.sleep(POLL_INTERVAL) + + return result + + +def wait_for_compute_global_operation(project_name, operation, compute): + """Poll for global compute operation until finished.""" + logger.info( + 'wait_for_compute_global_operation: ' + 'Waiting for operation {} to finish...'.format(operation['name']) + ) + + for _ in range(MAX_POLLS): + result = ( + compute.globalOperations() + .get( + project=project_name, + operation=operation['name'], + ) + .execute() + ) + if 'error' in result: + raise Exception(result['error']) + + if result['status'] == 'DONE': + logger.info('wait_for_compute_global_operation: Operation done.') + break + + time.sleep(POLL_INTERVAL) + + return result + + +def key_pair_name(i, region, project_id, ssh_user): + """Returns the ith default gcp_key_pair_name.""" + return f'{SKYPILOT}_gcp_{region}_{project_id}_{ssh_user}_{i}' + + +def key_pair_paths(key_name): + """Returns public and private key paths for a given key_name.""" + public_key_path = os.path.expanduser(f'~/.ssh/{key_name}.pub') + private_key_path = os.path.expanduser(f'~/.ssh/{key_name}.pem') + return public_key_path, private_key_path + + +def generate_rsa_key_pair(): + """Create public and private ssh-keys.""" + + key = rsa.generate_private_key( + backend=default_backend(), public_exponent=65537, key_size=2048 + ) + + public_key = ( + key.public_key() + .public_bytes( + serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH + ) + .decode('utf-8') + ) + + pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode('utf-8') + + return public_key, pem + + +def _has_tpus_in_node_configs(config: dict) -> bool: + """Check if any nodes in config are TPUs.""" + node_configs = [ + node_type['node_config'] + for node_type in config['available_node_types'].values() + ] + return any(get_node_type(node) == GCPNodeType.TPU for node in node_configs) + + +def _is_head_node_a_tpu(config: dict) -> bool: + """Check if the head node is a TPU.""" + node_configs = { + node_id: node_type['node_config'] + for node_id, node_type in config['available_node_types'].items() + } + return get_node_type(node_configs[config['head_node_type']]) == GCPNodeType.TPU + + +def _create_crm(gcp_credentials=None): + return discovery.build( + 'cloudresourcemanager', 'v1', credentials=gcp_credentials, cache_discovery=False + ) + + +def _create_iam(gcp_credentials=None): + return discovery.build( + 'iam', 'v1', credentials=gcp_credentials, cache_discovery=False + ) + + +def 
_create_compute(gcp_credentials=None): + return discovery.build( + 'compute', 'v1', credentials=gcp_credentials, cache_discovery=False + ) + + +def _create_tpu(gcp_credentials=None): + return discovery.build( + 'tpu', + TPU_VERSION, + credentials=gcp_credentials, + cache_discovery=False, + discoveryServiceUrl='https://tpu.googleapis.com/$discovery/rest', + ) + + +def construct_clients_from_provider_config(provider_config): + """ + Attempt to fetch and parse the JSON GCP credentials from the provider + config yaml file. + + tpu resource (the last element of the tuple) will be None if + `_has_tpus` in provider config is not set or False. + """ + gcp_credentials = provider_config.get('gcp_credentials') + if gcp_credentials is None: + logger.debug( + 'gcp_credentials not found in cluster yaml file. ' + 'Falling back to GOOGLE_APPLICATION_CREDENTIALS ' + 'environment variable.' + ) + tpu_resource = ( + _create_tpu() + if provider_config.get(HAS_TPU_PROVIDER_FIELD, False) + else None + ) + # If gcp_credentials is None, then discovery.build will search for + # credentials in the local environment. + return _create_crm(), _create_iam(), _create_compute(), tpu_resource + + assert ( + 'type' in gcp_credentials + ), 'gcp_credentials cluster yaml field missing "type" field.' + assert ( + 'credentials' in gcp_credentials + ), 'gcp_credentials cluster yaml field missing "credentials" field.' + + cred_type = gcp_credentials['type'] + credentials_field = gcp_credentials['credentials'] + + if cred_type == 'service_account': + # If parsing the gcp_credentials failed, then the user likely made a + # mistake in copying the credentials into the config yaml. + try: + service_account_info = json.loads(credentials_field) + except json.decoder.JSONDecodeError: + raise RuntimeError( + 'gcp_credentials found in cluster yaml file but ' + 'formatted improperly.' + ) + credentials = service_account.Credentials.from_service_account_info( + service_account_info + ) + elif cred_type == 'credentials_token': + # Otherwise the credentials type must be credentials_token. + credentials = OAuthCredentials(credentials_field) + + tpu_resource = ( + _create_tpu(credentials) + if provider_config.get(HAS_TPU_PROVIDER_FIELD, False) + else None + ) + + return ( + _create_crm(credentials), + _create_iam(credentials), + _create_compute(credentials), + tpu_resource, + ) + + +def bootstrap_gcp(config): + config = copy.deepcopy(config) + check_legacy_fields(config) + # Used internally to store head IAM role. + config['head_node'] = {} + + # Check if we have any TPUs defined, and if so, + # insert that information into the provider config + if _has_tpus_in_node_configs(config): + config['provider'][HAS_TPU_PROVIDER_FIELD] = True + + crm, iam, compute, tpu = construct_clients_from_provider_config(config['provider']) + + config = _configure_project(config, crm) + config = _configure_iam_role(config, crm, iam) + config = _configure_key_pair(config, compute) + config = _configure_subnet(config, compute) + + return config + + +def _configure_project(config, crm): + """Setup a Google Cloud Platform Project. + + Google Compute Platform organizes all the resources, such as storage + buckets, users, and instances under projects. This is different from + aws ec2 where everything is global. + """ + config = copy.deepcopy(config) + + project_id = config['provider'].get('project_id') + assert config['provider']['project_id'] is not None, ( + '"project_id" must be set in the "provider" section of the autoscaler' + ' config. 
Notice that the project id must be globally unique.'
+    )
+    project = _get_project(project_id, crm)
+
+    if project is None:
+        # Project not found, try creating it
+        _create_project(project_id, crm)
+        project = _get_project(project_id, crm)
+
+    assert project is not None, 'Failed to create project'
+    assert (
+        project['lifecycleState'] == 'ACTIVE'
+    ), 'Project status needs to be ACTIVE, got {}'.format(project['lifecycleState'])
+
+    config['provider']['project_id'] = project['projectId']
+
+    return config
+
+
+def _is_permission_satisfied(
+    service_account, crm, iam, required_permissions, required_roles
+):
+    """Check whether the roles or permissions are satisfied."""
+    if service_account is None:
+        return False, None
+
+    project_id = service_account['projectId']
+    email = service_account['email']
+
+    member_id = 'serviceAccount:' + email
+
+    required_permissions = set(required_permissions)
+    policy = crm.projects().getIamPolicy(resource=project_id, body={}).execute()
+    original_policy = copy.deepcopy(policy)
+    already_configured = True
+
+    logger.info(f'_configure_iam_role: Checking permissions for {email}...')
+
+    # Check the roles first, as checking the permission requires more API calls and
+    # permissions.
+    for role in required_roles:
+        role_exists = False
+        for binding in policy['bindings']:
+            if binding['role'] == role:
+                if member_id not in binding['members']:
+                    binding['members'].append(member_id)
+                    already_configured = False
+                role_exists = True
+
+        if not role_exists:
+            already_configured = False
+            policy['bindings'].append(
+                {
+                    'members': [member_id],
+                    'role': role,
+                }
+            )
+
+    if already_configured:
+        # In some managed environments, an admin needs to grant the
+        # roles, so only call setIamPolicy if needed.
+        return True, policy
+
+    for binding in original_policy['bindings']:
+        if member_id in binding['members']:
+            role = binding['role']
+            try:
+                role_definition = iam.projects().roles().get(name=role).execute()
+            except TypeError as e:
+                if 'does not match the pattern' in str(e):
+                    logger.info(
+                        f'_configure_iam_role: failed to check permissions for built-in role {role}; skipped.'
+                    )
+                    permissions = []
+                else:
+                    raise
+            else:
+                permissions = role_definition['includedPermissions']
+            required_permissions -= set(permissions)
+        if not required_permissions:
+            break
+    if not required_permissions:
+        # All required permissions are already granted.
+        return True, policy
+    logger.info(f'_configure_iam_role: missing permissions {required_permissions}')
+
+    return False, policy
+
+
+def _configure_iam_role(config, crm, iam):
+    """Setup a gcp service account with IAM roles.
+
+    Creates a gcp service account and binds IAM roles which allow it to control
+    storage/compute services. Specifically, the head node needs to have
+    an IAM role that allows it to create further gce instances and store items
+    in google cloud storage.
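+
+    The roles granted are DEFAULT_SERVICE_ACCOUNT_ROLES, plus
+    TPU_SERVICE_ACCOUNT_ROLES when TPU nodes are present in the config.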
+
+    TODO: Allow the name/id of the service account to be configured
+    """
+    config = copy.deepcopy(config)
+
+    email = SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
+        account_id=SKYPILOT_SERVICE_ACCOUNT_ID,
+        project_id=config['provider']['project_id'],
+    )
+    service_account = _get_service_account(email, config, iam)
+
+    permissions = VM_MINIMAL_PERMISSIONS
+    roles = DEFAULT_SERVICE_ACCOUNT_ROLES
+    if config['provider'].get(HAS_TPU_PROVIDER_FIELD, False):
+        roles = DEFAULT_SERVICE_ACCOUNT_ROLES + TPU_SERVICE_ACCOUNT_ROLES
+        permissions = VM_MINIMAL_PERMISSIONS + TPU_MINIMAL_PERMISSIONS
+
+    satisfied, policy = _is_permission_satisfied(
+        service_account, crm, iam, permissions, roles
+    )
+
+    if not satisfied:
+        # SkyPilot: Fall back to the old ray service account name for
+        # backwards compatibility. Users using GCP before #2112 have
+        # the old service account set up in their GCP project,
+        # and the user may not have the permissions to create the
+        # new service account. This is to ensure that the old service
+        # account is still usable.
+        email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
+            account_id=DEFAULT_SERVICE_ACCOUNT_ID,
+            project_id=config['provider']['project_id'],
+        )
+        logger.info(f'_configure_iam_role: Fallback to service account {email}')
+
+        ray_service_account = _get_service_account(email, config, iam)
+        ray_satisfied, _ = _is_permission_satisfied(
+            ray_service_account, crm, iam, permissions, roles
+        )
+        logger.info(
+            '_configure_iam_role: '
+            f'Fallback to service account {email} succeeded? {ray_satisfied}'
+        )
+
+        if ray_satisfied:
+            service_account = ray_service_account
+            satisfied = ray_satisfied
+        elif service_account is None:
+            logger.info(
+                '_configure_iam_role: '
+                'Creating new service account {}'.format(SKYPILOT_SERVICE_ACCOUNT_ID)
+            )
+            # SkyPilot: a GCP user without the permission to create a service
+            # account will fail here.
+            service_account = _create_service_account(
+                SKYPILOT_SERVICE_ACCOUNT_ID,
+                SKYPILOT_SERVICE_ACCOUNT_CONFIG,
+                config,
+                iam,
+            )
+            satisfied, policy = _is_permission_satisfied(
+                service_account, crm, iam, permissions, roles
+            )
+
+    assert service_account is not None, 'Failed to create service account'
+
+    if not satisfied:
+        logger.info(
+            '_configure_iam_role: ' f'Adding roles to service account {email}...'
+        )
+        _add_iam_policy_binding(service_account, policy, crm, iam)
+
+    account_dict = {
+        'email': service_account['email'],
+        # NOTE: The amount of access is determined by the scope + IAM
+        # role of the service account. Even if the cloud-platform scope
+        # gives (scope) access to the whole cloud-platform, the service
+        # account is limited by the IAM rights specified below.
+        'scopes': ['https://www.googleapis.com/auth/cloud-platform'],
+    }
+    if _is_head_node_a_tpu(config):
+        # SKY: The API for TPU VM is slightly different from normal compute instances.
+        # See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node
+        account_dict['scope'] = account_dict['scopes']
+        account_dict.pop('scopes')
+        config['head_node']['serviceAccount'] = account_dict
+    else:
+        config['head_node']['serviceAccounts'] = [account_dict]
+
+    return config
+
+
+def _configure_key_pair(config, compute):
+    """Configure SSH access, using an existing key pair if possible.
+
+    Creates a project-wide ssh key that can be used to access all the instances
+    unless explicitly prohibited by instance config.
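+
+    Key files are written to ~/.ssh/{key_name}.pem (private) and
+    ~/.ssh/{key_name}.pub (public); see key_pair_paths().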
+ + The ssh-keys created by ray are of format: + + [USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME] + + where: + + [USERNAME] is the user for the SSH key, specified in the config. + [KEY_VALUE] is the public SSH key value. + """ + config = copy.deepcopy(config) + + if 'ssh_private_key' in config['auth']: + return config + + ssh_user = config['auth']['ssh_user'] + + project = compute.projects().get(project=config['provider']['project_id']).execute() + + # Key pairs associated with project meta data. The key pairs are general, + # and not just ssh keys. + ssh_keys_str = next( + ( + item + for item in project['commonInstanceMetadata'].get('items', []) + if item['key'] == 'ssh-keys' + ), + {}, + ).get('value', '') + + ssh_keys = ssh_keys_str.split('\n') if ssh_keys_str else [] + + # Try a few times to get or create a good key pair. + key_found = False + for i in range(10): + key_name = key_pair_name( + i, config['provider']['region'], config['provider']['project_id'], ssh_user + ) + public_key_path, private_key_path = key_pair_paths(key_name) + + for ssh_key in ssh_keys: + key_parts = ssh_key.split(' ') + if len(key_parts) != 3: + continue + + if key_parts[2] == ssh_user and os.path.exists(private_key_path): + # Found a key + key_found = True + break + + # Writing the new ssh key to the filesystem fails if the ~/.ssh + # directory doesn't already exist. + os.makedirs(os.path.expanduser('~/.ssh'), exist_ok=True) + + # Create a key since it doesn't exist locally or in GCP + if not key_found and not os.path.exists(private_key_path): + logger.info( + '_configure_key_pair: Creating new key pair {}'.format(key_name) + ) + public_key, private_key = generate_rsa_key_pair() + + _create_project_ssh_key_pair(project, public_key, ssh_user, compute) + + # Create the directory if it doesn't exists + private_key_dir = os.path.dirname(private_key_path) + os.makedirs(private_key_dir, exist_ok=True) + + # We need to make sure to _create_ the file with the right + # permissions. In order to do that we need to change the default + # os.open behavior to include the mode we want. + with open( + private_key_path, + 'w', + opener=partial(os.open, mode=0o600), + ) as f: + f.write(private_key) + + with open(public_key_path, 'w') as f: + f.write(public_key) + + key_found = True + + break + + if key_found: + break + + assert key_found, 'SSH keypair for user {} not found for {}'.format( + ssh_user, private_key_path + ) + assert os.path.exists( + private_key_path + ), 'Private key file {} not found for user {}'.format(private_key_path, ssh_user) + + logger.info( + '_configure_key_pair: ' + 'Private key not specified in config, using' + '{}'.format(private_key_path) + ) + + config['auth']['ssh_private_key'] = private_key_path + + return config + + +def _check_firewall_rules(vpc_name, config, compute): + """Check if the firewall rules in the VPC are sufficient.""" + required_rules = FIREWALL_RULES_REQUIRED.copy() + + operation = compute.networks().getEffectiveFirewalls( + project=config['provider']['project_id'], network=vpc_name + ) + response = operation.execute() + if len(response) == 0: + return False + effective_rules = response['firewalls'] + + def _merge_and_refine_rule(rules): + """Returns the reformatted rules from the firewall rules + + The function translates firewall rules fetched from the cloud provider + to a format for simple comparison. + + Example of firewall rules from the cloud: + [ + { + ... 
+ "direction": "INGRESS", + "allowed": [ + {"IPProtocol": "tcp", "ports": ['80', '443']}, + {"IPProtocol": "udp", "ports": ['53']}, + ], + "sourceRanges": ["10.128.0.0/9"], + }, + { + ... + "direction": "INGRESS", + "allowed": [{ + "IPProtocol": "tcp", + "ports": ["22"], + }], + "sourceRanges": ["0.0.0.0/0"], + }, + ] + + Returns: + source2rules: Dict[(direction, sourceRanges) -> Dict(protocol -> Set[ports])] + Example { + ("INGRESS", "10.128.0.0/9"): {"tcp": {80, 443}, "udp": {53}}, + ("INGRESS", "0.0.0.0/0"): {"tcp": {22}}, + } + """ + source2rules: Dict[Tuple[str, str], Dict[str, Set[int]]] = {} + source2allowed_list: Dict[Tuple[str, str], List[Dict[str, str]]] = {} + for rule in rules: + # Rules applied to specific VM (targetTags) may not work for the + # current VM, so should be skipped. + # Filter by targetTags == ['cluster_name'] + # See https://developers.google.com/resources/api-libraries/documentation/compute/alpha/python/latest/compute_alpha.networks.html#getEffectiveFirewalls # pylint: disable=line-too-long + tags = rule.get('targetTags', None) + if tags is not None: + if len(tags) != 1: + continue + if tags[0] != config['cluster_name']: + continue + direction = rule.get('direction', '') + sources = rule.get('sourceRanges', []) + allowed = rule.get('allowed', []) + for source in sources: + key = (direction, source) + source2allowed_list[key] = source2allowed_list.get(key, []) + allowed + for direction_source, allowed_list in source2allowed_list.items(): + source2rules[direction_source] = {} + for allowed in allowed_list: + # Example of port_list: ['20', '50-60'] + # If list is empty, it means all ports + port_list = allowed.get('ports', []) + port_set = set() + if port_list == []: + port_set.update(set(range(1, 65536))) + else: + for port_range in port_list: + parse_ports = port_range.split('-') + if len(parse_ports) == 1: + port_set.add(int(parse_ports[0])) + else: + assert ( + len(parse_ports) == 2 + ), f'Failed to parse the port range: {port_range}' + port_set.update( + set(range(int(parse_ports[0]), int(parse_ports[1]) + 1)) + ) + if allowed['IPProtocol'] not in source2rules[direction_source]: + source2rules[direction_source][allowed['IPProtocol']] = set() + source2rules[direction_source][allowed['IPProtocol']].update(port_set) + return source2rules + + effective_rules = _merge_and_refine_rule(effective_rules) + required_rules = _merge_and_refine_rule(required_rules) + + for direction_source, allowed_req in required_rules.items(): + if direction_source not in effective_rules: + return False + allowed_eff = effective_rules[direction_source] + # Special case: 'all' means allowing all traffic + if 'all' in allowed_eff: + continue + # Check if the required ports are a subset of the effective ports + for protocol, ports_req in allowed_req.items(): + ports_eff = allowed_eff.get(protocol, set()) + if not ports_req.issubset(ports_eff): + return False + return True + + +def _create_rules(config, compute, rules, VPC_NAME, PROJ_ID): + opertaions = [] + for rule in rules: + # Query firewall rule by its name (unique in a project). + # If the rule already exists, delete it first. 
+ rule_name = rule['name'].format(VPC_NAME=VPC_NAME) + rule_list = _list_firewall_rules(config, compute, filter=f'(name={rule_name})') + if len(rule_list) > 0: + _delete_firewall_rule(config, compute, rule_name) + + body = rule.copy() + body['name'] = body['name'].format(VPC_NAME=VPC_NAME) + body['network'] = body['network'].format(PROJ_ID=PROJ_ID, VPC_NAME=VPC_NAME) + body['selfLink'] = body['selfLink'].format(PROJ_ID=PROJ_ID, VPC_NAME=VPC_NAME) + op = _create_firewall_rule_submit(config, compute, body) + opertaions.append(op) + for op in opertaions: + wait_for_compute_global_operation(config['provider']['project_id'], op, compute) + + +def get_usable_vpc(config): + """Return a usable VPC. + + If not found, create a new one with sufficient firewall rules. + """ + _, _, compute, _ = construct_clients_from_provider_config(config['provider']) + + # For backward compatibility, reuse the VPC if the VM is launched. + resource = GCPCompute( + compute, + config['provider']['project_id'], + config['provider']['availability_zone'], + config['cluster_name'], + ) + node = resource._list_instances(label_filters=None, status_filter=None) + if len(node) > 0: + netInterfaces = node[0].get('networkInterfaces', []) + if len(netInterfaces) > 0: + vpc_name = netInterfaces[0]['network'].split('/')[-1] + return vpc_name + + vpcnets_all = _list_vpcnets(config, compute) + + usable_vpc_name = None + for vpc in vpcnets_all: + if _check_firewall_rules(vpc['name'], config, compute): + usable_vpc_name = vpc['name'] + break + + proj_id = config['provider']['project_id'] + if usable_vpc_name is None: + logger.info(f'Creating a default VPC network, {SKYPILOT_VPC_NAME}...') + + # Create a SkyPilot VPC network if it doesn't exist + vpc_list = _list_vpcnets(config, compute, filter=f'name={SKYPILOT_VPC_NAME}') + if len(vpc_list) == 0: + body = VPC_TEMPLATE.copy() + body['name'] = body['name'].format(VPC_NAME=SKYPILOT_VPC_NAME) + body['selfLink'] = body['selfLink'].format( + PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME + ) + _create_vpcnet(config, compute, body) + + _create_rules( + config, compute, FIREWALL_RULES_TEMPLATE, SKYPILOT_VPC_NAME, proj_id + ) + + usable_vpc_name = SKYPILOT_VPC_NAME + logger.info(f'A VPC network {SKYPILOT_VPC_NAME} created.') + + return usable_vpc_name + + +def _configure_subnet(config, compute): + """Pick a reasonable subnet if not specified by the config.""" + config = copy.deepcopy(config) + + node_configs = [ + node_type['node_config'] + for node_type in config['available_node_types'].values() + ] + # Rationale: avoid subnet lookup if the network is already + # completely manually configured + + # networkInterfaces is compute, networkConfig is TPU + if all( + 'networkInterfaces' in node_config or 'networkConfig' in node_config + for node_config in node_configs + ): + return config + + # SkyPilot: make sure there's a usable VPC + usable_vpc_name = get_usable_vpc(config) + subnets = _list_subnets(config, compute, filter=f'(name="{usable_vpc_name}")') + default_subnet = subnets[0] + + default_interfaces = [ + { + 'subnetwork': default_subnet['selfLink'], + 'accessConfigs': [ + { + 'name': 'External NAT', + 'type': 'ONE_TO_ONE_NAT', + } + ], + } + ] + + for node_config in node_configs: + # The not applicable key will be removed during node creation + + # compute + if 'networkInterfaces' not in node_config: + node_config['networkInterfaces'] = copy.deepcopy(default_interfaces) + # TPU + if 'networkConfig' not in node_config: + node_config['networkConfig'] = copy.deepcopy(default_interfaces)[0] + 
node_config['networkConfig'].pop('accessConfigs') + + return config + + +def _create_firewall_rule_submit(config, compute, body): + operation = ( + compute.firewalls() + .insert(project=config['provider']['project_id'], body=body) + .execute() + ) + return operation + + +def _delete_firewall_rule(config, compute, name): + operation = ( + compute.firewalls() + .delete(project=config['provider']['project_id'], firewall=name) + .execute() + ) + response = wait_for_compute_global_operation( + config['provider']['project_id'], operation, compute + ) + return response + + +def _list_firewall_rules(config, compute, filter=None): + response = ( + compute.firewalls() + .list( + project=config['provider']['project_id'], + filter=filter, + ) + .execute() + ) + return response['items'] if 'items' in response else [] + + +def _create_vpcnet(config, compute, body): + operation = ( + compute.networks() + .insert(project=config['provider']['project_id'], body=body) + .execute() + ) + response = wait_for_compute_global_operation( + config['provider']['project_id'], operation, compute + ) + return response + + +def _list_vpcnets(config, compute, filter=None): + response = ( + compute.networks() + .list( + project=config['provider']['project_id'], + filter=filter, + ) + .execute() + ) + + return response['items'] if 'items' in response else [] + + +def _list_subnets(config, compute, filter=None): + response = ( + compute.subnetworks() + .list( + project=config['provider']['project_id'], + region=config['provider']['region'], + filter=filter, + ) + .execute() + ) + + return response['items'] if 'items' in response else [] + + +def _get_subnet(config, subnet_id, compute): + subnet = ( + compute.subnetworks() + .get( + project=config['provider']['project_id'], + region=config['provider']['region'], + subnetwork=subnet_id, + ) + .execute() + ) + + return subnet + + +def _get_project(project_id, crm): + try: + project = crm.projects().get(projectId=project_id).execute() + except errors.HttpError as e: + if e.resp.status != 403: + raise + project = None + + return project + + +def _create_project(project_id, crm): + operation = ( + crm.projects() + .create(body={'projectId': project_id, 'name': project_id}) + .execute() + ) + + result = wait_for_crm_operation(operation, crm) + + return result + + +def _get_service_account(account, config, iam): + project_id = config['provider']['project_id'] + full_name = 'projects/{project_id}/serviceAccounts/{account}'.format( + project_id=project_id, account=account + ) + try: + service_account = iam.projects().serviceAccounts().get(name=full_name).execute() + except errors.HttpError as e: + if e.resp.status not in [403, 404]: + # SkyPilot: added 403, which means the service account doesn't exist, + # or not accessible by the current account, which is fine, as we do the + # fallback in the caller. 
+ raise + service_account = None + + return service_account + + +def _create_service_account(account_id, account_config, config, iam): + project_id = config['provider']['project_id'] + + service_account = ( + iam.projects() + .serviceAccounts() + .create( + name='projects/{project_id}'.format(project_id=project_id), + body={ + 'accountId': account_id, + 'serviceAccount': account_config, + }, + ) + .execute() + ) + + return service_account + + +def _add_iam_policy_binding(service_account, policy, crm, iam): + """Add new IAM roles for the service account.""" + project_id = service_account['projectId'] + + result = ( + crm.projects() + .setIamPolicy( + resource=project_id, + body={ + 'policy': policy, + }, + ) + .execute() + ) + + return result + + +def _create_project_ssh_key_pair(project, public_key, ssh_user, compute): + """Inserts an ssh-key into project commonInstanceMetadata""" + + key_parts = public_key.split(' ') + + # Sanity checks to make sure that the generated key matches expectation + assert len(key_parts) == 2, key_parts + assert key_parts[0] == 'ssh-rsa', key_parts + + new_ssh_meta = '{ssh_user}:ssh-rsa {key_value} {ssh_user}'.format( + ssh_user=ssh_user, key_value=key_parts[1] + ) + + common_instance_info = project['commonInstanceMetadata'] + items = common_instance_info.get('items', []) + + ssh_keys_i = next( + (i for i, item in enumerate(items) if item['key'] == 'ssh-keys'), None + ) + + if ssh_keys_i is None: + items.append({'key': 'ssh-keys', 'value': new_ssh_meta}) + else: + ssh_keys = items[ssh_keys_i] + ssh_keys['value'] += '\n' + new_ssh_meta + items[ssh_keys_i] = ssh_keys + + common_instance_info['items'] = items + + operation = ( + compute.projects() + .setCommonInstanceMetadata(project=project['name'], body=common_instance_info) + .execute() + ) + + response = wait_for_compute_global_operation(project['name'], operation, compute) + + return response diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index cc80ce0dbd5..e119d756250 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -138,13 +138,14 @@ def bulk_provision( ) -> Optional[provision_common.ProvisionRecord]: """Provisions a cluster and wait until fully provisioned.""" original_config = common_utils.read_yaml(cluster_yaml) + head_node_type = original_config['head_node_type'] bootstrap_config = provision_common.ProvisionConfig( provider_config=original_config['provider'], authentication_config=original_config['auth'], docker_config=original_config.get('docker', {}), # NOTE: (might be a legacy issue) we call it # 'ray_head_default' in 'gcp-ray.yaml' - node_config=original_config['available_node_types']['ray.head.default'] + node_config=original_config['available_node_types'][head_node_type] ['node_config'], count=num_nodes, tags={}, From a972b83682442c988b3a2347a75cc061f1c2279d Mon Sep 17 00:00:00 2001 From: Siyuan Date: Thu, 5 Oct 2023 15:47:20 -0700 Subject: [PATCH 02/84] remove ray --- sky/provision/gcp/config.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index 6983a6d1a6c..e5cb52760d6 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -14,7 +14,6 @@ from google.oauth2.credentials import Credentials as OAuthCredentials from googleapiclient import discovery from googleapiclient import errors -from ray.autoscaler._private.util import check_legacy_fields from sky.skylet.providers.gcp.constants import FIREWALL_RULES_REQUIRED from 
sky.skylet.providers.gcp.constants import FIREWALL_RULES_TEMPLATE @@ -36,7 +35,7 @@ DEFAULT_SERVICE_ACCOUNT_ID = RAY + '-sa-' + VERSION SERVICE_ACCOUNT_EMAIL_TEMPLATE = '{account_id}@{project_id}.iam.gserviceaccount.com' DEFAULT_SERVICE_ACCOUNT_CONFIG = { - 'displayName': 'Ray Autoscaler Service Account ({})'.format(VERSION), + 'displayName': f'Ray Autoscaler Service Account ({VERSION})', } SKYPILOT = 'skypilot' @@ -45,7 +44,7 @@ '{account_id}@{project_id}.iam.gserviceaccount.com' ) SKYPILOT_SERVICE_ACCOUNT_CONFIG = { - 'displayName': 'SkyPilot Service Account ({})'.format(VERSION), + 'displayName': f'SkyPilot Service Account ({VERSION})', } # Those roles will be always added. @@ -291,7 +290,6 @@ def construct_clients_from_provider_config(provider_config): def bootstrap_gcp(config): config = copy.deepcopy(config) - check_legacy_fields(config) # Used internally to store head IAM role. config['head_node'] = {} From 95f470e1a5e362231462fa31852235950280c281 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Thu, 5 Oct 2023 21:51:10 -0700 Subject: [PATCH 03/84] update config --- sky/provision/gcp/config.py | 588 ++++++++++++------------------------ 1 file changed, 196 insertions(+), 392 deletions(-) diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index e5cb52760d6..caead01a28b 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -15,6 +15,7 @@ from googleapiclient import discovery from googleapiclient import errors +from sky.provision import common from sky.skylet.providers.gcp.constants import FIREWALL_RULES_REQUIRED from sky.skylet.providers.gcp.constants import FIREWALL_RULES_TEMPLATE from sky.skylet.providers.gcp.constants import SKYPILOT_VPC_NAME @@ -41,8 +42,7 @@ SKYPILOT = 'skypilot' SKYPILOT_SERVICE_ACCOUNT_ID = SKYPILOT + '-' + VERSION SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE = ( - '{account_id}@{project_id}.iam.gserviceaccount.com' -) + '{account_id}@{project_id}.iam.gserviceaccount.com') SKYPILOT_SERVICE_ACCOUNT_CONFIG = { 'displayName': f'SkyPilot Service Account ({VERSION})', } @@ -84,8 +84,7 @@ def get_node_type(node: dict) -> GCPNodeType: 'required. ' 'For a TPU instance, "acceleratorType" and no "machineType" ' 'is required. 
' - f'Got {list(node)}' - ) + f'Got {list(node)}') if 'machineType' not in node and 'acceleratorType' in node: return GCPNodeType.TPU @@ -94,10 +93,8 @@ def get_node_type(node: dict) -> GCPNodeType: def wait_for_crm_operation(operation, crm): """Poll for cloud resource manager operation until finished.""" - logger.info( - 'wait_for_crm_operation: ' - 'Waiting for operation {} to finish...'.format(operation) - ) + logger.info('wait_for_crm_operation: ' + 'Waiting for operation {} to finish...'.format(operation)) for _ in range(MAX_POLLS): result = crm.operations().get(name=operation['name']).execute() @@ -115,20 +112,15 @@ def wait_for_crm_operation(operation, crm): def wait_for_compute_global_operation(project_name, operation, compute): """Poll for global compute operation until finished.""" - logger.info( - 'wait_for_compute_global_operation: ' - 'Waiting for operation {} to finish...'.format(operation['name']) - ) + logger.info('wait_for_compute_global_operation: ' + 'Waiting for operation {} to finish...'.format( + operation['name'])) for _ in range(MAX_POLLS): - result = ( - compute.globalOperations() - .get( - project=project_name, - operation=operation['name'], - ) - .execute() - ) + result = (compute.globalOperations().get( + project=project_name, + operation=operation['name'], + ).execute()) if 'error' in result: raise Exception(result['error']) @@ -156,17 +148,13 @@ def key_pair_paths(key_name): def generate_rsa_key_pair(): """Create public and private ssh-keys.""" - key = rsa.generate_private_key( - backend=default_backend(), public_exponent=65537, key_size=2048 - ) + key = rsa.generate_private_key(backend=default_backend(), + public_exponent=65537, + key_size=2048) - public_key = ( - key.public_key() - .public_bytes( - serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH - ) - .decode('utf-8') - ) + public_key = (key.public_key().public_bytes( + serialization.Encoding.OpenSSH, + serialization.PublicFormat.OpenSSH).decode('utf-8')) pem = key.private_bytes( encoding=serialization.Encoding.PEM, @@ -177,40 +165,35 @@ def generate_rsa_key_pair(): return public_key, pem -def _has_tpus_in_node_configs(config: dict) -> bool: - """Check if any nodes in config are TPUs.""" - node_configs = [ - node_type['node_config'] - for node_type in config['available_node_types'].values() - ] - return any(get_node_type(node) == GCPNodeType.TPU for node in node_configs) - - def _is_head_node_a_tpu(config: dict) -> bool: """Check if the head node is a TPU.""" node_configs = { node_id: node_type['node_config'] for node_id, node_type in config['available_node_types'].items() } - return get_node_type(node_configs[config['head_node_type']]) == GCPNodeType.TPU + return get_node_type( + node_configs[config['head_node_type']]) == GCPNodeType.TPU def _create_crm(gcp_credentials=None): - return discovery.build( - 'cloudresourcemanager', 'v1', credentials=gcp_credentials, cache_discovery=False - ) + return discovery.build('cloudresourcemanager', + 'v1', + credentials=gcp_credentials, + cache_discovery=False) def _create_iam(gcp_credentials=None): - return discovery.build( - 'iam', 'v1', credentials=gcp_credentials, cache_discovery=False - ) + return discovery.build('iam', + 'v1', + credentials=gcp_credentials, + cache_discovery=False) def _create_compute(gcp_credentials=None): - return discovery.build( - 'compute', 'v1', credentials=gcp_credentials, cache_discovery=False - ) + return discovery.build('compute', + 'v1', + credentials=gcp_credentials, + cache_discovery=False) def 
_create_tpu(gcp_credentials=None): @@ -233,26 +216,19 @@ def construct_clients_from_provider_config(provider_config): """ gcp_credentials = provider_config.get('gcp_credentials') if gcp_credentials is None: - logger.debug( - 'gcp_credentials not found in cluster yaml file. ' - 'Falling back to GOOGLE_APPLICATION_CREDENTIALS ' - 'environment variable.' - ) - tpu_resource = ( - _create_tpu() - if provider_config.get(HAS_TPU_PROVIDER_FIELD, False) - else None - ) + logger.debug('gcp_credentials not found in cluster yaml file. ' + 'Falling back to GOOGLE_APPLICATION_CREDENTIALS ' + 'environment variable.') + tpu_resource = (_create_tpu() if provider_config.get( + HAS_TPU_PROVIDER_FIELD, False) else None) # If gcp_credentials is None, then discovery.build will search for # credentials in the local environment. return _create_crm(), _create_iam(), _create_compute(), tpu_resource - assert ( - 'type' in gcp_credentials - ), 'gcp_credentials cluster yaml field missing "type" field.' - assert ( - 'credentials' in gcp_credentials - ), 'gcp_credentials cluster yaml field missing "credentials" field.' + assert ('type' in gcp_credentials + ), 'gcp_credentials cluster yaml field missing "type" field.' + assert ('credentials' in gcp_credentials + ), 'gcp_credentials cluster yaml field missing "credentials" field.' cred_type = gcp_credentials['type'] credentials_field = gcp_credentials['credentials'] @@ -263,22 +239,16 @@ def construct_clients_from_provider_config(provider_config): try: service_account_info = json.loads(credentials_field) except json.decoder.JSONDecodeError: - raise RuntimeError( - 'gcp_credentials found in cluster yaml file but ' - 'formatted improperly.' - ) + raise RuntimeError('gcp_credentials found in cluster yaml file but ' + 'formatted improperly.') credentials = service_account.Credentials.from_service_account_info( - service_account_info - ) + service_account_info) elif cred_type == 'credentials_token': # Otherwise the credentials type must be credentials_token. credentials = OAuthCredentials(credentials_field) - tpu_resource = ( - _create_tpu(credentials) - if provider_config.get(HAS_TPU_PROVIDER_FIELD, False) - else None - ) + tpu_resource = (_create_tpu(credentials) if provider_config.get( + HAS_TPU_PROVIDER_FIELD, False) else None) return ( _create_crm(credentials), @@ -288,40 +258,41 @@ def construct_clients_from_provider_config(provider_config): ) -def bootstrap_gcp(config): - config = copy.deepcopy(config) - # Used internally to store head IAM role. - config['head_node'] = {} - +def bootstrap_gcp(region: str, cluster_name: str, + config: common.ProvisionConfig): # Check if we have any TPUs defined, and if so, # insert that information into the provider config - if _has_tpus_in_node_configs(config): - config['provider'][HAS_TPU_PROVIDER_FIELD] = True + if get_node_type(config.node_config) == GCPNodeType.TPU: + config.provider_config[HAS_TPU_PROVIDER_FIELD] = True - crm, iam, compute, tpu = construct_clients_from_provider_config(config['provider']) + crm, iam, compute, tpu = construct_clients_from_provider_config( + config.provider_config) - config = _configure_project(config, crm) - config = _configure_iam_role(config, crm, iam) - config = _configure_key_pair(config, compute) + # Setup a Google Cloud Platform Project. + + # Google Compute Platform organizes all the resources, such as storage + # buckets, users, and instances under projects. This is different from + # aws ec2 where everything is global. 
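+    # _configure_project below creates the project if it does not exist yet
+    # and waits until it reaches the ACTIVE lifecycle state.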
+
+    _configure_project(config.provider_config, crm)
+    config, iam_role = _configure_iam_role(config, crm, iam)
+    config.provider_config['iam_role'] = iam_role
     config = _configure_subnet(config, compute)
 
     return config
 
 
-def _configure_project(config, crm):
+def _configure_project(provider_config, crm):
     """Setup a Google Cloud Platform Project.
 
     Google Compute Platform organizes all the resources, such as storage
     buckets, users, and instances under projects. This is different from
     aws ec2 where everything is global.
     """
-    config = copy.deepcopy(config)
-
-    project_id = config['provider'].get('project_id')
-    assert config['provider']['project_id'] is not None, (
+    project_id = provider_config.get('project_id')
+    assert project_id is not None, (
         '"project_id" must be set in the "provider" section of the autoscaler'
-        ' config. Notice that the project id must be globally unique.'
-    )
+        ' config. Notice that the project id must be globally unique.')
     project = _get_project(project_id, crm)
 
     if project is None:
@@ -330,18 +301,15 @@
         project = _get_project(project_id, crm)
 
     assert project is not None, 'Failed to create project'
-    assert (
-        project['lifecycleState'] == 'ACTIVE'
-    ), 'Project status needs to be ACTIVE, got {}'.format(project['lifecycleState'])
-
-    config['provider']['project_id'] = project['projectId']
+    assert (project['lifecycleState'] == 'ACTIVE'
+           ), 'Project status needs to be ACTIVE, got {}'.format(
+               project['lifecycleState'])
 
-    return config
+    provider_config['project_id'] = project['projectId']
 
 
-def _is_permission_satisfied(
-    service_account, crm, iam, required_permissions, required_roles
-):
+def _is_permission_satisfied(service_account, crm, iam, required_permissions,
+                             required_roles):
     """Check whether the roles or permissions are satisfied."""
     if service_account is None:
         return False, None
@@ -358,8 +326,8 @@
 
     logger.info(f'_configure_iam_role: Checking permissions for {email}...')
 
-    # Check the roles first, as checking the permission requires more API calls and
-    # permissions.
+    # Check the roles first, as checking the permission
+    # requires more API calls and permissions.
     for role in required_roles:
         role_exists = False
         for binding in policy['bindings']:
@@ -371,12 +339,10 @@
 
         if not role_exists:
             already_configured = False
-            policy['bindings'].append(
-                {
-                    'members': [member_id],
-                    'role': role,
-                }
-            )
+            policy['bindings'].append({
+                'members': [member_id],
+                'role': role,
+            })
 
     if already_configured:
         # In some managed environments, an admin needs to grant the
@@ -387,12 +353,12 @@
         if member_id in binding['members']:
             role = binding['role']
             try:
-                role_definition = iam.projects().roles().get(name=role).execute()
+                role_definition = iam.projects().roles().get(
+                    name=role).execute()
             except TypeError as e:
                 if 'does not match the pattern' in str(e):
-                    logger.info(
-                        f'_configure_iam_role: failed to check permissions for built-in role {role}; skipped.'
-                    )
+                    logger.info('_configure_iam_role: failed to check permissions '
+                                f'for built-in role {role}; skipped.')
                     permissions = []
                 else:
                     raise
@@ -404,12 +370,13 @@
     if not required_permissions:
         # All required permissions are already granted.
        return True, policy
-    logger.info(f'_configure_iam_role: missing permissions {required_permissions}')
+    logger.info(
+        f'_configure_iam_role: missing permissions {required_permissions}')
 
     return False, policy
 
 
-def _configure_iam_role(config, crm, iam):
+def _configure_iam_role(config: common.ProvisionConfig, crm, iam):
     """Setup a gcp service account with IAM roles.
 
     Creates a gcp service account and binds IAM roles which allow it to control
@@ -419,23 +386,20 @@
 
     TODO: Allow the name/id of the service account to be configured
     """
-    config = copy.deepcopy(config)
-
     email = SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
         account_id=SKYPILOT_SERVICE_ACCOUNT_ID,
-        project_id=config['provider']['project_id'],
+        project_id=config.provider_config['project_id'],
     )
-    service_account = _get_service_account(email, config, iam)
+    service_account = _get_service_account(email, config.provider_config, iam)
 
     permissions = VM_MINIMAL_PERMISSIONS
     roles = DEFAULT_SERVICE_ACCOUNT_ROLES
-    if config['provider'].get(HAS_TPU_PROVIDER_FIELD, False):
+    if config.provider_config.get(HAS_TPU_PROVIDER_FIELD, False):
        roles = DEFAULT_SERVICE_ACCOUNT_ROLES + TPU_SERVICE_ACCOUNT_ROLES
        permissions = VM_MINIMAL_PERMISSIONS + TPU_MINIMAL_PERMISSIONS
 
-    satisfied, policy = _is_permission_satisfied(
-        service_account, crm, iam, permissions, roles
-    )
+    satisfied, policy = _is_permission_satisfied(service_account, crm, iam,
+                                                 permissions, roles)
 
     if not satisfied:
         # SkyPilot: Fall back to the old ray service account name for
@@ -446,45 +410,41 @@
         email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
             account_id=DEFAULT_SERVICE_ACCOUNT_ID,
-            project_id=config['provider']['project_id'],
+            project_id=config.provider_config['project_id'],
         )
         logger.info(f'_configure_iam_role: Fallback to service account {email}')
 
-        ray_service_account = _get_service_account(email, config, iam)
-        ray_satisfied, _ = _is_permission_satisfied(
-            ray_service_account, crm, iam, permissions, roles
-        )
+        ray_service_account = _get_service_account(email,
+                                                   config.provider_config, iam)
+        ray_satisfied, _ = _is_permission_satisfied(ray_service_account, crm,
+                                                    iam, permissions, roles)
         logger.info(
             '_configure_iam_role: '
-            f'Fallback to service account {email} succeeded? {ray_satisfied}'
-        )
+            f'Fallback to service account {email} succeeded? {ray_satisfied}')
 
         if ray_satisfied:
             service_account = ray_service_account
             satisfied = ray_satisfied
         elif service_account is None:
-            logger.info(
-                '_configure_iam_role: '
-                'Creating new service account {}'.format(SKYPILOT_SERVICE_ACCOUNT_ID)
-            )
+            logger.info('_configure_iam_role: '
+                        'Creating new service account {}'.format(
+                            SKYPILOT_SERVICE_ACCOUNT_ID))
             # SkyPilot: a GCP user without the permission to create a service
             # account will fail here.
             service_account = _create_service_account(
                 SKYPILOT_SERVICE_ACCOUNT_ID,
                 SKYPILOT_SERVICE_ACCOUNT_CONFIG,
-                config,
+                config.provider_config,
                 iam,
             )
             satisfied, policy = _is_permission_satisfied(
-                service_account, crm, iam, permissions, roles
-            )
+                service_account, crm, iam, permissions, roles)
 
     assert service_account is not None, 'Failed to create service account'
 
     if not satisfied:
-        logger.info(
-            '_configure_iam_role: ' f'Adding roles to service account {email}...'
- ) + logger.info('_configure_iam_role: ' + f'Adding roles to service account {email}...') _add_iam_policy_binding(service_account, policy, crm, iam) account_dict = { @@ -495,126 +455,16 @@ def _configure_iam_role(config, crm, iam): # account is limited by the IAM rights specified below. 'scopes': ['https://www.googleapis.com/auth/cloud-platform'], } - if _is_head_node_a_tpu(config): + if get_node_type(config.node_config) == GCPNodeType.TPU: # SKY: The API for TPU VM is slightly different from normal compute instances. # See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node account_dict['scope'] = account_dict['scopes'] account_dict.pop('scopes') - config['head_node']['serviceAccount'] = account_dict + iam_role = {'serviceAccount': account_dict} else: - config['head_node']['serviceAccounts'] = [account_dict] - - return config - - -def _configure_key_pair(config, compute): - """Configure SSH access, using an existing key pair if possible. - - Creates a project-wide ssh key that can be used to access all the instances - unless explicitly prohibited by instance config. - - The ssh-keys created by ray are of format: - - [USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME] - - where: - - [USERNAME] is the user for the SSH key, specified in the config. - [KEY_VALUE] is the public SSH key value. - """ - config = copy.deepcopy(config) - - if 'ssh_private_key' in config['auth']: - return config - - ssh_user = config['auth']['ssh_user'] + iam_role = {'serviceAccounts': [account_dict]} - project = compute.projects().get(project=config['provider']['project_id']).execute() - - # Key pairs associated with project meta data. The key pairs are general, - # and not just ssh keys. - ssh_keys_str = next( - ( - item - for item in project['commonInstanceMetadata'].get('items', []) - if item['key'] == 'ssh-keys' - ), - {}, - ).get('value', '') - - ssh_keys = ssh_keys_str.split('\n') if ssh_keys_str else [] - - # Try a few times to get or create a good key pair. - key_found = False - for i in range(10): - key_name = key_pair_name( - i, config['provider']['region'], config['provider']['project_id'], ssh_user - ) - public_key_path, private_key_path = key_pair_paths(key_name) - - for ssh_key in ssh_keys: - key_parts = ssh_key.split(' ') - if len(key_parts) != 3: - continue - - if key_parts[2] == ssh_user and os.path.exists(private_key_path): - # Found a key - key_found = True - break - - # Writing the new ssh key to the filesystem fails if the ~/.ssh - # directory doesn't already exist. - os.makedirs(os.path.expanduser('~/.ssh'), exist_ok=True) - - # Create a key since it doesn't exist locally or in GCP - if not key_found and not os.path.exists(private_key_path): - logger.info( - '_configure_key_pair: Creating new key pair {}'.format(key_name) - ) - public_key, private_key = generate_rsa_key_pair() - - _create_project_ssh_key_pair(project, public_key, ssh_user, compute) - - # Create the directory if it doesn't exists - private_key_dir = os.path.dirname(private_key_path) - os.makedirs(private_key_dir, exist_ok=True) - - # We need to make sure to _create_ the file with the right - # permissions. In order to do that we need to change the default - # os.open behavior to include the mode we want. 
- with open( - private_key_path, - 'w', - opener=partial(os.open, mode=0o600), - ) as f: - f.write(private_key) - - with open(public_key_path, 'w') as f: - f.write(public_key) - - key_found = True - - break - - if key_found: - break - - assert key_found, 'SSH keypair for user {} not found for {}'.format( - ssh_user, private_key_path - ) - assert os.path.exists( - private_key_path - ), 'Private key file {} not found for user {}'.format(private_key_path, ssh_user) - - logger.info( - '_configure_key_pair: ' - 'Private key not specified in config, using' - '{}'.format(private_key_path) - ) - - config['auth']['ssh_private_key'] = private_key_path - - return config + return config, iam_role def _check_firewall_rules(vpc_name, config, compute): @@ -622,8 +472,7 @@ def _check_firewall_rules(vpc_name, config, compute): required_rules = FIREWALL_RULES_REQUIRED.copy() operation = compute.networks().getEffectiveFirewalls( - project=config['provider']['project_id'], network=vpc_name - ) + project=config['provider']['project_id'], network=vpc_name) response = operation.execute() if len(response) == 0: return False @@ -682,7 +531,8 @@ def _merge_and_refine_rule(rules): allowed = rule.get('allowed', []) for source in sources: key = (direction, source) - source2allowed_list[key] = source2allowed_list.get(key, []) + allowed + source2allowed_list[key] = source2allowed_list.get(key, + []) + allowed for direction_source, allowed_list in source2allowed_list.items(): source2rules[direction_source] = {} for allowed in allowed_list: @@ -702,11 +552,14 @@ def _merge_and_refine_rule(rules): len(parse_ports) == 2 ), f'Failed to parse the port range: {port_range}' port_set.update( - set(range(int(parse_ports[0]), int(parse_ports[1]) + 1)) - ) + set( + range(int(parse_ports[0]), + int(parse_ports[1]) + 1))) if allowed['IPProtocol'] not in source2rules[direction_source]: - source2rules[direction_source][allowed['IPProtocol']] = set() - source2rules[direction_source][allowed['IPProtocol']].update(port_set) + source2rules[direction_source][ + allowed['IPProtocol']] = set() + source2rules[direction_source][allowed['IPProtocol']].update( + port_set) return source2rules effective_rules = _merge_and_refine_rule(effective_rules) @@ -733,18 +586,23 @@ def _create_rules(config, compute, rules, VPC_NAME, PROJ_ID): # Query firewall rule by its name (unique in a project). # If the rule already exists, delete it first. rule_name = rule['name'].format(VPC_NAME=VPC_NAME) - rule_list = _list_firewall_rules(config, compute, filter=f'(name={rule_name})') + rule_list = _list_firewall_rules(config, + compute, + filter=f'(name={rule_name})') if len(rule_list) > 0: _delete_firewall_rule(config, compute, rule_name) body = rule.copy() body['name'] = body['name'].format(VPC_NAME=VPC_NAME) - body['network'] = body['network'].format(PROJ_ID=PROJ_ID, VPC_NAME=VPC_NAME) - body['selfLink'] = body['selfLink'].format(PROJ_ID=PROJ_ID, VPC_NAME=VPC_NAME) + body['network'] = body['network'].format(PROJ_ID=PROJ_ID, + VPC_NAME=VPC_NAME) + body['selfLink'] = body['selfLink'].format(PROJ_ID=PROJ_ID, + VPC_NAME=VPC_NAME) op = _create_firewall_rule_submit(config, compute, body) opertaions.append(op) for op in opertaions: - wait_for_compute_global_operation(config['provider']['project_id'], op, compute) + wait_for_compute_global_operation(config['provider']['project_id'], op, + compute) def get_usable_vpc(config): @@ -752,7 +610,8 @@ def get_usable_vpc(config): If not found, create a new one with sufficient firewall rules. 
""" - _, _, compute, _ = construct_clients_from_provider_config(config['provider']) + _, _, compute, _ = construct_clients_from_provider_config( + config['provider']) # For backward compatibility, reuse the VPC if the VM is launched. resource = GCPCompute( @@ -781,18 +640,18 @@ def get_usable_vpc(config): logger.info(f'Creating a default VPC network, {SKYPILOT_VPC_NAME}...') # Create a SkyPilot VPC network if it doesn't exist - vpc_list = _list_vpcnets(config, compute, filter=f'name={SKYPILOT_VPC_NAME}') + vpc_list = _list_vpcnets(config, + compute, + filter=f'name={SKYPILOT_VPC_NAME}') if len(vpc_list) == 0: body = VPC_TEMPLATE.copy() body['name'] = body['name'].format(VPC_NAME=SKYPILOT_VPC_NAME) body['selfLink'] = body['selfLink'].format( - PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME - ) + PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME) _create_vpcnet(config, compute, body) - _create_rules( - config, compute, FIREWALL_RULES_TEMPLATE, SKYPILOT_VPC_NAME, proj_id - ) + _create_rules(config, compute, FIREWALL_RULES_TEMPLATE, + SKYPILOT_VPC_NAME, proj_id) usable_vpc_name = SKYPILOT_VPC_NAME logger.info(f'A VPC network {SKYPILOT_VPC_NAME} created.') @@ -812,28 +671,24 @@ def _configure_subnet(config, compute): # completely manually configured # networkInterfaces is compute, networkConfig is TPU - if all( - 'networkInterfaces' in node_config or 'networkConfig' in node_config - for node_config in node_configs - ): + if all('networkInterfaces' in node_config or 'networkConfig' in node_config + for node_config in node_configs): return config # SkyPilot: make sure there's a usable VPC usable_vpc_name = get_usable_vpc(config) - subnets = _list_subnets(config, compute, filter=f'(name="{usable_vpc_name}")') + subnets = _list_subnets(config, + compute, + filter=f'(name="{usable_vpc_name}")') default_subnet = subnets[0] - default_interfaces = [ - { - 'subnetwork': default_subnet['selfLink'], - 'accessConfigs': [ - { - 'name': 'External NAT', - 'type': 'ONE_TO_ONE_NAT', - } - ], - } - ] + default_interfaces = [{ + 'subnetwork': default_subnet['selfLink'], + 'accessConfigs': [{ + 'name': 'External NAT', + 'type': 'ONE_TO_ONE_NAT', + }], + }] for node_config in node_configs: # The not applicable key will be removed during node creation @@ -850,91 +705,54 @@ def _configure_subnet(config, compute): def _create_firewall_rule_submit(config, compute, body): - operation = ( - compute.firewalls() - .insert(project=config['provider']['project_id'], body=body) - .execute() - ) + operation = (compute.firewalls().insert( + project=config['provider']['project_id'], body=body).execute()) return operation def _delete_firewall_rule(config, compute, name): - operation = ( - compute.firewalls() - .delete(project=config['provider']['project_id'], firewall=name) - .execute() - ) + operation = (compute.firewalls().delete( + project=config['provider']['project_id'], firewall=name).execute()) response = wait_for_compute_global_operation( - config['provider']['project_id'], operation, compute - ) + config['provider']['project_id'], operation, compute) return response def _list_firewall_rules(config, compute, filter=None): - response = ( - compute.firewalls() - .list( - project=config['provider']['project_id'], - filter=filter, - ) - .execute() - ) + response = (compute.firewalls().list( + project=config['provider']['project_id'], + filter=filter, + ).execute()) return response['items'] if 'items' in response else [] def _create_vpcnet(config, compute, body): - operation = ( - compute.networks() - 
.insert(project=config['provider']['project_id'], body=body) - .execute() - ) + operation = (compute.networks().insert( + project=config['provider']['project_id'], body=body).execute()) response = wait_for_compute_global_operation( - config['provider']['project_id'], operation, compute - ) + config['provider']['project_id'], operation, compute) return response def _list_vpcnets(config, compute, filter=None): - response = ( - compute.networks() - .list( - project=config['provider']['project_id'], - filter=filter, - ) - .execute() - ) + response = (compute.networks().list( + project=config['provider']['project_id'], + filter=filter, + ).execute()) return response['items'] if 'items' in response else [] def _list_subnets(config, compute, filter=None): - response = ( - compute.subnetworks() - .list( - project=config['provider']['project_id'], - region=config['provider']['region'], - filter=filter, - ) - .execute() - ) + response = (compute.subnetworks().list( + project=config['provider']['project_id'], + region=config['provider']['region'], + filter=filter, + ).execute()) return response['items'] if 'items' in response else [] -def _get_subnet(config, subnet_id, compute): - subnet = ( - compute.subnetworks() - .get( - project=config['provider']['project_id'], - region=config['provider']['region'], - subnetwork=subnet_id, - ) - .execute() - ) - - return subnet - - def _get_project(project_id, crm): try: project = crm.projects().get(projectId=project_id).execute() @@ -947,24 +765,23 @@ def _get_project(project_id, crm): def _create_project(project_id, crm): - operation = ( - crm.projects() - .create(body={'projectId': project_id, 'name': project_id}) - .execute() - ) + operation = (crm.projects().create(body={ + 'projectId': project_id, + 'name': project_id + }).execute()) result = wait_for_crm_operation(operation, crm) return result -def _get_service_account(account, config, iam): - project_id = config['provider']['project_id'] +def _get_service_account(account, provider_config, iam): + project_id = provider_config['project_id'] full_name = 'projects/{project_id}/serviceAccounts/{account}'.format( - project_id=project_id, account=account - ) + project_id=project_id, account=account) try: - service_account = iam.projects().serviceAccounts().get(name=full_name).execute() + service_account = iam.projects().serviceAccounts().get( + name=full_name).execute() except errors.HttpError as e: if e.resp.status not in [403, 404]: # SkyPilot: added 403, which means the service account doesn't exist, @@ -976,21 +793,16 @@ def _get_service_account(account, config, iam): return service_account -def _create_service_account(account_id, account_config, config, iam): - project_id = config['provider']['project_id'] +def _create_service_account(account_id, account_config, provider_config, iam): + project_id = provider_config['project_id'] - service_account = ( - iam.projects() - .serviceAccounts() - .create( - name='projects/{project_id}'.format(project_id=project_id), - body={ - 'accountId': account_id, - 'serviceAccount': account_config, - }, - ) - .execute() - ) + service_account = (iam.projects().serviceAccounts().create( + name='projects/{project_id}'.format(project_id=project_id), + body={ + 'accountId': account_id, + 'serviceAccount': account_config, + }, + ).execute()) return service_account @@ -999,16 +811,12 @@ def _add_iam_policy_binding(service_account, policy, crm, iam): """Add new IAM roles for the service account.""" project_id = service_account['projectId'] - result = ( - crm.projects() - 
.setIamPolicy( - resource=project_id, - body={ - 'policy': policy, - }, - ) - .execute() - ) + result = (crm.projects().setIamPolicy( + resource=project_id, + body={ + 'policy': policy, + }, + ).execute()) return result @@ -1023,15 +831,13 @@ def _create_project_ssh_key_pair(project, public_key, ssh_user, compute): assert key_parts[0] == 'ssh-rsa', key_parts new_ssh_meta = '{ssh_user}:ssh-rsa {key_value} {ssh_user}'.format( - ssh_user=ssh_user, key_value=key_parts[1] - ) + ssh_user=ssh_user, key_value=key_parts[1]) common_instance_info = project['commonInstanceMetadata'] items = common_instance_info.get('items', []) ssh_keys_i = next( - (i for i, item in enumerate(items) if item['key'] == 'ssh-keys'), None - ) + (i for i, item in enumerate(items) if item['key'] == 'ssh-keys'), None) if ssh_keys_i is None: items.append({'key': 'ssh-keys', 'value': new_ssh_meta}) @@ -1042,12 +848,10 @@ def _create_project_ssh_key_pair(project, public_key, ssh_user, compute): common_instance_info['items'] = items - operation = ( - compute.projects() - .setCommonInstanceMetadata(project=project['name'], body=common_instance_info) - .execute() - ) + operation = (compute.projects().setCommonInstanceMetadata( + project=project['name'], body=common_instance_info).execute()) - response = wait_for_compute_global_operation(project['name'], operation, compute) + response = wait_for_compute_global_operation(project['name'], operation, + compute) return response From 8dd39e76643a3aa5f3f27407825e76a783ea09bb Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 6 Oct 2023 10:33:00 -0700 Subject: [PATCH 04/84] update --- sky/provision/gcp/config.py | 71 +++++++------------ sky/provision/gcp/constants.py | 124 +++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 44 deletions(-) create mode 100644 sky/provision/gcp/constants.py diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index caead01a28b..45179eb441f 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -16,12 +16,12 @@ from googleapiclient import errors from sky.provision import common -from sky.skylet.providers.gcp.constants import FIREWALL_RULES_REQUIRED -from sky.skylet.providers.gcp.constants import FIREWALL_RULES_TEMPLATE -from sky.skylet.providers.gcp.constants import SKYPILOT_VPC_NAME -from sky.skylet.providers.gcp.constants import TPU_MINIMAL_PERMISSIONS -from sky.skylet.providers.gcp.constants import VM_MINIMAL_PERMISSIONS -from sky.skylet.providers.gcp.constants import VPC_TEMPLATE +from sky.provision.gcp.constants import FIREWALL_RULES_REQUIRED +from sky.provision.gcp.constants import FIREWALL_RULES_TEMPLATE +from sky.provision.gcp.constants import SKYPILOT_VPC_NAME +from sky.provision.gcp.constants import TPU_MINIMAL_PERMISSIONS +from sky.provision.gcp.constants import VM_MINIMAL_PERMISSIONS +from sky.provision.gcp.constants import VPC_TEMPLATE from sky.skylet.providers.gcp.node import GCPCompute from sky.skylet.providers.gcp.node import GCPNodeType from sky.skylet.providers.gcp.node import MAX_POLLS @@ -165,16 +165,6 @@ def generate_rsa_key_pair(): return public_key, pem -def _is_head_node_a_tpu(config: dict) -> bool: - """Check if the head node is a TPU.""" - node_configs = { - node_id: node_type['node_config'] - for node_id, node_type in config['available_node_types'].items() - } - return get_node_type( - node_configs[config['head_node_type']]) == GCPNodeType.TPU - - def _create_crm(gcp_credentials=None): return discovery.build('cloudresourcemanager', 'v1', @@ -605,13 +595,13 @@ def 
_create_rules(config, compute, rules, VPC_NAME, PROJ_ID): compute) -def get_usable_vpc(config): +def get_usable_vpc(config: common.ProvisionConfig): """Return a usable VPC. If not found, create a new one with sufficient firewall rules. """ _, _, compute, _ = construct_clients_from_provider_config( - config['provider']) + config.provider_config) # For backward compatibility, reuse the VPC if the VM is launched. resource = GCPCompute( @@ -627,7 +617,7 @@ def get_usable_vpc(config): vpc_name = netInterfaces[0]['network'].split('/')[-1] return vpc_name - vpcnets_all = _list_vpcnets(config, compute) + vpcnets_all = _list_vpcnets(config.provider_config, compute) usable_vpc_name = None for vpc in vpcnets_all: @@ -640,7 +630,7 @@ def get_usable_vpc(config): logger.info(f'Creating a default VPC network, {SKYPILOT_VPC_NAME}...') # Create a SkyPilot VPC network if it doesn't exist - vpc_list = _list_vpcnets(config, + vpc_list = _list_vpcnets(config.provider_config, compute, filter=f'name={SKYPILOT_VPC_NAME}') if len(vpc_list) == 0: @@ -659,25 +649,19 @@ def get_usable_vpc(config): return usable_vpc_name -def _configure_subnet(config, compute): +def _configure_subnet(config: common.ProvisionConfig, compute): """Pick a reasonable subnet if not specified by the config.""" - config = copy.deepcopy(config) - - node_configs = [ - node_type['node_config'] - for node_type in config['available_node_types'].values() - ] + node_config = config.node_config # Rationale: avoid subnet lookup if the network is already # completely manually configured # networkInterfaces is compute, networkConfig is TPU - if all('networkInterfaces' in node_config or 'networkConfig' in node_config - for node_config in node_configs): + if 'networkInterfaces' in node_config or 'networkConfig' in node_config: return config # SkyPilot: make sure there's a usable VPC usable_vpc_name = get_usable_vpc(config) - subnets = _list_subnets(config, + subnets = _list_subnets(config.provider_config, compute, filter=f'(name="{usable_vpc_name}")') default_subnet = subnets[0] @@ -690,16 +674,15 @@ def _configure_subnet(config, compute): }], }] - for node_config in node_configs: - # The not applicable key will be removed during node creation + # The not applicable key will be removed during node creation - # compute - if 'networkInterfaces' not in node_config: - node_config['networkInterfaces'] = copy.deepcopy(default_interfaces) - # TPU - if 'networkConfig' not in node_config: - node_config['networkConfig'] = copy.deepcopy(default_interfaces)[0] - node_config['networkConfig'].pop('accessConfigs') + # compute + if 'networkInterfaces' not in node_config: + node_config['networkInterfaces'] = copy.deepcopy(default_interfaces) + # TPU + if 'networkConfig' not in node_config: + node_config['networkConfig'] = copy.deepcopy(default_interfaces)[0] + node_config['networkConfig'].pop('accessConfigs') return config @@ -734,19 +717,19 @@ def _create_vpcnet(config, compute, body): return response -def _list_vpcnets(config, compute, filter=None): +def _list_vpcnets(provider_config, compute, filter=None): response = (compute.networks().list( - project=config['provider']['project_id'], + project=provider_config['project_id'], filter=filter, ).execute()) return response['items'] if 'items' in response else [] -def _list_subnets(config, compute, filter=None): +def _list_subnets(provider_config, compute, filter=None): response = (compute.subnetworks().list( - project=config['provider']['project_id'], - region=config['provider']['region'], + 
project=provider_config['project_id'], + region=provider_config['region'], filter=filter, ).execute()) diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py new file mode 100644 index 00000000000..bcf3b02ef70 --- /dev/null +++ b/sky/provision/gcp/constants.py @@ -0,0 +1,124 @@ +SKYPILOT_VPC_NAME = "skypilot-vpc" + +# Below parameters are from the default VPC on GCP. +# https://cloud.google.com/vpc/docs/firewalls#more_rules_default_vpc +VPC_TEMPLATE = { + "name": "{VPC_NAME}", + "selfLink": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", + "autoCreateSubnetworks": True, + "mtu": 1460, + "routingConfig": {"routingMode": "GLOBAL"}, +} +# Required firewall rules for SkyPilot to work. +FIREWALL_RULES_REQUIRED = [ + # Allow internal connections between GCP VMs for Ray multi-node cluster. + { + "direction": "INGRESS", + "allowed": [ + {"IPProtocol": "tcp", "ports": ["0-65535"]}, + {"IPProtocol": "udp", "ports": ["0-65535"]}, + ], + "sourceRanges": ["10.128.0.0/9"], + }, + # Allow ssh connection from anywhere. + { + "direction": "INGRESS", + "allowed": [ + { + "IPProtocol": "tcp", + "ports": ["22"], + } + ], + "sourceRanges": ["0.0.0.0/0"], + }, +] +# Template when creating firewall rules for a new VPC. +FIREWALL_RULES_TEMPLATE = [ + { + "name": "{VPC_NAME}-allow-custom", + "description": "Allows connection from any source to any instance on the network using custom protocols.", + "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", + "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-custom", + "direction": "INGRESS", + "priority": 65534, + "allowed": [ + {"IPProtocol": "tcp", "ports": ["0-65535"]}, + {"IPProtocol": "udp", "ports": ["0-65535"]}, + {"IPProtocol": "icmp"}, + ], + "sourceRanges": ["10.128.0.0/9"], + }, + { + "name": "{VPC_NAME}-allow-ssh", + "description": "Allows TCP connections from any source to any instance on the network using port 22.", + "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", + "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-ssh", + "direction": "INGRESS", + "priority": 65534, + "allowed": [ + { + "IPProtocol": "tcp", + "ports": ["22"], + } + ], + "sourceRanges": ["0.0.0.0/0"], + }, + { + "name": "{VPC_NAME}-allow-icmp", + "description": "Allows ICMP connections from any source to any instance on the network.", + "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", + "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-icmp", + "direction": "INGRESS", + "priority": 65534, + "allowed": [ + { + "IPProtocol": "icmp", + } + ], + "sourceRanges": ["0.0.0.0/0"], + }, +] + +# A list of permissions required to run SkyPilot on GCP. 
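The `{VPC_NAME}` and `{PROJ_ID}` placeholders in `VPC_TEMPLATE` and `FIREWALL_RULES_TEMPLATE` are ordinary `str.format` fields; config.py expands them before each body is submitted to the Compute API (see `_create_rules` earlier in the series). A minimal sketch of that expansion, with a placeholder project ID standing in for a real one:

    # Sketch: instantiating the firewall-rule templates for one project/VPC.
    # 'my-project' is an assumed example value, not part of this patch.
    import copy

    project_id = 'my-project'
    vpc_name = SKYPILOT_VPC_NAME

    bodies = []
    for template in FIREWALL_RULES_TEMPLATE:
        body = copy.deepcopy(template)
        body['name'] = body['name'].format(VPC_NAME=vpc_name)
        body['network'] = body['network'].format(PROJ_ID=project_id,
                                                 VPC_NAME=vpc_name)
        body['selfLink'] = body['selfLink'].format(PROJ_ID=project_id,
                                                   VPC_NAME=vpc_name)
        bodies.append(body)
    # Each body is now ready for
    # compute.firewalls().insert(project=project_id, body=body).execute()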
+# Keep this in sync with https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions.html#gcp # pylint: disable=line-too-long +VM_MINIMAL_PERMISSIONS = [ + "compute.disks.create", + "compute.disks.list", + "compute.firewalls.create", + "compute.firewalls.delete", + "compute.firewalls.get", + "compute.instances.create", + "compute.instances.delete", + "compute.instances.get", + "compute.instances.list", + "compute.instances.setLabels", + "compute.instances.setServiceAccount", + "compute.instances.start", + "compute.instances.stop", + "compute.networks.get", + "compute.networks.list", + "compute.networks.getEffectiveFirewalls", + "compute.globalOperations.get", + "compute.subnetworks.use", + "compute.subnetworks.list", + "compute.subnetworks.useExternalIp", + "compute.projects.get", + "compute.zoneOperations.get", + "iam.roles.get", + "iam.serviceAccounts.actAs", + "iam.serviceAccounts.get", + "serviceusage.services.enable", + "serviceusage.services.list", + "serviceusage.services.use", + "resourcemanager.projects.get", + "resourcemanager.projects.getIamPolicy", +] + +TPU_MINIMAL_PERMISSIONS = [ + "tpu.nodes.create", + "tpu.nodes.delete", + "tpu.nodes.list", + "tpu.nodes.get", + "tpu.nodes.update", + "tpu.operations.get", +] From 45ef17aca27df2a6c855a9d97a16f36cc4b4ec46 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 6 Oct 2023 10:59:52 -0700 Subject: [PATCH 05/84] update --- sky/provision/gcp/config.py | 142 +++++++++--------------- sky/provision/gcp/constants.py | 196 ++++++++++++++++++--------------- 2 files changed, 155 insertions(+), 183 deletions(-) diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index 45179eb441f..c1aaf311744 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -1,15 +1,10 @@ """GCP config module.""" import copy -from functools import partial import json import logging -import os import time from typing import Dict, List, Set, Tuple -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa from google.oauth2 import service_account from google.oauth2.credentials import Credentials as OAuthCredentials from googleapiclient import discovery @@ -18,14 +13,14 @@ from sky.provision import common from sky.provision.gcp.constants import FIREWALL_RULES_REQUIRED from sky.provision.gcp.constants import FIREWALL_RULES_TEMPLATE +from sky.provision.gcp.constants import MAX_POLLS +from sky.provision.gcp.constants import POLL_INTERVAL from sky.provision.gcp.constants import SKYPILOT_VPC_NAME from sky.provision.gcp.constants import TPU_MINIMAL_PERMISSIONS from sky.provision.gcp.constants import VM_MINIMAL_PERMISSIONS from sky.provision.gcp.constants import VPC_TEMPLATE from sky.skylet.providers.gcp.node import GCPCompute from sky.skylet.providers.gcp.node import GCPNodeType -from sky.skylet.providers.gcp.node import MAX_POLLS -from sky.skylet.providers.gcp.node import POLL_INTERVAL logger = logging.getLogger(__name__) @@ -133,38 +128,6 @@ def wait_for_compute_global_operation(project_name, operation, compute): return result -def key_pair_name(i, region, project_id, ssh_user): - """Returns the ith default gcp_key_pair_name.""" - return f'{SKYPILOT}_gcp_{region}_{project_id}_{ssh_user}_{i}' - - -def key_pair_paths(key_name): - """Returns public and private key paths for a given key_name.""" - public_key_path = os.path.expanduser(f'~/.ssh/{key_name}.pub') - private_key_path = 
os.path.expanduser(f'~/.ssh/{key_name}.pem') - return public_key_path, private_key_path - - -def generate_rsa_key_pair(): - """Create public and private ssh-keys.""" - - key = rsa.generate_private_key(backend=default_backend(), - public_exponent=65537, - key_size=2048) - - public_key = (key.public_key().public_bytes( - serialization.Encoding.OpenSSH, - serialization.PublicFormat.OpenSSH).decode('utf-8')) - - pem = key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ).decode('utf-8') - - return public_key, pem - - def _create_crm(gcp_credentials=None): return discovery.build('cloudresourcemanager', 'v1', @@ -267,7 +230,7 @@ def bootstrap_gcp(region: str, cluster_name: str, _configure_project(config.provider_config, crm) config, iam_role = _configure_iam_role(config, crm, iam) config.provider_config['iam_role'] = iam_role - config = _configure_subnet(config, compute) + config = _configure_subnet(region, cluster_name, config, compute) return config @@ -457,12 +420,13 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): return config, iam_role -def _check_firewall_rules(vpc_name, config, compute): +def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str, + compute): """Check if the firewall rules in the VPC are sufficient.""" required_rules = FIREWALL_RULES_REQUIRED.copy() - operation = compute.networks().getEffectiveFirewalls( - project=config['provider']['project_id'], network=vpc_name) + operation = compute.networks().getEffectiveFirewalls(project=project_id, + network=vpc_name) response = operation.execute() if len(response) == 0: return False @@ -514,7 +478,7 @@ def _merge_and_refine_rule(rules): if tags is not None: if len(tags) != 1: continue - if tags[0] != config['cluster_name']: + if tags[0] != cluster_name: continue direction = rule.get('direction', '') sources = rule.get('sourceRanges', []) @@ -570,45 +534,45 @@ def _merge_and_refine_rule(rules): return True -def _create_rules(config, compute, rules, VPC_NAME, PROJ_ID): +def _create_rules(project_id: str, compute, rules, VPC_NAME): opertaions = [] for rule in rules: # Query firewall rule by its name (unique in a project). # If the rule already exists, delete it first. rule_name = rule['name'].format(VPC_NAME=VPC_NAME) - rule_list = _list_firewall_rules(config, + rule_list = _list_firewall_rules(project_id, compute, filter=f'(name={rule_name})') if len(rule_list) > 0: - _delete_firewall_rule(config, compute, rule_name) + _delete_firewall_rule(project_id, compute, rule_name) body = rule.copy() body['name'] = body['name'].format(VPC_NAME=VPC_NAME) - body['network'] = body['network'].format(PROJ_ID=PROJ_ID, + body['network'] = body['network'].format(PROJ_ID=project_id, VPC_NAME=VPC_NAME) - body['selfLink'] = body['selfLink'].format(PROJ_ID=PROJ_ID, + body['selfLink'] = body['selfLink'].format(PROJ_ID=project_id, VPC_NAME=VPC_NAME) - op = _create_firewall_rule_submit(config, compute, body) + op = compute.firewalls().insert(project=project_id, body=body).execute() opertaions.append(op) for op in opertaions: - wait_for_compute_global_operation(config['provider']['project_id'], op, - compute) + wait_for_compute_global_operation(project_id, op, compute) -def get_usable_vpc(config: common.ProvisionConfig): +def get_usable_vpc(cluster_name: str, config: common.ProvisionConfig): """Return a usable VPC. If not found, create a new one with sufficient firewall rules. 
""" + project_id = config.provider_config['project_id'] _, _, compute, _ = construct_clients_from_provider_config( config.provider_config) # For backward compatibility, reuse the VPC if the VM is launched. resource = GCPCompute( compute, - config['provider']['project_id'], - config['provider']['availability_zone'], - config['cluster_name'], + project_id, + config.provider_config['availability_zone'], + cluster_name, ) node = resource._list_instances(label_filters=None, status_filter=None) if len(node) > 0: @@ -617,31 +581,31 @@ def get_usable_vpc(config: common.ProvisionConfig): vpc_name = netInterfaces[0]['network'].split('/')[-1] return vpc_name - vpcnets_all = _list_vpcnets(config.provider_config, compute) + vpcnets_all = _list_vpcnets(project_id, compute) usable_vpc_name = None for vpc in vpcnets_all: - if _check_firewall_rules(vpc['name'], config, compute): + if _check_firewall_rules(cluster_name, vpc['name'], project_id, + compute): usable_vpc_name = vpc['name'] break - proj_id = config['provider']['project_id'] if usable_vpc_name is None: logger.info(f'Creating a default VPC network, {SKYPILOT_VPC_NAME}...') # Create a SkyPilot VPC network if it doesn't exist - vpc_list = _list_vpcnets(config.provider_config, + vpc_list = _list_vpcnets(project_id, compute, filter=f'name={SKYPILOT_VPC_NAME}') if len(vpc_list) == 0: body = VPC_TEMPLATE.copy() body['name'] = body['name'].format(VPC_NAME=SKYPILOT_VPC_NAME) body['selfLink'] = body['selfLink'].format( - PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME) - _create_vpcnet(config, compute, body) + PROJ_ID=project_id, VPC_NAME=SKYPILOT_VPC_NAME) + _create_vpcnet(project_id, compute, body) - _create_rules(config, compute, FIREWALL_RULES_TEMPLATE, - SKYPILOT_VPC_NAME, proj_id) + _create_rules(project_id, compute, FIREWALL_RULES_TEMPLATE, + SKYPILOT_VPC_NAME) usable_vpc_name = SKYPILOT_VPC_NAME logger.info(f'A VPC network {SKYPILOT_VPC_NAME} created.') @@ -649,7 +613,8 @@ def get_usable_vpc(config: common.ProvisionConfig): return usable_vpc_name -def _configure_subnet(config: common.ProvisionConfig, compute): +def _configure_subnet(region: str, cluster_name: str, + config: common.ProvisionConfig, compute): """Pick a reasonable subnet if not specified by the config.""" node_config = config.node_config # Rationale: avoid subnet lookup if the network is already @@ -660,8 +625,9 @@ def _configure_subnet(config: common.ProvisionConfig, compute): return config # SkyPilot: make sure there's a usable VPC - usable_vpc_name = get_usable_vpc(config) - subnets = _list_subnets(config.provider_config, + usable_vpc_name = get_usable_vpc(cluster_name, config) + subnets = _list_subnets(config.provider_config['project_id'], + region, compute, filter=f'(name="{usable_vpc_name}")') default_subnet = subnets[0] @@ -687,56 +653,48 @@ def _configure_subnet(config: common.ProvisionConfig, compute): return config -def _create_firewall_rule_submit(config, compute, body): - operation = (compute.firewalls().insert( - project=config['provider']['project_id'], body=body).execute()) - return operation - - -def _delete_firewall_rule(config, compute, name): - operation = (compute.firewalls().delete( - project=config['provider']['project_id'], firewall=name).execute()) - response = wait_for_compute_global_operation( - config['provider']['project_id'], operation, compute) +def _delete_firewall_rule(project_id: str, compute, name): + operation = (compute.firewalls().delete(project=project_id, + firewall=name).execute()) + response = wait_for_compute_global_operation(project_id, 
operation, compute) return response -def _list_firewall_rules(config, compute, filter=None): +def _list_firewall_rules(project_id, compute, filter=None): response = (compute.firewalls().list( - project=config['provider']['project_id'], + project=project_id, filter=filter, ).execute()) return response['items'] if 'items' in response else [] -def _create_vpcnet(config, compute, body): - operation = (compute.networks().insert( - project=config['provider']['project_id'], body=body).execute()) - response = wait_for_compute_global_operation( - config['provider']['project_id'], operation, compute) +def _create_vpcnet(project_id: str, compute, body): + operation = (compute.networks().insert(project=project_id, + body=body).execute()) + response = wait_for_compute_global_operation(project_id, operation, compute) return response -def _list_vpcnets(provider_config, compute, filter=None): +def _list_vpcnets(project_id: str, compute, filter=None): response = (compute.networks().list( - project=provider_config['project_id'], + project=project_id, filter=filter, ).execute()) return response['items'] if 'items' in response else [] -def _list_subnets(provider_config, compute, filter=None): +def _list_subnets(project_id: str, region: str, compute, filter=None): response = (compute.subnetworks().list( - project=provider_config['project_id'], - region=provider_config['region'], + project=project_id, + region=region, filter=filter, ).execute()) return response['items'] if 'items' in response else [] -def _get_project(project_id, crm): +def _get_project(project_id: str, crm): try: project = crm.projects().get(projectId=project_id).execute() except errors.HttpError as e: @@ -747,7 +705,7 @@ def _get_project(project_id, crm): return project -def _create_project(project_id, crm): +def _create_project(project_id: str, crm): operation = (crm.projects().create(body={ 'projectId': project_id, 'name': project_id diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index bcf3b02ef70..4f70839d407 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -1,124 +1,138 @@ -SKYPILOT_VPC_NAME = "skypilot-vpc" +SKYPILOT_VPC_NAME = 'skypilot-vpc' # Below parameters are from the default VPC on GCP. # https://cloud.google.com/vpc/docs/firewalls#more_rules_default_vpc VPC_TEMPLATE = { - "name": "{VPC_NAME}", - "selfLink": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", - "autoCreateSubnetworks": True, - "mtu": 1460, - "routingConfig": {"routingMode": "GLOBAL"}, + 'name': '{VPC_NAME}', + 'selfLink': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', + 'autoCreateSubnetworks': True, + 'mtu': 1460, + 'routingConfig': { + 'routingMode': 'GLOBAL' + }, } # Required firewall rules for SkyPilot to work. FIREWALL_RULES_REQUIRED = [ # Allow internal connections between GCP VMs for Ray multi-node cluster. { - "direction": "INGRESS", - "allowed": [ - {"IPProtocol": "tcp", "ports": ["0-65535"]}, - {"IPProtocol": "udp", "ports": ["0-65535"]}, + 'direction': 'INGRESS', + 'allowed': [ + { + 'IPProtocol': 'tcp', + 'ports': ['0-65535'] + }, + { + 'IPProtocol': 'udp', + 'ports': ['0-65535'] + }, ], - "sourceRanges": ["10.128.0.0/9"], + 'sourceRanges': ['10.128.0.0/9'], }, # Allow ssh connection from anywhere. 
{ - "direction": "INGRESS", - "allowed": [ - { - "IPProtocol": "tcp", - "ports": ["22"], - } - ], - "sourceRanges": ["0.0.0.0/0"], + 'direction': 'INGRESS', + 'allowed': [{ + 'IPProtocol': 'tcp', + 'ports': ['22'], + }], + 'sourceRanges': ['0.0.0.0/0'], }, ] # Template when creating firewall rules for a new VPC. FIREWALL_RULES_TEMPLATE = [ { - "name": "{VPC_NAME}-allow-custom", - "description": "Allows connection from any source to any instance on the network using custom protocols.", - "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", - "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-custom", - "direction": "INGRESS", - "priority": 65534, - "allowed": [ - {"IPProtocol": "tcp", "ports": ["0-65535"]}, - {"IPProtocol": "udp", "ports": ["0-65535"]}, - {"IPProtocol": "icmp"}, + 'name': '{VPC_NAME}-allow-custom', + 'description': 'Allows connection from any source to any instance on the network using custom protocols.', + 'network': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', + 'selfLink': 'projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-custom', + 'direction': 'INGRESS', + 'priority': 65534, + 'allowed': [ + { + 'IPProtocol': 'tcp', + 'ports': ['0-65535'] + }, + { + 'IPProtocol': 'udp', + 'ports': ['0-65535'] + }, + { + 'IPProtocol': 'icmp' + }, ], - "sourceRanges": ["10.128.0.0/9"], + 'sourceRanges': ['10.128.0.0/9'], }, { - "name": "{VPC_NAME}-allow-ssh", - "description": "Allows TCP connections from any source to any instance on the network using port 22.", - "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", - "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-ssh", - "direction": "INGRESS", - "priority": 65534, - "allowed": [ - { - "IPProtocol": "tcp", - "ports": ["22"], - } - ], - "sourceRanges": ["0.0.0.0/0"], + 'name': '{VPC_NAME}-allow-ssh', + 'description': 'Allows TCP connections from any source to any instance on the network using port 22.', + 'network': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', + 'selfLink': 'projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-ssh', + 'direction': 'INGRESS', + 'priority': 65534, + 'allowed': [{ + 'IPProtocol': 'tcp', + 'ports': ['22'], + }], + 'sourceRanges': ['0.0.0.0/0'], }, { - "name": "{VPC_NAME}-allow-icmp", - "description": "Allows ICMP connections from any source to any instance on the network.", - "network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}", - "selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-icmp", - "direction": "INGRESS", - "priority": 65534, - "allowed": [ - { - "IPProtocol": "icmp", - } - ], - "sourceRanges": ["0.0.0.0/0"], + 'name': '{VPC_NAME}-allow-icmp', + 'description': 'Allows ICMP connections from any source to any instance on the network.', + 'network': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', + 'selfLink': 'projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-icmp', + 'direction': 'INGRESS', + 'priority': 65534, + 'allowed': [{ + 'IPProtocol': 'icmp', + }], + 'sourceRanges': ['0.0.0.0/0'], }, ] # A list of permissions required to run SkyPilot on GCP. 
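At provision time these two required rules act as the acceptance test for reusing an existing VPC: `_check_firewall_rules` in config.py fetches the VPC's effective firewalls and only reuses the network when both rules are covered. A compressed sketch of that flow, assuming an already-built `compute` client and leaving the detailed source/protocol/port merging to a hypothetical `covers()` helper:

    def vpc_has_required_rules(compute, project_id: str, vpc_name: str) -> bool:
        # Effective firewalls include rules inherited from org-level policies.
        response = compute.networks().getEffectiveFirewalls(
            project=project_id, network=vpc_name).execute()
        if not response:
            return False
        effective_rules = response.get('firewalls', [])
        # covers() is a stand-in for the real merging logic in
        # _check_firewall_rules, which consolidates directions, sources,
        # protocols and port ranges before comparing against
        # FIREWALL_RULES_REQUIRED.
        return covers(effective_rules, FIREWALL_RULES_REQUIRED)

A VPC that fails this check is skipped, and a fresh `skypilot-vpc` is created from the templates above instead.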
# Keep this in sync with https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions.html#gcp # pylint: disable=line-too-long VM_MINIMAL_PERMISSIONS = [ - "compute.disks.create", - "compute.disks.list", - "compute.firewalls.create", - "compute.firewalls.delete", - "compute.firewalls.get", - "compute.instances.create", - "compute.instances.delete", - "compute.instances.get", - "compute.instances.list", - "compute.instances.setLabels", - "compute.instances.setServiceAccount", - "compute.instances.start", - "compute.instances.stop", - "compute.networks.get", - "compute.networks.list", - "compute.networks.getEffectiveFirewalls", - "compute.globalOperations.get", - "compute.subnetworks.use", - "compute.subnetworks.list", - "compute.subnetworks.useExternalIp", - "compute.projects.get", - "compute.zoneOperations.get", - "iam.roles.get", - "iam.serviceAccounts.actAs", - "iam.serviceAccounts.get", - "serviceusage.services.enable", - "serviceusage.services.list", - "serviceusage.services.use", - "resourcemanager.projects.get", - "resourcemanager.projects.getIamPolicy", + 'compute.disks.create', + 'compute.disks.list', + 'compute.firewalls.create', + 'compute.firewalls.delete', + 'compute.firewalls.get', + 'compute.instances.create', + 'compute.instances.delete', + 'compute.instances.get', + 'compute.instances.list', + 'compute.instances.setLabels', + 'compute.instances.setServiceAccount', + 'compute.instances.start', + 'compute.instances.stop', + 'compute.networks.get', + 'compute.networks.list', + 'compute.networks.getEffectiveFirewalls', + 'compute.globalOperations.get', + 'compute.subnetworks.use', + 'compute.subnetworks.list', + 'compute.subnetworks.useExternalIp', + 'compute.projects.get', + 'compute.zoneOperations.get', + 'iam.roles.get', + 'iam.serviceAccounts.actAs', + 'iam.serviceAccounts.get', + 'serviceusage.services.enable', + 'serviceusage.services.list', + 'serviceusage.services.use', + 'resourcemanager.projects.get', + 'resourcemanager.projects.getIamPolicy', ] TPU_MINIMAL_PERMISSIONS = [ - "tpu.nodes.create", - "tpu.nodes.delete", - "tpu.nodes.list", - "tpu.nodes.get", - "tpu.nodes.update", - "tpu.operations.get", + 'tpu.nodes.create', + 'tpu.nodes.delete', + 'tpu.nodes.list', + 'tpu.nodes.get', + 'tpu.nodes.update', + 'tpu.operations.get', ] + +# The maximum number of times to poll for the status of an operation. +MAX_POLLS = 12 +POLL_INTERVAL = 5 From 5d07647a28206ee1ebbff6456794ad3d53343e76 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 6 Oct 2023 11:07:25 -0700 Subject: [PATCH 06/84] update --- sky/provision/gcp/config.py | 60 +++++++------------------------------ 1 file changed, 11 insertions(+), 49 deletions(-) diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index c1aaf311744..e9cc12c537e 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -228,7 +228,7 @@ def bootstrap_gcp(region: str, cluster_name: str, # aws ec2 where everything is global. 
_configure_project(config.provider_config, crm) - config, iam_role = _configure_iam_role(config, crm, iam) + iam_role = _configure_iam_role(config, crm, iam) config.provider_config['iam_role'] = iam_role config = _configure_subnet(region, cluster_name, config, compute) @@ -339,11 +339,12 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): TODO: Allow the name/id of the service account to be configured """ + project_id = config.provider_config['project_id'] email = SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( account_id=SKYPILOT_SERVICE_ACCOUNT_ID, - project_id=config.provider_config['project_id'], + project_id=project_id, ) - service_account = _get_service_account(email, config.provider_config, iam) + service_account = _get_service_account(email, project_id, iam) permissions = VM_MINIMAL_PERMISSIONS roles = DEFAULT_SERVICE_ACCOUNT_ROLES @@ -363,12 +364,11 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): # account is still usable. email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( account_id=DEFAULT_SERVICE_ACCOUNT_ID, - project_id=config.provider_config['project_id'], + project_id=project_id, ) logger.info(f'_configure_iam_role: Fallback to service account {email}') - ray_service_account = _get_service_account(email, - config.provider_config, iam) + ray_service_account = _get_service_account(email, project_id, iam) ray_satisfied, _ = _is_permission_satisfied(ray_service_account, crm, iam, permissions, roles) logger.info( @@ -387,7 +387,7 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): service_account = _create_service_account( SKYPILOT_SERVICE_ACCOUNT_ID, SKYPILOT_SERVICE_ACCOUNT_CONFIG, - config.provider_config, + project_id, iam, ) satisfied, policy = _is_permission_satisfied( @@ -417,7 +417,7 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): else: iam_role = {'serviceAccounts': [account_dict]} - return config, iam_role + return iam_role def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str, @@ -716,8 +716,7 @@ def _create_project(project_id: str, crm): return result -def _get_service_account(account, provider_config, iam): - project_id = provider_config['project_id'] +def _get_service_account(account: str, project_id: str, iam): full_name = 'projects/{project_id}/serviceAccounts/{account}'.format( project_id=project_id, account=account) try: @@ -734,9 +733,8 @@ def _get_service_account(account, provider_config, iam): return service_account -def _create_service_account(account_id, account_config, provider_config, iam): - project_id = provider_config['project_id'] - +def _create_service_account(account_id: str, account_config, project_id: str, + iam): service_account = (iam.projects().serviceAccounts().create( name='projects/{project_id}'.format(project_id=project_id), body={ @@ -760,39 +758,3 @@ def _add_iam_policy_binding(service_account, policy, crm, iam): ).execute()) return result - - -def _create_project_ssh_key_pair(project, public_key, ssh_user, compute): - """Inserts an ssh-key into project commonInstanceMetadata""" - - key_parts = public_key.split(' ') - - # Sanity checks to make sure that the generated key matches expectation - assert len(key_parts) == 2, key_parts - assert key_parts[0] == 'ssh-rsa', key_parts - - new_ssh_meta = '{ssh_user}:ssh-rsa {key_value} {ssh_user}'.format( - ssh_user=ssh_user, key_value=key_parts[1]) - - common_instance_info = project['commonInstanceMetadata'] - items = common_instance_info.get('items', []) - - ssh_keys_i = next( - (i for i, item in 
enumerate(items) if item['key'] == 'ssh-keys'), None) - - if ssh_keys_i is None: - items.append({'key': 'ssh-keys', 'value': new_ssh_meta}) - else: - ssh_keys = items[ssh_keys_i] - ssh_keys['value'] += '\n' + new_ssh_meta - items[ssh_keys_i] = ssh_keys - - common_instance_info['items'] = items - - operation = (compute.projects().setCommonInstanceMetadata( - project=project['name'], body=common_instance_info).execute()) - - response = wait_for_compute_global_operation(project['name'], operation, - compute) - - return response From b45c09f40c8aabf22e398251eae90bafbc601560 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 6 Oct 2023 15:45:00 -0700 Subject: [PATCH 07/84] complete bootstrapping --- sky/provision/gcp/config.py | 34 +++++++++---------- sky/provision/gcp/instance.py | 9 ++--- sky/provision/gcp/instance_utils.py | 52 ++++++++++++++++++++++------- 3 files changed, 62 insertions(+), 33 deletions(-) diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index e9cc12c537e..7357048408c 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -1,4 +1,4 @@ -"""GCP config module.""" +"""GCP configuration bootstrapping.""" import copy import json import logging @@ -11,6 +11,7 @@ from googleapiclient import errors from sky.provision import common +from sky.provision.gcp import instance_utils from sky.provision.gcp.constants import FIREWALL_RULES_REQUIRED from sky.provision.gcp.constants import FIREWALL_RULES_TEMPLATE from sky.provision.gcp.constants import MAX_POLLS @@ -19,8 +20,6 @@ from sky.provision.gcp.constants import TPU_MINIMAL_PERMISSIONS from sky.provision.gcp.constants import VM_MINIMAL_PERMISSIONS from sky.provision.gcp.constants import VPC_TEMPLATE -from sky.skylet.providers.gcp.node import GCPCompute -from sky.skylet.providers.gcp.node import GCPNodeType logger = logging.getLogger(__name__) @@ -62,7 +61,7 @@ # with ServiceAccounts. -def get_node_type(node: dict) -> GCPNodeType: +def get_node_type(node: dict) -> instance_utils.GCPNodeType: """Returns node type based on the keys in ``node``. This is a very simple check. If we have a ``machineType`` key, @@ -82,8 +81,8 @@ def get_node_type(node: dict) -> GCPNodeType: f'Got {list(node)}') if 'machineType' not in node and 'acceleratorType' in node: - return GCPNodeType.TPU - return GCPNodeType.COMPUTE + return instance_utils.GCPNodeType.TPU + return instance_utils.GCPNodeType.COMPUTE def wait_for_crm_operation(operation, crm): @@ -211,11 +210,12 @@ def construct_clients_from_provider_config(provider_config): ) -def bootstrap_gcp(region: str, cluster_name: str, - config: common.ProvisionConfig): +def bootstrap_instances( + region: str, cluster_name: str, + config: common.ProvisionConfig) -> common.ProvisionConfig: # Check if we have any TPUs defined, and if so, # insert that information into the provider config - if get_node_type(config.node_config) == GCPNodeType.TPU: + if get_node_type(config.node_config) == instance_utils.GCPNodeType.TPU: config.provider_config[HAS_TPU_PROVIDER_FIELD] = True crm, iam, compute, tpu = construct_clients_from_provider_config( @@ -408,7 +408,7 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): # account is limited by the IAM rights specified below. 'scopes': ['https://www.googleapis.com/auth/cloud-platform'], } - if get_node_type(config.node_config) == GCPNodeType.TPU: + if get_node_type(config.node_config) == instance_utils.GCPNodeType.TPU: # SKY: The API for TPU VM is slightly different from normal compute instances. 
# See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node account_dict['scope'] = account_dict['scopes'] @@ -568,15 +568,15 @@ def get_usable_vpc(cluster_name: str, config: common.ProvisionConfig): config.provider_config) # For backward compatibility, reuse the VPC if the VM is launched. - resource = GCPCompute( - compute, + instance_dict = instance_utils.GCPComputeInstance.filter( project_id, config.provider_config['availability_zone'], - cluster_name, - ) - node = resource._list_instances(label_filters=None, status_filter=None) - if len(node) > 0: - netInterfaces = node[0].get('networkInterfaces', []) + label_filters=None, + status_filters=None) + + if instance_dict: + instance_metadata = list(instance_dict.values())[0] + netInterfaces = instance_metadata.get('networkInterfaces', []) if len(netInterfaces) > 0: vpc_name = netInterfaces[0]['network'].split('/')[-1] return vpc_name diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 99b80ab5074..e967c39c862 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -37,10 +37,11 @@ def _filter_instances( instances = set() logger.debug(f'handlers: {handlers}') for instance_handler in handlers: - instances |= set( - instance_handler.filter(project_id, zone, label_filters, - status_filters_fn(instance_handler), - included_instances, excluded_instances)) + instance_dict = instance_handler.filter( + project_id, zone, label_filters, + status_filters_fn(instance_handler), included_instances, + excluded_instances) + instances |= set(instance_dict.keys()) handler_to_instances = collections.defaultdict(list) for instance in instances: handler = instance_utils.instance_to_handler(instance) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index cf46bef52fa..a1644df5dcf 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -1,6 +1,7 @@ """Utilities for GCP instances.""" +import enum import re -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from sky import sky_logging from sky.adaptors import gcp @@ -73,7 +74,7 @@ def filter( status_filters: Optional[List[str]], included_instances: Optional[List[str]] = None, excluded_instances: Optional[List[str]] = None, - ) -> List[str]: + ) -> Dict[str, Any]: raise NotImplementedError @classmethod @@ -174,7 +175,7 @@ def filter( status_filters: Optional[List[str]], included_instances: Optional[List[str]] = None, excluded_instances: Optional[List[str]] = None, - ) -> List[str]: + ) -> Dict[str, Any]: if label_filters: label_filter_expr = ('(' + ' AND '.join([ '(labels.{key} = {value})'.format(key=key, value=value) @@ -206,11 +207,17 @@ def filter( zone=zone, ).execute()) instances = response.get('items', []) - instances = [i['name'] for i in instances] + instances = {i['name']: i for i in instances} if included_instances: - instances = [i for i in instances if i in included_instances] + instances = { + k: v for k, v in instances.items() if k in included_instances + } if excluded_instances: - instances = [i for i in instances if i not in excluded_instances] + instances = { + k: v + for k, v in instances.items() + if k not in excluded_instances + } return instances @classmethod @@ -403,7 +410,7 @@ def filter( status_filters: Optional[List[str]], included_instances: Optional[List[str]] = None, excluded_instances: Optional[List[str]] = None, - ) -> List[str]: + ) -> Dict[str, Any]: path = 
f'projects/{project_id}/locations/{zone}' try: response = (cls.load_resource().projects().locations().nodes().list( @@ -413,7 +420,7 @@ def filter( # Return empty list instead of raising exception to not break # ray down. logger.warning(f'googleapiclient.errors.HttpError: {e.reason}') - return [] + return {} instances = response.get('nodes', []) @@ -439,13 +446,18 @@ def filter_instance(instance) -> bool: return True instances = list(filter(filter_instance, instances)) - instances = [i['name'] for i in instances] + instances = {i['name']: i for i in instances} if included_instances: - instances = [i for i in instances if i in included_instances] + instances = { + k: v for k, v in instances.items() if k in included_instances + } if excluded_instances: - instances = [i for i in instances if i not in excluded_instances] - + instances = { + k: v + for k, v in instances.items() + if k not in excluded_instances + } return instances @classmethod @@ -512,3 +524,19 @@ def get_vpc_name( with ux_utils.print_exception_no_traceback(): raise ValueError( f'Failed to get VPC name for instance {instance}') from e + + +class GCPNodeType(enum.Enum): + """Enum for GCP node types (compute & tpu)""" + + COMPUTE = "compute" + TPU = "tpu" + + @staticmethod + def name_to_type(name: str): + """Provided a node name, determine the type. + + This expects the name to be in format '[NAME]-[UUID]-[TYPE]', + where [TYPE] is either 'compute' or 'tpu'. + """ + return GCPNodeType(name.split("-")[-1]) From ee7a924615968b459003d707791160488d574bac Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 8 Oct 2023 14:41:38 -0700 Subject: [PATCH 08/84] add start instance --- sky/provision/gcp/__init__.py | 1 + sky/provision/gcp/config.py | 30 +--- sky/provision/gcp/instance.py | 144 ++++++++++++++++++ sky/provision/gcp/instance_utils.py | 224 +++++++++++++++++++++++++++- 4 files changed, 372 insertions(+), 27 deletions(-) diff --git a/sky/provision/gcp/__init__.py b/sky/provision/gcp/__init__.py index 9faaab37088..4c83a198568 100644 --- a/sky/provision/gcp/__init__.py +++ b/sky/provision/gcp/__init__.py @@ -1,5 +1,6 @@ """GCP provisioner for SkyPilot.""" +from sky.provision.gcp.config import bootstrap_instances from sky.provision.gcp.instance import cleanup_ports from sky.provision.gcp.instance import open_ports from sky.provision.gcp.instance import stop_instances diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index 7357048408c..caeb93b8b8c 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -61,30 +61,6 @@ # with ServiceAccounts. -def get_node_type(node: dict) -> instance_utils.GCPNodeType: - """Returns node type based on the keys in ``node``. - - This is a very simple check. If we have a ``machineType`` key, - this is a Compute instance. If we don't have a ``machineType`` key, - but we have ``acceleratorType``, this is a TPU. Otherwise, it's - invalid and an exception is raised. - - This works for both node configs and API returned nodes. - """ - - if 'machineType' not in node and 'acceleratorType' not in node: - raise ValueError( - 'Invalid node. For a Compute instance, "machineType" is ' - 'required. ' - 'For a TPU instance, "acceleratorType" and no "machineType" ' - 'is required. 
' - f'Got {list(node)}') - - if 'machineType' not in node and 'acceleratorType' in node: - return instance_utils.GCPNodeType.TPU - return instance_utils.GCPNodeType.COMPUTE - - def wait_for_crm_operation(operation, crm): """Poll for cloud resource manager operation until finished.""" logger.info('wait_for_crm_operation: ' @@ -215,7 +191,8 @@ def bootstrap_instances( config: common.ProvisionConfig) -> common.ProvisionConfig: # Check if we have any TPUs defined, and if so, # insert that information into the provider config - if get_node_type(config.node_config) == instance_utils.GCPNodeType.TPU: + if instance_utils.get_node_type( + config.node_config) == instance_utils.GCPNodeType.TPU: config.provider_config[HAS_TPU_PROVIDER_FIELD] = True crm, iam, compute, tpu = construct_clients_from_provider_config( @@ -408,7 +385,8 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): # account is limited by the IAM rights specified below. 'scopes': ['https://www.googleapis.com/auth/cloud-platform'], } - if get_node_type(config.node_config) == instance_utils.GCPNodeType.TPU: + if instance_utils.get_node_type( + config.node_config) == instance_utils.GCPNodeType.TPU: # SKY: The API for TPU VM is slightly different from normal compute instances. # See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node account_dict['scope'] = account_dict['scopes'] diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index e967c39c862..575d3a8a5ec 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -1,13 +1,29 @@ """GCP instance provisioning.""" import collections +import copy import re import time from typing import Any, Callable, Dict, Iterable, List, Optional, Type from sky import sky_logging +from sky import status_lib from sky.adaptors import gcp +from sky.provision import common from sky.provision.gcp import instance_utils +# Tag for user defined node types (e.g., m4xl_spot). This is used for multi +# node type clusters. +TAG_RAY_USER_NODE_TYPE = "ray-user-node-type" +# Hash of the node launch config, used to identify out-of-date nodes +TAG_RAY_LAUNCH_CONFIG = "ray-launch-config" +# Tag for autofilled node types for legacy cluster yamls without multi +# node type defined in the cluster configs. +NODE_TYPE_LEGACY_HEAD = "ray-legacy-head-node-type" +NODE_TYPE_LEGACY_WORKER = "ray-legacy-worker-node-type" + +# Tag that reports the current state of the node (e.g. Updating, Up-to-date) +TAG_RAY_NODE_STATUS = "ray-node-status" + logger = sky_logging.init_logger(__name__) MAX_POLLS = 12 @@ -73,6 +89,134 @@ def _wait_for_operations( total_polls += 1 +def run_instances(region: str, cluster_name: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: + """See sky/provision/__init__.py""" + result_dict = {} + labels = config.tags # gcp uses "labels" instead of aws "tags" + labels = dict(sorted(copy.deepcopy(labels).items())) + node_type = instance_utils.get_node_type(config.node_config) + count = config.count + project_id = config.provider_config['project_id'] + availability_zone = config.provider_config['availability_zone'] + + # SKY: "TERMINATED" for compute VM, "STOPPED" for TPU VM + # "STOPPING" means the VM is being stopped, which needs + # to be included to avoid creating a new VM. 
+    if node_type == instance_utils.GCPNodeType.COMPUTE:
+        resource = instance_utils.GCPComputeInstance
+        STOPPED_STATUS = ["TERMINATED", "STOPPING"]
+    elif node_type == instance_utils.GCPNodeType.TPU:
+        resource = instance_utils.GCPTPUVMInstance
+        STOPPED_STATUS = ["STOPPED", "STOPPING"]
+    else:
+        raise ValueError(f'Unknown node type {node_type}')
+
+    # Try to reuse previously stopped nodes with compatible configs.
+    if config.resume_stopped_nodes:
+        filters = {
+            TAG_RAY_NODE_KIND: labels[TAG_RAY_NODE_KIND],
+            # SkyPilot: removed TAG_RAY_LAUNCH_CONFIG to allow reusing nodes
+            # with different launch configs.
+            # Reference: https://github.com/skypilot-org/skypilot/pull/1671
+        }
+        # This tag may not always be present.
+        if TAG_RAY_USER_NODE_TYPE in labels:
+            filters[TAG_RAY_USER_NODE_TYPE] = labels[TAG_RAY_USER_NODE_TYPE]
+        filters_with_launch_config = copy.copy(filters)
+        filters_with_launch_config[TAG_RAY_LAUNCH_CONFIG] = labels[
+            TAG_RAY_LAUNCH_CONFIG]
+
+        # SkyPilot: We try to use instances with a matching launch_config
+        # first. If there are not enough such instances, we then use all the
+        # instances with a matching launch_config plus some instances with a
+        # mismatched launch_config.
+        def get_order_key(node):
+            import datetime
+
+            timestamp = node.get("lastStartTimestamp")
+            if timestamp is not None:
+                return datetime.datetime.strptime(timestamp,
+                                                  "%Y-%m-%dT%H:%M:%S.%f%z")
+            return node['id']
+
+        nodes_matching_launch_config = resource.filter(
+            project_id=project_id,
+            zone=availability_zone,
+            label_filters=filters_with_launch_config,
+            status_filters=STOPPED_STATUS,
+        )
+        nodes_matching_launch_config = list(
+            nodes_matching_launch_config.values())
+
+        nodes_matching_launch_config.sort(key=lambda n: get_order_key(n),
+                                          reverse=True)
+        if len(nodes_matching_launch_config) >= count:
+            reuse_nodes = nodes_matching_launch_config[:count]
+        else:
+            nodes_all = resource.filter(
+                project_id=project_id,
+                zone=availability_zone,
+                label_filters=filters,
+                status_filters=STOPPED_STATUS,
+            )
+            nodes_all = list(nodes_all.values())
+
+            nodes_matching_launch_config_ids = set(
+                n['id'] for n in nodes_matching_launch_config)
+            nodes_non_matching_launch_config = [
+                n for n in nodes_all
+                if n['id'] not in nodes_matching_launch_config_ids
+            ]
+            # This is for backward compatibility, where the user may already
+            # have leaked stopped nodes with a different launch config from
+            # before the update to #1671, and the total number of leaked
+            # nodes is greater than the number of nodes to be created. With
+            # this, we make sure we reuse the most recently used nodes.
+            # This can be removed in the future when we are sure all users
+            # have updated to #1671.
+            nodes_non_matching_launch_config.sort(
+                key=lambda n: get_order_key(n), reverse=True)
+            reuse_nodes = (nodes_matching_launch_config +
+                           nodes_non_matching_launch_config)
+            # The total number of reusable nodes can be less than the number
+            # of nodes to be created. This `[:count]` is fine, as it will get
+            # all the reusable nodes, even if there are fewer.
+            reuse_nodes = reuse_nodes[:count]
+
+        reuse_node_ids = [n['id'] for n in reuse_nodes]
+        if reuse_nodes:
+            # TODO(suquark): Some instances could still be stopping.
+            # We may wait until these instances stop.
+            logger.info(
+                # TODO: handle plural vs singular?
+                f"Reusing nodes {reuse_node_ids}. 
" + "To disable reuse, set `cache_stopped_nodes: False` " + "under `provider` in the cluster configuration.") + for node_id in reuse_node_ids: + result = resource.start_instance(node_id, project_id, + availability_zone) + result_dict[node_id] = {node_id: result} + for node in reuse_nodes: + resource.set_labels(project_id, availability_zone, node, labels) + count -= len(reuse_node_ids) + if count: + results = resource.create_instances(project_id, availability_zone, + config.node_config, labels, count) + result_dict.update( + {instance_id: result for result, instance_id in results}) + return result_dict + + +def wait_instances(region: str, cluster_name: str, + state: Optional[status_lib.ClusterStatus]) -> None: + """See sky/provision/__init__.py""" + # TODO: maybe we just wait for the instances to be running, immediately + # after the instances are created, so we do not need to wait for the + # instances to be ready here. + raise NotImplementedError + + def stop_instances( cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None, diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index a1644df5dcf..4b549a2b182 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -1,10 +1,14 @@ """Utilities for GCP instances.""" import enum +import functools import re -from typing import Any, Dict, List, Optional +import time +from typing import Any, Dict, List, Optional, Tuple from sky import sky_logging from sky.adaptors import gcp +from sky.provision.gcp.constants import MAX_POLLS +from sky.provision.gcp.constants import POLL_INTERVAL from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) @@ -17,6 +21,45 @@ r'The resource \'projects/.*/global/firewalls/.*\' was not found') +def _retry_on_http_exception( + regex: Optional[str] = None, + max_retries: int = MAX_POLLS, + retry_interval_s: int = POLL_INTERVAL, +): + """Retry a function call n-times for as long as it throws an exception.""" + from googleapiclient.errors import HttpError + + exception = HttpError + + def dec(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + + def try_catch_exc(): + try: + value = func(*args, **kwargs) + return value + except Exception as e: + if not isinstance(e, exception) or ( + regex and not re.search(regex, str(e))): + raise e + return e + + for _ in range(max_retries): + ret = try_catch_exc() + if not isinstance(ret, Exception): + break + time.sleep(retry_interval_s) + if isinstance(ret, Exception): + raise ret + return ret + + return wrapper + + return dec + + def instance_to_handler(instance: str): instance_type = instance.split('-')[-1] if instance_type == 'compute': @@ -115,6 +158,61 @@ def add_network_tag_if_not_exist( ) -> None: raise NotImplementedError + def create_instance(self, + node_config: dict, + labels: dict, + wait_for_operation: bool = True) -> Tuple[dict, str]: + """Creates a single instance and returns result. + + Returns a tuple of (result, node_name). + """ + raise NotImplementedError + + @classmethod + def create_instances( + cls, + project_id: str, + zone: str, + node_config: dict, + labels: dict, + count: int, + wait_for_operation: bool = True, + ) -> List[Tuple[dict, str]]: + """Creates multiple instances and returns result. + + Returns a list of tuples of (result, node_name). 
+ """ + operations = [ + cls.create_instance(node_config, labels, wait_for_operation=False) + for i in range(count) + ] + + if wait_for_operation: + results = [(cls.wait_for_operation(operation, project_id, + zone), node_name) + for operation, node_name in operations] + else: + results = operations + + return results + + def start_instance(cls, + node_id: str, + project_id: str, + zone: str, + wait_for_operation: bool = True) -> dict: + """Start a stopped instance.""" + raise NotImplementedError + + @classmethod + def set_labels(cls, + project_id: str, + availability_zone: str, + node: dict, + labels: dict, + wait_for_operation: bool = True) -> dict: + return NotImplementedError + class GCPComputeInstance(GCPInstance): """Instance handler for GCP compute instances.""" @@ -364,6 +462,57 @@ def create_or_update_firewall_rule( ).execute() return operation + @classmethod + def set_labels(cls, + project_id: str, + availability_zone: str, + node: dict, + labels: dict, + wait_for_operation: bool = True) -> dict: + body = { + "labels": dict(node["labels"], **labels), + "labelFingerprint": node["labelFingerprint"], + } + node_id = node["name"] + operation = (cls.load_resource().instances().setLabels( + project=project_id, + zone=availability_zone, + instance=node_id, + body=body, + ).execute()) + + if wait_for_operation: + result = cls.wait_for_operation(operation, project_id, + availability_zone) + else: + result = operation + + return result + + def create_instance(self, + node_config: dict, + labels: dict, + wait_for_operation: bool = True) -> Tuple[dict, str]: + raise NotImplementedError + + def start_instance(cls, + node_id: str, + project_id: str, + zone: str, + wait_for_operation: bool = True) -> dict: + operation = (cls.load_resource().instances().start( + project=project_id, + zone=zone, + instance=node_id, + ).execute()) + + if wait_for_operation: + result = cls.wait_for_operation(operation, project_id, zone) + else: + result = operation + + return result + class GCPTPUVMInstance(GCPInstance): """Instance handler for GCP TPU node.""" @@ -525,6 +674,55 @@ def get_vpc_name( raise ValueError( f'Failed to get VPC name for instance {instance}') from e + @classmethod + @_retry_on_http_exception("unable to queue the operation") + def set_labels(cls, + project_id: str, + availability_zone: str, + node: dict, + labels: dict, + wait_for_operation: bool = True) -> dict: + body = { + "labels": dict(node["labels"], **labels), + } + update_mask = "labels" + + operation = (cls.load_resource().projects().locations().nodes().patch( + name=node["name"], + updateMask=update_mask, + body=body, + ).execute()) + + if wait_for_operation: + result = cls.wait_for_operation(project_id, availability_zone, + operation) + else: + result = operation + + return result + + def create_instance(self, + node_config: dict, + labels: dict, + wait_for_operation: bool = True) -> Tuple[dict, str]: + raise NotImplementedError + + def start_instance(cls, + node_id: str, + project_id: str, + zone: str, + wait_for_operation: bool = True) -> dict: + operation = (cls.load_resource().projects().locations().nodes().start( + name=node_id).execute()) + + # FIXME: original implementation has the "max_polls=MAX_POLLS" option. + if wait_for_operation: + result = cls.wait_for_operation(operation, project_id, zone) + else: + result = operation + + return result + class GCPNodeType(enum.Enum): """Enum for GCP node types (compute & tpu)""" @@ -540,3 +738,27 @@ def name_to_type(name: str): where [TYPE] is either 'compute' or 'tpu'. 
""" return GCPNodeType(name.split("-")[-1]) + + +def get_node_type(node: dict) -> GCPNodeType: + """Returns node type based on the keys in ``node``. + + This is a very simple check. If we have a ``machineType`` key, + this is a Compute instance. If we don't have a ``machineType`` key, + but we have ``acceleratorType``, this is a TPU. Otherwise, it's + invalid and an exception is raised. + + This works for both node configs and API returned nodes. + """ + + if 'machineType' not in node and 'acceleratorType' not in node: + raise ValueError( + 'Invalid node. For a Compute instance, "machineType" is ' + 'required. ' + 'For a TPU instance, "acceleratorType" and no "machineType" ' + 'is required. ' + f'Got {list(node)}') + + if 'machineType' not in node and 'acceleratorType' in node: + return GCPNodeType.TPU + return GCPNodeType.COMPUTE From 70ede43a85be5df9f87d1ee3bf9f4cdac838addb Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 8 Oct 2023 16:02:57 -0700 Subject: [PATCH 09/84] fix --- sky/provision/gcp/instance_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 4b549a2b182..58aa6c36944 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -158,6 +158,7 @@ def add_network_tag_if_not_exist( ) -> None: raise NotImplementedError + @classmethod def create_instance(self, node_config: dict, labels: dict, @@ -196,6 +197,7 @@ def create_instances( return results + @classmethod def start_instance(cls, node_id: str, project_id: str, @@ -489,12 +491,14 @@ def set_labels(cls, return result + @classmethod def create_instance(self, node_config: dict, labels: dict, wait_for_operation: bool = True) -> Tuple[dict, str]: raise NotImplementedError + @classmethod def start_instance(cls, node_id: str, project_id: str, @@ -701,12 +705,14 @@ def set_labels(cls, return result - def create_instance(self, + @classmethod + def create_instance(cls, node_config: dict, labels: dict, wait_for_operation: bool = True) -> Tuple[dict, str]: raise NotImplementedError + @classmethod def start_instance(cls, node_id: str, project_id: str, From f8fd06dbe82ab7a7e0a1eeec190afae708eb4c83 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Wed, 18 Oct 2023 16:20:57 -0700 Subject: [PATCH 10/84] fix --- sky/backends/cloud_vm_ray_backend.py | 2 +- sky/provision/gcp/__init__.py | 1 + sky/provision/gcp/config.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 7fc596d8457..ff9d815724e 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1611,7 +1611,7 @@ def _retry_zones( global_user_state.set_owner_identity_for_cluster( cluster_name, cloud_user_identity) - if isinstance(to_provision.cloud, clouds.AWS): + if isinstance(to_provision.cloud, (clouds.AWS, clouds.GCP)): # Use the new provisioner for AWS. # TODO (suquark): Gradually move the other clouds to # the new provisioner once they are ready. 
diff --git a/sky/provision/gcp/__init__.py b/sky/provision/gcp/__init__.py index 4c83a198568..269481469f0 100644 --- a/sky/provision/gcp/__init__.py +++ b/sky/provision/gcp/__init__.py @@ -1,6 +1,7 @@ """GCP provisioner for SkyPilot.""" from sky.provision.gcp.config import bootstrap_instances +from sky.provision.gcp.instance import run_instances from sky.provision.gcp.instance import cleanup_ports from sky.provision.gcp.instance import open_ports from sky.provision.gcp.instance import stop_instances diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index caeb93b8b8c..7b36f627041 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -206,7 +206,7 @@ def bootstrap_instances( _configure_project(config.provider_config, crm) iam_role = _configure_iam_role(config, crm, iam) - config.provider_config['iam_role'] = iam_role + config.provider_config['iam_role'] = iam_role # temporary store config = _configure_subnet(region, cluster_name, config, compute) return config From 83880d20551955d880ff7d5f8c9cc1191bd92d70 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Wed, 18 Oct 2023 19:26:34 -0700 Subject: [PATCH 11/84] fix --- sky/provision/gcp/__init__.py | 2 +- sky/provision/gcp/instance.py | 129 ++++++++++++++----------- sky/provision/gcp/instance_utils.py | 140 ++++++++++++++++++++++++++-- 3 files changed, 204 insertions(+), 67 deletions(-) diff --git a/sky/provision/gcp/__init__.py b/sky/provision/gcp/__init__.py index 269481469f0..654c38146f4 100644 --- a/sky/provision/gcp/__init__.py +++ b/sky/provision/gcp/__init__.py @@ -1,8 +1,8 @@ """GCP provisioner for SkyPilot.""" from sky.provision.gcp.config import bootstrap_instances -from sky.provision.gcp.instance import run_instances from sky.provision.gcp.instance import cleanup_ports from sky.provision.gcp.instance import open_ports +from sky.provision.gcp.instance import run_instances from sky.provision.gcp.instance import stop_instances from sky.provision.gcp.instance import terminate_instances diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 575d3a8a5ec..97c713be5df 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -95,6 +95,7 @@ def run_instances(region: str, cluster_name: str, result_dict = {} labels = config.tags # gcp uses "labels" instead of aws "tags" labels = dict(sorted(copy.deepcopy(labels).items())) + node_type = instance_utils.get_node_type(config.node_config) count = config.count project_id = config.provider_config['project_id'] @@ -105,27 +106,64 @@ def run_instances(region: str, cluster_name: str, # to be included to avoid creating a new VM. if node_type == instance_utils.GCPNodeType.COMPUTE: resource = instance_utils.GCPComputeInstance - STOPPED_STATUS = ["TERMINATED", "STOPPING"] + STOPPED_STATUS = 'TERMINATED' elif node_type == instance_utils.GCPNodeType.TPU: resource = instance_utils.GCPTPUVMInstance - STOPPED_STATUS = ["STOPPED", "STOPPING"] + STOPPED_STATUS = 'STOPPED' else: raise ValueError(f'Unknown node type {node_type}') + PENDING_STATUS = ['PROVISIONING', 'STAGING'] + filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name} + + exist_instances = resource.filter( + project_id=project_id, + zone=availability_zone, + label_filters=filter_labels, + status_filters=None, + ) + exist_instances = list(exist_instances.values()) + + # NOTE: We are not handling REPAIRING, SUSPENDING, SUSPENDED status. 
+ pending_instances = [] + running_instances = [] + stopping_instances = [] + stopped_instances = [] + + for inst in exist_instances: + state = inst['status'] + if state in PENDING_STATUS: + pending_instances.append(inst) + elif state == 'RUNNING': + running_instances.append(inst) + elif state == 'STOPPING': + stopping_instances.append(inst) + elif state == STOPPED_STATUS: + stopped_instances.append(inst) + else: + raise RuntimeError(f'Unsupported state "{state}".') + + # TODO(suquark): Maybe in the future, users could adjust the number + # of instances dynamically. Then this case would not be an error. + if config.resume_stopped_nodes and len(exist_instances) > config.count: + raise RuntimeError('The number of running/stopped/stopping ' + f'instances combined ({len(exist_instances)}) in ' + f'cluster "{cluster_name}" is greater than the ' + f'number requested by the user ({config.count}). ' + 'This is likely a resource leak. ' + 'Use "sky down" to terminate the cluster.') + + # TODO: if there are running instances, use their zones instead + + to_start_count = (config.count - len(running_instances) - + len(pending_instances)) + # Try to reuse previously stopped nodes with compatible configs - if config.resume_stopped_nodes: - filters = { - TAG_RAY_NODE_KIND: labels[TAG_RAY_NODE_KIND], - # SkyPilot: removed TAG_RAY_LAUNCH_CONFIG to allow reusing nodes - # with different launch configs. - # Reference: https://github.com/skypilot-org/skypilot/pull/1671 - } - # This tag may not always be present. - if TAG_RAY_USER_NODE_TYPE in labels: - filters[TAG_RAY_USER_NODE_TYPE] = labels[TAG_RAY_USER_NODE_TYPE] - filters_with_launch_config = copy.copy(filters) - filters_with_launch_config[TAG_RAY_LAUNCH_CONFIG] = labels[ - TAG_RAY_LAUNCH_CONFIG] + if config.resume_stopped_nodes and to_start_count > 0 and ( + stopping_instances or stopped_instances): + # TODO: we should wait until stopped instances are actually stopped. + # However, in GCP it is hard to know whether one instance is stopping for termination. + # So we need to wait and check. # SkyPilot: We try to use the instances with the same matching launch_config first. 
If
         # there is not enough instances with matching launch_config, we then use all the
         # instances with the same matching launch_config plus some instances with wrong
         # launch_config.
         def get_order_key(node):
             import datetime
 
             timestamp = node.get("lastStartTimestamp")
             if timestamp is not None:
                 return datetime.datetime.strptime(timestamp,
                                                   "%Y-%m-%dT%H:%M:%S.%f%z")
             return node.id
 
@@ -140,49 +178,20 @@ def get_order_key(node):
-        nodes_matching_launch_config = resource.filter(
+        stopped_nodes = resource.filter(
             project_id=project_id,
             zone=availability_zone,
-            label_filters=filters_with_launch_config,
-            status_filters=STOPPED_STATUS,
+            label_filters=filter_labels,
+            status_filters=['STOPPING', STOPPED_STATUS],
         )
-        nodes_matching_launch_config = list(
-            nodes_matching_launch_config.values())
+        stopped_nodes = list(stopped_nodes.values())
+        stopped_nodes.sort(key=lambda n: get_order_key(n), reverse=True)
 
-        nodes_matching_launch_config.sort(key=lambda n: get_order_key(n),
-                                          reverse=True)
-        if len(nodes_matching_launch_config) >= count:
-            reuse_nodes = nodes_matching_launch_config[:count]
-        else:
-            nodes_all = resource.filter(
-                project_id=project_id,
-                zone=availability_zone,
-                label_filters=filters,
-                status_filters=STOPPED_STATUS,
-            )
-            nodes_all = list(nodes_all.values())
-
-            nodes_matching_launch_config_ids = set(
-                n.id for n in nodes_matching_launch_config)
-            nodes_non_matching_launch_config = [
-                n for n in nodes_all
-                if n.id not in nodes_matching_launch_config_ids
-            ]
-            # This is for backward compatibility, where the user already has leaked
-            # stopped nodes with a different launch config before updating to #1671,
-            # and the total number of the leaked nodes is greater than the number of
-            # nodes to be created. With this, we will make sure we will reuse the
-            # most recently used nodes.
-            # This can be removed in the future when we are sure all the users
-            # have updated to #1671.
-            nodes_non_matching_launch_config.sort(
-                key=lambda n: get_order_key(n), reverse=True)
-            reuse_nodes = (nodes_matching_launch_config +
-                           nodes_non_matching_launch_config)
-            # The total number of reusable nodes can be less than the number of nodes to be created.
-            # This `[:count]` is fine, as it will get all the reusable nodes, even if there are
-            # less nodes.
-            reuse_nodes = reuse_nodes[:count]
+        # The total number of reusable nodes can be less than the number of nodes to be created.
+        # This `[:count]` is fine, as it will get all the reusable nodes, even if there are
+        # less nodes.
+        # FIXME: This is not correct. Use the method in AWS. 
+ reuse_nodes = stopped_nodes[:count] reuse_node_ids = [n.id for n in reuse_nodes] if reuse_nodes: @@ -200,9 +209,15 @@ def get_order_key(node): for node in reuse_nodes: resource.set_labels(project_id, availability_zone, node, labels) count -= len(reuse_node_ids) - if count: - results = resource.create_instances(project_id, availability_zone, - config.node_config, labels, count) + + if to_start_count > 0: + results = resource.create_instances(cluster_name, project_id, + availability_zone, + config.node_config, labels, + to_start_count) + for success, instance_id in results: + resource.set_labels(project_id, availability_zone, instance_id, + labels) result_dict.update( {instance_id: result for result, instance_id in results}) return result_dict diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 58aa6c36944..2a4a885b9ac 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -1,9 +1,11 @@ """Utilities for GCP instances.""" +import copy import enum import functools import re import time from typing import Any, Dict, List, Optional, Tuple +import uuid from sky import sky_logging from sky.adaptors import gcp @@ -11,6 +13,14 @@ from sky.provision.gcp.constants import POLL_INTERVAL from sky.utils import ux_utils +# Tag uniquely identifying all nodes of a cluster +TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' +TAG_RAY_CLUSTER_NAME = "ray-cluster-name" +# Tag for the name of the node +TAG_RAY_NODE_NAME = "ray-node-name" +INSTANCE_NAME_MAX_LEN = 64 +INSTANCE_NAME_UUID_LEN = 8 + logger = sky_logging.init_logger(__name__) # Using v2 according to @@ -60,6 +70,20 @@ def try_catch_exc(): return dec +def _generate_node_name(cluster_name: str, node_suffix: str) -> str: + """Generate node name from labels and suffix. + + This is required so that the correct resource can be selected + when the only information autoscaler has is the name of the node. + + The suffix is expected to be one of 'compute' or 'tpu' + (as in ``GCPNodeType``). + """ + suffix = f'-{uuid.uuid4().hex[:INSTANCE_NAME_UUID_LEN]}-{node_suffix}' + prefix = cluster_name[:INSTANCE_NAME_MAX_LEN - len(suffix)] + return prefix + suffix + + def instance_to_handler(instance: str): instance_type = instance.split('-')[-1] if instance_type == 'compute': @@ -159,7 +183,10 @@ def add_network_tag_if_not_exist( raise NotImplementedError @classmethod - def create_instance(self, + def create_instance(cls, + cluster_name: str, + project_id: str, + availability_zone: str, node_config: dict, labels: dict, wait_for_operation: bool = True) -> Tuple[dict, str]: @@ -172,6 +199,7 @@ def create_instance(self, @classmethod def create_instances( cls, + cluster_name: str, project_id: str, zone: str, node_config: dict, @@ -184,8 +212,12 @@ def create_instances( Returns a list of tuples of (result, node_name). 
""" operations = [ - cls.create_instance(node_config, labels, wait_for_operation=False) - for i in range(count) + cls.create_instance(cluster_name, + project_id, + zone, + node_config, + labels, + wait_for_operation=False) for i in range(count) ] if wait_for_operation: @@ -210,7 +242,7 @@ def start_instance(cls, def set_labels(cls, project_id: str, availability_zone: str, - node: dict, + node_id: str, labels: dict, wait_for_operation: bool = True) -> dict: return NotImplementedError @@ -468,14 +500,19 @@ def create_or_update_firewall_rule( def set_labels(cls, project_id: str, availability_zone: str, - node: dict, + node_id: str, labels: dict, wait_for_operation: bool = True) -> dict: + response = (cls.load_resource().instances().list( + project=project_id, + filter=f'name = {node_id}', + zone=availability_zone, + ).execute()) + node = response.get('items', [])[0] body = { "labels": dict(node["labels"], **labels), "labelFingerprint": node["labelFingerprint"], } - node_id = node["name"] operation = (cls.load_resource().instances().setLabels( project=project_id, zone=availability_zone, @@ -492,11 +529,96 @@ def set_labels(cls, return result @classmethod - def create_instance(self, + def _convert_resources_to_urls( + cls, project_id: str, availability_zone: str, + configuration_dict: Dict[str, Any]) -> Dict[str, Any]: + """Ensures that resources are in their full URL form. + + GCP expects machineType and accleratorType to be a full URL (e.g. + `zones/us-west1/machineTypes/n1-standard-2`) instead of just the + type (`n1-standard-2`) + + Args: + configuration_dict: Dict of options that will be passed to GCP + Returns: + Input dictionary, but with possibly expanding `machineType` and + `acceleratorType`. + """ + configuration_dict = copy.deepcopy(configuration_dict) + existing_machine_type = configuration_dict["machineType"] + if not re.search(".*/machineTypes/.*", existing_machine_type): + configuration_dict[ + "machineType"] = "zones/{zone}/machineTypes/{machine_type}".format( + zone=availability_zone, + machine_type=configuration_dict["machineType"], + ) + + for accelerator in configuration_dict.get("guestAccelerators", []): + gpu_type = accelerator["acceleratorType"] + if not re.search(".*/acceleratorTypes/.*", gpu_type): + accelerator[ + "acceleratorType"] = "projects/{project}/zones/{zone}/acceleratorTypes/{accelerator}".format( # noqa: E501 + project=project_id, + zone=availability_zone, + accelerator=gpu_type, + ) + + return configuration_dict + + @classmethod + def create_instance(cls, + cluster_name: str, + project_id: str, + availability_zone: str, node_config: dict, labels: dict, wait_for_operation: bool = True) -> Tuple[dict, str]: - raise NotImplementedError + config = cls._convert_resources_to_urls(project_id, availability_zone, + node_config) + # removing TPU-specific default key set in config.py + config.pop("networkConfig", None) + name = _generate_node_name(cluster_name, GCPNodeType.COMPUTE.value) + + labels = dict(config.get("labels", {}), **labels) + + config.update({ + "labels": dict( + labels, **{ + TAG_RAY_CLUSTER_NAME: cluster_name, + TAG_SKYPILOT_CLUSTER_NAME: cluster_name + }), + "name": name, + }) + + # Allow Google Compute Engine instance templates. + # + # Config example: + # + # ... + # node_config: + # sourceInstanceTemplate: global/instanceTemplates/worker-16 + # machineType: e2-standard-16 + # ... + # + # node_config parameters override matching template parameters, if any. 
+ # + # https://cloud.google.com/compute/docs/instance-templates + # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert + source_instance_template = config.pop("sourceInstanceTemplate", None) + operation = (cls.load_resource().instances().insert( + project=project_id, + zone=availability_zone, + sourceInstanceTemplate=source_instance_template, + body=config, + ).execute()) + + if wait_for_operation: + result = cls.wait_for_operation(operation, project_id, + availability_zone) + else: + result = operation + + return result, name @classmethod def start_instance(cls, @@ -683,7 +805,7 @@ def get_vpc_name( def set_labels(cls, project_id: str, availability_zone: str, - node: dict, + node_id: str, labels: dict, wait_for_operation: bool = True) -> dict: body = { From 75b23acca359d4b7f6d730823d2060fa44217b45 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Wed, 18 Oct 2023 23:46:47 -0700 Subject: [PATCH 12/84] update --- sky/provision/gcp/instance.py | 119 +++++++++++++++++----------- sky/provision/gcp/instance_utils.py | 44 +++++++++- 2 files changed, 117 insertions(+), 46 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 97c713be5df..6efdcff807c 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -31,6 +31,7 @@ MAX_POLLS_STOP = MAX_POLLS * 8 POLL_INTERVAL = 5 +TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' # Tag uniquely identifying all nodes of a cluster TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' TAG_RAY_NODE_KIND = 'ray-node-type' @@ -89,15 +90,27 @@ def _wait_for_operations( total_polls += 1 +def _get_head_instance_id(instances: List) -> Optional[str]: + head_instance_id = None + for inst in instances: + labels = inst.get('labels', {}) + if (labels.get(TAG_RAY_NODE_KIND) == 'head' or + labels.get(TAG_SKYPILOT_HEAD_NODE) == '1'): + head_instance_id = inst['id'] + break + return head_instance_id + + def run_instances(region: str, cluster_name: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """See sky/provision/__init__.py""" result_dict = {} labels = config.tags # gcp uses "labels" instead of aws "tags" labels = dict(sorted(copy.deepcopy(labels).items())) + resumed_instance_ids: List[str] = [] + created_instance_ids: List[str] = [] node_type = instance_utils.get_node_type(config.node_config) - count = config.count project_id = config.provider_config['project_id'] availability_zone = config.provider_config['availability_zone'] @@ -123,6 +136,7 @@ def run_instances(region: str, cluster_name: str, status_filters=None, ) exist_instances = list(exist_instances.values()) + head_instance_id = _get_head_instance_id(exist_instances) # NOTE: We are not handling REPAIRING, SUSPENDING, SUSPENDED status. pending_instances = [] @@ -130,6 +144,19 @@ def run_instances(region: str, cluster_name: str, stopping_instances = [] stopped_instances = [] + # SkyPilot: We try to use the instances with the same matching launch_config first. If + # there is not enough instances with matching launch_config, we then use all the + # instances with the same matching launch_config plus some instances with wrong + # launch_config. 
+ def get_order_key(node): + import datetime + + timestamp = node.get("lastStartTimestamp") + if timestamp is not None: + return datetime.datetime.strptime(timestamp, + "%Y-%m-%dT%H:%M:%S.%f%z") + return node['id'] + for inst in exist_instances: state = inst['status'] if state in PENDING_STATUS: @@ -143,6 +170,19 @@ def run_instances(region: str, cluster_name: str, else: raise RuntimeError(f'Unsupported state "{state}".') + pending_instances.sort(key=lambda n: get_order_key(n), reverse=True) + running_instances.sort(key=lambda n: get_order_key(n), reverse=True) + stopping_instances.sort(key=lambda n: get_order_key(n), reverse=True) + stopped_instances.sort(key=lambda n: get_order_key(n), reverse=True) + + if head_instance_id is None: + if running_instances: + head_instance_id = resource.create_node_tag( + running_instances[0]['id']) + elif pending_instances: + head_instance_id = resource.create_node_tag( + pending_instances[0]['id']) + # TODO(suquark): Maybe in the future, users could adjust the number # of instances dynamically. Then this case would not be an error. if config.resume_stopped_nodes and len(exist_instances) > config.count: @@ -164,72 +204,61 @@ def run_instances(region: str, cluster_name: str, # TODO: we should wait until stopped instances are actually stopped. # However, in GCP it is hard to know whether one instance is stopping for termination. # So we need to wait and check. - - # SkyPilot: We try to use the instances with the same matching launch_config first. If - # there is not enough instances with matching launch_config, we then use all the - # instances with the same matching launch_config plus some instances with wrong - # launch_config. - def get_order_key(node): - import datetime - - timestamp = node.get("lastStartTimestamp") - if timestamp is not None: - return datetime.datetime.strptime(timestamp, - "%Y-%m-%dT%H:%M:%S.%f%z") - return node.id - - stopped_nodes = resource.filter( - project_id=project_id, - zone=availability_zone, - label_filters=filter_labels, - status_filters=['STOPPING', STOPPED_STATUS], - ) - stopped_nodes = list(stopped_nodes.values()) + stopped_nodes = stopping_instances + stopped_instances stopped_nodes.sort(key=lambda n: get_order_key(n), reverse=True) - - # The total number of reusable nodes can be less than the number of nodes to be created. - # This `[:count]` is fine, as it will get all the reusable nodes, even if there are - # less nodes. - # FIXME: This is not correct. Use the method in AWS. - reuse_nodes = stopped_nodes[:count] - - reuse_node_ids = [n.id for n in reuse_nodes] - if reuse_nodes: + resumed_instance_ids = [n['id'] for n in stopped_nodes] + if resumed_instance_ids: # TODO(suquark): Some instances could still be stopping. # We may wait until these instances stop. - logger.info( - # TODO: handle plural vs singular? - f"Reusing nodes {reuse_node_ids}. 
" - "To disable reuse, set `cache_stopped_nodes: False` " - "under `provider` in the cluster configuration.") - for node_id in reuse_node_ids: + for node_id in resumed_instance_ids: result = resource.start_instance(node_id, project_id, availability_zone) result_dict[node_id] = {node_id: result} - for node in reuse_nodes: - resource.set_labels(project_id, availability_zone, node, labels) - count -= len(reuse_node_ids) + resource.set_labels(project_id, availability_zone, node_id, + labels) + to_start_count -= len(resumed_instance_ids) + + if head_instance_id is None: + head_instance_id = resource.create_node_tag(resumed_instance_ids[0]) if to_start_count > 0: results = resource.create_instances(cluster_name, project_id, availability_zone, config.node_config, labels, to_start_count) + # FIXME: it seems that success is always False. for success, instance_id in results: resource.set_labels(project_id, availability_zone, instance_id, labels) + created_instance_ids.append(instance_id) result_dict.update( {instance_id: result for result, instance_id in results}) - return result_dict + + # NOTE: we only create worker tags for newly started nodes, because + # the worker tag is a legacy feature, so we would not care about + # more corner cases. + if head_instance_id is None: + head_instance_id = resource.create_node_tag(created_instance_ids[0]) + for inst in created_instance_ids[1:]: + resource.create_node_tag(inst, is_head=False) + else: + for inst in created_instance_ids: + resource.create_node_tag(inst, is_head=False) + return common.ProvisionRecord(provider_name='gcp', + region=region, + zone=availability_zone, + cluster_name=cluster_name, + head_instance_id=head_instance_id, + resumed_instance_ids=resumed_instance_ids, + created_instance_ids=created_instance_ids) def wait_instances(region: str, cluster_name: str, state: Optional[status_lib.ClusterStatus]) -> None: """See sky/provision/__init__.py""" - # TODO: maybe we just wait for the instances to be running, immediately - # after the instances are created, so we do not need to wait for the - # instances to be ready here. - raise NotImplementedError + # We already wait for the instances to be running in run_instances. + # So we don't need to wait here. 
+    return
 
 
 def stop_instances(
diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py
index 2a4a885b9ac..c06f1019720 100644
--- a/sky/provision/gcp/instance_utils.py
+++ b/sky/provision/gcp/instance_utils.py
@@ -20,6 +20,10 @@
 TAG_RAY_NODE_NAME = "ray-node-name"
 INSTANCE_NAME_MAX_LEN = 64
 INSTANCE_NAME_UUID_LEN = 8
+TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node'
+# Tag uniquely identifying all nodes of a cluster
+TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
+TAG_RAY_NODE_KIND = 'ray-node-type'
 
 logger = sky_logging.init_logger(__name__)
 
@@ -245,7 +249,17 @@ def set_labels(cls,
                    node_id: str,
                    labels: dict,
                    wait_for_operation: bool = True) -> dict:
-        return NotImplementedError
+        raise NotImplementedError
+
+    @classmethod
+    def create_node_tag(cls,
+                        cluster_name: str,
+                        project_id: str,
+                        availability_zone: str,
+                        target_instance_id: str,
+                        is_head: bool = True,
+                        wait_for_operation: bool = True) -> str:
+        raise NotImplementedError
 
 
 class GCPComputeInstance(GCPInstance):
@@ -528,6 +542,34 @@ def set_labels(cls,
 
         return result
 
+    @classmethod
+    def create_node_tag(cls,
+                        cluster_name: str,
+                        project_id: str,
+                        availability_zone: str,
+                        target_instance_id: str,
+                        is_head: bool = True,
+                        wait_for_operation: bool = True) -> str:
+        if is_head:
+            node_tag = {
+                TAG_SKYPILOT_HEAD_NODE: '1',
+                TAG_RAY_NODE_KIND: 'head',
+                'Name': f'sky-{cluster_name}-head',
+            }
+        else:
+            node_tag = {
+                TAG_SKYPILOT_HEAD_NODE: '0',
+                TAG_RAY_NODE_KIND: 'worker',
+                'Name': f'sky-{cluster_name}-worker',
+            }
+        cls.set_labels(project_id=project_id,
+                       availability_zone=availability_zone,
+                       node_id=target_instance_id,
+                       labels=node_tag,
+                       wait_for_operation=wait_for_operation)
+
+        return target_instance_id
+

From 26149b7b91b565b11ccc11210535f73b69316e90 Mon Sep 17 00:00:00 2001
From: Siyuan 
Date: Thu, 19 Oct 2023 13:38:52 -0700
Subject: [PATCH 13/84] wait stopping instances

---
 sky/provision/gcp/instance.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py
index 6efdcff807c..d5272a4f96a 100644
--- a/sky/provision/gcp/instance.py
+++ b/sky/provision/gcp/instance.py
@@ -129,6 +129,18 @@ def run_instances(region: str, cluster_name: str,
     PENDING_STATUS = ['PROVISIONING', 'STAGING']
     filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name}
 
+    # wait until all stopping instances are stopped/terminated
+    while True:
+        instances = resource.filter(
+            project_id=project_id,
+            zone=availability_zone,
+            label_filters=filter_labels,
+            status_filters=['STOPPING'],
+        )
+        if not instances:
+            break
+        time.sleep(POLL_INTERVAL)
+
     exist_instances = resource.filter(
         project_id=project_id,
         zone=availability_zone,
@@ -175,6 +187,10 @@ def get_order_key(node):
     stopping_instances.sort(key=lambda n: get_order_key(n), reverse=True)
     stopped_instances.sort(key=lambda n: get_order_key(n), reverse=True)
 
+    if stopping_instances:
+        raise RuntimeError(
+            'Some instances are being stopped during provisioning.')
+
     if head_instance_id is None:
         if running_instances:
             head_instance_id = resource.create_node_tag(
@@ -193,7 +209,13 @@ def get_order_key(node):
             'This is likely a resource leak. 
' 'Use "sky down" to terminate the cluster.') - # TODO: if there are running instances, use their zones instead - to_start_count = (config.count - len(running_instances) - len(pending_instances)) # Try to reuse previously stopped nodes with compatible configs - if config.resume_stopped_nodes and to_start_count > 0 and ( - stopping_instances or stopped_instances): - # TODO: we should wait until stopped instances are actually stopped. - # However, in GCP it is hard to know whether one instance is stopping for termination. - # So we need to wait and check. - stopped_nodes = stopping_instances + stopped_instances - stopped_nodes.sort(key=lambda n: get_order_key(n), reverse=True) - resumed_instance_ids = [n['id'] for n in stopped_nodes] + if config.resume_stopped_nodes and to_start_count > 0 and stopped_instances: + resumed_instance_ids = [n['id'] for n in stopped_instances] if resumed_instance_ids: - # TODO(suquark): Some instances could still be stopping. - # We may wait until these instances stop. for node_id in resumed_instance_ids: result = resource.start_instance(node_id, project_id, availability_zone) @@ -244,6 +250,7 @@ def get_order_key(node): else: for inst in created_instance_ids: resource.create_node_tag(inst, is_head=False) + return common.ProvisionRecord(provider_name='gcp', region=region, zone=availability_zone, From 03ff947bef561c6f4b151ca147c787b904dfe1d4 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Thu, 19 Oct 2023 13:39:25 -0700 Subject: [PATCH 14/84] support normal gcp tpus first --- sky/backends/cloud_vm_ray_backend.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index ff9d815724e..14d0268da8d 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1611,7 +1611,10 @@ def _retry_zones( global_user_state.set_owner_identity_for_cluster( cluster_name, cloud_user_identity) - if isinstance(to_provision.cloud, (clouds.AWS, clouds.GCP)): + if isinstance( + to_provision.cloud, + (clouds.AWS, + clouds.GCP)) and not tpu_utils.is_tpu_vm(to_provision): # Use the new provisioner for AWS. # TODO (suquark): Gradually move the other clouds to # the new provisioner once they are ready. 
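
One note on the "wait stopping instances" loop that PATCH 13 adds above: as written it polls forever if an instance never leaves STOPPING. A bounded sketch of the same polling pattern — `filter_stopping` is a hypothetical stand-in for the `resource.filter(..., status_filters=['STOPPING'])` call, and MAX_POLLS/POLL_INTERVAL are the constants this series already imports from sky/provision/gcp/constants.py:

import time

from sky.provision.gcp.constants import MAX_POLLS, POLL_INTERVAL


def wait_until_stopping_done(filter_stopping) -> None:
    # Poll until no instance is in STOPPING, giving up after MAX_POLLS tries
    # instead of spinning forever like the unbounded `while True` loop.
    for _ in range(MAX_POLLS):
        if not filter_stopping():
            return
        time.sleep(POLL_INTERVAL)
    raise TimeoutError('Instances are still STOPPING after polling.')
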
From aea9d1abc051dd790378df4ec2bad2d363b8f36a Mon Sep 17 00:00:00 2001 From: Siyuan Date: Thu, 19 Oct 2023 15:03:55 -0700 Subject: [PATCH 15/84] fix gcp --- sky/provision/gcp/__init__.py | 2 + sky/provision/gcp/instance.py | 72 ++++++++++-------- sky/provision/gcp/instance_utils.py | 111 +++++++++++++++++++++++++--- 3 files changed, 143 insertions(+), 42 deletions(-) diff --git a/sky/provision/gcp/__init__.py b/sky/provision/gcp/__init__.py index 654c38146f4..6f840c29e9f 100644 --- a/sky/provision/gcp/__init__.py +++ b/sky/provision/gcp/__init__.py @@ -6,3 +6,5 @@ from sky.provision.gcp.instance import run_instances from sky.provision.gcp.instance import stop_instances from sky.provision.gcp.instance import terminate_instances +from sky.provision.gcp.instance import wait_instances +from sky.provision.gcp.instance import get_cluster_info diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index d5272a4f96a..170d8bd884c 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -96,7 +96,7 @@ def _get_head_instance_id(instances: List) -> Optional[str]: labels = inst.get('labels', {}) if (labels.get(TAG_RAY_NODE_KIND) == 'head' or labels.get(TAG_SKYPILOT_HEAD_NODE) == '1'): - head_instance_id = inst['id'] + head_instance_id = inst['name'] break return head_instance_id @@ -104,7 +104,8 @@ def _get_head_instance_id(instances: List) -> Optional[str]: def run_instances(region: str, cluster_name: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """See sky/provision/__init__.py""" - result_dict = {} + # NOTE: although google cloud instances have IDs, but they are + # not used for indexing. Instead, we use the instance name. labels = config.tags # gcp uses "labels" instead of aws "tags" labels = dict(sorted(copy.deepcopy(labels).items())) resumed_instance_ids: List[str] = [] @@ -194,11 +195,20 @@ def get_order_key(node): if head_instance_id is None: if running_instances: head_instance_id = resource.create_node_tag( - running_instances[0]['id']) + cluster_name, + project_id, + availability_zone, + running_instances[0]['name'], + is_head=True, + ) elif pending_instances: head_instance_id = resource.create_node_tag( - pending_instances[0]['id']) - + cluster_name, + project_id, + availability_zone, + pending_instances[0]['name'], + is_head=True, + ) # TODO(suquark): Maybe in the future, users could adjust the number # of instances dynamically. Then this case would not be an error. 
if config.resume_stopped_nodes and len(exist_instances) > config.count: @@ -214,42 +224,35 @@ def get_order_key(node): # Try to reuse previously stopped nodes with compatible configs if config.resume_stopped_nodes and to_start_count > 0 and stopped_instances: - resumed_instance_ids = [n['id'] for n in stopped_instances] + resumed_instance_ids = [n['name'] for n in stopped_instances] if resumed_instance_ids: - for node_id in resumed_instance_ids: - result = resource.start_instance(node_id, project_id, - availability_zone) - result_dict[node_id] = {node_id: result} - resource.set_labels(project_id, availability_zone, node_id, + for instance_id in resumed_instance_ids: + resource.start_instance(instance_id, project_id, + availability_zone) + resource.set_labels(project_id, availability_zone, instance_id, labels) to_start_count -= len(resumed_instance_ids) if head_instance_id is None: - head_instance_id = resource.create_node_tag(resumed_instance_ids[0]) + head_instance_id = resource.create_node_tag( + cluster_name, + project_id, + availability_zone, + resumed_instance_ids[0], + is_head=True, + ) if to_start_count > 0: results = resource.create_instances(cluster_name, project_id, availability_zone, config.node_config, labels, - to_start_count) - # FIXME: it seems that success is always False. - for success, instance_id in results: - resource.set_labels(project_id, availability_zone, instance_id, - labels) - created_instance_ids.append(instance_id) - result_dict.update( - {instance_id: result for result, instance_id in results}) - - # NOTE: we only create worker tags for newly started nodes, because - # the worker tag is a legacy feature, so we would not care about - # more corner cases. + to_start_count, + head_instance_id is None) if head_instance_id is None: - head_instance_id = resource.create_node_tag(created_instance_ids[0]) - for inst in created_instance_ids[1:]: - resource.create_node_tag(inst, is_head=False) - else: - for inst in created_instance_ids: - resource.create_node_tag(inst, is_head=False) + head_instance_id = results[0][1] + # FIXME: it seems that success could be False sometimes, but the instances + # are started correctly. + created_instance_ids = [instance_id for success, instance_id in results] return common.ProvisionRecord(provider_name='gcp', region=region, @@ -268,6 +271,15 @@ def wait_instances(region: str, cluster_name: str, return +def get_cluster_info(region: str, cluster_name: str) -> common.ClusterInfo: + """See sky/provision/__init__.py""" + raise NotImplementedError + return common.ClusterInfo( + instances=instances, + head_instance_id=head_instance_id, + ) + + def stop_instances( cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None, diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index c06f1019720..c32427b0c7e 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -74,7 +74,8 @@ def try_catch_exc(): return dec -def _generate_node_name(cluster_name: str, node_suffix: str) -> str: +def _generate_node_name(cluster_name: str, node_suffix: str, + is_head: bool) -> str: """Generate node name from labels and suffix. This is required so that the correct resource can be selected @@ -84,6 +85,10 @@ def _generate_node_name(cluster_name: str, node_suffix: str) -> str: (as in ``GCPNodeType``). 
""" suffix = f'-{uuid.uuid4().hex[:INSTANCE_NAME_UUID_LEN]}-{node_suffix}' + if is_head: + suffix = f'-head{suffix}' + else: + suffix = f'-worker{suffix}' prefix = cluster_name[:INSTANCE_NAME_MAX_LEN - len(suffix)] return prefix + suffix @@ -193,6 +198,7 @@ def create_instance(cls, availability_zone: str, node_config: dict, labels: dict, + is_head_node: bool, wait_for_operation: bool = True) -> Tuple[dict, str]: """Creates a single instance and returns result. @@ -209,6 +215,7 @@ def create_instances( node_config: dict, labels: dict, count: int, + include_head_node: bool, wait_for_operation: bool = True, ) -> List[Tuple[dict, str]]: """Creates multiple instances and returns result. @@ -221,6 +228,7 @@ def create_instances( zone, node_config, labels, + is_head_node=include_head_node and i == 0, wait_for_operation=False) for i in range(count) ] @@ -517,9 +525,9 @@ def set_labels(cls, node_id: str, labels: dict, wait_for_operation: bool = True) -> dict: - response = (cls.load_resource().instances().list( + response = (cls.load_resource().instances().get( project=project_id, - filter=f'name = {node_id}', + instance=node_id, zone=availability_zone, ).execute()) node = response.get('items', [])[0] @@ -607,6 +615,76 @@ def _convert_resources_to_urls( return configuration_dict + # @classmethod + # def create_instances( + # cls, + # cluster_name: str, + # project_id: str, + # zone: str, + # node_config: dict, + # labels: dict, + # count: int, + # include_head_node: bool, + # wait_for_operation: bool = True, + # ) -> List[Tuple[dict, str]]: + # config = cls._convert_resources_to_urls(project_id, zone, + # node_config) + # # removing TPU-specific default key set in config.py + # config.pop("networkConfig", None) + # names = [] + # if include_head_node: + # names.append(_generate_node_name(cluster_name, GCPNodeType.COMPUTE.value, is_head=True)) + # for _ in range(count - 1): + # names.append(_generate_node_name(cluster_name, GCPNodeType.COMPUTE.value, is_head=False)) + # else: + # for _ in range(count): + # names.append(_generate_node_name(cluster_name, GCPNodeType.COMPUTE.value, is_head=False)) + + # labels = dict(config.get("labels", {}), **labels) + + # config.update({ + # "labels": dict( + # labels, **{ + # TAG_RAY_CLUSTER_NAME: cluster_name, + # TAG_SKYPILOT_CLUSTER_NAME: cluster_name + # }), + # }) + # source_instance_template = config.pop("sourceInstanceTemplate", None) + # body = { + # 'count': count, + # 'instanceProperties': config, + # 'sourceInstanceTemplate': source_instance_template, + # 'perInstanceProperties': {n: {} for n in names} + # } + + # # Allow Google Compute Engine instance templates. + # # + # # Config example: + # # + # # ... + # # node_config: + # # sourceInstanceTemplate: global/instanceTemplates/worker-16 + # # machineType: e2-standard-16 + # # ... + # # + # # node_config parameters override matching template parameters, if any. 
+ # # + # # https://cloud.google.com/compute/docs/instance-templates + # # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert + # operation = (cls.load_resource().instances().bulkInsert( + # project=project_id, + # zone=zone, + # body=body, + # ).execute()) + + # if wait_for_operation: + # result = cls.wait_for_operation(operation, project_id, + # zone) + # else: + # result = operation + + # return result, name + @classmethod def create_instance(cls, cluster_name: str, @@ -614,23 +692,32 @@ def create_instance(cls, availability_zone: str, node_config: dict, labels: dict, + is_head_node: bool, wait_for_operation: bool = True) -> Tuple[dict, str]: config = cls._convert_resources_to_urls(project_id, availability_zone, node_config) # removing TPU-specific default key set in config.py config.pop("networkConfig", None) - name = _generate_node_name(cluster_name, GCPNodeType.COMPUTE.value) + name = _generate_node_name(cluster_name, GCPNodeType.COMPUTE.value, + is_head_node) labels = dict(config.get("labels", {}), **labels) - - config.update({ - "labels": dict( - labels, **{ - TAG_RAY_CLUSTER_NAME: cluster_name, - TAG_SKYPILOT_CLUSTER_NAME: cluster_name - }), - "name": name, + labels.update({ + TAG_RAY_CLUSTER_NAME: cluster_name, + TAG_SKYPILOT_CLUSTER_NAME: cluster_name }) + if is_head_node: + labels.update({ + TAG_SKYPILOT_HEAD_NODE: '1', + TAG_RAY_NODE_KIND: 'head', + }) + else: + labels.update({ + TAG_SKYPILOT_HEAD_NODE: '0', + TAG_RAY_NODE_KIND: 'worker', + }) + + config.update({"labels": labels, "name": name}) # Allow Google Compute Engine instance templates. # From 0135eea895e5fd8aa6e8b9483847b6585ac0cec3 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 20 Oct 2023 14:12:26 -0700 Subject: [PATCH 16/84] support get cluster info --- sky/provision/__init__.py | 7 +++-- sky/provision/aws/instance.py | 7 +++-- sky/provision/gcp/__init__.py | 2 +- sky/provision/gcp/instance.py | 46 +++++++++++++++++++++++++++-- sky/provision/gcp/instance_utils.py | 2 -- sky/provision/provisioner.py | 6 ++-- 6 files changed, 58 insertions(+), 12 deletions(-) diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index a0c717ed281..8af83f08b85 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -137,7 +137,10 @@ def wait_instances(provider_name: str, region: str, cluster_name_on_cloud: str, @_route_to_cloud_impl -def get_cluster_info(provider_name: str, region: str, - cluster_name_on_cloud: str) -> common.ClusterInfo: +def get_cluster_info( + provider_name: str, + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """Get the metadata of instances in a cluster.""" raise NotImplementedError diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py index 60e59cf4185..222b405b408 100644 --- a/sky/provision/aws/instance.py +++ b/sky/provision/aws/instance.py @@ -758,9 +758,12 @@ def wait_instances(region: str, cluster_name_on_cloud: str, waiter.wait(WaiterConfig={'Delay': 5, 'MaxAttempts': 120}, Filters=filters) -def get_cluster_info(region: str, - cluster_name_on_cloud: str) -> common.ClusterInfo: +def get_cluster_info( + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """See sky/provision/__init__.py""" + del provider_config # unused ec2 = _default_ec2_resource(region) filters = [ { diff --git a/sky/provision/gcp/__init__.py b/sky/provision/gcp/__init__.py index 6f840c29e9f..12a984457a2 100644 --- 
a/sky/provision/gcp/__init__.py
+++ b/sky/provision/gcp/__init__.py
@@ -2,9 +2,9 @@
 
 from sky.provision.gcp.config import bootstrap_instances
 from sky.provision.gcp.instance import cleanup_ports
+from sky.provision.gcp.instance import get_cluster_info
 from sky.provision.gcp.instance import open_ports
 from sky.provision.gcp.instance import run_instances
 from sky.provision.gcp.instance import stop_instances
 from sky.provision.gcp.instance import terminate_instances
 from sky.provision.gcp.instance import wait_instances
-from sky.provision.gcp.instance import get_cluster_info
diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py
index 170d8bd884c..d409d55a103 100644
--- a/sky/provision/gcp/instance.py
+++ b/sky/provision/gcp/instance.py
@@ -271,11 +271,51 @@ def wait_instances(region: str, cluster_name: str,
     return
 
 
-def get_cluster_info(region: str, cluster_name: str) -> common.ClusterInfo:
+def get_cluster_info(
+        region: str,
+        cluster_name: str,
+        provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo:
     """See sky/provision/__init__.py"""
-    raise NotImplementedError
+    assert provider_config is not None, cluster_name
+    zone = provider_config['availability_zone']
+    project_id = provider_config['project_id']
+    label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name}
+
+    handlers: List[Type[instance_utils.GCPInstance]] = [
+        instance_utils.GCPComputeInstance
+    ]
+    use_tpu_vms = provider_config.get('_has_tpus', False)
+    if use_tpu_vms:
+        handlers.append(instance_utils.GCPTPUVMInstance)
+
+    handler_to_instances = _filter_instances(
+        handlers,
+        project_id,
+        zone,
+        label_filters,
+        lambda _: ['RUNNING'],
+    )
+    all_instances = [
+        i for instances in handler_to_instances.values() for i in instances
+    ]
+
+    head_instances = _filter_instances(
+        handlers,
+        project_id,
+        zone,
+        {
+            **label_filters, TAG_RAY_NODE_KIND: 'head'
+        },
+        lambda _: ['RUNNING'],
+    )
+    head_instance_id = None
+    for insts in head_instances.values():
+        if insts and insts[0]:
+            head_instance_id = insts[0]
+            break
 
     return common.ClusterInfo(
-        instances=instances,
+        instances=all_instances,
         head_instance_id=head_instance_id,
     )
diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py
index c32427b0c7e..9f2d186d0cc 100644
--- a/sky/provision/gcp/instance_utils.py
+++ b/sky/provision/gcp/instance_utils.py
@@ -562,13 +562,11 @@ def create_node_tag(cls,
             node_tag = {
                 TAG_SKYPILOT_HEAD_NODE: '1',
                 TAG_RAY_NODE_KIND: 'head',
-                'Name': f'sky-{cluster_name}-head',
             }
         else:
             node_tag = {
                 TAG_SKYPILOT_HEAD_NODE: '0',
                 TAG_RAY_NODE_KIND: 'worker',
-                'Name': f'sky-{cluster_name}-worker',
             }
         cls.set_labels(project_id=project_id,
                        availability_zone=availability_zone,
diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py
index e119d756250..2cb84961c94 100644
--- a/sky/provision/provisioner.py
+++ b/sky/provision/provisioner.py
@@ -313,9 +313,12 @@ def _post_provision_setup(
         cloud_name: str, cluster_name: ClusterName, cluster_yaml: str,
         provision_record: provision_common.ProvisionRecord,
         custom_resource: Optional[str]) -> provision_common.ClusterInfo:
+    config_from_yaml = common_utils.read_yaml(cluster_yaml)
+    provider_config = config_from_yaml.get('provider')
     cluster_info = provision.get_cluster_info(cloud_name,
                                               provision_record.region,
-                                              cluster_name.name_on_cloud)
+                                              cluster_name.name_on_cloud,
+                                              provider_config=provider_config)
 
     if len(cluster_info.instances) > 1:
         # Only worker nodes have logs in the per-instance log directory. 
Head @@ -337,7 +340,6 @@ def _post_provision_setup( 'Could not find any head instance.') # TODO(suquark): Move wheel build here in future PRs. - config_from_yaml = common_utils.read_yaml(cluster_yaml) ip_list = cluster_info.get_feasible_ips() ssh_credentials = backend_utils.ssh_credential_from_yaml(cluster_yaml) From a5b7537f0fb8224b8c67479927bd792bf2caddad Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 20 Oct 2023 15:15:14 -0700 Subject: [PATCH 17/84] fix --- sky/provision/gcp/instance.py | 9 ++++--- sky/provision/gcp/instance_utils.py | 39 ++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index d409d55a103..06457d11128 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -295,9 +295,10 @@ def get_cluster_info( label_filters, lambda _: ['RUNNING'], ) - all_instances = [ - i for instances in handler_to_instances.values() for i in instances - ] + instances = {} + for res, insts in handler_to_instances.items(): + for inst in insts: + instances[inst] = res.get_instance_info(project_id, zone, inst) head_instances = _filter_instances( handlers, @@ -315,7 +316,7 @@ def get_cluster_info( break return common.ClusterInfo( - instances=all_instances, + instances=instances, head_instance_id=head_instance_id, ) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 9f2d186d0cc..407597ec01d 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -9,6 +9,7 @@ from sky import sky_logging from sky.adaptors import gcp +from sky.provision import common from sky.provision.gcp.constants import MAX_POLLS from sky.provision.gcp.constants import POLL_INTERVAL from sky.utils import ux_utils @@ -269,6 +270,15 @@ def create_node_tag(cls, wait_for_operation: bool = True) -> str: raise NotImplementedError + @classmethod + def get_instance_info( + cls, + project_id: str, + availability_zone: str, + instance_id: str, + wait_for_operation: bool = True) -> common.InstanceInfo: + raise NotImplementedError + class GCPComputeInstance(GCPInstance): """Instance handler for GCP compute instances.""" @@ -525,12 +535,11 @@ def set_labels(cls, node_id: str, labels: dict, wait_for_operation: bool = True) -> dict: - response = (cls.load_resource().instances().get( + node = cls.load_resource().instances().get( project=project_id, instance=node_id, zone=availability_zone, - ).execute()) - node = response.get('items', [])[0] + ).execute() body = { "labels": dict(node["labels"], **labels), "labelFingerprint": node["labelFingerprint"], @@ -766,6 +775,30 @@ def start_instance(cls, return result + @classmethod + def get_instance_info( + cls, + project_id: str, + availability_zone: str, + instance_id: str, + wait_for_operation: bool = True) -> common.InstanceInfo: + result = cls.load_resource().instances().get( + project=project_id, + zone=availability_zone, + instance=instance_id, + ).execute() + external_ip = (result.get("networkInterfaces", + [{}])[0].get("accessConfigs", + [{}])[0].get("natIP", None)) + internal_ip = result.get("networkInterfaces", [{}])[0].get("networkIP") + + return common.InstanceInfo( + instance_id=instance_id, + internal_ip=internal_ip, + external_ip=external_ip, + tags=result.get('labels', {}), + ) + class GCPTPUVMInstance(GCPInstance): """Instance handler for GCP TPU node.""" From 2bb7438e5a7470e2cedc89d6660f6f47176cab66 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 20 Oct 2023 15:25:37 -0700 Subject: 
[PATCH 18/84] update --- sky/provision/gcp/instance.py | 6 +++--- sky/provision/gcp/instance_utils.py | 9 ++++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 06457d11128..c216b66dfe0 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -273,13 +273,13 @@ def wait_instances(region: str, cluster_name: str, def get_cluster_info( region: str, - cluster_name: str, + cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """See sky/provision/__init__.py""" - assert provider_config is not None, cluster_name + assert provider_config is not None, cluster_name_on_cloud zone = provider_config['availability_zone'] project_id = provider_config['project_id'] - label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name} + label_filters = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} handlers: List[Type[instance_utils.GCPInstance]] = [ instance_utils.GCPComputeInstance diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 407597ec01d..32b17229c51 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -968,13 +968,16 @@ def set_labels(cls, node_id: str, labels: dict, wait_for_operation: bool = True) -> dict: + node = cls.load_resource().projects().locations().nodes().get( + name=node_id, + ) body = { "labels": dict(node["labels"], **labels), } update_mask = "labels" operation = (cls.load_resource().projects().locations().nodes().patch( - name=node["name"], + name=node_id, updateMask=update_mask, body=body, ).execute()) @@ -989,8 +992,12 @@ def set_labels(cls, @classmethod def create_instance(cls, + cluster_name: str, + project_id: str, + availability_zone: str, node_config: dict, labels: dict, + is_head_node: bool, wait_for_operation: bool = True) -> Tuple[dict, str]: raise NotImplementedError From e525f060950453b23ac9de4405f4673dedc5ce0f Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 27 Oct 2023 17:19:52 -0700 Subject: [PATCH 19/84] wait for instance starting --- sky/provision/gcp/instance.py | 11 +++++++++++ sky/provision/gcp/instance_utils.py | 3 +-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index c216b66dfe0..ae25d8ddca1 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -254,6 +254,17 @@ def get_order_key(node): # are started correctly. 
created_instance_ids = [instance_id for success, instance_id in results] + while True: + # wait until all instances are running + instances = resource.filter( + project_id=project_id, + zone=availability_zone, + label_filters=filter_labels, + status_filters=PENDING_STATUS, + ) + if not instances: + break + return common.ProvisionRecord(provider_name='gcp', region=region, zone=availability_zone, diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 32b17229c51..ffa4d0afe83 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -969,8 +969,7 @@ def set_labels(cls, labels: dict, wait_for_operation: bool = True) -> dict: node = cls.load_resource().projects().locations().nodes().get( - name=node_id, - ) + name=node_id) body = { "labels": dict(node["labels"], **labels), } From 0ae66fc597edf5144e5394d34ae532cd982c85ff Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 3 Nov 2023 10:01:19 -0700 Subject: [PATCH 20/84] rename --- sky/provision/gcp/instance.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index ae25d8ddca1..2157f097f2f 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -101,7 +101,7 @@ def _get_head_instance_id(instances: List) -> Optional[str]: return head_instance_id -def run_instances(region: str, cluster_name: str, +def run_instances(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """See sky/provision/__init__.py""" # NOTE: although google cloud instances have IDs, but they are @@ -128,7 +128,7 @@ def run_instances(region: str, cluster_name: str, raise ValueError(f'Unknown node type {node_type}') PENDING_STATUS = ['PROVISIONING', 'STAGING'] - filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name} + filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} # wait until all stopping instances are stopped/terminated while True: @@ -195,7 +195,7 @@ def get_order_key(node): if head_instance_id is None: if running_instances: head_instance_id = resource.create_node_tag( - cluster_name, + cluster_name_on_cloud, project_id, availability_zone, running_instances[0]['name'], @@ -203,7 +203,7 @@ def get_order_key(node): ) elif pending_instances: head_instance_id = resource.create_node_tag( - cluster_name, + cluster_name_on_cloud, project_id, availability_zone, pending_instances[0]['name'], @@ -212,12 +212,13 @@ def get_order_key(node): # TODO(suquark): Maybe in the future, users could adjust the number # of instances dynamically. Then this case would not be an error. if config.resume_stopped_nodes and len(exist_instances) > config.count: - raise RuntimeError('The number of running/stopped/stopping ' - f'instances combined ({len(exist_instances)}) in ' - f'cluster "{cluster_name}" is greater than the ' - f'number requested by the user ({config.count}). ' - 'This is likely a resource leak. ' - 'Use "sky down" to terminate the cluster.') + raise RuntimeError( + 'The number of running/stopped/stopping ' + f'instances combined ({len(exist_instances)}) in ' + f'cluster "{cluster_name_on_cloud}" is greater than the ' + f'number requested by the user ({config.count}). ' + 'This is likely a resource leak. 
' + 'Use "sky down" to terminate the cluster.') to_start_count = (config.count - len(running_instances) - len(pending_instances)) @@ -235,7 +236,7 @@ def get_order_key(node): if head_instance_id is None: head_instance_id = resource.create_node_tag( - cluster_name, + cluster_name_on_cloud, project_id, availability_zone, resumed_instance_ids[0], @@ -243,7 +244,7 @@ def get_order_key(node): ) if to_start_count > 0: - results = resource.create_instances(cluster_name, project_id, + results = resource.create_instances(cluster_name_on_cloud, project_id, availability_zone, config.node_config, labels, to_start_count, @@ -268,13 +269,13 @@ def get_order_key(node): return common.ProvisionRecord(provider_name='gcp', region=region, zone=availability_zone, - cluster_name=cluster_name, + cluster_name=cluster_name_on_cloud, head_instance_id=head_instance_id, resumed_instance_ids=resumed_instance_ids, created_instance_ids=created_instance_ids) -def wait_instances(region: str, cluster_name: str, +def wait_instances(region: str, cluster_name_on_cloud: str, state: Optional[status_lib.ClusterStatus]) -> None: """See sky/provision/__init__.py""" # We already wait for the instances to be running in run_instances. From 084170b5fc4d8f25875959042e103cd840caec40 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 3 Nov 2023 10:41:01 -0700 Subject: [PATCH 21/84] hide gcp package import --- sky/adaptors/gcp.py | 23 ++++++++++++++++ sky/provision/gcp/config.py | 52 ++++++++++++------------------------- 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/sky/adaptors/gcp.py b/sky/adaptors/gcp.py index 6e611ee1f2b..fd41f7d8b20 100644 --- a/sky/adaptors/gcp.py +++ b/sky/adaptors/gcp.py @@ -2,6 +2,7 @@ # pylint: disable=import-outside-toplevel import functools +import json googleapiclient = None google = None @@ -82,3 +83,25 @@ def credential_error_exception(): """CredentialError exception.""" from google.auth import exceptions return exceptions.DefaultCredentialsError + + +@import_package +def get_credentials(cred_type: str, credentials_field: str): + """Get GCP credentials.""" + from google.oauth2 import service_account + from google.oauth2.credentials import Credentials as OAuthCredentials + + if cred_type == 'service_account': + # If parsing the gcp_credentials failed, then the user likely made a + # mistake in copying the credentials into the config yaml. + try: + service_account_info = json.loads(credentials_field) + except json.decoder.JSONDecodeError: + raise RuntimeError('gcp_credentials found in cluster yaml file but ' + 'formatted improperly.') + credentials = service_account.Credentials.from_service_account_info( + service_account_info) + elif cred_type == 'credentials_token': + # Otherwise the credentials type must be credentials_token. 
+ credentials = OAuthCredentials(credentials_field) + return credentials diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index 7b36f627041..fbb9aeb217a 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -1,15 +1,10 @@ """GCP configuration bootstrapping.""" import copy -import json import logging import time from typing import Dict, List, Set, Tuple -from google.oauth2 import service_account -from google.oauth2.credentials import Credentials as OAuthCredentials -from googleapiclient import discovery -from googleapiclient import errors - +from sky.adaptors import gcp from sky.provision import common from sky.provision.gcp import instance_utils from sky.provision.gcp.constants import FIREWALL_RULES_REQUIRED @@ -104,28 +99,28 @@ def wait_for_compute_global_operation(project_name, operation, compute): def _create_crm(gcp_credentials=None): - return discovery.build('cloudresourcemanager', - 'v1', - credentials=gcp_credentials, - cache_discovery=False) + return gcp.build('cloudresourcemanager', + 'v1', + credentials=gcp_credentials, + cache_discovery=False) def _create_iam(gcp_credentials=None): - return discovery.build('iam', - 'v1', - credentials=gcp_credentials, - cache_discovery=False) + return gcp.build('iam', + 'v1', + credentials=gcp_credentials, + cache_discovery=False) def _create_compute(gcp_credentials=None): - return discovery.build('compute', - 'v1', - credentials=gcp_credentials, - cache_discovery=False) + return gcp.build('compute', + 'v1', + credentials=gcp_credentials, + cache_discovery=False) def _create_tpu(gcp_credentials=None): - return discovery.build( + return gcp.build( 'tpu', TPU_VERSION, credentials=gcp_credentials, @@ -160,20 +155,7 @@ def construct_clients_from_provider_config(provider_config): cred_type = gcp_credentials['type'] credentials_field = gcp_credentials['credentials'] - - if cred_type == 'service_account': - # If parsing the gcp_credentials failed, then the user likely made a - # mistake in copying the credentials into the config yaml. - try: - service_account_info = json.loads(credentials_field) - except json.decoder.JSONDecodeError: - raise RuntimeError('gcp_credentials found in cluster yaml file but ' - 'formatted improperly.') - credentials = service_account.Credentials.from_service_account_info( - service_account_info) - elif cred_type == 'credentials_token': - # Otherwise the credentials type must be credentials_token. 
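# A usage sketch of the gcp.get_credentials adaptor introduced in this patch
# (dummy values; a real 'service_account' blob must be the full key-file JSON):
from sky.adaptors import gcp

# OAuth-token path: the raw token string is wrapped in OAuthCredentials.
creds = gcp.get_credentials('credentials_token', 'ya29.example-token')

# Service-account path: the string is json.loads()-ed first, and malformed
# JSON surfaces as a RuntimeError pointing at the cluster yaml rather than a
# bare JSONDecodeError:
# creds = gcp.get_credentials('service_account', key_file_contents)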
- credentials = OAuthCredentials(credentials_field) + credentials = gcp.get_credentials(cred_type, credentials_field) tpu_resource = (_create_tpu(credentials) if provider_config.get( HAS_TPU_PROVIDER_FIELD, False) else None) @@ -675,7 +657,7 @@ def _list_subnets(project_id: str, region: str, compute, filter=None): def _get_project(project_id: str, crm): try: project = crm.projects().get(projectId=project_id).execute() - except errors.HttpError as e: + except gcp.http_error_exception() as e: if e.resp.status != 403: raise project = None @@ -700,7 +682,7 @@ def _get_service_account(account: str, project_id: str, iam): try: service_account = iam.projects().serviceAccounts().get( name=full_name).execute() - except errors.HttpError as e: + except gcp.http_error_exception() as e: if e.resp.status not in [403, 404]: # SkyPilot: added 403, which means the service account doesn't exist, # or not accessible by the current account, which is fine, as we do the From 6940640d8b504cf58e0ac90c3b3a0acf9b8cf0ba Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 3 Nov 2023 11:39:46 -0700 Subject: [PATCH 22/84] fix --- sky/provision/gcp/instance_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index ffa4d0afe83..7f96b746ace 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -42,11 +42,11 @@ def _retry_on_http_exception( retry_interval_s: int = POLL_INTERVAL, ): """Retry a function call n-times for as long as it throws an exception.""" - from googleapiclient.errors import HttpError - - exception = HttpError def dec(func): + from googleapiclient.errors import HttpError + + exception = HttpError @functools.wraps(func) def wrapper(*args, **kwargs): From 45f00c1127793cbc94ebc2dac1d0f07f91ca759f Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 3 Nov 2023 11:44:26 -0700 Subject: [PATCH 23/84] fix --- sky/provision/gcp/instance_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 7f96b746ace..a12e851335f 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -44,12 +44,12 @@ def _retry_on_http_exception( """Retry a function call n-times for as long as it throws an exception.""" def dec(func): - from googleapiclient.errors import HttpError - - exception = HttpError @functools.wraps(func) def wrapper(*args, **kwargs): + from googleapiclient.errors import HttpError + + exception = HttpError def try_catch_exc(): try: From 9b5428f62bc1127bbcc2b3b866801a6ad93f7aa3 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Wed, 15 Nov 2023 12:48:58 -0800 Subject: [PATCH 24/84] update constants --- sky/provision/gcp/config.py | 108 ++++++++++----------------------- sky/provision/gcp/constants.py | 37 +++++++++++ 2 files changed, 70 insertions(+), 75 deletions(-) diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index fbb9aeb217a..32cfd6bdcb1 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -6,62 +6,18 @@ from sky.adaptors import gcp from sky.provision import common +from sky.provision.gcp import constants from sky.provision.gcp import instance_utils -from sky.provision.gcp.constants import FIREWALL_RULES_REQUIRED -from sky.provision.gcp.constants import FIREWALL_RULES_TEMPLATE -from sky.provision.gcp.constants import MAX_POLLS -from sky.provision.gcp.constants import POLL_INTERVAL -from 
sky.provision.gcp.constants import SKYPILOT_VPC_NAME -from sky.provision.gcp.constants import TPU_MINIMAL_PERMISSIONS -from sky.provision.gcp.constants import VM_MINIMAL_PERMISSIONS -from sky.provision.gcp.constants import VPC_TEMPLATE logger = logging.getLogger(__name__) -VERSION = 'v1' -TPU_VERSION = 'v2alpha' # change once v2 is stable - -RAY = 'ray-autoscaler' -DEFAULT_SERVICE_ACCOUNT_ID = RAY + '-sa-' + VERSION -SERVICE_ACCOUNT_EMAIL_TEMPLATE = '{account_id}@{project_id}.iam.gserviceaccount.com' -DEFAULT_SERVICE_ACCOUNT_CONFIG = { - 'displayName': f'Ray Autoscaler Service Account ({VERSION})', -} - -SKYPILOT = 'skypilot' -SKYPILOT_SERVICE_ACCOUNT_ID = SKYPILOT + '-' + VERSION -SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE = ( - '{account_id}@{project_id}.iam.gserviceaccount.com') -SKYPILOT_SERVICE_ACCOUNT_CONFIG = { - 'displayName': f'SkyPilot Service Account ({VERSION})', -} - -# Those roles will be always added. -# NOTE: `serviceAccountUser` allows the head node to create workers with -# a serviceAccount. `roleViewer` allows the head node to run bootstrap_gcp. -DEFAULT_SERVICE_ACCOUNT_ROLES = [ - 'roles/storage.objectAdmin', - 'roles/compute.admin', - 'roles/iam.serviceAccountUser', - 'roles/iam.roleViewer', -] -# Those roles will only be added if there are TPU nodes defined in config. -TPU_SERVICE_ACCOUNT_ROLES = ['roles/tpu.admin'] - -# If there are TPU nodes in config, this field will be set -# to True in config['provider']. -HAS_TPU_PROVIDER_FIELD = '_has_tpus' - -# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes -# with ServiceAccounts. - def wait_for_crm_operation(operation, crm): """Poll for cloud resource manager operation until finished.""" logger.info('wait_for_crm_operation: ' 'Waiting for operation {} to finish...'.format(operation)) - for _ in range(MAX_POLLS): + for _ in range(constants.MAX_POLLS): result = crm.operations().get(name=operation['name']).execute() if 'error' in result: raise Exception(result['error']) @@ -70,7 +26,7 @@ def wait_for_crm_operation(operation, crm): logger.info('wait_for_crm_operation: Operation done.') break - time.sleep(POLL_INTERVAL) + time.sleep(constants.POLL_INTERVAL) return result @@ -81,7 +37,7 @@ def wait_for_compute_global_operation(project_name, operation, compute): 'Waiting for operation {} to finish...'.format( operation['name'])) - for _ in range(MAX_POLLS): + for _ in range(constants.MAX_POLLS): result = (compute.globalOperations().get( project=project_name, operation=operation['name'], @@ -93,7 +49,7 @@ def wait_for_compute_global_operation(project_name, operation, compute): logger.info('wait_for_compute_global_operation: Operation done.') break - time.sleep(POLL_INTERVAL) + time.sleep(constants.POLL_INTERVAL) return result @@ -122,7 +78,7 @@ def _create_compute(gcp_credentials=None): def _create_tpu(gcp_credentials=None): return gcp.build( 'tpu', - TPU_VERSION, + constants.TPU_VERSION, credentials=gcp_credentials, cache_discovery=False, discoveryServiceUrl='https://tpu.googleapis.com/$discovery/rest', @@ -143,7 +99,7 @@ def construct_clients_from_provider_config(provider_config): 'Falling back to GOOGLE_APPLICATION_CREDENTIALS ' 'environment variable.') tpu_resource = (_create_tpu() if provider_config.get( - HAS_TPU_PROVIDER_FIELD, False) else None) + constants.HAS_TPU_PROVIDER_FIELD, False) else None) # If gcp_credentials is None, then discovery.build will search for # credentials in the local environment. 
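# Each `_create_*` helper above is a single discovery call; a sketch of the
# shared shape (assuming `gcp.build` forwards to
# googleapiclient.discovery.build, which is what the adaptor exposes):
from sky.adaptors import gcp

def _create_service(name: str, version: str, gcp_credentials=None):
    # credentials=None falls back to Application Default Credentials.
    return gcp.build(name, version,
                     credentials=gcp_credentials,
                     cache_discovery=False)

# e.g. _create_service('cloudresourcemanager', 'v1') builds the CRM client.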
return _create_crm(), _create_iam(), _create_compute(), tpu_resource @@ -158,7 +114,7 @@ def construct_clients_from_provider_config(provider_config): credentials = gcp.get_credentials(cred_type, credentials_field) tpu_resource = (_create_tpu(credentials) if provider_config.get( - HAS_TPU_PROVIDER_FIELD, False) else None) + constants.HAS_TPU_PROVIDER_FIELD, False) else None) return ( _create_crm(credentials), @@ -175,7 +131,7 @@ def bootstrap_instances( # insert that information into the provider config if instance_utils.get_node_type( config.node_config) == instance_utils.GCPNodeType.TPU: - config.provider_config[HAS_TPU_PROVIDER_FIELD] = True + config.provider_config[constants.HAS_TPU_PROVIDER_FIELD] = True crm, iam, compute, tpu = construct_clients_from_provider_config( config.provider_config) @@ -299,17 +255,17 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): TODO: Allow the name/id of the service account to be configured """ project_id = config.provider_config['project_id'] - email = SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( - account_id=SKYPILOT_SERVICE_ACCOUNT_ID, + email = constants.SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( + account_id=constants.SKYPILOT_SERVICE_ACCOUNT_ID, project_id=project_id, ) service_account = _get_service_account(email, project_id, iam) - permissions = VM_MINIMAL_PERMISSIONS - roles = DEFAULT_SERVICE_ACCOUNT_ROLES - if config.provider_config.get(HAS_TPU_PROVIDER_FIELD, False): - roles = DEFAULT_SERVICE_ACCOUNT_ROLES + TPU_SERVICE_ACCOUNT_ROLES - permissions = VM_MINIMAL_PERMISSIONS + TPU_MINIMAL_PERMISSIONS + permissions = constants.VM_MINIMAL_PERMISSIONS + roles = constants.DEFAULT_SERVICE_ACCOUNT_ROLES + if config.provider_config.get(constants.HAS_TPU_PROVIDER_FIELD, False): + roles = constants.DEFAULT_SERVICE_ACCOUNT_ROLES + constants.TPU_SERVICE_ACCOUNT_ROLES + permissions = constants.VM_MINIMAL_PERMISSIONS + constants.TPU_MINIMAL_PERMISSIONS satisfied, policy = _is_permission_satisfied(service_account, crm, iam, permissions, roles) @@ -321,8 +277,8 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): # and the user may not have the permissions to create the # new service account. This is to ensure that the old service # account is still usable. - email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( - account_id=DEFAULT_SERVICE_ACCOUNT_ID, + email = constants.SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( + account_id=constants.DEFAULT_SERVICE_ACCOUNT_ID, project_id=project_id, ) logger.info(f'_configure_iam_role: Fallback to service account {email}') @@ -340,12 +296,12 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): elif service_account is None: logger.info('_configure_iam_role: ' 'Creating new service account {}'.format( - SKYPILOT_SERVICE_ACCOUNT_ID)) + constants.SKYPILOT_SERVICE_ACCOUNT_ID)) # SkyPilot: a GCP user without the permission to create a service # account will fail here. 
service_account = _create_service_account( - SKYPILOT_SERVICE_ACCOUNT_ID, - SKYPILOT_SERVICE_ACCOUNT_CONFIG, + constants.SKYPILOT_SERVICE_ACCOUNT_ID, + constants.SKYPILOT_SERVICE_ACCOUNT_CONFIG, project_id, iam, ) @@ -383,7 +339,7 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str, compute): """Check if the firewall rules in the VPC are sufficient.""" - required_rules = FIREWALL_RULES_REQUIRED.copy() + required_rules = constants.FIREWALL_RULES_REQUIRED.copy() operation = compute.networks().getEffectiveFirewalls(project=project_id, network=vpc_name) @@ -551,24 +507,26 @@ def get_usable_vpc(cluster_name: str, config: common.ProvisionConfig): break if usable_vpc_name is None: - logger.info(f'Creating a default VPC network, {SKYPILOT_VPC_NAME}...') + logger.info( + f'Creating a default VPC network, {constants.SKYPILOT_VPC_NAME}...') # Create a SkyPilot VPC network if it doesn't exist vpc_list = _list_vpcnets(project_id, compute, - filter=f'name={SKYPILOT_VPC_NAME}') + filter=f'name={constants.SKYPILOT_VPC_NAME}') if len(vpc_list) == 0: - body = VPC_TEMPLATE.copy() - body['name'] = body['name'].format(VPC_NAME=SKYPILOT_VPC_NAME) + body = constants.VPC_TEMPLATE.copy() + body['name'] = body['name'].format( + VPC_NAME=constants.SKYPILOT_VPC_NAME) body['selfLink'] = body['selfLink'].format( - PROJ_ID=project_id, VPC_NAME=SKYPILOT_VPC_NAME) + PROJ_ID=project_id, VPC_NAME=constants.SKYPILOT_VPC_NAME) _create_vpcnet(project_id, compute, body) - _create_rules(project_id, compute, FIREWALL_RULES_TEMPLATE, - SKYPILOT_VPC_NAME) + _create_rules(project_id, compute, constants.FIREWALL_RULES_TEMPLATE, + constants.SKYPILOT_VPC_NAME) - usable_vpc_name = SKYPILOT_VPC_NAME - logger.info(f'A VPC network {SKYPILOT_VPC_NAME} created.') + usable_vpc_name = constants.SKYPILOT_VPC_NAME + logger.info(f'A VPC network {constants.SKYPILOT_VPC_NAME} created.') return usable_vpc_name diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index 4f70839d407..d50eb06d71e 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -1,3 +1,40 @@ +VERSION = 'v1' +TPU_VERSION = 'v2' # change once v2 is stable + +RAY = 'ray-autoscaler' +DEFAULT_SERVICE_ACCOUNT_ID = RAY + '-sa-' + VERSION +SERVICE_ACCOUNT_EMAIL_TEMPLATE = '{account_id}@{project_id}.iam.gserviceaccount.com' +DEFAULT_SERVICE_ACCOUNT_CONFIG = { + 'displayName': f'Ray Autoscaler Service Account ({VERSION})', +} + +SKYPILOT = 'skypilot' +SKYPILOT_SERVICE_ACCOUNT_ID = SKYPILOT + '-' + VERSION +SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE = ( + '{account_id}@{project_id}.iam.gserviceaccount.com') +SKYPILOT_SERVICE_ACCOUNT_CONFIG = { + 'displayName': f'SkyPilot Service Account ({VERSION})', +} + +# Those roles will be always added. +# NOTE: `serviceAccountUser` allows the head node to create workers with +# a serviceAccount. `roleViewer` allows the head node to run bootstrap_gcp. +DEFAULT_SERVICE_ACCOUNT_ROLES = [ + 'roles/storage.objectAdmin', + 'roles/compute.admin', + 'roles/iam.serviceAccountUser', + 'roles/iam.roleViewer', +] +# Those roles will only be added if there are TPU nodes defined in config. +TPU_SERVICE_ACCOUNT_ROLES = ['roles/tpu.admin'] + +# If there are TPU nodes in config, this field will be set +# to True in config['provider']. +HAS_TPU_PROVIDER_FIELD = '_has_tpus' + +# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes +# with ServiceAccounts. 
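# How the email templates above expand in practice (sketch; the project id is
# hypothetical, the constants mirror the definitions in this file):
SKYPILOT_SERVICE_ACCOUNT_ID = 'skypilot-v1'  # SKYPILOT + '-' + VERSION
SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE = (
    '{account_id}@{project_id}.iam.gserviceaccount.com')

email = SKYPILOT_SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
    account_id=SKYPILOT_SERVICE_ACCOUNT_ID, project_id='my-project')
assert email == 'skypilot-v1@my-project.iam.gserviceaccount.com'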
+ SKYPILOT_VPC_NAME = 'skypilot-vpc' # Below parameters are from the default VPC on GCP. From c4bed4691e4e3cc9a0ffb526adcc2df0879665ac Mon Sep 17 00:00:00 2001 From: Siyuan Date: Wed, 15 Nov 2023 15:21:56 -0800 Subject: [PATCH 25/84] fix comments --- sky/provision/gcp/instance.py | 33 +++++++++++++++++------------ sky/provision/gcp/instance_utils.py | 6 +++--- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 2157f097f2f..095bffb6b43 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -11,19 +11,6 @@ from sky.provision import common from sky.provision.gcp import instance_utils -# Tag for user defined node types (e.g., m4xl_spot). This is used for multi -# node type clusters. -TAG_RAY_USER_NODE_TYPE = "ray-user-node-type" -# Hash of the node launch config, used to identify out-of-date nodes -TAG_RAY_LAUNCH_CONFIG = "ray-launch-config" -# Tag for autofilled node types for legacy cluster yamls without multi -# node type defined in the cluster configs. -NODE_TYPE_LEGACY_HEAD = "ray-legacy-head-node-type" -NODE_TYPE_LEGACY_WORKER = "ray-legacy-worker-node-type" - -# Tag that reports the current state of the node (e.g. Updating, Up-to-date) -TAG_RAY_NODE_STATUS = "ray-node-status" - logger = sky_logging.init_logger(__name__) MAX_POLLS = 12 @@ -140,6 +127,8 @@ def run_instances(region: str, cluster_name_on_cloud: str, ) if not instances: break + logger.info( + f'Waiting for {len(instances)} instances in STOPPING status') time.sleep(POLL_INTERVAL) exist_instances = resource.filter( @@ -190,7 +179,8 @@ def get_order_key(node): if stopping_instances: raise RuntimeError( - f'Some instances are being stopped during provisioning.') + 'Some instances are being stopped during provisioning. ' + 'Please wait a while and retry.') if head_instance_id is None: if running_instances: @@ -266,6 +256,21 @@ def get_order_key(node): if not instances: break + # Check if the number of running instances is the same as the requested. + instances = resource.filter( + project_id=project_id, + zone=availability_zone, + label_filters=filter_labels, + status_filters=['RUNNING'], + ) + if len(instances) != config.count: + logger.warning('The number of running instances is different from ' + 'the requested number after provisioning ' + f'(requested: {config.count}, ' + f'observed: {len(instances)}). 
'
                           'This could mean that some instances failed to start '
                           'or that some resources have leaked.')
 
     return common.ProvisionRecord(provider_name='gcp',
                                   region=region,
                                   zone=availability_zone,
diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py
index a12e851335f..688ae092082 100644
--- a/sky/provision/gcp/instance_utils.py
+++ b/sky/provision/gcp/instance_utils.py
@@ -49,16 +49,16 @@ def dec(func):
 
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
             from googleapiclient.errors import HttpError
 
-            exception = HttpError
+            exception_type = HttpError
 
             def try_catch_exc():
                 try:
                     value = func(*args, **kwargs)
                     return value
                 except Exception as e:
-                    if not isinstance(e, exception) or (
+                    if not isinstance(e, exception_type) or (
                             regex and not re.search(regex, str(e))):
-                        raise e
+                        raise
                     return e
 
             for _ in range(max_retries):

From 4f7fb160e6d96363f68e484977e2825fa4acb4b2 Mon Sep 17 00:00:00 2001
From: Siyuan
Date: Wed, 15 Nov 2023 15:39:26 -0800
Subject: [PATCH 26/84] remove unused methods

--- sky/provision/gcp/instance_utils.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py
index 688ae092082..ed7e124117c 100644
--- a/sky/provision/gcp/instance_utils.py
+++ b/sky/provision/gcp/instance_utils.py
@@ -1024,15 +1024,6 @@ class GCPNodeType(enum.Enum):
     COMPUTE = "compute"
     TPU = "tpu"
 
-    @staticmethod
-    def name_to_type(name: str):
-        """Provided a node name, determine the type.
-
-        This expects the name to be in format '[NAME]-[UUID]-[TYPE]',
-        where [TYPE] is either 'compute' or 'tpu'.
-        """
-        return GCPNodeType(name.split("-")[-1])
-
 
 def get_node_type(node: dict) -> GCPNodeType:
     """Returns node type based on the keys in ``node``.

From 0647233b1c1132f27b849aa6b0c33abe74e67066 Mon Sep 17 00:00:00 2001
From: Siyuan
Date: Wed, 15 Nov 2023 16:20:17 -0800
Subject: [PATCH 27/84] fix comments

--- sky/provision/gcp/config.py | 3 +++
 sky/provision/gcp/instance_utils.py | 9 ++++++---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py
index 32cfd6bdcb1..b849ce5e927 100644
--- a/sky/provision/gcp/config.py
+++ b/sky/provision/gcp/config.py
@@ -104,6 +104,9 @@ def construct_clients_from_provider_config(provider_config):
         # credentials in the local environment.
         return _create_crm(), _create_iam(), _create_compute(), tpu_resource
 
+    # Note: The following code has not been used yet, as we will never set
+    # `gcp_credentials` in provider_config.
+    # It will only be used when we allow users to specify their own credentials.
     assert ('type' in gcp_credentials
            ), 'gcp_credentials cluster yaml field missing "type" field.'
     assert ('credentials' in gcp_credentials
diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py
index ed7e124117c..80a97cec680 100644
--- a/sky/provision/gcp/instance_utils.py
+++ b/sky/provision/gcp/instance_utils.py
@@ -12,6 +12,7 @@
 from sky.provision import common
 from sky.provision.gcp.constants import MAX_POLLS
 from sky.provision.gcp.constants import POLL_INTERVAL
+from sky.utils import common_utils
 from sky.utils import ux_utils
 
 # Tag uniquely identifying all nodes of a cluster
@@ -85,13 +86,15 @@ def _generate_node_name(cluster_name: str, node_suffix: str,
     The suffix is expected to be one of 'compute' or 'tpu'
     (as in ``GCPNodeType``).
""" - suffix = f'-{uuid.uuid4().hex[:INSTANCE_NAME_UUID_LEN]}-{node_suffix}' + suffix_id = common_utils.base36_encode(uuid.uuid4().hex) + suffix = f'-{suffix_id[:INSTANCE_NAME_UUID_LEN]}-{node_suffix}' if is_head: suffix = f'-head{suffix}' else: suffix = f'-worker{suffix}' - prefix = cluster_name[:INSTANCE_NAME_MAX_LEN - len(suffix)] - return prefix + suffix + node_name = cluster_name + suffix + assert len(node_name) <= INSTANCE_NAME_MAX_LEN, cluster_name + return node_name def instance_to_handler(instance: str): From a4dbcb07c703edca598159b952593e0c8125afac Mon Sep 17 00:00:00 2001 From: Siyuan Date: Thu, 16 Nov 2023 13:58:24 -0800 Subject: [PATCH 28/84] sync 'config' & 'constants' with upstream, Nov 16 --- sky/provision/gcp/config.py | 189 +++++++++++++++++++++++---------- sky/provision/gcp/constants.py | 12 ++- 2 files changed, 144 insertions(+), 57 deletions(-) diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index b849ce5e927..1abfd4295b7 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -2,6 +2,7 @@ import copy import logging import time +import typing from typing import Dict, List, Set, Tuple from sky.adaptors import gcp @@ -11,6 +12,18 @@ logger = logging.getLogger(__name__) +if typing.TYPE_CHECKING: + import google + + +def _skypilot_log_error_and_exit_for_failover(error: str) -> None: + """Logs an message then raises a specific RuntimeError to trigger failover. + Mainly used for handling VPC/subnet errors before nodes are launched. + """ + # NOTE: keep. The backend looks for this to know no nodes are launched. + prefix = "SKYPILOT_ERROR_NO_NODES_LAUNCHED: " + raise RuntimeError(prefix + error) + def wait_for_crm_operation(operation, crm): """Poll for cloud resource manager operation until finished.""" @@ -204,11 +217,15 @@ def _is_permission_satisfied(service_account, crm, iam, required_permissions, for binding in policy['bindings']: if binding['role'] == role: if member_id not in binding['members']: + logger.info( + f"_configure_iam_role: role {role} is not attached to {member_id}..." + ) binding['members'].append(member_id) already_configured = False role_exists = True if not role_exists: + logger.info(f"_configure_iam_role: role {role} does not exist.") already_configured = False policy['bindings'].append({ 'members': [member_id], @@ -477,61 +494,119 @@ def _create_rules(project_id: str, compute, rules, VPC_NAME): wait_for_compute_global_operation(project_id, op, compute) -def get_usable_vpc(cluster_name: str, config: common.ProvisionConfig): - """Return a usable VPC. +def _network_interface_to_vpc_name(network_interface: Dict[str, str]) -> str: + """Returns the VPC name of a network interface.""" + return network_interface["network"].split("/")[-1] - If not found, create a new one with sufficient firewall rules. - """ - project_id = config.provider_config['project_id'] - _, _, compute, _ = construct_clients_from_provider_config( - config.provider_config) - # For backward compatibility, reuse the VPC if the VM is launched. 
-    instance_dict = instance_utils.GCPComputeInstance.filter(
-        project_id,
-        config.provider_config['availability_zone'],
-        label_filters=None,
-        status_filters=None)
-
-    if instance_dict:
-        instance_metadata = list(instance_dict.values())[0]
-        netInterfaces = instance_metadata.get('networkInterfaces', [])
-        if len(netInterfaces) > 0:
-            vpc_name = netInterfaces[0]['network'].split('/')[-1]
-            return vpc_name
-
-    vpcnets_all = _list_vpcnets(project_id, compute)
-
-    usable_vpc_name = None
-    for vpc in vpcnets_all:
-        if _check_firewall_rules(cluster_name, vpc['name'], project_id,
-                                 compute):
-            usable_vpc_name = vpc['name']
-            break
+    # For an existing cluster, it is ok to return a VPC and subnet not used by
+    # the cluster, as GCP will ignore them.
+    # There is a corner case where the multi-node cluster was partially
+    # launched; launching the cluster again can cause the nodes to land on
+    # different VPCs, if the VPCs in the project have changed. It should be
+    # fine not to handle this special case, as we don't want to sacrifice
+    # performance on every launch just for this rare case.
+
+    specific_vpc_to_use = config.provider_config.get("vpc_name", None)
+    if specific_vpc_to_use is not None:
+        vpcnets_all = _list_vpcnets(project_id,
+                                    compute,
+                                    filter=f"name={specific_vpc_to_use}")
+        # On GCP, VPC names are unique, so it'd be 0 or 1 VPC found.
+        assert (
+            len(vpcnets_all) <= 1
+        ), f"{len(vpcnets_all)} VPCs found with the same name {specific_vpc_to_use}"
+        if len(vpcnets_all) == 1:
+            # Skip checking any firewall rules if the user has specified a VPC.
+            logger.info(f"Using user-specified VPC {specific_vpc_to_use!r}.")
+            subnets = _list_subnets(project_id,
+                                    region,
+                                    compute,
+                                    filter=f'(name="{specific_vpc_to_use}")')
+            if not subnets:
+                _skypilot_log_error_and_exit_for_failover(
+                    f"No subnet for region {region} found for specified VPC {specific_vpc_to_use!r}. 
" + f"Check the subnets of VPC {specific_vpc_to_use!r} at https://console.cloud.google.com/networking/networks" + ) + return specific_vpc_to_use, subnets[0] + else: + # VPC with this name not found. Error out and let SkyPilot failover. + _skypilot_log_error_and_exit_for_failover( + f"No VPC with name {specific_vpc_to_use!r} is found. " + "To fix: specify a correct VPC name.") + # Should not reach here. + + subnets_all = _list_subnets(project_id, region, compute) + + # Check if VPC for subnet has sufficient firewall rules. + insufficient_vpcs = set() + for subnet in subnets_all: + vpc_name = _network_interface_to_vpc_name(subnet) + if vpc_name in insufficient_vpcs: + continue + if _check_firewall_rules(cluster_name, vpc_name, project_id, compute): + logger.info( + f"get_usable_vpc: Found a usable VPC network {vpc_name!r}.") + return vpc_name, subnet + else: + insufficient_vpcs.add(vpc_name) + + # No usable VPC found. Try to create one. + logger.info( + f"Creating a default VPC network, {constants.SKYPILOT_VPC_NAME}...") + + # Create a SkyPilot VPC network if it doesn't exist + vpc_list = _list_vpcnets(project_id, + compute, + filter=f"name={constants.SKYPILOT_VPC_NAME}") + if len(vpc_list) == 0: + body = constants.VPC_TEMPLATE.copy() + body["name"] = body["name"].format(VPC_NAME=constants.SKYPILOT_VPC_NAME) + body["selfLink"] = body["selfLink"].format( + PROJ_ID=project_id, VPC_NAME=constants.SKYPILOT_VPC_NAME) + _create_vpcnet(project_id, compute, body) + + _create_rules(project_id, compute, constants.FIREWALL_RULES_TEMPLATE, + constants.SKYPILOT_VPC_NAME) + + usable_vpc_name = constants.SKYPILOT_VPC_NAME + subnets = _list_subnets(project_id, + region, + compute, + filter=f'(name="{usable_vpc_name}")') + if not subnets: + _skypilot_log_error_and_exit_for_failover( + f"No subnet for region {region} found for generated VPC {usable_vpc_name!r}. " + "This is probably due to the region being disabled in the account/project_id." 
+ ) + usable_subnet = subnets[0] + logger.info(f"A VPC network {constants.SKYPILOT_VPC_NAME} created.") + return usable_vpc_name, usable_subnet def _configure_subnet(region: str, cluster_name: str, @@ -546,12 +621,8 @@ def _configure_subnet(region: str, cluster_name: str, return config # SkyPilot: make sure there's a usable VPC - usable_vpc_name = get_usable_vpc(cluster_name, config) - subnets = _list_subnets(config.provider_config['project_id'], - region, - compute, - filter=f'(name="{usable_vpc_name}")') - default_subnet = subnets[0] + _, default_subnet = get_usable_vpc_and_subnet(cluster_name, region, config, + compute) default_interfaces = [{ 'subnetwork': default_subnet['selfLink'], @@ -602,10 +673,16 @@ def _list_vpcnets(project_id: str, compute, filter=None): filter=filter, ).execute()) - return response['items'] if 'items' in response else [] + return (list(sorted(response["items"], key=lambda x: x["name"])) + if "items" in response else []) -def _list_subnets(project_id: str, region: str, compute, filter=None): +def _list_subnets( + project_id: str, + region: str, + compute, + filter=None +) -> List["google.cloud.compute_v1.types.compute.Subnetwork"]: response = (compute.subnetworks().list( project=project_id, region=region, diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index d50eb06d71e..beed3dad976 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -72,9 +72,12 @@ 'IPProtocol': 'tcp', 'ports': ['22'], }], + # TODO(skypilot): some users reported that this should be relaxed (e.g., + # allowlisting only certain IPs to have ssh access). 'sourceRanges': ['0.0.0.0/0'], }, ] + # Template when creating firewall rules for a new VPC. FIREWALL_RULES_TEMPLATE = [ { @@ -110,6 +113,8 @@ 'IPProtocol': 'tcp', 'ports': ['22'], }], + # TODO(skypilot): some users reported that this should be relaxed (e.g., + # allowlisting only certain IPs to have ssh access). 'sourceRanges': ['0.0.0.0/0'], }, { @@ -127,10 +132,15 @@ ] # A list of permissions required to run SkyPilot on GCP. -# Keep this in sync with https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions.html#gcp # pylint: disable=line-too-long +# Keep this in sync with https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/gcp.html # pylint: disable=line-too-long VM_MINIMAL_PERMISSIONS = [ 'compute.disks.create', 'compute.disks.list', + # TODO(skypilot): some users reported that firewalls changes + # (create/delete/update) should be removed if VPC/firewalls are separately + # set up. It is undesirable for a normal account to have these permissions. + # Note that if these permissions are removed, opening ports (e.g., via + # `resources.ports`) would fail. 
'compute.firewalls.create', 'compute.firewalls.delete', 'compute.firewalls.get', From 43bf2e3b3e0e9b997119492381b4039d3554c74b Mon Sep 17 00:00:00 2001 From: Siyuan Date: Thu, 16 Nov 2023 14:38:59 -0800 Subject: [PATCH 29/84] sync 'instace_utils' with the upstream, Nov 16 --- sky/provision/gcp/instance_utils.py | 82 +++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 5 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 80a97cec680..c84d9c14c3d 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -48,9 +48,7 @@ def dec(func): @functools.wraps(func) def wrapper(*args, **kwargs): - from googleapiclient.errors import HttpError - - exception_type = HttpError + exception_type = gcp.http_error_exception() def try_catch_exc(): try: @@ -243,6 +241,10 @@ def create_instances( else: results = operations + if "sourceMachineImage" in node_config: + for _, instance_id in results: + cls.resize_disk(project_id, zone, node_config, instance_id) + return results @classmethod @@ -282,6 +284,18 @@ def get_instance_info( wait_for_operation: bool = True) -> common.InstanceInfo: raise NotImplementedError + @classmethod + def resize_disk(cls, + project_id: str, + availability_zone: str, + node_config: dict, + instance_name: str, + wait_for_operation: bool = True) -> dict: + """Resize a Google Cloud disk based on the provided configuration. + Returns the response of resize operation. + """ + raise NotImplementedError + class GCPComputeInstance(GCPInstance): """Instance handler for GCP compute instances.""" @@ -802,6 +816,50 @@ def get_instance_info( tags=result.get('labels', {}), ) + @classmethod + def resize_disk(cls, + project_id: str, + availability_zone: str, + node_config: dict, + instance_name: str, + wait_for_operation: bool = True) -> bool: + """Resize a Google Cloud disk based on the provided configuration.""" + + # Extract the specified disk size from the configuration + new_size_gb = node_config["disks"][0]["initializeParams"]["diskSizeGb"] + + # Fetch the instance details to get the disk name and current disk size + response = (cls.load_resource().instances().get( + project=project_id, + zone=availability_zone, + instance=instance_name, + ).execute()) + disk_name = response["disks"][0]["source"].split("/")[-1] + + try: + # Execute the resize request and return the response + operation = (cls.load_resource().disks().resize( + project=project_id, + zone=availability_zone, + disk=disk_name, + body={ + "sizeGb": str(new_size_gb), + }, + ).execute()) + except gcp.http_error_exception() as e: + # Catch HttpError when provided with invalid value for new disk size. 
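# Shape of the node_config consumed by resize_disk above (sketch with
# hypothetical values; only disks[0].initializeParams.diskSizeGb is read):
node_config = {
    'disks': [{
        'initializeParams': {
            'diskSizeGb': 256,  # target size; GCP persistent disks only grow
            'diskType': 'pd-ssd',
        },
    }],
}
new_size_gb = node_config['disks'][0]['initializeParams']['diskSizeGb']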
+ # Allowing users to create instances with the same size as the image + logger.warning(f"googleapiclient.errors.HttpError: {e.reason}") + return False + + if wait_for_operation: + result = cls.wait_for_operation(operation, project_id, + availability_zone) + else: + result = operation + + return result + class GCPTPUVMInstance(GCPInstance): """Instance handler for GCP TPU node.""" @@ -985,8 +1043,8 @@ def set_labels(cls, ).execute()) if wait_for_operation: - result = cls.wait_for_operation(project_id, availability_zone, - operation) + result = cls.wait_for_operation(operation, project_id, + availability_zone) else: result = operation @@ -1020,6 +1078,20 @@ def start_instance(cls, return result + @classmethod + def resize_disk(cls, + project_id: str, + availability_zone: str, + node_config: dict, + instance_name: str, + wait_for_operation: bool = True) -> dict: + """ + TODO: Implement the feature to attach persistent disks for TPU VMs. + The boot disk of TPU VMs is not resizable, and users need to add a + persistent disk to expand disk capacity. Related issue: #2387 + """ + return False + class GCPNodeType(enum.Enum): """Enum for GCP node types (compute & tpu)""" From 0a867ee2057804cf994fc0af80872389e2308581 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Thu, 16 Nov 2023 16:24:02 -0800 Subject: [PATCH 30/84] fix typing --- sky/provision/gcp/instance_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index c84d9c14c3d..0153009ddb0 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -290,7 +290,7 @@ def resize_disk(cls, availability_zone: str, node_config: dict, instance_name: str, - wait_for_operation: bool = True) -> dict: + wait_for_operation: bool = True) -> bool: """Resize a Google Cloud disk based on the provided configuration. Returns the response of resize operation. """ @@ -1084,7 +1084,7 @@ def resize_disk(cls, availability_zone: str, node_config: dict, instance_name: str, - wait_for_operation: bool = True) -> dict: + wait_for_operation: bool = True) -> bool: """ TODO: Implement the feature to attach persistent disks for TPU VMs. 
The boot disk of TPU VMs is not resizable, and users need to add a From 9ee76debcac710631b270e96da9438f3368e9368 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Thu, 16 Nov 2023 16:35:45 -0800 Subject: [PATCH 31/84] parallelize provisioning --- sky/provision/gcp/instance.py | 9 ++++++--- sky/provision/gcp/instance_utils.py | 16 +++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index 095bffb6b43..e720c06cff0 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -1,6 +1,7 @@ """GCP instance provisioning.""" import collections import copy +from multiprocessing import pool import re import time from typing import Any, Callable, Dict, Iterable, List, Optional, Type @@ -312,10 +313,12 @@ def get_cluster_info( label_filters, lambda _: ['RUNNING'], ) - instances = {} + instances: Dict[str, common.InstanceInfo] = {} for res, insts in handler_to_instances.items(): - for inst in insts: - instances[inst] = res.get_instance_info(project_id, zone, inst) + with pool.ThreadPool() as p: + inst_info = p.starmap(res.get_instance_info, + [(project_id, zone, inst) for inst in insts]) + instances.update(zip(insts, inst_info)) head_instances = _filter_instances( handlers, diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 0153009ddb0..c3b5f8f09a1 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -2,6 +2,7 @@ import copy import enum import functools +from multiprocessing import pool import re import time from typing import Any, Dict, List, Optional, Tuple @@ -224,15 +225,12 @@ def create_instances( Returns a list of tuples of (result, node_name). """ - operations = [ - cls.create_instance(cluster_name, - project_id, - zone, - node_config, - labels, - is_head_node=include_head_node and i == 0, - wait_for_operation=False) for i in range(count) - ] + with pool.ThreadPool() as p: + operations = p.starmap( + cls.create_instance, + [(cluster_name, project_id, zone, node_config, labels, + include_head_node and i == 0, False) for i in range(count)], + ) if wait_for_operation: results = [(cls.wait_for_operation(operation, project_id, From 8e584d4e23ef215425a27163a6bbae8d366c83f9 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 17 Nov 2023 06:41:17 +0000 Subject: [PATCH 32/84] Fix TPU node --- sky/backends/cloud_vm_ray_backend.py | 32 +++++++++++++++------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 14d0268da8d..e205ab351f1 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1641,7 +1641,20 @@ def _retry_zones( config_dict['provision_record'] = provision_record config_dict['resources_vars'] = resources_vars config_dict['handle'] = handle - return config_dict + tpu_name = config_dict.get('tpu_name') + if tpu_name is None: + return config_dict + + # tpu_name will only be set when TPU node (not TPU VM) + # is required. + logger.info( + f'{colorama.Style.BRIGHT}Provisioning TPU node on ' + f'{to_provision.cloud} ' + f'{region.name}{colorama.Style.RESET_ALL}{zone_str}') + + success = self._try_provision_tpu(to_provision, config_dict) + if success: + return config_dict # NOTE: We try to cleanup the cluster even if the previous # cluster does not exist. 
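# The fan-out pattern patch 31 above applies to per-instance API calls
# (sketch; fetch_info is a stand-in for res.get_instance_info and its
# (project_id, zone, instance) argument tuples):
from multiprocessing import pool

def fetch_info(project_id: str, zone: str, name: str) -> str:
    return f'{project_id}/{zone}/{name}'  # placeholder for the real API call

names = ['node-a', 'node-b']
with pool.ThreadPool() as p:
    infos = p.starmap(fetch_info,
                      [('my-proj', 'us-central1-a', n) for n in names])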
Also we are fast at @@ -1661,17 +1674,6 @@ def _retry_zones( # NOTE: The code below in the loop should not be reachable # with the new provisioner. - tpu_name = config_dict.get('tpu_name') - if tpu_name is not None: - logger.info( - f'{colorama.Style.BRIGHT}Provisioning TPU on ' - f'{to_provision.cloud} ' - f'{region.name}{colorama.Style.RESET_ALL}{zone_str}') - - success = self._try_provision_tpu(to_provision, config_dict) - if not success: - continue - logging_info = { 'cluster_name': cluster_name, 'region_name': region.name, @@ -1795,8 +1797,8 @@ def _retry_zones( raise exceptions.ResourcesUnavailableError( message, no_failover=is_prev_cluster_healthy) - def _tpu_pod_setup(self, cluster_yaml: str, - cluster_handle: 'backends.CloudVmRayResourceHandle'): + def _tpu_vm_pod_setup(self, cluster_yaml: str, + cluster_handle: 'backends.CloudVmRayResourceHandle'): """Completes setup and start Ray cluster on TPU VM Pod nodes. This is a workaround for Ray Autoscaler where `ray up` does not @@ -2040,7 +2042,7 @@ def need_ray_up( if tpu_utils.is_tpu_vm_pod(resources): logger.info(f'{style.BRIGHT}Setting up TPU VM Pod workers...' f'{style.RESET_ALL}') - self._tpu_pod_setup(cluster_config_file, cluster_handle) + self._tpu_vm_pod_setup(cluster_config_file, cluster_handle) # Only 1 node or head node provisioning failure. if cluster_handle.launched_nodes == 1 and returncode == 0: From d713618f3baaf79b3c9e389cf5027c379256df84 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 17 Nov 2023 19:16:26 +0000 Subject: [PATCH 33/84] Fix TPU NAME env for tpu node --- sky/backends/cloud_vm_ray_backend.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index e205ab351f1..60bfb798f67 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2973,6 +2973,10 @@ def _provision( # Update launched resources. handle.launched_resources = handle.launched_resources.copy( region=provision_record.region, zone=provision_record.zone) + + if 'tpu_name' in config_dict: + self._set_tpu_name(handle, config_dict['tpu_name']) + self._update_after_cluster_provisioned( handle, to_provision_config.prev_handle, task, prev_cluster_status, handle.external_ips(), @@ -2991,9 +2995,6 @@ def _provision( if 'docker' in config: handle.setup_docker_user(cluster_config_file) - if 'tpu_name' in config_dict: - self._set_tpu_name(handle, config_dict['tpu_name']) - # Get actual zone info and save it into handle. # NOTE: querying zones is expensive, observed 1node GCP >=4s. 
zone = handle.launched_resources.zone From aebb1977886675e9a23bc2d447043384c69dec16 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Fri, 17 Nov 2023 17:31:44 -0800 Subject: [PATCH 34/84] implement bulk provision --- sky/provision/gcp/instance.py | 16 +- sky/provision/gcp/instance_utils.py | 220 ++++++++++------------------ 2 files changed, 83 insertions(+), 153 deletions(-) diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index e720c06cff0..e518422f31a 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -235,16 +235,14 @@ def get_order_key(node): ) if to_start_count > 0: - results = resource.create_instances(cluster_name_on_cloud, project_id, - availability_zone, - config.node_config, labels, - to_start_count, - head_instance_id is None) + success, created_instance_ids = resource.create_instances( + cluster_name_on_cloud, project_id, availability_zone, + config.node_config, labels, to_start_count, + head_instance_id is None) + if not success: + raise RuntimeError('Failed to launch instances.') if head_instance_id is None: - head_instance_id = results[0][1] - # FIXME: it seems that success could be False sometimes, but the instances - # are started correctly. - created_instance_ids = [instance_id for success, instance_id in results] + head_instance_id = created_instance_ids[0] while True: # wait until all instances are running diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index c3b5f8f09a1..715df0c8f73 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -194,21 +194,6 @@ def add_network_tag_if_not_exist( ) -> None: raise NotImplementedError - @classmethod - def create_instance(cls, - cluster_name: str, - project_id: str, - availability_zone: str, - node_config: dict, - labels: dict, - is_head_node: bool, - wait_for_operation: bool = True) -> Tuple[dict, str]: - """Creates a single instance and returns result. - - Returns a tuple of (result, node_name). - """ - raise NotImplementedError - @classmethod def create_instances( cls, @@ -220,30 +205,12 @@ def create_instances( count: int, include_head_node: bool, wait_for_operation: bool = True, - ) -> List[Tuple[dict, str]]: + ) -> Tuple[bool, List[str]]: """Creates multiple instances and returns result. - Returns a list of tuples of (result, node_name). + Returns a tuple of (result, list[instance_names]). 
""" - with pool.ThreadPool() as p: - operations = p.starmap( - cls.create_instance, - [(cluster_name, project_id, zone, node_config, labels, - include_head_node and i == 0, False) for i in range(count)], - ) - - if wait_for_operation: - results = [(cls.wait_for_operation(operation, project_id, - zone), node_name) - for operation, node_name in operations] - else: - results = operations - - if "sourceMachineImage" in node_config: - for _, instance_id in results: - cls.resize_disk(project_id, zone, node_config, instance_id) - - return results + raise NotImplementedError @classmethod def start_instance(cls, @@ -589,7 +556,7 @@ def create_node_tag(cls, } else: node_tag = { - TAG_SKYPILOT_HEAD_NODE: '1', + TAG_SKYPILOT_HEAD_NODE: '0', TAG_RAY_NODE_KIND: 'worker', } cls.set_labels(project_id=project_id, @@ -637,109 +604,66 @@ def _convert_resources_to_urls( return configuration_dict - # @classmethod - # def create_instances( - # cls, - # cluster_name: str, - # project_id: str, - # zone: str, - # node_config: dict, - # labels: dict, - # count: int, - # include_head_node: bool, - # wait_for_operation: bool = True, - # ) -> List[Tuple[dict, str]]: - # config = cls._convert_resources_to_urls(project_id, zone, - # node_config) - # # removing TPU-specific default key set in config.py - # config.pop("networkConfig", None) - # names = [] - # if include_head_node: - # names.append(_generate_node_name(cluster_name, GCPNodeType.COMPUTE.value, is_head=True)) - # for _ in range(count - 1): - # names.append(_generate_node_name(cluster_name, GCPNodeType.COMPUTE.value, is_head=False)) - # else: - # for _ in range(count): - # names.append(_generate_node_name(cluster_name, GCPNodeType.COMPUTE.value, is_head=False)) - - # labels = dict(config.get("labels", {}), **labels) - - # config.update({ - # "labels": dict( - # labels, **{ - # TAG_RAY_CLUSTER_NAME: cluster_name, - # TAG_SKYPILOT_CLUSTER_NAME: cluster_name - # }), - # }) - # source_instance_template = config.pop("sourceInstanceTemplate", None) - # body = { - # 'count': count, - # 'instanceProperties': config, - # 'sourceInstanceTemplate': source_instance_template, - # 'perInstanceProperties': {n: {} for n in names} - # } - - # # Allow Google Compute Engine instance templates. - # # - # # Config example: - # # - # # ... - # # node_config: - # # sourceInstanceTemplate: global/instanceTemplates/worker-16 - # # machineType: e2-standard-16 - # # ... - # # - # # node_config parameters override matching template parameters, if any. - # # - # # https://cloud.google.com/compute/docs/instance-templates - # # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert - # operation = (cls.load_resource().instances().bulkInsert( - # project=project_id, - # zone=zone, - # body=body, - # ).execute()) - - # if wait_for_operation: - # result = cls.wait_for_operation(operation, project_id, - # zone) - # else: - # result = operation - - # return result, name - @classmethod - def create_instance(cls, - cluster_name: str, - project_id: str, - availability_zone: str, - node_config: dict, - labels: dict, - is_head_node: bool, - wait_for_operation: bool = True) -> Tuple[dict, str]: - config = cls._convert_resources_to_urls(project_id, availability_zone, - node_config) + def create_instances( + cls, + cluster_name: str, + project_id: str, + zone: str, + node_config: dict, + labels: dict, + count: int, + include_head_node: bool, + wait_for_operation: bool = True, + ) -> Tuple[bool, List[str]]: + # NOTE: The syntax for bulkInsert() is different from insert(). 
+ # bulkInsert expects resource names without prefix. Otherwise + # it causes a 503 error. + + # TODO: We could remove "_convert_resources_to_urls". It is here + # just for possible backward compat. + config = cls._convert_resources_to_urls(project_id, zone, node_config) + + for disk in config.get('disks', []): + disk_type = disk.get('initializeParams', {}).get('diskType') + if disk_type: + disk['initializeParams']['diskType'] = disk_type.rsplit('/', + 1)[-1] + config['machineType'] = config['machineType'].rsplit('/', 1)[-1] + for accelerator in config.get("guestAccelerators", []): + accelerator["acceleratorType"] = accelerator[ + "acceleratorType"].rsplit('/', 1)[-1] + # removing TPU-specific default key set in config.py config.pop("networkConfig", None) - name = _generate_node_name(cluster_name, GCPNodeType.COMPUTE.value, - is_head_node) + + head_tag_needed = [False] * count + if include_head_node: + head_tag_needed[0] = True + + names = [] + for i in range(count): + names.append( + _generate_node_name(cluster_name, + GCPNodeType.COMPUTE.value, + is_head=head_tag_needed[i])) labels = dict(config.get("labels", {}), **labels) - labels.update({ - TAG_RAY_CLUSTER_NAME: cluster_name, - TAG_SKYPILOT_CLUSTER_NAME: cluster_name - }) - if is_head_node: - labels.update({ - TAG_SKYPILOT_HEAD_NODE: '1', - TAG_RAY_NODE_KIND: 'head', - }) - else: - labels.update({ - TAG_SKYPILOT_HEAD_NODE: '0', - TAG_RAY_NODE_KIND: 'worker', - }) - config.update({"labels": labels, "name": name}) + config.update({ + "labels": dict( + labels, **{ + TAG_RAY_CLUSTER_NAME: cluster_name, + TAG_SKYPILOT_CLUSTER_NAME: cluster_name + }), + }) + source_instance_template = config.pop("sourceInstanceTemplate", None) + body = { + 'count': count, + 'instanceProperties': config, + 'sourceInstanceTemplate': source_instance_template, + 'perInstanceProperties': {n: {} for n in names} + } # Allow Google Compute Engine instance templates. 
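# Putting the two notes above together (sketch, hypothetical values):
# bulkInsert wants bare resource names plus a body keyed by instance name.
machine_type = 'zones/us-central1-a/machineTypes/n2-standard-8'
machine_type = machine_type.rsplit('/', 1)[-1]  # -> 'n2-standard-8'

body = {
    'count': 2,
    'instanceProperties': {'machineType': machine_type},  # + disks, labels...
    'sourceInstanceTemplate': None,  # or 'global/instanceTemplates/my-tmpl'
    'perInstanceProperties': {
        'my-cluster-head-1a2b3c4d-compute': {},
        'my-cluster-worker-5e6f7a8b-compute': {},
    },
}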
# @@ -755,21 +679,29 @@ def create_instance(cls, # # https://cloud.google.com/compute/docs/instance-templates # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert - source_instance_template = config.pop("sourceInstanceTemplate", None) - operation = (cls.load_resource().instances().insert( + operation = (cls.load_resource().instances().bulkInsert( project=project_id, - zone=availability_zone, - sourceInstanceTemplate=source_instance_template, - body=config, + zone=zone, + body=body, ).execute()) if wait_for_operation: - result = cls.wait_for_operation(operation, project_id, - availability_zone) - else: - result = operation + result = cls.load_resource().zoneOperations().wait( + project=project_id, + operation=operation['name'], + zone=zone, + ).execute() + success = result['status'] == 'DONE' + if success: + # assign labels for head node + with pool.ThreadPool() as p: + p.starmap(cls.create_node_tag, + [(cluster_name, project_id, zone, names[i], + head_tag_needed[i]) for i in range(count)]) - return result, name + return success, names + + return operation @classmethod def start_instance(cls, From b5b02466b57c087a6ed59a56c9804a031f1fa3ec Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 20 Nov 2023 20:01:49 +0000 Subject: [PATCH 35/84] refactor selflink --- sky/provision/gcp/instance_utils.py | 44 ++++++++++++++++++----------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 715df0c8f73..cf9d59076c4 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -5,7 +5,7 @@ from multiprocessing import pool import re import time -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import uuid from sky import sky_logging @@ -96,6 +96,18 @@ def _generate_node_name(cluster_name: str, node_suffix: str, return node_name +def selflink_to_name(selflink: str) -> str: + """Converts a selflink to a name. + + Args: + selflink: The selflink to convert. + + Returns: + The name of the resource. 
+ """ + return selflink.rsplit('/', 1)[-1] + + def instance_to_handler(instance: str): instance_type = instance.split('-')[-1] if instance_type == 'compute': @@ -227,7 +239,7 @@ def set_labels(cls, availability_zone: str, node_id: str, labels: dict, - wait_for_operation: bool = True) -> dict: + wait_for_operation: bool = True) -> Union[bool, dict]: raise NotImplementedError @classmethod @@ -406,7 +418,7 @@ def get_vpc_name( ).execute() # Format: projects/PROJECT_ID/global/networks/VPC_NAME vpc_link = response['networkInterfaces'][0]['network'] - return vpc_link.split('/')[-1] + return selflink_to_name(vpc_link) except gcp.http_error_exception() as e: with ux_utils.print_exception_no_traceback(): raise ValueError( @@ -516,7 +528,7 @@ def set_labels(cls, availability_zone: str, node_id: str, labels: dict, - wait_for_operation: bool = True) -> dict: + wait_for_operation: bool = True) -> Union[bool, dict]: node = cls.load_resource().instances().get( project=project_id, instance=node_id, @@ -627,12 +639,12 @@ def create_instances( for disk in config.get('disks', []): disk_type = disk.get('initializeParams', {}).get('diskType') if disk_type: - disk['initializeParams']['diskType'] = disk_type.rsplit('/', - 1)[-1] - config['machineType'] = config['machineType'].rsplit('/', 1)[-1] + disk['initializeParams']['diskType'] = selflink_to_name( + disk_type) + config['machineType'] = selflink_to_name(config['machineType']) for accelerator in config.get("guestAccelerators", []): - accelerator["acceleratorType"] = accelerator[ - "acceleratorType"].rsplit('/', 1)[-1] + accelerator['acceleratorType'] = selflink_to_name( + accelerator['acceleratorType']) # removing TPU-specific default key set in config.py config.pop("networkConfig", None) @@ -708,7 +720,7 @@ def start_instance(cls, node_id: str, project_id: str, zone: str, - wait_for_operation: bool = True) -> dict: + wait_for_operation: bool = True) -> Union[bool, dict]: operation = (cls.load_resource().instances().start( project=project_id, zone=zone, @@ -752,7 +764,7 @@ def resize_disk(cls, availability_zone: str, node_config: dict, instance_name: str, - wait_for_operation: bool = True) -> bool: + wait_for_operation: bool = True) -> Union[bool, dict]: """Resize a Google Cloud disk based on the provided configuration.""" # Extract the specified disk size from the configuration @@ -764,7 +776,7 @@ def resize_disk(cls, zone=availability_zone, instance=instance_name, ).execute()) - disk_name = response["disks"][0]["source"].split("/")[-1] + disk_name = selflink_to_name(response["disks"][0]["source"]) try: # Execute the resize request and return the response @@ -945,7 +957,7 @@ def get_vpc_name( response = cls.load_resource().projects().locations().nodes().get( name=instance).execute() vpc_link = response['networkConfig']['network'] - return vpc_link.split('/')[-1] + return selflink_to_name(vpc_link) except gcp.http_error_exception() as e: with ux_utils.print_exception_no_traceback(): raise ValueError( @@ -958,7 +970,7 @@ def set_labels(cls, availability_zone: str, node_id: str, labels: dict, - wait_for_operation: bool = True) -> dict: + wait_for_operation: bool = True) -> Union[bool, dict]: node = cls.load_resource().projects().locations().nodes().get( name=node_id) body = { @@ -996,7 +1008,7 @@ def start_instance(cls, node_id: str, project_id: str, zone: str, - wait_for_operation: bool = True) -> dict: + wait_for_operation: bool = True) -> Union[bool, dict]: operation = (cls.load_resource().projects().locations().nodes().start( name=node_id).execute()) @@ 
-1014,7 +1026,7 @@ def resize_disk(cls, availability_zone: str, node_config: dict, instance_name: str, - wait_for_operation: bool = True) -> bool: + wait_for_operation: bool = True) -> Union[bool, dict]: """ TODO: Implement the feature to attach persistent disks for TPU VMs. The boot disk of TPU VMs is not resizable, and users need to add a From b02e70d4024958fdaae978e4ba135c0f9e3e4dff Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 20 Nov 2023 21:02:56 +0000 Subject: [PATCH 36/84] format --- sky/adaptors/gcp.py | 4 +- sky/provision/gcp/config.py | 159 +++++++++++++++------------- sky/provision/gcp/constants.py | 20 ++-- sky/provision/gcp/instance.py | 45 ++++---- sky/provision/gcp/instance_utils.py | 73 ++++++------- 5 files changed, 161 insertions(+), 140 deletions(-) diff --git a/sky/adaptors/gcp.py b/sky/adaptors/gcp.py index fd41f7d8b20..3835d004338 100644 --- a/sky/adaptors/gcp.py +++ b/sky/adaptors/gcp.py @@ -96,9 +96,9 @@ def get_credentials(cred_type: str, credentials_field: str): # mistake in copying the credentials into the config yaml. try: service_account_info = json.loads(credentials_field) - except json.decoder.JSONDecodeError: + except json.decoder.JSONDecodeError as e: raise RuntimeError('gcp_credentials found in cluster yaml file but ' - 'formatted improperly.') + 'formatted improperly.') from e credentials = service_account.Credentials.from_service_account_info( service_account_info) elif cred_type == 'credentials_token': diff --git a/sky/provision/gcp/config.py b/sky/provision/gcp/config.py index 1abfd4295b7..c5b804f864f 100644 --- a/sky/provision/gcp/config.py +++ b/sky/provision/gcp/config.py @@ -3,7 +3,7 @@ import logging import time import typing -from typing import Dict, List, Set, Tuple +from typing import Any, Dict, List, Set, Tuple from sky.adaptors import gcp from sky.provision import common @@ -21,7 +21,7 @@ def _skypilot_log_error_and_exit_for_failover(error: str) -> None: Mainly used for handling VPC/subnet errors before nodes are launched. """ # NOTE: keep. The backend looks for this to know no nodes are launched. - prefix = "SKYPILOT_ERROR_NO_NODES_LAUNCHED: " + prefix = 'SKYPILOT_ERROR_NO_NODES_LAUNCHED: ' raise RuntimeError(prefix + error) @@ -99,9 +99,7 @@ def _create_tpu(gcp_credentials=None): def construct_clients_from_provider_config(provider_config): - """ - Attempt to fetch and parse the JSON GCP credentials from the provider - config yaml file. + """Attempt to fetch and parse the JSON GCP credentials. tpu resource (the last element of the tuple) will be None if `_has_tpus` in provider config is not set or False. @@ -149,7 +147,7 @@ def bootstrap_instances( config.node_config) == instance_utils.GCPNodeType.TPU: config.provider_config[constants.HAS_TPU_PROVIDER_FIELD] = True - crm, iam, compute, tpu = construct_clients_from_provider_config( + crm, iam, compute, _ = construct_clients_from_provider_config( config.provider_config) # Setup a Google Cloud Platform Project. @@ -217,15 +215,14 @@ def _is_permission_satisfied(service_account, crm, iam, required_permissions, for binding in policy['bindings']: if binding['role'] == role: if member_id not in binding['members']: - logger.info( - f"_configure_iam_role: role {role} is not attached to {member_id}..." 
- ) + logger.info(f'_configure_iam_role: role {role} is not ' + f'attached to {member_id}...') binding['members'].append(member_id) already_configured = False role_exists = True if not role_exists: - logger.info(f"_configure_iam_role: role {role} does not exist.") + logger.info(f'_configure_iam_role: role {role} does not exist.') already_configured = False policy['bindings'].append({ 'members': [member_id], @@ -264,7 +261,7 @@ def _is_permission_satisfied(service_account, crm, iam, required_permissions, return False, policy -def _configure_iam_role(config: common.ProvisionConfig, crm, iam): +def _configure_iam_role(config: common.ProvisionConfig, crm, iam) -> dict: """Setup a gcp service account with IAM roles. Creates a gcp service acconut and binds IAM roles which allow it to control @@ -284,8 +281,10 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): permissions = constants.VM_MINIMAL_PERMISSIONS roles = constants.DEFAULT_SERVICE_ACCOUNT_ROLES if config.provider_config.get(constants.HAS_TPU_PROVIDER_FIELD, False): - roles = constants.DEFAULT_SERVICE_ACCOUNT_ROLES + constants.TPU_SERVICE_ACCOUNT_ROLES - permissions = constants.VM_MINIMAL_PERMISSIONS + constants.TPU_MINIMAL_PERMISSIONS + roles = (constants.DEFAULT_SERVICE_ACCOUNT_ROLES + + constants.TPU_SERVICE_ACCOUNT_ROLES) + permissions = (constants.VM_MINIMAL_PERMISSIONS + + constants.TPU_MINIMAL_PERMISSIONS) satisfied, policy = _is_permission_satisfied(service_account, crm, iam, permissions, roles) @@ -343,10 +342,12 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam): # account is limited by the IAM rights specified below. 'scopes': ['https://www.googleapis.com/auth/cloud-platform'], } + iam_role: Dict[str, Any] if instance_utils.get_node_type( config.node_config) == instance_utils.GCPNodeType.TPU: - # SKY: The API for TPU VM is slightly different from normal compute instances. - # See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node + # SKY: The API for TPU VM is slightly different from normal compute + # instances. + # See https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#Node # pylint: disable=line-too-long account_dict['scope'] = account_dict['scopes'] account_dict.pop('scopes') iam_role = {'serviceAccount': account_dict} @@ -368,7 +369,8 @@ def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str, return False effective_rules = response['firewalls'] - def _merge_and_refine_rule(rules): + def _merge_and_refine_rule( + rules) -> Dict[Tuple[str, str], Dict[str, Set[int]]]: """Returns the reformatted rules from the firewall rules The function translates firewall rules fetched from the cloud provider @@ -378,30 +380,32 @@ def _merge_and_refine_rule(rules): [ { ... - "direction": "INGRESS", - "allowed": [ - {"IPProtocol": "tcp", "ports": ['80', '443']}, - {"IPProtocol": "udp", "ports": ['53']}, + 'direction': 'INGRESS', + 'allowed': [ + {'IPProtocol': 'tcp', 'ports': ['80', '443']}, + {'IPProtocol': 'udp', 'ports': ['53']}, ], - "sourceRanges": ["10.128.0.0/9"], + 'sourceRanges': ['10.128.0.0/9'], }, { ... 
- "direction": "INGRESS", - "allowed": [{ - "IPProtocol": "tcp", - "ports": ["22"], + 'direction': 'INGRESS', + 'allowed': [{ + 'IPProtocol': 'tcp', + 'ports': ['22'], }], - "sourceRanges": ["0.0.0.0/0"], + 'sourceRanges': ['0.0.0.0/0'], }, ] Returns: - source2rules: Dict[(direction, sourceRanges) -> Dict(protocol -> Set[ports])] - Example { - ("INGRESS", "10.128.0.0/9"): {"tcp": {80, 443}, "udp": {53}}, - ("INGRESS", "0.0.0.0/0"): {"tcp": {22}}, - } + source2rules: Dict[(direction, sourceRanges) -> + Dict(protocol -> Set[ports])] + + Example { + ('INGRESS', '10.128.0.0/9'): {'tcp': {80, 443}, 'udp': {53}}, + ('INGRESS', '0.0.0.0/0'): {'tcp': {22}}, + } """ source2rules: Dict[Tuple[str, str], Dict[str, Set[int]]] = {} source2allowed_list: Dict[Tuple[str, str], List[Dict[str, str]]] = {} @@ -452,13 +456,13 @@ def _merge_and_refine_rule(rules): port_set) return source2rules - effective_rules = _merge_and_refine_rule(effective_rules) - required_rules = _merge_and_refine_rule(required_rules) + effective_rules_map = _merge_and_refine_rule(effective_rules) + required_rules_map = _merge_and_refine_rule(required_rules) - for direction_source, allowed_req in required_rules.items(): - if direction_source not in effective_rules: + for direction_source, allowed_req in required_rules_map.items(): + if direction_source not in effective_rules_map: return False - allowed_eff = effective_rules[direction_source] + allowed_eff = effective_rules_map[direction_source] # Special case: 'all' means allowing all traffic if 'all' in allowed_eff: continue @@ -470,12 +474,12 @@ def _merge_and_refine_rule(rules): return True -def _create_rules(project_id: str, compute, rules, VPC_NAME): +def _create_rules(project_id: str, compute, rules, vpc_name): opertaions = [] for rule in rules: # Query firewall rule by its name (unique in a project). # If the rule already exists, delete it first. - rule_name = rule['name'].format(VPC_NAME=VPC_NAME) + rule_name = rule['name'].format(VPC_NAME=vpc_name) rule_list = _list_firewall_rules(project_id, compute, filter=f'(name={rule_name})') @@ -483,11 +487,11 @@ def _create_rules(project_id: str, compute, rules, VPC_NAME): _delete_firewall_rule(project_id, compute, rule_name) body = rule.copy() - body['name'] = body['name'].format(VPC_NAME=VPC_NAME) + body['name'] = body['name'].format(VPC_NAME=vpc_name) body['network'] = body['network'].format(PROJ_ID=project_id, - VPC_NAME=VPC_NAME) + VPC_NAME=vpc_name) body['selfLink'] = body['selfLink'].format(PROJ_ID=project_id, - VPC_NAME=VPC_NAME) + VPC_NAME=vpc_name) op = compute.firewalls().insert(project=project_id, body=body).execute() opertaions.append(op) for op in opertaions: @@ -496,7 +500,7 @@ def _create_rules(project_id: str, compute, rules, VPC_NAME): def _network_interface_to_vpc_name(network_interface: Dict[str, str]) -> str: """Returns the VPC name of a network interface.""" - return network_interface["network"].split("/")[-1] + return network_interface['network'].split('/')[-1] def get_usable_vpc_and_subnet( @@ -504,7 +508,7 @@ def get_usable_vpc_and_subnet( region: str, config: common.ProvisionConfig, compute, -) -> Tuple[str, "google.cloud.compute_v1.types.compute.Subnetwork"]: +) -> Tuple[str, 'google.cloud.compute_v1.types.compute.Subnetwork']: """Return a usable VPC and the subnet in it. If config.provider_config['vpc_name'] is set, return the VPC with the name @@ -520,7 +524,8 @@ def get_usable_vpc_and_subnet( region. Raises: - RuntimeError: if the user has specified a VPC name but the VPC is not found. 
+ RuntimeError: if the user has specified a VPC name but the VPC is not + found. """ project_id = config.provider_config['project_id'] @@ -532,33 +537,34 @@ def get_usable_vpc_and_subnet( # not handle this special case as we don't want to sacrifice the performance # for every launch just for this rare case. - specific_vpc_to_use = config.provider_config.get("vpc_name", None) + specific_vpc_to_use = config.provider_config.get('vpc_name', None) if specific_vpc_to_use is not None: vpcnets_all = _list_vpcnets(project_id, compute, - filter=f"name={specific_vpc_to_use}") + filter=f'name={specific_vpc_to_use}') # On GCP, VPC names are unique, so it'd be 0 or 1 VPC found. - assert ( - len(vpcnets_all) <= 1 - ), f"{len(vpcnets_all)} VPCs found with the same name {specific_vpc_to_use}" + assert (len(vpcnets_all) <= + 1), (f'{len(vpcnets_all)} VPCs found with the same name ' + f'{specific_vpc_to_use}') if len(vpcnets_all) == 1: # Skip checking any firewall rules if the user has specified a VPC. - logger.info(f"Using user-specified VPC {specific_vpc_to_use!r}.") + logger.info(f'Using user-specified VPC {specific_vpc_to_use!r}.') subnets = _list_subnets(project_id, region, compute, filter=f'(name="{specific_vpc_to_use}")') if not subnets: _skypilot_log_error_and_exit_for_failover( - f"No subnet for region {region} found for specified VPC {specific_vpc_to_use!r}. " - f"Check the subnets of VPC {specific_vpc_to_use!r} at https://console.cloud.google.com/networking/networks" - ) + f'No subnet for region {region} found for specified VPC ' + f'{specific_vpc_to_use!r}. ' + f'Check the subnets of VPC {specific_vpc_to_use!r} at ' + 'https://console.cloud.google.com/networking/networks') return specific_vpc_to_use, subnets[0] else: # VPC with this name not found. Error out and let SkyPilot failover. _skypilot_log_error_and_exit_for_failover( - f"No VPC with name {specific_vpc_to_use!r} is found. " - "To fix: specify a correct VPC name.") + f'No VPC with name {specific_vpc_to_use!r} is found. ' + 'To fix: specify a correct VPC name.') # Should not reach here. subnets_all = _list_subnets(project_id, region, compute) @@ -571,23 +577,23 @@ def get_usable_vpc_and_subnet( continue if _check_firewall_rules(cluster_name, vpc_name, project_id, compute): logger.info( - f"get_usable_vpc: Found a usable VPC network {vpc_name!r}.") + f'get_usable_vpc: Found a usable VPC network {vpc_name!r}.') return vpc_name, subnet else: insufficient_vpcs.add(vpc_name) # No usable VPC found. Try to create one. logger.info( - f"Creating a default VPC network, {constants.SKYPILOT_VPC_NAME}...") + f'Creating a default VPC network, {constants.SKYPILOT_VPC_NAME}...') # Create a SkyPilot VPC network if it doesn't exist vpc_list = _list_vpcnets(project_id, compute, - filter=f"name={constants.SKYPILOT_VPC_NAME}") + filter=f'name={constants.SKYPILOT_VPC_NAME}') if len(vpc_list) == 0: body = constants.VPC_TEMPLATE.copy() - body["name"] = body["name"].format(VPC_NAME=constants.SKYPILOT_VPC_NAME) - body["selfLink"] = body["selfLink"].format( + body['name'] = body['name'].format(VPC_NAME=constants.SKYPILOT_VPC_NAME) + body['selfLink'] = body['selfLink'].format( PROJ_ID=project_id, VPC_NAME=constants.SKYPILOT_VPC_NAME) _create_vpcnet(project_id, compute, body) @@ -601,11 +607,11 @@ def get_usable_vpc_and_subnet( filter=f'(name="{usable_vpc_name}")') if not subnets: _skypilot_log_error_and_exit_for_failover( - f"No subnet for region {region} found for generated VPC {usable_vpc_name!r}. 
" - "This is probably due to the region being disabled in the account/project_id." - ) + f'No subnet for region {region} found for generated VPC ' + f'{usable_vpc_name!r}. This is probably due to the region being ' + 'disabled in the account/project_id.') usable_subnet = subnets[0] - logger.info(f"A VPC network {constants.SKYPILOT_VPC_NAME} created.") + logger.info(f'A VPC network {constants.SKYPILOT_VPC_NAME} created.') return usable_vpc_name, usable_subnet @@ -652,6 +658,7 @@ def _delete_firewall_rule(project_id: str, compute, name): return response +# pylint: disable=redefined-builtin def _list_firewall_rules(project_id, compute, filter=None): response = (compute.firewalls().list( project=project_id, @@ -667,22 +674,23 @@ def _create_vpcnet(project_id: str, compute, body): return response -def _list_vpcnets(project_id: str, compute, filter=None): +def _list_vpcnets(project_id: str, compute, filter=None): # pylint: disable=redefined-builtin response = (compute.networks().list( project=project_id, filter=filter, ).execute()) - return (list(sorted(response["items"], key=lambda x: x["name"])) - if "items" in response else []) + return (list(sorted(response['items'], key=lambda x: x['name'])) + if 'items' in response else []) def _list_subnets( - project_id: str, - region: str, - compute, - filter=None -) -> List["google.cloud.compute_v1.types.compute.Subnetwork"]: + project_id: str, + region: str, + compute, + # pylint: disable=redefined-builtin + filter=None +) -> List['google.cloud.compute_v1.types.compute.Subnetwork']: response = (compute.subnetworks().list( project=project_id, region=region, @@ -722,9 +730,9 @@ def _get_service_account(account: str, project_id: str, iam): name=full_name).execute() except gcp.http_error_exception() as e: if e.resp.status not in [403, 404]: - # SkyPilot: added 403, which means the service account doesn't exist, - # or not accessible by the current account, which is fine, as we do the - # fallback in the caller. + # SkyPilot: added 403, which means the service account doesn't + # exist, or not accessible by the current account, which is fine, as + # we do the fallback in the caller. raise service_account = None @@ -746,6 +754,7 @@ def _create_service_account(account_id: str, account_config, project_id: str, def _add_iam_policy_binding(service_account, policy, crm, iam): """Add new IAM roles for the service account.""" + del iam project_id = service_account['projectId'] result = (crm.projects().setIamPolicy( diff --git a/sky/provision/gcp/constants.py b/sky/provision/gcp/constants.py index beed3dad976..21efb9d8377 100644 --- a/sky/provision/gcp/constants.py +++ b/sky/provision/gcp/constants.py @@ -1,9 +1,12 @@ +"""Constants used by the GCP provisioner.""" + VERSION = 'v1' TPU_VERSION = 'v2' # change once v2 is stable RAY = 'ray-autoscaler' DEFAULT_SERVICE_ACCOUNT_ID = RAY + '-sa-' + VERSION -SERVICE_ACCOUNT_EMAIL_TEMPLATE = '{account_id}@{project_id}.iam.gserviceaccount.com' +SERVICE_ACCOUNT_EMAIL_TEMPLATE = ( + '{account_id}@{project_id}.iam.gserviceaccount.com') DEFAULT_SERVICE_ACCOUNT_CONFIG = { 'displayName': f'Ray Autoscaler Service Account ({VERSION})', } @@ -39,7 +42,7 @@ # Below parameters are from the default VPC on GCP. 
# https://cloud.google.com/vpc/docs/firewalls#more_rules_default_vpc -VPC_TEMPLATE = { +VPC_TEMPLATE: dict = { 'name': '{VPC_NAME}', 'selfLink': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', 'autoCreateSubnetworks': True, @@ -82,9 +85,11 @@ FIREWALL_RULES_TEMPLATE = [ { 'name': '{VPC_NAME}-allow-custom', - 'description': 'Allows connection from any source to any instance on the network using custom protocols.', + 'description': ('Allows connection from any source to any instance on ' + 'the network using custom protocols.'), 'network': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', - 'selfLink': 'projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-custom', + 'selfLink': + ('projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-custom'), 'direction': 'INGRESS', 'priority': 65534, 'allowed': [ @@ -104,7 +109,9 @@ }, { 'name': '{VPC_NAME}-allow-ssh', - 'description': 'Allows TCP connections from any source to any instance on the network using port 22.', + 'description': + ('Allows TCP connections from any source to any instance on the ' + 'network using port 22.'), 'network': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', 'selfLink': 'projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-ssh', 'direction': 'INGRESS', @@ -119,7 +126,8 @@ }, { 'name': '{VPC_NAME}-allow-icmp', - 'description': 'Allows ICMP connections from any source to any instance on the network.', + 'description': ('Allows ICMP connections from any source to any ' + 'instance on the network.'), 'network': 'projects/{PROJ_ID}/global/networks/{VPC_NAME}', 'selfLink': 'projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-icmp', 'direction': 'INGRESS', diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index e518422f31a..a27843cfefa 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -94,7 +94,7 @@ def run_instances(region: str, cluster_name_on_cloud: str, """See sky/provision/__init__.py""" # NOTE: although google cloud instances have IDs, but they are # not used for indexing. Instead, we use the instance name. - labels = config.tags # gcp uses "labels" instead of aws "tags" + labels = config.tags # gcp uses 'labels' instead of aws 'tags' labels = dict(sorted(copy.deepcopy(labels).items())) resumed_instance_ids: List[str] = [] created_instance_ids: List[str] = [] @@ -103,19 +103,20 @@ def run_instances(region: str, cluster_name_on_cloud: str, project_id = config.provider_config['project_id'] availability_zone = config.provider_config['availability_zone'] - # SKY: "TERMINATED" for compute VM, "STOPPED" for TPU VM - # "STOPPING" means the VM is being stopped, which needs + # SKY: 'TERMINATED' for compute VM, 'STOPPED' for TPU VM + # 'STOPPING' means the VM is being stopped, which needs # to be included to avoid creating a new VM. 
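As the following hunk shows, run_instances first drains these transient states before deciding whether to reuse or create VMs, roughly along these lines (a sketch; list_instances stands in for the provider query):

import time

def wait_until_no_stopping(list_instances, poll_interval=5, max_polls=12):
    # Poll until no instance reports the transient 'STOPPING' state.
    for _ in range(max_polls):
        if not any(inst['status'] == 'STOPPING' for inst in list_instances()):
            return True
        time.sleep(poll_interval)
    return False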
+ resource: Type[instance_utils.GCPInstance] if node_type == instance_utils.GCPNodeType.COMPUTE: resource = instance_utils.GCPComputeInstance - STOPPED_STATUS = 'TERMINATED' + stopped_status = 'TERMINATED' elif node_type == instance_utils.GCPNodeType.TPU: resource = instance_utils.GCPTPUVMInstance - STOPPED_STATUS = 'STOPPED' + stopped_status = 'STOPPED' else: raise ValueError(f'Unknown node type {node_type}') - PENDING_STATUS = ['PROVISIONING', 'STAGING'] + pending_status = ['PROVISIONING', 'STAGING'] filter_labels = {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} # wait until all stopping instances are stopped/terminated @@ -147,36 +148,36 @@ def run_instances(region: str, cluster_name_on_cloud: str, stopping_instances = [] stopped_instances = [] - # SkyPilot: We try to use the instances with the same matching launch_config first. If - # there is not enough instances with matching launch_config, we then use all the - # instances with the same matching launch_config plus some instances with wrong - # launch_config. + # SkyPilot: We try to use the instances with the same matching launch_config + # first. If there is not enough instances with matching launch_config, we + # then use all the instances with the same matching launch_config plus some + # instances with wrong launch_config. def get_order_key(node): - import datetime + import datetime # pylint: disable=import-outside-toplevel - timestamp = node.get("lastStartTimestamp") + timestamp = node.get('lastStartTimestamp') if timestamp is not None: return datetime.datetime.strptime(timestamp, - "%Y-%m-%dT%H:%M:%S.%f%z") + '%Y-%m-%dT%H:%M:%S.%f%z') return node['id'] for inst in exist_instances: state = inst['status'] - if state in PENDING_STATUS: + if state in pending_status: pending_instances.append(inst) elif state == 'RUNNING': running_instances.append(inst) elif state == 'STOPPING': stopping_instances.append(inst) - elif state == STOPPED_STATUS: + elif state == stopped_status: stopped_instances.append(inst) else: raise RuntimeError(f'Unsupported state "{state}".') - pending_instances.sort(key=lambda n: get_order_key(n), reverse=True) - running_instances.sort(key=lambda n: get_order_key(n), reverse=True) - stopping_instances.sort(key=lambda n: get_order_key(n), reverse=True) - stopped_instances.sort(key=lambda n: get_order_key(n), reverse=True) + pending_instances.sort(key=get_order_key, reverse=True) + running_instances.sort(key=get_order_key, reverse=True) + stopping_instances.sort(key=get_order_key, reverse=True) + stopped_instances.sort(key=get_order_key, reverse=True) if stopping_instances: raise RuntimeError( @@ -250,7 +251,7 @@ def get_order_key(node): project_id=project_id, zone=availability_zone, label_filters=filter_labels, - status_filters=PENDING_STATUS, + status_filters=pending_status, ) if not instances: break @@ -270,6 +271,7 @@ def get_order_key(node): 'This could be some instances failed to start ' 'or some resource leak.') + assert head_instance_id is not None, 'head_instance_id is None' return common.ProvisionRecord(provider_name='gcp', region=region, zone=availability_zone, @@ -282,9 +284,9 @@ def get_order_key(node): def wait_instances(region: str, cluster_name_on_cloud: str, state: Optional[status_lib.ClusterStatus]) -> None: """See sky/provision/__init__.py""" + del region, cluster_name_on_cloud, state # We already wait for the instances to be running in run_instances. # So we don't need to wait here. 
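The lastStartTimestamp ordering used above can be exercised standalone (illustrative timestamps; note that the integer-id fallback only sorts cleanly when no node in the list carries a timestamp):

import datetime

def get_order_key(node):
    timestamp = node.get('lastStartTimestamp')
    if timestamp is not None:
        return datetime.datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%f%z')
    return node['id']

nodes = [{'id': 'a', 'lastStartTimestamp': '2023-11-26T08:00:00.000-08:00'},
         {'id': 'b', 'lastStartTimestamp': '2023-11-27T08:00:00.000-08:00'}]
nodes.sort(key=get_order_key, reverse=True)  # most recently started first
assert [n['id'] for n in nodes] == ['b', 'a']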
- return def get_cluster_info( @@ -292,6 +294,7 @@ def get_cluster_info( cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: """See sky/provision/__init__.py""" + del region assert provider_config is not None, cluster_name_on_cloud zone = provider_config['availability_zone'] project_id = provider_config['project_id'] diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index cf9d59076c4..4e472a24930 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -18,9 +18,9 @@ # Tag uniquely identifying all nodes of a cluster TAG_SKYPILOT_CLUSTER_NAME = 'skypilot-cluster-name' -TAG_RAY_CLUSTER_NAME = "ray-cluster-name" +TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' # Tag for the name of the node -TAG_RAY_NODE_NAME = "ray-node-name" +TAG_RAY_NODE_NAME = 'ray-node-name' INSTANCE_NAME_MAX_LEN = 64 INSTANCE_NAME_UUID_LEN = 8 TAG_SKYPILOT_HEAD_NODE = 'skypilot-head-node' @@ -55,7 +55,7 @@ def try_catch_exc(): try: value = func(*args, **kwargs) return value - except Exception as e: + except Exception as e: # pylint: disable=broad-except if not isinstance(e, exception_type) or ( regex and not re.search(regex, str(e))): raise @@ -229,7 +229,7 @@ def start_instance(cls, node_id: str, project_id: str, zone: str, - wait_for_operation: bool = True) -> dict: + wait_for_operation: bool = True) -> Union[bool, dict]: """Start a stopped instance.""" raise NotImplementedError @@ -267,7 +267,7 @@ def resize_disk(cls, availability_zone: str, node_config: dict, instance_name: str, - wait_for_operation: bool = True) -> bool: + wait_for_operation: bool = True) -> Union[bool, dict]: """Resize a Google Cloud disk based on the provided configuration. Returns the response of resize operation. """ @@ -535,8 +535,8 @@ def set_labels(cls, zone=availability_zone, ).execute() body = { - "labels": dict(node["labels"], **labels), - "labelFingerprint": node["labelFingerprint"], + 'labels': dict(node['labels'], **labels), + 'labelFingerprint': node['labelFingerprint'], } operation = (cls.load_resource().instances().setLabels( project=project_id, @@ -596,19 +596,19 @@ def _convert_resources_to_urls( `acceleratorType`. """ configuration_dict = copy.deepcopy(configuration_dict) - existing_machine_type = configuration_dict["machineType"] - if not re.search(".*/machineTypes/.*", existing_machine_type): + existing_machine_type = configuration_dict['machineType'] + if not re.search('.*/machineTypes/.*', existing_machine_type): configuration_dict[ - "machineType"] = "zones/{zone}/machineTypes/{machine_type}".format( + 'machineType'] = 'zones/{zone}/machineTypes/{machine_type}'.format( zone=availability_zone, - machine_type=configuration_dict["machineType"], + machine_type=configuration_dict['machineType'], ) - for accelerator in configuration_dict.get("guestAccelerators", []): - gpu_type = accelerator["acceleratorType"] - if not re.search(".*/acceleratorTypes/.*", gpu_type): + for accelerator in configuration_dict.get('guestAccelerators', []): + gpu_type = accelerator['acceleratorType'] + if not re.search('.*/acceleratorTypes/.*', gpu_type): accelerator[ - "acceleratorType"] = "projects/{project}/zones/{zone}/acceleratorTypes/{accelerator}".format( # noqa: E501 + 'acceleratorType'] = 'projects/{project}/zones/{zone}/acceleratorTypes/{accelerator}'.format( # noqa: E501 project=project_id, zone=availability_zone, accelerator=gpu_type, @@ -632,7 +632,7 @@ def create_instances( # bulkInsert expects resource names without prefix. 
Otherwise # it causes a 503 error. - # TODO: We could remove "_convert_resources_to_urls". It is here + # TODO: We could remove '_convert_resources_to_urls'. It is here # just for possible backward compat. config = cls._convert_resources_to_urls(project_id, zone, node_config) @@ -642,12 +642,12 @@ def create_instances( disk['initializeParams']['diskType'] = selflink_to_name( disk_type) config['machineType'] = selflink_to_name(config['machineType']) - for accelerator in config.get("guestAccelerators", []): + for accelerator in config.get('guestAccelerators', []): accelerator['acceleratorType'] = selflink_to_name( accelerator['acceleratorType']) # removing TPU-specific default key set in config.py - config.pop("networkConfig", None) + config.pop('networkConfig', None) head_tag_needed = [False] * count if include_head_node: @@ -660,16 +660,16 @@ def create_instances( GCPNodeType.COMPUTE.value, is_head=head_tag_needed[i])) - labels = dict(config.get("labels", {}), **labels) + labels = dict(config.get('labels', {}), **labels) config.update({ - "labels": dict( + 'labels': dict( labels, **{ TAG_RAY_CLUSTER_NAME: cluster_name, TAG_SKYPILOT_CLUSTER_NAME: cluster_name }), }) - source_instance_template = config.pop("sourceInstanceTemplate", None) + source_instance_template = config.pop('sourceInstanceTemplate', None) body = { 'count': count, 'instanceProperties': config, @@ -746,10 +746,10 @@ def get_instance_info( zone=availability_zone, instance=instance_id, ).execute() - external_ip = (result.get("networkInterfaces", - [{}])[0].get("accessConfigs", - [{}])[0].get("natIP", None)) - internal_ip = result.get("networkInterfaces", [{}])[0].get("networkIP") + external_ip = (result.get('networkInterfaces', + [{}])[0].get('accessConfigs', + [{}])[0].get('natIP', None)) + internal_ip = result.get('networkInterfaces', [{}])[0].get('networkIP') return common.InstanceInfo( instance_id=instance_id, @@ -768,7 +768,7 @@ def resize_disk(cls, """Resize a Google Cloud disk based on the provided configuration.""" # Extract the specified disk size from the configuration - new_size_gb = node_config["disks"][0]["initializeParams"]["diskSizeGb"] + new_size_gb = node_config['disks'][0]['initializeParams']['diskSizeGb'] # Fetch the instance details to get the disk name and current disk size response = (cls.load_resource().instances().get( @@ -776,7 +776,7 @@ def resize_disk(cls, zone=availability_zone, instance=instance_name, ).execute()) - disk_name = selflink_to_name(response["disks"][0]["source"]) + disk_name = selflink_to_name(response['disks'][0]['source']) try: # Execute the resize request and return the response @@ -785,13 +785,13 @@ def resize_disk(cls, zone=availability_zone, disk=disk_name, body={ - "sizeGb": str(new_size_gb), + 'sizeGb': str(new_size_gb), }, ).execute()) except gcp.http_error_exception() as e: # Catch HttpError when provided with invalid value for new disk size. 
# Allowing users to create instances with the same size as the image - logger.warning(f"googleapiclient.errors.HttpError: {e.reason}") + logger.warning(f'googleapiclient.errors.HttpError: {e.reason}') return False if wait_for_operation: @@ -964,7 +964,7 @@ def get_vpc_name( f'Failed to get VPC name for instance {instance}') from e @classmethod - @_retry_on_http_exception("unable to queue the operation") + @_retry_on_http_exception('unable to queue the operation') def set_labels(cls, project_id: str, availability_zone: str, @@ -974,9 +974,9 @@ def set_labels(cls, node = cls.load_resource().projects().locations().nodes().get( name=node_id) body = { - "labels": dict(node["labels"], **labels), + 'labels': dict(node['labels'], **labels), } - update_mask = "labels" + update_mask = 'labels' operation = (cls.load_resource().projects().locations().nodes().patch( name=node_id, @@ -1012,7 +1012,7 @@ def start_instance(cls, operation = (cls.load_resource().projects().locations().nodes().start( name=node_id).execute()) - # FIXME: original implementation has the "max_polls=MAX_POLLS" option. + # FIXME: original implementation has the 'max_polls=MAX_POLLS' option. if wait_for_operation: result = cls.wait_for_operation(operation, project_id, zone) else: @@ -1027,7 +1027,8 @@ def resize_disk(cls, node_config: dict, instance_name: str, wait_for_operation: bool = True) -> Union[bool, dict]: - """ + """Resize the disk a machine image with a different size is used. + TODO: Implement the feature to attach persistent disks for TPU VMs. The boot disk of TPU VMs is not resizable, and users need to add a persistent disk to expand disk capacity. Related issue: #2387 @@ -1038,8 +1039,8 @@ def resize_disk(cls, class GCPNodeType(enum.Enum): """Enum for GCP node types (compute & tpu)""" - COMPUTE = "compute" - TPU = "tpu" + COMPUTE = 'compute' + TPU = 'tpu' def get_node_type(node: dict) -> GCPNodeType: From 01ba45f1166173f604792e88419c731e5a88272a Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 21 Nov 2023 11:48:22 +0000 Subject: [PATCH 37/84] reduce the sleep time for autostop --- tests/test_smoke.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 686ed24d1e8..700762d3546 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -1599,7 +1599,7 @@ def test_autostop(generic_cloud: str): f'sky status | grep {name} | grep "1m"', # Ensure the cluster is not stopped early. - 'sleep 45', + 'sleep 40', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is STOPPED. @@ -1653,7 +1653,7 @@ def test_autodown(generic_cloud: str): # Ensure autostop is set. f'sky status | grep {name} | grep "1m (down)"', # Ensure the cluster is not terminated early. - 'sleep 45', + 'sleep 40', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is terminated. 
'sleep 200', From d90c22b8c347789ccb48ad29a0d0b278271acf90 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 26 Nov 2023 07:32:18 +0000 Subject: [PATCH 38/84] provisioner version refactoring --- sky/backends/cloud_vm_ray_backend.py | 6 ++---- sky/clouds/aws.py | 2 ++ sky/clouds/cloud.py | 9 +++++++++ sky/clouds/gcp.py | 2 ++ sky/skylet/events.py | 4 +++- 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 60bfb798f67..affae4e1946 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1611,10 +1611,8 @@ def _retry_zones( global_user_state.set_owner_identity_for_cluster( cluster_name, cloud_user_identity) - if isinstance( - to_provision.cloud, - (clouds.AWS, - clouds.GCP)) and not tpu_utils.is_tpu_vm(to_provision): + if (to_provision.cloud.PROVISIONER_VERSION >= 2 and + not tpu_utils.is_tpu_vm(to_provision)): # Use the new provisioner for AWS. # TODO (suquark): Gradually move the other clouds to # the new provisioner once they are ready. diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index fbe3947a9d6..ec9db79963f 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -113,6 +113,8 @@ class AWS(clouds.Cloud): 'https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html' # pylint: disable=line-too-long ) + PROVISIONER_VERSION = 2 + @classmethod def _cloud_unsupported_features( cls) -> Dict[clouds.CloudImplementationFeatures, str]: diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 625694f8780..fd2d3f01354 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -60,6 +60,8 @@ class Cloud: _REPR = '' _DEFAULT_DISK_TIER = 'medium' + PROVISIONER_VERSION = 1 + @classmethod def _cloud_unsupported_features( cls) -> Dict[CloudImplementationFeatures, str]: @@ -679,3 +681,10 @@ def delete_image(cls, image_id: str, region: Optional[str]) -> None: def __repr__(self): return self._REPR + + def __setstate__(self, state): + state.pop('PROVISIONER_VERSION', None) + self.__dict__.update(state) + # Make sure the provisioner version is always the latest. 
+ # pylint: disable=invalid-name + self.PROVISIONER_VERSION = type(self).PROVISIONER_VERSION diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index edb67a94f3e..05a979c4fcf 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -225,6 +225,8 @@ class GCP(clouds.Cloud): 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#google-cloud-platform-gcp' # pylint: disable=line-too-long ) + PROVISIONER_VERSION = 2 + def __init__(self): super().__init__() diff --git a/sky/skylet/events.py b/sky/skylet/events.py index ade43104048..12b8d28a3af 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -13,6 +13,7 @@ from sky import sky_logging from sky.backends import backend_utils from sky.backends import cloud_vm_ray_backend +from sky.clouds import cloud_registry from sky.serve import serve_utils from sky.skylet import autostop_lib from sky.skylet import job_lib @@ -149,7 +150,8 @@ def _stop_cluster(self, autostop_config): provider_module) assert provider_search is not None, config provider_name = provider_search.group(1).lower() - if provider_name in ('aws', 'gcp'): + if (cloud_registry.CLOUD_REGISTRY.from_str( + provider_name).PROVISIONER_VERSION >= 2): logger.info('Using new provisioner to stop the cluster.') self._stop_cluster_with_new_provisioner(autostop_config, config, provider_name) From 2df5d394c02cce15ecc6234a6a3be6f55f07456f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 00:02:02 +0000 Subject: [PATCH 39/84] refactor --- sky/backends/cloud_vm_ray_backend.py | 5 +++-- sky/clouds/__init__.py | 1 + sky/clouds/aws.py | 2 +- sky/clouds/azure.py | 2 ++ sky/clouds/cloud.py | 18 +++++++++++++++++- sky/clouds/gcp.py | 2 +- sky/skylet/events.py | 4 +++- 7 files changed, 28 insertions(+), 6 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index affae4e1946..322912ad15d 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1611,7 +1611,8 @@ def _retry_zones( global_user_state.set_owner_identity_for_cluster( cluster_name, cloud_user_identity) - if (to_provision.cloud.PROVISIONER_VERSION >= 2 and + if (to_provision.cloud.PROVISIONER_VERSION + == clouds.ProvisionerVersion.SKYPILOT and not tpu_utils.is_tpu_vm(to_provision)): # Use the new provisioner for AWS. # TODO (suquark): Gradually move the other clouds to @@ -3957,7 +3958,7 @@ def teardown_no_lock(self, stderr = '' # Use the new provisioner for AWS. - if isinstance(cloud, (clouds.AWS, clouds.GCP)): + if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR): # Stop the ray autoscaler first to avoid the head node trying to # re-launch the worker nodes, during the termination of the # cluster. 
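The pickling changes in this series (the __setstate__ above, later replaced by a __getstate__ in PATCH 41) share one goal, illustrated by this standalone sketch, which is not SkyPilot code: a per-instance copy of the version saved by an older release must not shadow the upgraded class attribute.

import pickle

class Cloud:
    PROVISIONER_VERSION = 1

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop('PROVISIONER_VERSION', None)  # never persist the version
        return state

c = Cloud()
c.PROVISIONER_VERSION = 1      # simulate a stale per-instance value
Cloud.PROVISIONER_VERSION = 2  # simulate upgraded code
restored = pickle.loads(pickle.dumps(c))
assert restored.PROVISIONER_VERSION == 2  # class attribute wins again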
diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py index d3d8aab0d9f..0fc4f8dc35e 100644 --- a/sky/clouds/__init__.py +++ b/sky/clouds/__init__.py @@ -1,6 +1,7 @@ """Clouds in Sky.""" from sky.clouds.cloud import Cloud from sky.clouds.cloud import CloudImplementationFeatures +from sky.clouds.cloud import ProvisionerVersion from sky.clouds.cloud import Region from sky.clouds.cloud import Zone from sky.clouds.cloud_registry import CLOUD_REGISTRY diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index ec9db79963f..cfb9be87846 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -113,7 +113,7 @@ class AWS(clouds.Cloud): 'https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html' # pylint: disable=line-too-long ) - PROVISIONER_VERSION = 2 + PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT @classmethod def _cloud_unsupported_features( diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index 41f6862d9df..e649bd8a368 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -60,6 +60,8 @@ class Azure(clouds.Cloud): _INDENT_PREFIX = ' ' * 4 + PROVISIONER_VERSION = clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR + @classmethod def _cloud_unsupported_features( cls) -> Dict[clouds.CloudImplementationFeatures, str]: diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index fd2d3f01354..d9e0b6c196e 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -54,13 +54,29 @@ class Zone(collections.namedtuple('Zone', ['name'])): region: Region +class ProvisionerVersion(enum.Enum): + """The version of the provisioner. + + 1: ray node provider based implementation + 2: ray node provider for provisioning and SkyPilot provisioner for + stopping and termination + 3: SkyPilot provisioner for both provisioning and stopping + """ + RAY_AUTOSCALER = 1 + RAY_PROVISIONER_SKYPILOT_TERMINATOR = 2 + SKYPILOT = 3 + + def __ge__(self, other): + return self.value >= other.value + + class Cloud: """A cloud provider.""" _REPR = '' _DEFAULT_DISK_TIER = 'medium' - PROVISIONER_VERSION = 1 + PROVISIONER_VERSION = ProvisionerVersion.RAY_AUTOSCALER @classmethod def _cloud_unsupported_features( diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 05a979c4fcf..939ead0c97d 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -225,7 +225,7 @@ class GCP(clouds.Cloud): 'https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#google-cloud-platform-gcp' # pylint: disable=line-too-long ) - PROVISIONER_VERSION = 2 + PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT def __init__(self): super().__init__() diff --git a/sky/skylet/events.py b/sky/skylet/events.py index 12b8d28a3af..b7d960672e1 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -10,6 +10,7 @@ import psutil import yaml +from sky import clouds from sky import sky_logging from sky.backends import backend_utils from sky.backends import cloud_vm_ray_backend @@ -151,7 +152,8 @@ def _stop_cluster(self, autostop_config): assert provider_search is not None, config provider_name = provider_search.group(1).lower() if (cloud_registry.CLOUD_REGISTRY.from_str( - provider_name).PROVISIONER_VERSION >= 2): + provider_name).PROVISIONER_VERSION >= clouds. 
+ ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR): logger.info('Using new provisioner to stop the cluster.') self._stop_cluster_with_new_provisioner(autostop_config, config, provider_name) From 06fbad6b0914efc5784efd8969ed7776e7cb6fbd Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 00:03:58 +0000 Subject: [PATCH 40/84] Add logging --- sky/backends/cloud_vm_ray_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 322912ad15d..9d1b39f4ee5 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3959,6 +3959,8 @@ def teardown_no_lock(self, # Use the new provisioner for AWS. if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR): + logger.debug(f'Provisioner version: {cloud.PROVISIONER_VERSION} ' + 'using new provisioner for teardown.') # Stop the ray autoscaler first to avoid the head node trying to # re-launch the worker nodes, during the termination of the # cluster. From aee06ddb3b2cdcca1c54709e8f1ef70635d2f105 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 00:11:09 +0000 Subject: [PATCH 41/84] avoid saving the provisioner version --- sky/clouds/cloud.py | 8 +++----- sky/clouds/gcp.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index d9e0b6c196e..d10ba12e07e 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -698,9 +698,7 @@ def delete_image(cls, image_id: str, region: Optional[str]) -> None: def __repr__(self): return self._REPR - def __setstate__(self, state): + def __getstate__(self): + state = self.__dict__.copy() state.pop('PROVISIONER_VERSION', None) - self.__dict__.update(state) - # Make sure the provisioner version is always the latest. - # pylint: disable=invalid-name - self.PROVISIONER_VERSION = type(self).PROVISIONER_VERSION + return state diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 939ead0c97d..0a574e7ab36 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -1231,7 +1231,7 @@ def delete_image(cls, image_id: str, region: Optional[str]) -> None: stream_logs=True) def __getstate__(self) -> Dict[str, Any]: - state = self.__dict__.copy() + state = super().__getstate__() # We should avoid saving third-party object to the state, as it may # cause unpickling error when the third-party API is updated. state.pop('_list_reservations_cache', None) From 918a3c64d4de273f52e01feecd0d68affdffb998 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 00:37:49 +0000 Subject: [PATCH 42/84] format --- sky/backends/cloud_vm_ray_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 9d1b39f4ee5..75556133b25 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3958,7 +3958,8 @@ def teardown_no_lock(self, stderr = '' # Use the new provisioner for AWS. 
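Note that the version gate below relies on the __ge__ added to ProvisionerVersion in PATCH 39; a plain enum.Enum is unordered, so without it the comparison would raise TypeError. A standalone check (only __ge__ is defined, which is all these call sites need):

import enum

class ProvisionerVersion(enum.Enum):
    RAY_AUTOSCALER = 1
    RAY_PROVISIONER_SKYPILOT_TERMINATOR = 2
    SKYPILOT = 3

    def __ge__(self, other):
        return self.value >= other.value

assert ProvisionerVersion.SKYPILOT >= ProvisionerVersion.RAY_AUTOSCALER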
-        if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR):
+        if (cloud.PROVISIONER_VERSION >=
+                clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR):
             logger.debug(f'Provisioner version: {cloud.PROVISIONER_VERSION} '
                          'using new provisioner for teardown.')
             # Stop the ray autoscaler first to avoid the head node trying to
             # re-launch the worker nodes, during the termination of the
             # cluster.

From 61bec6a4cad0ee5f6c7de3492b57a3f3e95f05a4 Mon Sep 17 00:00:00 2001
From: Zhanghao Wu
Date: Mon, 27 Nov 2023 00:59:21 +0000
Subject: [PATCH 43/84] format

---
 sky/backends/cloud_vm_ray_backend.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index 75556133b25..4e2f8ea9a8f 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -4226,7 +4226,8 @@ def post_teardown_cleanup(self,
             #  provision_lib.supports(cloud, 'cleanup_ports')
             # so that our backend does not need to know the specific details
             # for different clouds.
-            if isinstance(cloud, (clouds.AWS, clouds.GCP, clouds.Azure)):
+            if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.
+                    RAY_PROVISIONER_SKYPILOT_TERMINATOR):
                 provision_lib.cleanup_ports(repr(cloud), cluster_name_on_cloud,
                                             config['provider'])

From 0f07964f2bc3d9c2bef51810549e5d627c529312 Mon Sep 17 00:00:00 2001
From: Zhanghao Wu
Date: Mon, 27 Nov 2023 02:17:53 +0000
Subject: [PATCH 44/84] Fix scheduling field in config

---
 sky/provision/gcp/instance_utils.py | 30 +++++++++++++++++++++++------
 sky/templates/gcp-ray.yml.j2        | 18 +++++++++++------
 2 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py
index 4e472a24930..4e9952b1573 100644
--- a/sky/provision/gcp/instance_utils.py
+++ b/sky/provision/gcp/instance_utils.py
@@ -636,6 +636,17 @@ def create_instances(
         # just for possible backward compat.
         config = cls._convert_resources_to_urls(project_id, zone, node_config)
 
+        if 'scheduling' in config and isinstance(config['scheduling'], list):
+            # For backward compatibility: converting the list of dictionaries
+            # to a dictionary due to the use of deprecated API.
+ # [{'preemptible': True}, {'onHostMaintenance': 'TERMINATE'}] + # to {'preemptible': True, 'onHostMaintenance': 'TERMINATE'} + config['scheduling'] = { + k: v + for d in config['scheduling'] + for k, v in d.items() + } + for disk in config.get('disks', []): disk_type = disk.get('initializeParams', {}).get('diskType') if disk_type: @@ -669,6 +680,7 @@ def create_instances( TAG_SKYPILOT_CLUSTER_NAME: cluster_name }), }) + source_instance_template = config.pop('sourceInstanceTemplate', None) body = { 'count': count, @@ -691,11 +703,15 @@ def create_instances( # # https://cloud.google.com/compute/docs/instance-templates # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert - operation = (cls.load_resource().instances().bulkInsert( - project=project_id, - zone=zone, - body=body, - ).execute()) + try: + operation = cls.load_resource().instances().bulkInsert( + project=project_id, + zone=zone, + body=body, + ).execute() + except gcp.http_error_exception() as e: + logger.warning(f'googleapiclient.errors.HttpError: {e}') + return False, names if wait_for_operation: result = cls.load_resource().zoneOperations().wait( @@ -710,7 +726,9 @@ def create_instances( p.starmap(cls.create_node_tag, [(cluster_name, project_id, zone, names[i], head_tag_needed[i]) for i in range(count)]) - + else: + # Print out the error message + logger.warning(f'Failed to create instances: {result["error"]}') return success, names return operation diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index ca73199e317..5264dc362c2 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ b/sky/templates/gcp-ray.yml.j2 @@ -117,12 +117,14 @@ available_node_types: - key: install-nvidia-driver value: "True" {%- endif %} + {%- if use_spot or gpu is not none %} scheduling: {%- if use_spot %} - - preemptible: true + preemptible: true {%- endif %} {%- if gpu is not none %} - - onHostMaintenance: TERMINATE # Required for GPU-attached VMs. + onHostMaintenance: TERMINATE # Required for GPU-attached VMs. + {%- endif %} {%- endif %} {%- endif %} {% if num_nodes - 1 - num_specific_reserved_workers > 0 %} @@ -178,12 +180,14 @@ available_node_types: - key: install-nvidia-driver value: "True" {%- endif %} + {%- if use_spot or gpu is not none %} scheduling: {%- if use_spot %} - - preemptible: true + preemptible: true {%- endif %} {%- if gpu is not none %} - - onHostMaintenance: TERMINATE # Required for GPU-attached VMs. + onHostMaintenance: TERMINATE # Required for GPU-attached VMs. + {%- endif %} {%- endif %} {%- endif %} {%- endif %} @@ -245,12 +249,14 @@ available_node_types: - key: install-nvidia-driver value: "True" {%- endif %} + {%- if use_spot or gpu is not none %} scheduling: {%- if use_spot %} - - preemptible: true + preemptible: true {%- endif %} {%- if gpu is not none %} - - onHostMaintenance: TERMINATE # Required for GPU-attached VMs. + onHostMaintenance: TERMINATE # Required for GPU-attached VMs. 
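The normalization added to create_instances above can be verified in isolation; the deprecated per-item list form collapses to a single dictionary (later keys would win on duplicates):

scheduling = [{'preemptible': True}, {'onHostMaintenance': 'TERMINATE'}]
flattened = {k: v for d in scheduling for k, v in d.items()}
assert flattened == {'preemptible': True, 'onHostMaintenance': 'TERMINATE'}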
+ {%- endif %} {%- endif %} {%- endif %} {%- endif %} From 279b301c1841fc85411b6cda4322f329318d03a2 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 02:18:59 +0000 Subject: [PATCH 45/84] format --- sky/provision/gcp/instance_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index 4e9952b1573..bd05eff3b19 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -642,9 +642,7 @@ def create_instances( # [{'preemptible': True}, {'onHostMaintenance': 'TERMINATE'}] # to {'preemptible': True, 'onHostMaintenance': 'TERMINATE'} config['scheduling'] = { - k: v - for d in config['scheduling'] - for k, v in d.items() + k: v for d in config['scheduling'] for k, v in d.items() } for disk in config.get('disks', []): From 625c51038e6684b6f65de8dc5bd96dc4b5a7bb88 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 02:34:20 +0000 Subject: [PATCH 46/84] fix public key content --- sky/templates/gcp-ray.yml.j2 | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index 5264dc362c2..17cd96811a5 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ b/sky/templates/gcp-ray.yml.j2 @@ -79,7 +79,7 @@ available_node_types: # TPU VM's metadata has different format than normal VMs. # After replacing the variables, this will become username:ssh_public_key_content. # This is a specific syntax required by GCP https://cloud.google.com/compute/docs/connect/add-ssh-keys - ssh-keys: | + ssh-keys: |- skypilot:ssh_user:skypilot:ssh_public_key_content {%- if use_spot %} schedulingConfig: @@ -111,7 +111,7 @@ available_node_types: - key: ssh-keys # After replacing the variables, this will become username:ssh_public_key_content. # This is a specific syntax required by GCP https://cloud.google.com/compute/docs/connect/add-ssh-keys - value: | + value: |- skypilot:ssh_user:skypilot:ssh_public_key_content {%- if gpu is not none %} - key: install-nvidia-driver @@ -142,7 +142,7 @@ available_node_types: # TPU VM's metadata has different format than normal VMs. # After replacing the variables, this will become username:ssh_public_key_content. # This is a specific syntax required by GCP https://cloud.google.com/compute/docs/connect/add-ssh-keys - ssh-keys: | + ssh-keys: |- skypilot:ssh_user:skypilot:ssh_public_key_content {%- if use_spot %} schedulingConfig: @@ -174,7 +174,7 @@ available_node_types: - key: ssh-keys # After replacing the variables, this will become username:ssh_public_key_content. # This is a specific syntax required by GCP https://cloud.google.com/compute/docs/connect/add-ssh-keys - value: | + value: |- skypilot:ssh_user:skypilot:ssh_public_key_content {%- if gpu is not none %} - key: install-nvidia-driver @@ -211,7 +211,7 @@ available_node_types: # TPU VM's metadata has different format than normal VMs. # After replacing the variables, this will become username:ssh_public_key_content. # This is a specific syntax required by GCP https://cloud.google.com/compute/docs/connect/add-ssh-keys - ssh-keys: | + ssh-keys: |- skypilot:ssh_user:skypilot:ssh_public_key_content {%- if use_spot %} schedulingConfig: @@ -243,7 +243,7 @@ available_node_types: - key: ssh-keys # After replacing the variables, this will become username:ssh_public_key_content. 
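The switch from '|' to '|-' matters because a YAML literal block scalar keeps its trailing newline unless the strip-chomping indicator is used; a stray newline in the single-line ssh-keys value appears to be what this commit fixes. A quick check, assuming PyYAML is available:

import yaml

kept = yaml.safe_load('value: |\n  user:ssh-rsa AAA\n')
stripped = yaml.safe_load('value: |-\n  user:ssh-rsa AAA\n')
assert kept['value'] == 'user:ssh-rsa AAA\n'
assert stripped['value'] == 'user:ssh-rsa AAA'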
# This is a specific syntax required by GCP https://cloud.google.com/compute/docs/connect/add-ssh-keys - value: | + value: |- skypilot:ssh_user:skypilot:ssh_public_key_content {%- if gpu is not none %} - key: install-nvidia-driver From ebe92f557eb14e6842062b55362077e4f85f1815 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 02:57:56 +0000 Subject: [PATCH 47/84] Fix provisioner version for azure --- sky/clouds/azure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index e649bd8a368..fc05fedd8f4 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -60,7 +60,7 @@ class Azure(clouds.Cloud): _INDENT_PREFIX = ' ' * 4 - PROVISIONER_VERSION = clouds.ProvisionerVersion.RAY_PROVISIONER_SKYPILOT_TERMINATOR + PROVISIONER_VERSION = clouds.ProvisionerVersion.RAY_AUTOSCALER @classmethod def _cloud_unsupported_features( From 54a87ef8e114f7c0047234f5750eeb564152bfc6 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 03:35:57 +0000 Subject: [PATCH 48/84] Use ray port from head node for workers --- sky/provision/instance_setup.py | 9 ++++++--- sky/provision/provisioner.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 9a21f6ed9ca..eb4d1248ed3 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -37,7 +37,9 @@ _RAY_PORT_COMMAND = ( 'RAY_PORT=$(python -c "from sky.skylet import job_lib; ' - 'print(job_lib.get_ray_port())" 2> /dev/null || echo 6379)') + 'print(job_lib.get_ray_port())" 2> /dev/null || echo 6379);' + 'python -c "from sky.utils import common_utils; ' + 'print(common_utils.encode_payload({\'ray_port\': $RAY_PORT}))"') # Command that calls `ray status` with SkyPilot's Ray port set. RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND = ( @@ -241,6 +243,7 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], @_auto_retry def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, custom_resource: Optional[str], + ray_port: int, cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: """Start Ray on the worker nodes.""" @@ -269,14 +272,14 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, cmd = (f'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; ' 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' f'ray start --disable-usage-stats {ray_options} || exit 1;' + - _RAY_PRLIMIT + _DUMP_RAY_PORTS) + _RAY_PRLIMIT) if no_restart: # We do not use ray status to check whether ray is running, because # on worker node, if the user started their own ray cluster, ray status # will return 0, i.e., we don't know skypilot's ray cluster is running. # Instead, we check whether the raylet process is running on gcs address # that is connected to the head with the correct port. 
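With this patch the head node reports its Ray port through SkyPilot's payload helpers. Only the lossless round trip through stdout matters for the logic that follows; these stand-ins for common_utils.encode_payload/decode_payload are hypothetical sketches, not the library's actual implementation:

import json

_PREFIX, _SUFFIX = '<payload>', '</payload>'

def encode_payload(obj):
    # Wrap the JSON so it can be picked out of arbitrary command output.
    return _PREFIX + json.dumps(obj) + _SUFFIX

def decode_payload(stdout):
    start = stdout.index(_PREFIX) + len(_PREFIX)
    return json.loads(stdout[start:stdout.index(_SUFFIX)])

assert decode_payload(encode_payload({'ray_port': 6379}))['ray_port'] == 6379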
-        cmd = (f'{_RAY_PORT_COMMAND}; ps aux | grep "ray/raylet/raylet" | '
+        cmd = (f'RAY_PORT={ray_port}; ps aux | grep "ray/raylet/raylet" | '
                f'grep "gcs-address={head_private_ip}:${{RAY_PORT}}" || '
                f'{{ {cmd}; }}')
     else:
         cmd = 'ray stop; ' + cmd
diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py
index 2cb84961c94..6815b90ffe8 100644
--- a/sky/provision/provisioner.py
+++ b/sky/provision/provisioner.py
@@ -22,6 +22,7 @@
 from sky.provision import instance_setup
 from sky.provision import logging as provision_logging
 from sky.provision import metadata_utils
+from sky.skylet import constants
 from sky.utils import command_runner
 from sky.utils import common_utils
 from sky.utils import rich_utils
@@ -415,13 +416,16 @@ def _post_provision_setup(
         if not provision_record.is_instance_just_booted(
                 head_instance.instance_id):
             # Check if head node Ray is alive
-            returncode = head_runner.run(
+            returncode, stdout, _ = head_runner.run(
                 instance_setup.RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND,
-                stream_logs=False)
+                stream_logs=False,
+                require_outputs=True)
             if returncode:
                 logger.info('Ray cluster on head is not up. Restarting...')
+                ray_port = constants.SKY_REMOTE_RAY_PORT
             else:
                 logger.debug('Ray cluster on head is up.')
+                ray_port = common_utils.decode_payload(stdout)['ray_port']
             full_ray_setup = bool(returncode)
 
         if full_ray_setup:
@@ -446,6 +450,11 @@ def _post_provision_setup(
             cluster_name.name_on_cloud,
             no_restart=not full_ray_setup,
             custom_resource=custom_resource,
+            # Pass the ray_port to worker nodes for backward compatibility,
+            # as in some existing clusters the ray_port is not dumped with
+            # instance_setup._DUMP_RAY_PORTS. We should use the ray_port
+            # from the head node for worker nodes.
+            ray_port=ray_port,
             cluster_info=cluster_info,
             ssh_credentials=ssh_credentials)

From b1896ac4c4d40d3561e8bb09358b78a2f40b0904 Mon Sep 17 00:00:00 2001
From: Zhanghao Wu
Date: Mon, 27 Nov 2023 03:39:04 +0000
Subject: [PATCH 49/84] format

---
 sky/provision/instance_setup.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py
index eb4d1248ed3..7961c1f17b3 100644
--- a/sky/provision/instance_setup.py
+++ b/sky/provision/instance_setup.py
@@ -242,8 +242,7 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
 @_log_start_end
 @_auto_retry
 def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
-                              custom_resource: Optional[str],
-                              ray_port: int,
+                              custom_resource: Optional[str], ray_port: int,
                               cluster_info: common.ClusterInfo,
                               ssh_credentials: Dict[str, Any]) -> None:
     """Start Ray on the worker nodes."""

From 15e193d2544d1d266426e99b625faa523b2f75cb Mon Sep 17 00:00:00 2001
From: Zhanghao Wu
Date: Mon, 27 Nov 2023 03:53:32 +0000
Subject: [PATCH 50/84] fix ray_port

---
 sky/provision/provisioner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py
index 6815b90ffe8..36e9c77d511 100644
--- a/sky/provision/provisioner.py
+++ b/sky/provision/provisioner.py
@@ -413,6 +413,7 @@ def _post_provision_setup(
         status.update(
             runtime_preparation_str.format(step=3, step_name='runtime'))
         full_ray_setup = True
+        ray_port = constants.SKY_REMOTE_RAY_PORT
         if not provision_record.is_instance_just_booted(
                 head_instance.instance_id):
             # Check if head node Ray is alive
@@ -422,7 +423,6 @@ def _post_provision_setup(
                 require_outputs=True)
             if returncode:
                 logger.info('Ray cluster on head is not up.
Restarting...') - ray_port = constants.SKY_REMOTE_RAY_PORT else: logger.debug('Ray cluster on head is up.') ray_port = common_utils.decode_payload(stdout)['ray_port'] From 504b7e79071463fab577db336c6f54d5292b4f9e Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 06:07:41 +0000 Subject: [PATCH 51/84] fix smoke tests --- examples/job_queue/job.yaml | 2 +- sky/provision/instance_setup.py | 2 +- tests/test_smoke.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/job_queue/job.yaml b/examples/job_queue/job.yaml index c54e3d9a173..aa9c3502247 100644 --- a/examples/job_queue/job.yaml +++ b/examples/job_queue/job.yaml @@ -17,7 +17,7 @@ setup: | run: | timestamp=$(date +%s) conda env list - for i in {1..120}; do + for i in {1..140}; do echo "$timestamp $i" sleep 1 done diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 7961c1f17b3..54b3f34b6ac 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -280,7 +280,7 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, # that is connected to the head with the correct port. cmd = (f'RAY_PORT={ray_port}; ps aux | grep "ray/raylet/raylet" | ' f'grep "gcs-address={head_private_ip}:${{RAY_PORT}}" || ' - f'{{ {cmd}; }}') + f'{{ {cmd} }}') else: cmd = 'ray stop; ' + cmd diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 700762d3546..d359e6b5522 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -1599,7 +1599,7 @@ def test_autostop(generic_cloud: str): f'sky status | grep {name} | grep "1m"', # Ensure the cluster is not stopped early. - 'sleep 40', + 'sleep 35', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is STOPPED. @@ -1653,7 +1653,7 @@ def test_autodown(generic_cloud: str): # Ensure autostop is set. f'sky status | grep {name} | grep "1m (down)"', # Ensure the cluster is not terminated early. - 'sleep 40', + 'sleep 35', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is terminated. 'sleep 200', From b6ba235617f05db20e6eb15801fe9e5c1afcc431 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 06:17:12 +0000 Subject: [PATCH 52/84] shorter sleep time --- tests/test_smoke.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index d359e6b5522..05dc1abce1a 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -1599,7 +1599,7 @@ def test_autostop(generic_cloud: str): f'sky status | grep {name} | grep "1m"', # Ensure the cluster is not stopped early. - 'sleep 35', + 'sleep 30', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is STOPPED. @@ -1653,7 +1653,7 @@ def test_autodown(generic_cloud: str): # Ensure autostop is set. f'sky status | grep {name} | grep "1m (down)"', # Ensure the cluster is not terminated early. - 'sleep 35', + 'sleep 30', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is terminated. 
'sleep 200', From 497c438667a8e8e7f85b0de1b4be38e5b6f7b127 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 27 Nov 2023 06:58:10 +0000 Subject: [PATCH 53/84] refactor status refresh version --- sky/backends/backend_utils.py | 3 ++- sky/clouds/__init__.py | 3 +++ sky/clouds/aws.py | 1 + sky/clouds/cloud.py | 15 +++++++++++++++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 2ff5220d875..bde6fa866fb 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -1872,7 +1872,8 @@ def _query_cluster_status_via_cloud_api( # Query the cloud provider. # TODO(suquark): move implementations of more clouds here - if isinstance(handle.launched_resources.cloud, clouds.AWS): + if (handle.launched_resources.cloud.STATUS_VERSION >= + clouds.StatusVersion.SKYPILOT): cloud_name = repr(handle.launched_resources.cloud) try: node_status_dict = provision_lib.query_instances( diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py index 0fc4f8dc35e..36d843e267e 100644 --- a/sky/clouds/__init__.py +++ b/sky/clouds/__init__.py @@ -3,6 +3,7 @@ from sky.clouds.cloud import CloudImplementationFeatures from sky.clouds.cloud import ProvisionerVersion from sky.clouds.cloud import Region +from sky.clouds.cloud import StatusVersion from sky.clouds.cloud import Zone from sky.clouds.cloud_registry import CLOUD_REGISTRY @@ -33,4 +34,6 @@ 'Region', 'Zone', 'CLOUD_REGISTRY', + 'ProvisionerVersion', + 'StatusVersion', ] diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index cfb9be87846..84232073341 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -114,6 +114,7 @@ class AWS(clouds.Cloud): ) PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT + STATUS_VERSION = clouds.StatusVersion.SKYPILOT @classmethod def _cloud_unsupported_features( diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index d10ba12e07e..eee0bd4cb6f 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -70,6 +70,19 @@ def __ge__(self, other): return self.value >= other.value +class StatusVersion(enum.Enum): + """The version of the status query. 
+ + 1: cloud-CLI based implementation + 2: SkyPilot provisioner based implementation + """ + CLOUD_CLI = 1 + SKYPILOT = 2 + + def __ge__(self, other): + return self.value >= other.value + + class Cloud: """A cloud provider.""" @@ -77,6 +90,7 @@ class Cloud: _DEFAULT_DISK_TIER = 'medium' PROVISIONER_VERSION = ProvisionerVersion.RAY_AUTOSCALER + STATUS_VERSION = StatusVersion.CLOUD_CLI @classmethod def _cloud_unsupported_features( @@ -701,4 +715,5 @@ def __repr__(self): def __getstate__(self): state = self.__dict__.copy() state.pop('PROVISIONER_VERSION', None) + state.pop('STATUS_VERSION', None) return state From 21a36924d9c47439e45566a2fbb2604b394cdf31 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 06:43:17 +0000 Subject: [PATCH 54/84] Use new provisioner to launch runpod to avoid issue with ray autoscaler on head Co-authored-by: Justin Merrell --- sky/__init__.py | 2 + sky/adaptors/runpod.py | 29 +++ sky/authentication.py | 15 ++ sky/backends/backend_utils.py | 2 + sky/backends/cloud_vm_ray_backend.py | 41 +++ sky/clouds/__init__.py | 2 + sky/clouds/runpod.py | 258 +++++++++++++++++++ sky/clouds/service_catalog/runpod_catalog.py | 113 ++++++++ sky/provision/__init__.py | 1 + sky/provision/common.py | 11 + sky/provision/instance_setup.py | 12 +- sky/provision/provisioner.py | 19 +- sky/provision/runpod/__init__.py | 10 + sky/provision/runpod/config.py | 11 + sky/provision/runpod/instance.py | 207 +++++++++++++++ sky/provision/runpod/utils.py | 180 +++++++++++++ sky/setup_files/setup.py | 4 +- sky/templates/runpod-ray.yml.j2 | 115 +++++++++ 18 files changed, 1020 insertions(+), 12 deletions(-) create mode 100644 sky/adaptors/runpod.py create mode 100644 sky/clouds/runpod.py create mode 100644 sky/clouds/service_catalog/runpod_catalog.py create mode 100644 sky/provision/runpod/__init__.py create mode 100644 sky/provision/runpod/config.py create mode 100644 sky/provision/runpod/instance.py create mode 100644 sky/provision/runpod/utils.py create mode 100644 sky/templates/runpod-ray.yml.j2 diff --git a/sky/__init__.py b/sky/__init__.py index b27de4a5c3f..a673669774a 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -82,6 +82,7 @@ def get_git_commit(): Local = clouds.Local Kubernetes = clouds.Kubernetes OCI = clouds.OCI +RunPod = clouds.RunPod optimize = Optimizer.optimize __all__ = [ @@ -94,6 +95,7 @@ def get_git_commit(): 'Lambda', 'Local', 'OCI', + 'RunPod', 'SCP', 'Optimizer', 'OptimizeTarget', diff --git a/sky/adaptors/runpod.py b/sky/adaptors/runpod.py new file mode 100644 index 00000000000..2f4699f80bf --- /dev/null +++ b/sky/adaptors/runpod.py @@ -0,0 +1,29 @@ +"""RunPod cloud adaptor.""" + +import functools + +_runpod_sdk = None + + +def import_package(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + global _runpod_sdk + if _runpod_sdk is None: + try: + import runpod as _runpod # pylint: disable=import-outside-toplevel + _runpod_sdk = _runpod + except ImportError: + raise ImportError( + 'Failed to import dependencies for runpod. ' 
+ 'Try pip install "skypilot[runpod]"') from None + return func(*args, **kwargs) + + return wrapper + + +@import_package +def runpod(): + """Return the runpod package.""" + return _runpod_sdk diff --git a/sky/authentication.py b/sky/authentication.py index efce12defa4..b52f62c44e5 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -41,6 +41,7 @@ from sky import skypilot_config from sky.adaptors import gcp from sky.adaptors import ibm +from sky.adaptors import runpod from sky.clouds.utils import lambda_utils from sky.utils import common_utils from sky.utils import kubernetes_enums @@ -449,3 +450,17 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: config['auth']['ssh_proxy_command'] = ssh_proxy_cmd return config + + +# ---------------------------------- RunPod ---------------------------------- # +def setup_runpod_authentication(config: Dict[str, Any]) -> Dict[str, Any]: + """Sets up SSH authentication for RunPod. + - Generates a new SSH key pair if one does not exist. + - Adds the public SSH key to the user's RunPod account. + """ + _, public_key_path = get_or_generate_keys() + with open(public_key_path, 'r', encoding='UTF-8') as pub_key_file: + public_key = pub_key_file.read().strip() + runpod.runpod().cli.groups.ssh.functions.add_ssh_key(public_key) + + return configure_ssh_info(config) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 4d8b2923b2d..c3e66456ba2 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -1181,6 +1181,8 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str): config = auth.setup_kubernetes_authentication(config) elif isinstance(cloud, clouds.IBM): config = auth.setup_ibm_authentication(config) + elif isinstance(cloud, clouds.RunPod): + config = auth.setup_runpod_authentication(config) else: assert isinstance(cloud, clouds.Local), cloud # Local cluster case, authentication is already filled by the user diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 2eb2cd10b3e..49e23cff37a 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -147,6 +147,7 @@ def _get_cluster_config_template(cloud): clouds.Local: 'local-ray.yml.j2', clouds.SCP: 'scp-ray.yml.j2', clouds.OCI: 'oci-ray.yml.j2', + clouds.RunPod: 'runpod-ray.yml.j2', clouds.Kubernetes: 'kubernetes-ray.yml.j2', } return cloud_to_template[type(cloud)] @@ -1170,6 +1171,36 @@ def _update_blocklist_on_oci_error( self._blocked_resources.add( launchable_resources.copy(zone=zone.name)) + def _update_blocklist_on_runpod_error( + self, launchable_resources: 'resources_lib.Resources', + region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + stdout: str, stderr: str): + del zones # Unused. 
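+ # Scan stdout and stderr for RunPod-specific error markers; if none match, + # surface the raw output and abort, since the failure cannot be attributed + # to a specific resource to block.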
+ style = colorama.Style + stdout_splits = stdout.split('\n') + stderr_splits = stderr.split('\n') + errors = [ + s.strip() + for s in stdout_splits + stderr_splits + if any(err in s.strip() + for err in ['runpod.error.QueryError:', 'RunPodError:']) + ] + if not errors: + logger.info('====== stdout ======') + for s in stdout_splits: + print(s) + logger.info('====== stderr ======') + for s in stderr_splits: + print(s) + with ux_utils.print_exception_no_traceback(): + raise RuntimeError('Errors occurred during provision; ' + 'check logs above.') + + logger.warning(f'Got error(s) in {region.name}:') + messages = '\n\t'.join(errors) + logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') + self._blocked_resources.add(launchable_resources.copy(zone=None)) + def _update_blocklist_on_error( self, launchable_resources: 'resources_lib.Resources', region: 'clouds.Region', zones: Optional[List['clouds.Zone']], @@ -1209,6 +1240,7 @@ def _update_blocklist_on_error( clouds.Local: self._update_blocklist_on_local_error, clouds.Kubernetes: self._update_blocklist_on_kubernetes_error, clouds.OCI: self._update_blocklist_on_oci_error, + clouds.RunPod: self._update_blocklist_on_runpod_error, } cloud = launchable_resources.cloud cloud_type = type(cloud) @@ -2451,6 +2483,15 @@ def update_ssh_ports(self, max_attempts: int = 1) -> None: Use this method to use any cloud-specific port fetching logic. """ del max_attempts # Unused. + if isinstance(self.launched_resources.cloud, clouds.RunPod): + cluster_info = provision_lib.get_cluster_info( + str(self.launched_resources.cloud).lower(), + region=self.launched_resources.region, + cluster_name_on_cloud=self.cluster_name_on_cloud, + provider_config=None) + self.stable_ssh_ports = cluster_info.get_ssh_ports() + return + head_ssh_port = 22 self.stable_ssh_ports = ([head_ssh_port] + [22] * (self.num_node_ips - 1)) diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py index 36d843e267e..0860d26b47f 100644 --- a/sky/clouds/__init__.py +++ b/sky/clouds/__init__.py @@ -17,6 +17,7 @@ from sky.clouds.lambda_cloud import Lambda from sky.clouds.local import Local from sky.clouds.oci import OCI +from sky.clouds.runpod import RunPod from sky.clouds.scp import SCP __all__ = [ @@ -28,6 +29,7 @@ 'Lambda', 'Local', 'SCP', + 'RunPod', 'OCI', 'Kubernetes', 'CloudImplementationFeatures', diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py new file mode 100644 index 00000000000..49254d9cd4b --- /dev/null +++ b/sky/clouds/runpod.py @@ -0,0 +1,258 @@ +""" RunPod Cloud. """ + +import json +import typing +from typing import Dict, Iterator, List, Optional, Tuple + +from sky import clouds +from sky.clouds import service_catalog + +if typing.TYPE_CHECKING: + from sky import resources as resources_lib + +_CREDENTIAL_FILES = [ + 'config.toml', +] + + +@clouds.CLOUD_REGISTRY.register +class RunPod(clouds.Cloud): + """ RunPod GPU Cloud + + _REPR | The string representation for the RunPod GPU cloud object. 
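+ _CLOUD_UNSUPPORTED_FEATURES | Features RunPod does not support, each + mapped to a short human-readable reason.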
+ """ + _REPR = 'RunPod' + _CLOUD_UNSUPPORTED_FEATURES = { + clouds.CloudImplementationFeatures.AUTOSTOP: 'Stopping not supported.', + clouds.CloudImplementationFeatures.MULTI_NODE: 'Multi-node unsupported.', # pylint: disable=line-too-long + } + _MAX_CLUSTER_NAME_LEN_LIMIT = 120 + _regions: List[clouds.Region] = [] + + PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT + STATUS_VERSION = clouds.StatusVersion.SKYPILOT + + @classmethod + def _cloud_unsupported_features( + cls) -> Dict[clouds.CloudImplementationFeatures, str]: + return cls._CLOUD_UNSUPPORTED_FEATURES + + @classmethod + def _max_cluster_name_length(cls) -> Optional[int]: + return cls._MAX_CLUSTER_NAME_LEN_LIMIT + + @classmethod + def regions_with_offering(cls, instance_type: str, + accelerators: Optional[Dict[str, int]], + use_spot: bool, region: Optional[str], + zone: Optional[str]) -> List[clouds.Region]: + assert zone is None, 'RunPod does not support zones.' + del accelerators, zone # unused + if use_spot: + return [] + else: + regions = service_catalog.get_region_zones_for_instance_type( + instance_type, use_spot, 'runpod') + + if region is not None: + regions = [r for r in regions if r.name == region] + return regions + + @classmethod + def get_vcpus_mem_from_instance_type( + cls, + instance_type: str, + ) -> Tuple[Optional[float], Optional[float]]: + return service_catalog.get_vcpus_mem_from_instance_type(instance_type, + clouds='runpod') + + @classmethod + def zones_provision_loop( + cls, + *, + region: str, + num_nodes: int, + instance_type: str, + accelerators: Optional[Dict[str, int]] = None, + use_spot: bool = False, + ) -> Iterator[None]: + del num_nodes # unused + regions = cls.regions_with_offering(instance_type, + accelerators, + use_spot, + region=region, + zone=None) + for r in regions: + assert r.zones is None, r + yield r.zones + + def instance_type_to_hourly_cost(self, + instance_type: str, + use_spot: bool, + region: Optional[str] = None, + zone: Optional[str] = None) -> float: + return service_catalog.get_hourly_cost(instance_type, + use_spot=use_spot, + region=region, + zone=zone, + clouds='runpod') + + def accelerators_to_hourly_cost(self, + accelerators: Dict[str, int], + use_spot: bool, + region: Optional[str] = None, + zone: Optional[str] = None) -> float: + """Returns the hourly cost of the accelerators, in dollars/hour.""" + del accelerators, use_spot, region, zone # unused + return 0.0 # RunPod includes accelerators in the hourly cost. + + def get_egress_cost(self, num_gigabytes: float) -> float: + return 0.0 + + def __repr__(self): + return 'RunPod' + + def is_same_cloud(self, other: clouds.Cloud) -> bool: + # Returns true if the two clouds are the same cloud type. 
+ return isinstance(other, RunPod) + + @classmethod + def get_default_instance_type( + cls, + cpus: Optional[str] = None, + memory: Optional[str] = None, + disk_tier: Optional[str] = None) -> Optional[str]: + """Returns the default instance type for RunPod.""" + return service_catalog.get_default_instance_type(cpus=cpus, + memory=memory, + disk_tier=disk_tier, + clouds='runpod') + + @classmethod + def get_accelerators_from_instance_type( + cls, instance_type: str) -> Optional[Dict[str, int]]: + return service_catalog.get_accelerators_from_instance_type( + instance_type, clouds='runpod') + + @classmethod + def get_zone_shell_cmd(cls) -> Optional[str]: + return None + + def make_deploy_resources_variables( + self, resources: 'resources_lib.Resources', + cluster_name_on_cloud: str, region: 'clouds.Region', + zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]: + del zones + + r = resources + acc_dict = self.get_accelerators_from_instance_type(r.instance_type) + if acc_dict is not None: + custom_resources = json.dumps(acc_dict, separators=(',', ':')) + else: + custom_resources = None + + return { + 'instance_type': resources.instance_type, + 'custom_resources': custom_resources, + 'region': region.name, + } + + def _get_feasible_launchable_resources( + self, resources: 'resources_lib.Resources'): + """Returns a list of feasible resources for the given resources.""" + if resources.use_spot: + return ([], []) + if resources.instance_type is not None: + assert resources.is_launchable(), resources + resources = resources.copy(accelerators=None) + return ([resources], []) + + def _make(instance_list): + resource_list = [] + for instance_type in instance_list: + r = resources.copy( + cloud=RunPod(), + instance_type=instance_type, + accelerators=None, + cpus=None, + ) + resource_list.append(r) + return resource_list + + # Currently, handle a filter on accelerators only. + accelerators = resources.accelerators + if accelerators is None: + # Return a default instance type + default_instance_type = RunPod.get_default_instance_type( + cpus=resources.cpus) + if default_instance_type is None: + return ([], []) + else: + return (_make([default_instance_type]), []) + + assert len(accelerators) == 1, resources + acc, acc_count = list(accelerators.items())[0] + (instance_list, fuzzy_candidate_list + ) = service_catalog.get_instance_type_for_accelerator( + acc, + acc_count, + use_spot=resources.use_spot, + cpus=resources.cpus, + region=resources.region, + zone=resources.zone, + clouds='runpod') + if instance_list is None: + return ([], fuzzy_candidate_list) + return (_make(instance_list), fuzzy_candidate_list) + + @classmethod + def check_credentials(cls) -> Tuple[bool, Optional[str]]: + """ Verify that the user has valid credentials for RunPod. """ + try: + import runpod # pylint: disable=import-outside-toplevel + valid, error = runpod.check_credentials() + + if not valid: + return False, ( + f'{error} \n' # First line is indented by 4 spaces + ' Credentials can be set up by running: \n' + f' $ pip install runpod \n' + f' $ runpod store_api_key \n' + ' For more information, see https://docs.runpod.io/docs/skypilot' # pylint: disable=line-too-long + ) + + return True, None + + except ImportError: + return False, ( + 'Failed to import runpod.' 
+ ' To install, run: "pip install runpod" or "pip install skypilot[runpod]"' # pylint: disable=line-too-long + ) + + def get_credential_file_mounts(self) -> Dict[str, str]: + return { + f'~/.runpod/{filename}': f'~/.runpod/{filename}' + for filename in _CREDENTIAL_FILES + } + + @classmethod + def get_current_user_identity(cls) -> Optional[List[str]]: + # NOTE: used for very advanced SkyPilot functionality + # Can implement later if desired + return None + + def instance_type_exists(self, instance_type: str) -> bool: + return service_catalog.instance_type_exists(instance_type, 'runpod') + + def validate_region_zone(self, region: Optional[str], zone: Optional[str]): + return service_catalog.validate_region_zone(region, + zone, + clouds='runpod') + + def accelerator_in_region_or_zone(self, + accelerator: str, + acc_count: int, + region: Optional[str] = None, + zone: Optional[str] = None) -> bool: + return service_catalog.accelerator_in_region_or_zone( + accelerator, acc_count, region, zone, 'runpod') diff --git a/sky/clouds/service_catalog/runpod_catalog.py b/sky/clouds/service_catalog/runpod_catalog.py new file mode 100644 index 00000000000..aa8ea5539af --- /dev/null +++ b/sky/clouds/service_catalog/runpod_catalog.py @@ -0,0 +1,113 @@ +""" RunPod | Catalog + +This module loads the service catalog file and can be used to +query instance types and pricing information for RunPod. +""" + +import typing +from typing import Dict, List, Optional, Tuple + +from sky.clouds.service_catalog import common +from sky.utils import ux_utils + +if typing.TYPE_CHECKING: + from sky.clouds import cloud + +_df = common.read_catalog('runpod/vms.csv') + + +def instance_type_exists(instance_type: str) -> bool: + return common.instance_type_exists_impl(_df, instance_type) + + +def validate_region_zone( + region: Optional[str], + zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]: + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('RunPod does not support zones.') + return common.validate_region_zone_impl('runpod', _df, region, zone) + + +def accelerator_in_region_or_zone(acc_name: str, + acc_count: int, + region: Optional[str] = None, + zone: Optional[str] = None) -> bool: + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('RunPod does not support zones.') + return common.accelerator_in_region_or_zone_impl(_df, acc_name, acc_count, + region, zone) + + +def get_hourly_cost(instance_type: str, + use_spot: bool = False, + region: Optional[str] = None, + zone: Optional[str] = None) -> float: + """Returns the cost, or the cheapest cost among all zones for spot.""" + assert not use_spot, 'RunPod does not support spot.' + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('RunPod does not support zones.') + return common.get_hourly_cost_impl(_df, instance_type, use_spot, region, + zone) + + +def get_vcpus_mem_from_instance_type( + instance_type: str) -> Tuple[Optional[float], Optional[float]]: + return common.get_vcpus_mem_from_instance_type_impl(_df, instance_type) + + +def get_default_instance_type(cpus: Optional[str] = None, + memory: Optional[str] = None, + disk_tier: Optional[str] = None) -> Optional[str]: + del disk_tier # RunPod does not support disk tiers. + # NOTE: After expanding catalog to multiple entries, you may + # want to specify a default instance type or family. 
+ return common.get_instance_type_for_cpus_mem_impl(_df, cpus, memory) + + +def get_accelerators_from_instance_type( + instance_type: str) -> Optional[Dict[str, int]]: + return common.get_accelerators_from_instance_type_impl(_df, instance_type) + + +def get_instance_type_for_accelerator( + acc_name: str, + acc_count: int, + cpus: Optional[str] = None, + memory: Optional[str] = None, + use_spot: bool = False, + region: Optional[str] = None, + zone: Optional[str] = None) -> Tuple[Optional[List[str]], List[str]]: + """Returns a list of instance types that have the given accelerator.""" + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('RunPod does not support zones.') + return common.get_instance_type_for_accelerator_impl(df=_df, + acc_name=acc_name, + acc_count=acc_count, + cpus=cpus, + memory=memory, + use_spot=use_spot, + region=region, + zone=zone) + + +def get_region_zones_for_instance_type(instance_type: str, + use_spot: bool) -> List['cloud.Region']: + df = _df[_df['InstanceType'] == instance_type] + return common.get_region_zones(df, use_spot) + + +def list_accelerators( + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + quantity_filter: Optional[int], + case_sensitive: bool = True +) -> Dict[str, List[common.InstanceTypeInfo]]: + """Returns all instance types in RunPod offering GPUs.""" + return common.list_accelerators_impl('RunPodCloud', _df, gpus_only, + name_filter, region_filter, + quantity_filter, case_sensitive) diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index 8af83f08b85..20f0b86ba50 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -16,6 +16,7 @@ from sky.provision import azure from sky.provision import common from sky.provision import gcp +from sky.provision import runpod logger = sky_logging.init_logger(__name__) diff --git a/sky/provision/common.py b/sky/provision/common.py index 6e796666f06..678d50a2cd5 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -70,6 +70,7 @@ class InstanceInfo: internal_ip: str external_ip: Optional[str] tags: Dict[str, str] + ssh_port: int = 22 def get_feasible_ip(self) -> str: """Get the most feasible IPs of the instance. 
This function returns @@ -143,3 +144,13 @@ def get_head_instance(self) -> Optional[InstanceInfo]: if self.head_instance_id not in self.instances: raise ValueError('Head instance ID not in the cluster metadata.') return self.instances[self.head_instance_id] + + def get_ssh_ports(self) -> List[int]: + """Get the SSH port of all the instances.""" + head_node_port, other_ports = [], [] + for instance in self.instances.values(): + if instance.instance_id == self.head_instance_id: + head_node_port.append(instance.ssh_port) + else: + other_ports.append(instance.ssh_port) + return head_node_port + other_ports diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 54b3f34b6ac..8dfc4579e89 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -102,7 +102,7 @@ def _parallel_ssh_with_cache(func, cluster_name: str, stage_name: str, results = [] for instance_id, metadata in cluster_info.instances.items(): runner = command_runner.SSHCommandRunner(metadata.get_feasible_ip(), - port=22, + port=metadata.ssh_port, **ssh_credentials) wrapper = metadata_utils.cache_func(cluster_name, instance_id, stage_name, digest) @@ -198,8 +198,9 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], ssh_credentials: Dict[str, Any]) -> None: """Start Ray on the head node.""" ip_list = cluster_info.get_feasible_ips() + port_list = cluster_info.get_ssh_ports() ssh_runner = command_runner.SSHCommandRunner(ip_list[0], - port=22, + port=port_list[0], **ssh_credentials) assert cluster_info.head_instance_id is not None, (cluster_name, cluster_info) @@ -251,7 +252,9 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, _hint_worker_log_path(cluster_name, cluster_info, 'ray_cluster') ip_list = cluster_info.get_feasible_ips() ssh_runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list[1:], port_list=None, **ssh_credentials) + ip_list[1:], + port_list=cluster_info.get_ssh_ports()[1:], + **ssh_credentials) worker_ids = [ instance_id for instance_id in cluster_info.instances if instance_id != cluster_info.head_instance_id @@ -319,8 +322,9 @@ def start_skylet_on_head_node(cluster_name: str, """Start skylet on the head node.""" del cluster_name ip_list = cluster_info.get_feasible_ips() + port_list = cluster_info.get_ssh_ports() ssh_runner = command_runner.SSHCommandRunner(ip_list[0], - port=22, + port=port_list[0], **ssh_credentials) assert cluster_info.head_instance_id is not None, cluster_info log_path_abs = str(provision_logging.get_log_path()) diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index a479c2f10d4..07393d3574f 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -196,6 +196,7 @@ def teardown_cluster(cloud_name: str, cluster_name: ClusterName, def _ssh_probe_command(ip: str, + ssh_port: int, ssh_user: str, ssh_private_key: str, ssh_proxy_command: Optional[str] = None) -> List[str]: @@ -207,6 +208,8 @@ def _ssh_probe_command(ip: str, '-i', ssh_private_key, f'{ssh_user}@{ip}', + '-p', + str(ssh_port), '-o', 'StrictHostKeyChecking=no', '-o', @@ -239,13 +242,14 @@ def _shlex_join(command: List[str]) -> str: def _wait_ssh_connection_direct( ip: str, + ssh_port: int, ssh_user: str, ssh_private_key: str, ssh_control_name: Optional[str] = None, ssh_proxy_command: Optional[str] = None) -> bool: assert ssh_proxy_command is None, 'SSH proxy command is not supported.' 
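# Cheap reachability probe first: a live sshd greets with an 'SSH' banner on # connect; only afterwards fall back to running a full ssh probe command.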
try: - with socket.create_connection((ip, 22), timeout=1) as s: + with socket.create_connection((ip, ssh_port), timeout=1) as s: if s.recv(100).startswith(b'SSH'): # Wait for SSH being actually ready, otherwise we may get the # following error: @@ -259,7 +263,7 @@ def _wait_ssh_connection_direct( pass except Exception: # pylint: disable=broad-except pass - command = _ssh_probe_command(ip, ssh_user, ssh_private_key, + command = _ssh_probe_command(ip, ssh_port, ssh_user, ssh_private_key, ssh_proxy_command) logger.debug(f'Waiting for SSH to {ip}. Try: ' f'{_shlex_join(command)}') @@ -268,12 +272,13 @@ def _wait_ssh_connection_direct( def _wait_ssh_connection_indirect( ip: str, + ssh_port: int, ssh_user: str, ssh_private_key: str, ssh_control_name: Optional[str] = None, ssh_proxy_command: Optional[str] = None) -> bool: del ssh_control_name - command = _ssh_probe_command(ip, ssh_user, ssh_private_key, + command = _ssh_probe_command(ip, ssh_port, ssh_user, ssh_private_key, ssh_proxy_command) proc = subprocess.run(command, shell=False, @@ -299,14 +304,17 @@ def wait_for_ssh(cluster_info: provision_common.ClusterInfo, # See https://github.com/skypilot-org/skypilot/pull/1512 waiter = _wait_ssh_connection_indirect ip_list = cluster_info.get_feasible_ips() + port_list = cluster_info.get_ssh_ports() timeout = 60 * 10 # 10-min maximum timeout start = time.time() # use a queue for SSH querying ips = collections.deque(ip_list) + ssh_ports = collections.deque(port_list) while ips: ip = ips.popleft() - if not waiter(ip, **ssh_credentials): + ssh_port = ssh_ports.popleft() + if not waiter(ip, ssh_port, **ssh_credentials): ips.append(ip) if time.time() - start > timeout: with ux_utils.print_exception_no_traceback(): @@ -348,6 +356,7 @@ def _post_provision_setup( # TODO(suquark): Move wheel build here in future PRs. ip_list = cluster_info.get_feasible_ips() + port_list = cluster_info.get_ssh_ports() ssh_credentials = backend_utils.ssh_credential_from_yaml(cluster_yaml) # TODO(suquark): Handle TPU VMs when dealing with GCP later. 
@@ -413,7 +422,7 @@ def _post_provision_setup( cluster_info, ssh_credentials) head_runner = command_runner.SSHCommandRunner(ip_list[0], - port=22, + port=port_list[0], **ssh_credentials) status.update( diff --git a/sky/provision/runpod/__init__.py b/sky/provision/runpod/__init__.py new file mode 100644 index 00000000000..59297d7a3b0 --- /dev/null +++ b/sky/provision/runpod/__init__.py @@ -0,0 +1,10 @@ +"""RunPod provisioner for SkyPilot.""" + +from sky.provision.runpod.config import bootstrap_instances +from sky.provision.runpod.instance import cleanup_ports +from sky.provision.runpod.instance import get_cluster_info +from sky.provision.runpod.instance import query_instances +from sky.provision.runpod.instance import run_instances +from sky.provision.runpod.instance import stop_instances +from sky.provision.runpod.instance import terminate_instances +from sky.provision.runpod.instance import wait_instances diff --git a/sky/provision/runpod/config.py b/sky/provision/runpod/config.py new file mode 100644 index 00000000000..f0d6ca3488d --- /dev/null +++ b/sky/provision/runpod/config.py @@ -0,0 +1,11 @@ +"""RunPod configuration bootstrapping.""" + +from sky.provision import common + + +def bootstrap_instances( + region: str, cluster_name: str, + config: common.ProvisionConfig) -> common.ProvisionConfig: + """Bootstraps instances for the given cluster.""" + del region, cluster_name # unused + return config diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py new file mode 100644 index 00000000000..6fa7fff1c93 --- /dev/null +++ b/sky/provision/runpod/instance.py @@ -0,0 +1,207 @@ +"""RunPod instance provisioning.""" +import os +import time +from typing import Any, Dict, List, Optional + +from sky import sky_logging +from sky import status_lib +from sky.provision import common +from sky.provision.runpod import utils +from sky.utils import command_runner +from sky.utils import subprocess_utils + +POLL_INTERVAL = 5 +PRIVATE_SSH_KEY_PATH = '~/.ssh/sky-key' + +_GET_INTERNAL_IP_CMD = r'ip -4 -br addr show | grep UP | grep -Eo "(10\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)|172\.(1[6-9]|2[0-9][0-9]?|3[0-1]))\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"' # pylint: disable=line-too-long + +logger = sky_logging.init_logger(__name__) + + +def _filter_instances(cluster_name_on_cloud: str, + status_filters: Optional[List[str]]) -> Dict[str, Any]: + + def _get_internal_ip(node: Dict[str, Any]): + # TODO(ewzeng): cache internal ips in metadata file to reduce + # ssh overhead. 
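+ # A pod that has not been assigned a public IP yet cannot be reached + # over SSH, so its internal IP is unknown at this point.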
+ if node['ip'] is None: + node['internal_ip'] = None + return + runner = command_runner.SSHCommandRunner( + node['ip'], + 'root', + os.path.expanduser(PRIVATE_SSH_KEY_PATH), + port=node['ssh_port']) + rc, stdout, stderr = runner.run(_GET_INTERNAL_IP_CMD, + require_outputs=True, + stream_logs=False) + subprocess_utils.handle_returncode( + rc, + _GET_INTERNAL_IP_CMD, + 'Failed to obtain private IP from node', + stderr=stdout + stderr) + node['internal_ip'] = stdout.strip() + + instances = utils.list_instances() + possible_names = [ + f'{cluster_name_on_cloud}-head', f'{cluster_name_on_cloud}-worker' + ] + + filtered_nodes = {} + for instance_id, instance in instances.items(): + if status_filters is not None and instance[ + 'status'] not in status_filters: + continue + if instance.get('name') in possible_names: + filtered_nodes[instance_id] = instance + subprocess_utils.run_in_parallel(_get_internal_ip, + list(filtered_nodes.values())) + return filtered_nodes + + +def _get_head_instance_id(instances: Dict[str, Any]) -> Optional[str]: + head_instance_id = None + for inst_id, inst in instances.items(): + if inst['name'].endswith('-head'): + head_instance_id = inst_id + break + return head_instance_id + + +def run_instances(region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: + """Runs instances for the given cluster.""" + + pending_status = ['CREATED', 'RESTARTING'] + + while True: + instances = _filter_instances(cluster_name_on_cloud, pending_status) + if not instances: + break + logger.info(f'Waiting for {len(instances)} instances to be ready.') + time.sleep(POLL_INTERVAL) + exist_instances = _filter_instances(cluster_name_on_cloud, ['RUNNING']) + head_instance_id = _get_head_instance_id(exist_instances) + + to_start_count = config.count - len(exist_instances) + if to_start_count < 0: + raise RuntimeError( + f'Cluster {cluster_name_on_cloud} already has ' + f'{len(exist_instances)} nodes, but {config.count} are required.') + if to_start_count == 0: + if head_instance_id is None: + raise RuntimeError( + f'Cluster {cluster_name_on_cloud} has no head node.') + logger.info(f'Cluster {cluster_name_on_cloud} already has ' + f'{len(exist_instances)} nodes, no need to start more.') + return common.ProvisionRecord(provider_name='runpod', + cluster_name=cluster_name_on_cloud, + region=region, + zone=None, + head_instance_id=head_instance_id, + resumed_instance_ids=[], + created_instance_ids=[]) + + created_instance_ids = [] + for _ in range(to_start_count): + node_type = 'head' if head_instance_id is None else 'worker' + instance_id = utils.launch( + name=f'{cluster_name_on_cloud}-{node_type}', + instance_type=config.node_config['InstanceType'], + region=region) + logger.info(f'Launched instance {instance_id}.') + created_instance_ids.append(instance_id) + if head_instance_id is None: + head_instance_id = instance_id + assert head_instance_id is not None, 'head_instance_id should not be None' + return common.ProvisionRecord(provider_name='runpod', + cluster_name=cluster_name_on_cloud, + region=region, + zone=None, + head_instance_id=head_instance_id, + resumed_instance_ids=[], + created_instance_ids=created_instance_ids) + + +def wait_instances(region: str, cluster_name_on_cloud: str, + state: Optional[status_lib.ClusterStatus]) -> None: + del region, cluster_name_on_cloud, state + + +def stop_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + worker_only: bool = False, +) -> None: + raise 
NotImplementedError() + + +def terminate_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + worker_only: bool = False, +) -> None: + """See sky/provision/__init__.py""" + assert provider_config is not None, (cluster_name_on_cloud, provider_config) + instances = _filter_instances(cluster_name_on_cloud, None) + for inst_id, inst in instances.items(): + if worker_only and inst['name'].endswith('-head'): + continue + utils.remove(inst_id) + + +def get_cluster_info( + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: + del region, provider_config + nodes = _filter_instances(cluster_name_on_cloud, ['RUNNING']) + instances: Dict[str, common.InstanceInfo] = {} + head_instance_id = None + for node_id, node_info in nodes.items(): + instances[node_id] = common.InstanceInfo( + instance_id=node_id, + internal_ip=node_info['internal_ip'], + external_ip=node_info['ip'], + ssh_port=node_info['ssh_port'], + tags={}, + ) + if node_info['name'].endswith('-head'): + head_instance_id = node_id + + return common.ClusterInfo( + instances=instances, + head_instance_id=head_instance_id, + ) + + +def query_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + non_terminated_only: bool = True, +) -> Dict[str, Optional[status_lib.ClusterStatus]]: + """See sky/provision/__init__.py""" + assert provider_config is not None, (cluster_name_on_cloud, provider_config) + instances = _filter_instances(cluster_name_on_cloud, None) + + status_map = { + 'CREATED': status_lib.ClusterStatus.INIT, + 'RESTARTING': status_lib.ClusterStatus.INIT, + 'PAUSED': status_lib.ClusterStatus.INIT, + 'RUNNING': status_lib.ClusterStatus.UP, + } + statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {} + for inst_id, inst in instances.items(): + status = status_map[inst['status']] + if non_terminated_only and status is None: + continue + statuses[inst_id] = status + return statuses + + +def cleanup_ports( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, +) -> None: + del cluster_name_on_cloud, provider_config + pass diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py new file mode 100644 index 00000000000..12a3ef581b7 --- /dev/null +++ b/sky/provision/runpod/utils.py @@ -0,0 +1,180 @@ +"""RunPod library wrapper for SkyPilot.""" + +import json +import os +from pathlib import Path +import time +from typing import Dict, Optional + +from sky import sky_logging +from sky.adaptors import runpod +from sky.skylet import constants +from sky.utils import common_utils + +logger = sky_logging.init_logger(__name__) + +GPU_NAME_MAP = { + 'A100-80GB': 'NVIDIA A100 80GB PCIe', + 'A100-40GB': 'NVIDIA A100-PCIE-40GB', + 'A100-80GB-SXM4': 'NVIDIA A100-SXM4-80GB', + 'A30': 'NVIDIA A30', + 'A40': 'NVIDIA A40', + 'RTX3070': 'NVIDIA GeForce RTX 3070', + 'RTX3080': 'NVIDIA GeForce RTX 3080', + 'RTX3080Ti': 'NVIDIA GeForce RTX 3080 Ti', + 'RTX3090': 'NVIDIA GeForce RTX 3090', + 'RTX3090Ti': 'NVIDIA GeForce RTX 3090 Ti', + 'RTX4070Ti': 'NVIDIA GeForce RTX 4070 Ti', + 'RTX4080': 'NVIDIA GeForce RTX 4080', + 'RTX4090': 'NVIDIA GeForce RTX 4090', + 'H100-80GB-HBM3': 'NVIDIA H100 80GB HBM3', + 'H100-PCIe': 'NVIDIA H100 PCIe', + 'L4': 'NVIDIA L4', + 'L40': 'NVIDIA L40', + 'RTX4000-Ada-SFF': 'NVIDIA RTX 4000 SFF Ada Generation', + 'RTX6000-Ada': 'NVIDIA RTX 6000 Ada Generation', + 'RTXA4000': 'NVIDIA RTX A4000', + 'RTXA4500': 'NVIDIA RTX A4500', + 'RTXA5000': 'NVIDIA 
RTX A5000', + 'RTXA6000': 'NVIDIA RTX A6000', + 'RTX5000': 'Quadro RTX 5000', + 'V100-16GB-FHHL': 'Tesla V100-FHHL-16GB', + 'V100-16GB-SXM2': 'V100-SXM2-16GB', + 'RTXA2000': 'NVIDIA RTX A2000', + 'V100-16GB-PCIe': 'Tesla V100-PCIE-16GB' +} + + +def retry(func): + """Decorator to retry a function.""" + + def wrapper(*args, **kwargs): + """Wrapper for retrying a function.""" + cnt = 0 + while True: + try: + return func(*args, **kwargs) + except runpod.runpod().error.QueryError as e: + if cnt >= 3: + raise + logger.warning('Retrying for exception: ' + f'{common_utils.format_exception(e)}.') + time.sleep(1) + + return wrapper + + +def get_set_tags(instance_id: str, new_tags: Optional[Dict]) -> Dict: + """Gets the tags for the given instance. + - Creates the tag file if it doesn't exist. + - Returns the tags for the given instance. + - If tags are provided, sets the tags for the given instance. + """ + tag_file_path = os.path.expanduser('~/.runpod/skypilot_tags.json') + + # Ensure the tag file exists, create it if it doesn't. + if not os.path.exists(tag_file_path): + Path(os.path.dirname(tag_file_path)).mkdir(parents=True, exist_ok=True) + with open(tag_file_path, 'w', encoding='UTF-8') as tag_file: + json.dump({}, tag_file, indent=4) + + # Read existing tags + with open(tag_file_path, 'r', encoding='UTF-8') as tag_file: + tags = json.load(tag_file) + + if tags is None: + tags = {} + + # If new_tags is provided, update the tags for the instance + if new_tags: + instance_tags = tags.get(instance_id, {}) + instance_tags.update(new_tags) + tags[instance_id] = instance_tags + with open(tag_file_path, 'w', encoding='UTF-8') as tag_file: + json.dump(tags, tag_file, indent=4) + + return tags.get(instance_id, {}) + + +@retry +def list_instances(): + """Lists instances associated with API key.""" + instances = runpod.runpod().get_pods() + + instance_list = {} + for instance in instances: + instance_list[instance['id']] = {} + + instance_list[instance['id']]['status'] = instance['desiredStatus'] + instance_list[instance['id']]['name'] = instance['name'] + + if instance['desiredStatus'] == 'RUNNING' and instance.get('runtime'): + for port in instance['runtime']['ports']: + if port['privatePort'] == 22 and port['isIpPublic']: + instance_list[instance['id']]['ip'] = port['ip'] + instance_list[ + instance['id']]['ssh_port'] = port['publicPort'] + + instance_list[instance['id']]['tags'] = get_set_tags( + instance['id'], None) + + return instance_list + + +def launch(name: str, instance_type: str, region: str): + """Launches an instance with the given parameters. + + Converts the instance_type to the RunPod GPU name, finds the specs for the + GPU, and launches the instance. 
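+ `instance_type` is expected to be of the form '<count>x_<GPU name>_<cloud + type>' (for example, '1x_RTX4090_SECURE'), where the GPU name must be a + key of GPU_NAME_MAP.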
+ """ + gpu_type = GPU_NAME_MAP[instance_type.split('_')[1]] + gpu_quantity = int(instance_type.split('_')[0].replace('x', '')) + cloud_type = instance_type.split('_')[2] + + gpu_specs = runpod.runpod().get_gpu(gpu_type) + + new_instance = runpod.runpod().create_pod( + name=name, + image_name='runpod/base:0.0.2', + gpu_type_id=gpu_type, + cloud_type=cloud_type, + container_disk_in_gb=50, + min_vcpu_count=4 * gpu_quantity, + min_memory_in_gb=gpu_specs['memoryInGb'] * gpu_quantity, + country_code=region, + ports=(f'22/tcp,' + f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,' + f'{constants.SKY_REMOTE_RAY_PORT}/tcp'), + support_public_ip=True, + ) + + return new_instance['id'] + + +def set_tags(instance_id: str, tags: Dict): + """Sets the tags for the given instance.""" + get_set_tags(instance_id, tags) + + +@retry +def remove(instance_id: str): + """Terminates the given instance.""" + runpod.runpod().terminate_pod(instance_id) + + +def get_ssh_ports(cluster_name): + """Gets the SSH ports for the given cluster.""" + logger.debug(f'Getting SSH ports for cluster {cluster_name}.') + + instances = list_instances() + possible_names = [f'{cluster_name}-head', f'{cluster_name}-worker'] + + ssh_ports = [] + + for instance in instances.values(): + if instance['name'] in possible_names: + ssh_ports.append(instance['ssh_port']) + assert ssh_ports, ( + f'Could not find any instances for cluster {cluster_name}.') + + return ssh_ports diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index 5cbc801b2e8..ebecdcfe0d5 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -234,13 +234,11 @@ def parse_readme(readme: str) -> str: 'oci': ['oci'] + local_ray, 'kubernetes': ['kubernetes'] + local_ray, 'remote': remote, + 'runpod': ['runpod>=1.3.7'] } extras_require['all'] = sum(extras_require.values(), []) -# Install aws requirements by default, as it is the most common cloud provider, -# and the installation is quick. -install_requires += extras_require['aws'] long_description = '' readme_filepath = 'README.md' diff --git a/sky/templates/runpod-ray.yml.j2 b/sky/templates/runpod-ray.yml.j2 new file mode 100644 index 00000000000..846d801a982 --- /dev/null +++ b/sky/templates/runpod-ray.yml.j2 @@ -0,0 +1,115 @@ +cluster_name: {{cluster_name_on_cloud}} + +# The maximum number of workers nodes to launch in addition to the head node. +max_workers: {{num_nodes - 1}} +upscaling_speed: {{num_nodes - 1}} +idle_timeout_minutes: 60 + +provider: + type: external + module: sky.skylet.providers.runpod.RunPodNodeProvider + region: "{{region}}" + disable_launch_config_check: true + +auth: + ssh_user: root + ssh_private_key: {{ssh_private_key}} + +available_node_types: + ray_head_default: + resources: {} + node_config: + InstanceType: {{instance_type}} +{% if num_nodes > 1 %} + ray_worker_default: + min_workers: {{num_nodes - 1}} + max_workers: {{num_nodes - 1}} + resources: {} + node_config: + InstanceType: {{instance_type}} +{%- endif %} + +head_node_type: ray_head_default + +# Format: `REMOTE_PATH : LOCAL_PATH` +file_mounts: { + "{{sky_ray_yaml_remote_path}}": "{{sky_ray_yaml_local_path}}", + "{{sky_remote_path}}/{{sky_wheel_hash}}": "{{sky_local_path}}", +{%- for remote_path, local_path in credentials.items() %} + "{{remote_path}}": "{{local_path}}", +{%- endfor %} +} + +rsync_exclude: [] + +initialization_commands: [] + +# List of shell commands to run to set up nodes. +# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH +# connection, which is expensive. 
Try your best to co-locate commands into fewer +# items! +# +# Increment the following for catching performance bugs easier: +# current num items (num SSH connections): 1 +setup_commands: + # Disable `unattended-upgrades` to prevent apt-get from hanging. It should be called at the beginning before the process started to avoid being blocked. (This is a temporary fix.) + # Create ~/.ssh/config file in case the file does not exist in the image. + # Line 'rm ..': there is another installation of pip. + # Line 'sudo bash ..': set the ulimit as suggested by ray docs for performance. https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html#system-configuration + # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit stucking issue when the number of running ray jobs increase. + # Line 'mkdir -p ..': disable host key check + # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys` + - sudo systemctl stop unattended-upgrades || true; + sudo systemctl disable unattended-upgrades || true; + sudo sed -i 's/Unattended-Upgrade "1"/Unattended-Upgrade "0"/g' /etc/apt/apt.conf.d/20auto-upgrades || true; + sudo kill -9 `sudo lsof /var/lib/dpkg/lock-frontend | awk '{print $2}' | tail -n 1` || true; + sudo pkill -9 apt-get; + sudo pkill -9 dpkg; + sudo dpkg --configure -a; + mkdir -p ~/.ssh; touch ~/.ssh/config; + {{ conda_installation_commands }} + (type -a python | grep -q python3) || echo 'alias python=python3' >> ~/.bashrc; + (type -a pip | grep -q pip3) || echo 'alias pip=pip3' >> ~/.bashrc; + source ~/.bashrc; + (pip3 list | grep ray | grep {{ray_version}} 2>&1 > /dev/null || pip3 install -U ray[default]=={{ray_version}}) && mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app && touch ~/.sudo_as_admin_successful; + (pip3 list | grep skypilot && [ "$(cat {{sky_remote_path}}/current_sky_wheel_hash)" == "{{sky_wheel_hash}}" ]) || (pip3 uninstall skypilot -y; pip3 install "$(echo {{sky_remote_path}}/{{sky_wheel_hash}}/skypilot-{{sky_version}}*.whl)[runpod,remote]" && echo "{{sky_wheel_hash}}" > {{sky_remote_path}}/current_sky_wheel_hash || exit 1); + sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf'; + sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload; + mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config; + python3 -c "from sky.skylet.ray_patches import patch; patch()" || exit 1; + [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); + +# Command to start ray on the head node. You don't need to change this. +# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH +# connection, which is expensive. Try your best to co-locate commands into fewer +# items! The same comment applies for worker_start_ray_commands. +# +# Increment the following for catching performance bugs easier: +# current num items (num SSH connections): 1 +head_start_ray_commands: + # NOTE: --disable-usage-stats in `ray start` saves 10 seconds of idle wait. 
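+ # Line '{{dump_port_command}}': persist the Ray ports used on the head so + # that later lookups (e.g. job_lib.get_ray_port()) can read them back.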
+ # Line "which prlimit ..": increase the limit of the number of open files for the raylet process, as the `ulimit` may not take effect at this point, because it requires + # all the sessions to be reloaded. This is a workaround. + - export SKYPILOT_NUM_GPUS=0 && which nvidia-smi > /dev/null && SKYPILOT_NUM_GPUS=$(nvidia-smi --query-gpu=index,name --format=csv,noheader | wc -l); + ray stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ray start --disable-usage-stats --head --port={{ray_port}} --dashboard-port={{ray_dashboard_port}} --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml {{"--resources='%s'" % custom_resources if custom_resources}} --num-gpus=$SKYPILOT_NUM_GPUS --temp-dir {{ray_temp_dir}} || exit 1; + which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; + {{dump_port_command}}; + +# Worker commands are needed for TPU VM Pods +{%- if num_nodes > 1 or tpu_vm %} +worker_start_ray_commands: + - SKYPILOT_NUM_GPUS=0 && which nvidia-smi > /dev/null && SKYPILOT_NUM_GPUS=$(nvidia-smi --query-gpu=index,name --format=csv,noheader | wc -l); + ray stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ray start --disable-usage-stats --address=$RAY_HEAD_IP:{{ray_port}} --object-manager-port=8076 {{"--resources='%s'" % custom_resources if custom_resources}} --num-gpus=$SKYPILOT_NUM_GPUS --temp-dir {{ray_temp_dir}} || exit 1; + which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; +{%- else %} +worker_start_ray_commands: [] +{%- endif %} + +head_node: {} +worker_nodes: {} + +# These fields are required for external cloud providers. +head_setup_commands: [] +worker_setup_commands: [] +cluster_synced_files: [] +file_mounts_sync_continuously: False From 533afb0575536aac4f26540fa3708aecf1b09d4c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 06:58:34 +0000 Subject: [PATCH 55/84] Add wait for the instances to be ready --- sky/provision/runpod/instance.py | 17 ++++++++++++++++- sky/setup_files/setup.py | 1 - 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index 6fa7fff1c93..274a0dac119 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -24,7 +24,8 @@ def _filter_instances(cluster_name_on_cloud: str, def _get_internal_ip(node: Dict[str, Any]): # TODO(ewzeng): cache internal ips in metadata file to reduce # ssh overhead. - if node['ip'] is None: + if node.get('ip') is None: + node['ip'] = None node['internal_ip'] = None return runner = command_runner.SSHCommandRunner( @@ -113,6 +114,20 @@ def run_instances(region: str, cluster_name_on_cloud: str, created_instance_ids.append(instance_id) if head_instance_id is None: head_instance_id = instance_id + + # Wait for instances to be ready. 
+ while True: + instances = _filter_instances(cluster_name_on_cloud, ['RUNNING']) + ready_instance_cnt = 0 + for instance_id, instance in instances.items(): + if instance.get('ssh_port') is not None: + ready_instance_cnt += 1 + if ready_instance_cnt == config.count: + break + + logger.info('Waiting for instances to be ready ' + f'({len(instances)}/{config.count}).') + time.sleep(POLL_INTERVAL) assert head_instance_id is not None, 'head_instance_id should not be None' return common.ProvisionRecord(provider_name='runpod', cluster_name=cluster_name_on_cloud, diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index ebecdcfe0d5..3ecbeacedf2 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -239,7 +239,6 @@ def parse_readme(readme: str) -> str: extras_require['all'] = sum(extras_require.values(), []) - long_description = '' readme_filepath = 'README.md' # When sky/backends/wheel_utils.py builds wheels, it will not contain the From 389bd214a140db11f84d9edc82e316eb213d0f4c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 07:02:20 +0000 Subject: [PATCH 56/84] fix setup --- sky/clouds/runpod.py | 2 +- sky/setup_files/setup.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 49254d9cd4b..53841e41da9 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -24,7 +24,7 @@ class RunPod(clouds.Cloud): _REPR = 'RunPod' _CLOUD_UNSUPPORTED_FEATURES = { clouds.CloudImplementationFeatures.AUTOSTOP: 'Stopping not supported.', - clouds.CloudImplementationFeatures.MULTI_NODE: 'Multi-node unsupported.', # pylint: disable=line-too-long + clouds.CloudImplementationFeatures.STOP: 'Stopping not supported.', } _MAX_CLUSTER_NAME_LEN_LIMIT = 120 _regions: List[clouds.Region] = [] diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index 3ecbeacedf2..854e9eb4238 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -135,9 +135,7 @@ def parse_readme(readme: str) -> str: 'cachetools', # NOTE: ray requires click>=7.0. 'click >= 7.0', - # NOTE: required by awscli. To avoid ray automatically installing - # the latest version. - 'colorama < 0.4.5', + 'colorama', 'cryptography', # Jinja has a bug in older versions because of the lack of pinning # the version of the underlying markupsafe package. See: @@ -207,6 +205,9 @@ def parse_readme(readme: str) -> str: 'awscli>=1.27.10', 'botocore>=1.29.10', 'boto3>=1.26.1', + # NOTE: required by awscli. To avoid ray automatically installing + # the latest version. 
+ 'colorama < 0.4.5', ] extras_require: Dict[str, List[str]] = { 'aws': aws_dependencies, From b797e3bb0fa5bea34f4ebb32fd81963e91c7d80f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 07:26:19 +0000 Subject: [PATCH 57/84] Retry and give up for getting internal IP --- sky/provision/runpod/instance.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index 274a0dac119..acccf4d2aed 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -33,9 +33,23 @@ def _get_internal_ip(node: Dict[str, Any]): 'root', os.path.expanduser(PRIVATE_SSH_KEY_PATH), port=node['ssh_port']) - rc, stdout, stderr = runner.run(_GET_INTERNAL_IP_CMD, - require_outputs=True, - stream_logs=False) + retry_cnt = 0 + while True: + rc, stdout, stderr = runner.run(_GET_INTERNAL_IP_CMD, + require_outputs=True, + stream_logs=False) + if not rc or rc != 255: + break + if retry_cnt >= 3: + if rc != 255: + break + # If we fail to connect the node for 3 times: + # 1. The node is terminated. + # 2. We are on the same node as the node we are trying to + # connect to, and runpod does not allow ssh to itself. + node['internal_ip'] = None + return + time.sleep(1) subprocess_utils.handle_returncode( rc, _GET_INTERNAL_IP_CMD, From 4ef8a591adda293414846358d7900d70d7f448ea Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 07:27:49 +0000 Subject: [PATCH 58/84] comment --- sky/provision/runpod/instance.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index acccf4d2aed..b9aed2a53b9 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -43,10 +43,11 @@ def _get_internal_ip(node: Dict[str, Any]): if retry_cnt >= 3: if rc != 255: break - # If we fail to connect the node for 3 times: + # If we fail to connect the node for 3 times, it is likely that: # 1. The node is terminated. # 2. We are on the same node as the node we are trying to # connect to, and runpod does not allow ssh to itself. + # In both cases, we can safely set the internal ip to None. node['internal_ip'] = None return time.sleep(1) From 184c1dec7c430659adef84112a1c9388ea8f5bec Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 08:48:04 +0000 Subject: [PATCH 59/84] Remove internal IP --- sky/backends/cloud_vm_ray_backend.py | 30 ------------ sky/provision/runpod/instance.py | 73 +++++++--------------------- sky/provision/runpod/utils.py | 14 +++--- sky/templates/runpod-ray.yml.j2 | 9 +--- 4 files changed, 27 insertions(+), 99 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 49e23cff37a..82ae663da64 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1171,35 +1171,6 @@ def _update_blocklist_on_oci_error( self._blocked_resources.add( launchable_resources.copy(zone=zone.name)) - def _update_blocklist_on_runpod_error( - self, launchable_resources: 'resources_lib.Resources', - region: 'clouds.Region', zones: Optional[List['clouds.Zone']], - stdout: str, stderr: str): - del zones # Unused. 
- style = colorama.Style - stdout_splits = stdout.split('\n') - stderr_splits = stderr.split('\n') - errors = [ - s.strip() - for s in stdout_splits + stderr_splits - if any(err in s.strip() - for err in ['runpod.error.QueryError:', 'RunPodError:']) - ] - if not errors: - logger.info('====== stdout ======') - for s in stdout_splits: - print(s) - logger.info('====== stderr ======') - for s in stderr_splits: - print(s) - with ux_utils.print_exception_no_traceback(): - raise RuntimeError('Errors occurred during provision; ' - 'check logs above.') - - logger.warning(f'Got error(s) in {region.name}:') - messages = '\n\t'.join(errors) - logger.warning(f'{style.DIM}\t{messages}{style.RESET_ALL}') - self._blocked_resources.add(launchable_resources.copy(zone=None)) def _update_blocklist_on_error( self, launchable_resources: 'resources_lib.Resources', @@ -1240,7 +1211,6 @@ def _update_blocklist_on_error( clouds.Local: self._update_blocklist_on_local_error, clouds.Kubernetes: self._update_blocklist_on_kubernetes_error, clouds.OCI: self._update_blocklist_on_oci_error, - clouds.RunPod: self._update_blocklist_on_runpod_error, } cloud = launchable_resources.cloud cloud_type = type(cloud) diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index b9aed2a53b9..fd6cff3ed1a 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -1,5 +1,4 @@ """GCP instance provisioning.""" -import os import time from typing import Any, Dict, List, Optional @@ -7,13 +6,9 @@ from sky import status_lib from sky.provision import common from sky.provision.runpod import utils -from sky.utils import command_runner -from sky.utils import subprocess_utils +from sky.utils import common_utils POLL_INTERVAL = 5 -PRIVATE_SSH_KEY_PATH = '~/.ssh/sky-key' - -_GET_INTERNAL_IP_CMD = r'ip -4 -br addr show | grep UP | grep -Eo "(10\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)|172\.(1[6-9]|2[0-9][0-9]?|3[0-1]))\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"' # pylint: disable=line-too-long logger = sky_logging.init_logger(__name__) @@ -21,43 +16,6 @@ def _filter_instances(cluster_name_on_cloud: str, status_filters: Optional[List[str]]) -> Dict[str, Any]: - def _get_internal_ip(node: Dict[str, Any]): - # TODO(ewzeng): cache internal ips in metadata file to reduce - # ssh overhead. - if node.get('ip') is None: - node['ip'] = None - node['internal_ip'] = None - return - runner = command_runner.SSHCommandRunner( - node['ip'], - 'root', - os.path.expanduser(PRIVATE_SSH_KEY_PATH), - port=node['ssh_port']) - retry_cnt = 0 - while True: - rc, stdout, stderr = runner.run(_GET_INTERNAL_IP_CMD, - require_outputs=True, - stream_logs=False) - if not rc or rc != 255: - break - if retry_cnt >= 3: - if rc != 255: - break - # If we fail to connect the node for 3 times, it is likely that: - # 1. The node is terminated. - # 2. We are on the same node as the node we are trying to - # connect to, and runpod does not allow ssh to itself. - # In both cases, we can safely set the internal ip to None. 
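
Note on this hunk: the `_get_internal_ip` helper being deleted here wrapped its SSH probe in a bounded retry before giving up on the internal IP. A minimal, self-contained sketch of the policy its comments describe (`run_probe` is a hypothetical stand-in for the project's `runner.run`; the deleted loop's actual break conditions were slightly more tangled than this):

    import time
    from typing import Callable, Optional

    def fetch_with_retries(run_probe: Callable[[], int],
                           max_retries: int = 3,
                           interval: float = 1.0) -> Optional[int]:
        """Retry a flaky SSH probe; give up after repeated connect failures."""
        for attempt in range(max_retries + 1):
            rc = run_probe()
            if rc != 255:  # ssh exits with 255 only when the connection fails.
                return rc
            if attempt < max_retries:
                time.sleep(interval)
        # Every attempt failed with 255: the node is terminated or, on RunPod,
        # we are SSHing from the node to itself, which is disallowed.
        return None
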
- node['internal_ip'] = None - return - time.sleep(1) - subprocess_utils.handle_returncode( - rc, - _GET_INTERNAL_IP_CMD, - 'Failed get obtain private IP from node', - stderr=stdout + stderr) - node['internal_ip'] = stdout.strip() - instances = utils.list_instances() possible_names = [ f'{cluster_name_on_cloud}-head', f'{cluster_name_on_cloud}-worker' @@ -65,13 +23,11 @@ def _get_internal_ip(node: Dict[str, Any]): filtered_nodes = {} for instance_id, instance in instances.items(): - if status_filters is not None and instance[ - 'status'] not in status_filters: + if (status_filters is not None and + instance['status'] not in status_filters): continue if instance.get('name') in possible_names: filtered_nodes[instance_id] = instance - subprocess_utils.run_in_parallel(_get_internal_ip, - list(filtered_nodes.values())) return filtered_nodes @@ -121,10 +77,15 @@ def run_instances(region: str, cluster_name_on_cloud: str, created_instance_ids = [] for _ in range(to_start_count): node_type = 'head' if head_instance_id is None else 'worker' - instance_id = utils.launch( - name=f'{cluster_name_on_cloud}-{node_type}', - instance_type=config.node_config['InstanceType'], - region=region) + try: + instance_id = utils.launch( + name=f'{cluster_name_on_cloud}-{node_type}', + instance_type=config.node_config['InstanceType'], + region=region, + disk_size=config.node_config['DiskSize']) + except Exception as e: # pylint: disable=broad-except + logger.warning(f'run_instances error: {e}') + raise logger.info(f'Launched instance {instance_id}.') created_instance_ids.append(instance_id) if head_instance_id is None: @@ -172,11 +133,14 @@ def terminate_instances( worker_only: bool = False, ) -> None: """See sky/provision/__init__.py""" - assert provider_config is not None, (cluster_name_on_cloud, provider_config) + del provider_config instances = _filter_instances(cluster_name_on_cloud, None) for inst_id, inst in instances.items(): + logger.info(f'Terminating instance {inst_id}.' 
+ f'{inst}') if worker_only and inst['name'].endswith('-head'): continue + logger.info(f'Start {inst_id}: {inst}') utils.remove(inst_id) @@ -191,8 +155,8 @@ def get_cluster_info( for node_id, node_info in nodes.items(): instances[node_id] = common.InstanceInfo( instance_id=node_id, - internal_ip=node_info['internal_ip'], - external_ip=node_info['ip'], + internal_ip=node_info['external_ip'], + external_ip=node_info['external_ip'], ssh_port=node_info['ssh_port'], tags={}, ) @@ -234,4 +198,3 @@ def cleanup_ports( provider_config: Optional[Dict[str, Any]] = None, ) -> None: del cluster_name_on_cloud, provider_config - pass diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index 12a3ef581b7..baa11aafa4d 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -96,7 +96,7 @@ def get_set_tags(instance_id: str, new_tags: Optional[Dict]) -> Dict: return tags.get(instance_id, {}) -@retry +# @retry def list_instances(): """Lists instances associated with API key.""" instances = runpod.runpod().get_pods() @@ -111,9 +111,11 @@ def list_instances(): if instance['desiredStatus'] == 'RUNNING' and instance.get('runtime'): for port in instance['runtime']['ports']: if port['privatePort'] == 22 and port['isIpPublic']: - instance_list[instance['id']]['ip'] = port['ip'] + instance_list[instance['id']]['external_ip'] = port['ip'] instance_list[ instance['id']]['ssh_port'] = port['publicPort'] + elif not port['isIpPublic']: + instance_list[instance['id']]['internal_ip'] = port['ip'] instance_list[instance['id']]['tags'] = get_set_tags( instance['id'], None) @@ -121,7 +123,7 @@ def list_instances(): return instance_list -def launch(name: str, instance_type: str, region: str): +def launch(name: str, instance_type: str, region: str, disk_size: int): """Launches an instance with the given parameters. 
Converts the instance_type to the RunPod GPU name, finds the specs for the @@ -138,13 +140,13 @@ def launch(name: str, instance_type: str, region: str): image_name='runpod/base:0.0.2', gpu_type_id=gpu_type, cloud_type=cloud_type, - container_disk_in_gb=50, + container_disk_in_gb=disk_size, min_vcpu_count=4 * gpu_quantity, min_memory_in_gb=gpu_specs['memoryInGb'] * gpu_quantity, country_code=region, ports=(f'22/tcp,' f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,' - f'{constants.SKY_REMOTE_RAY_PORT}/tcp'), + f'{constants.SKY_REMOTE_RAY_PORT}/http'), support_public_ip=True, ) @@ -156,7 +158,7 @@ def set_tags(instance_id: str, tags: Dict): get_set_tags(instance_id, tags) -@retry +# @retry def remove(instance_id: str): """Terminates the given instance.""" runpod.runpod().terminate_pod(instance_id) diff --git a/sky/templates/runpod-ray.yml.j2 b/sky/templates/runpod-ray.yml.j2 index 846d801a982..182834dd979 100644 --- a/sky/templates/runpod-ray.yml.j2 +++ b/sky/templates/runpod-ray.yml.j2 @@ -20,14 +20,7 @@ available_node_types: resources: {} node_config: InstanceType: {{instance_type}} -{% if num_nodes > 1 %} - ray_worker_default: - min_workers: {{num_nodes - 1}} - max_workers: {{num_nodes - 1}} - resources: {} - node_config: - InstanceType: {{instance_type}} -{%- endif %} + DiskSize: {{disk_size}} head_node_type: ray_head_default From 8162e84ee44ae37b1126c71a35c8c802765b1ccf Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 08:53:51 +0000 Subject: [PATCH 60/84] use external IP TODO: use external ray port --- sky/provision/runpod/instance.py | 2 ++ sky/provision/runpod/utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index fd6cff3ed1a..d36a43e7e2e 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -155,6 +155,8 @@ def get_cluster_info( for node_id, node_info in nodes.items(): instances[node_id] = common.InstanceInfo( instance_id=node_id, + # We have to use external IP as we are not able to connect to the + # ray cluster using internal IP. internal_ip=node_info['external_ip'], external_ip=node_info['external_ip'], ssh_port=node_info['ssh_port'], diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index baa11aafa4d..089807d5f97 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -145,7 +145,7 @@ def launch(name: str, instance_type: str, region: str, disk_size: int): min_memory_in_gb=gpu_specs['memoryInGb'] * gpu_quantity, country_code=region, ports=(f'22/tcp,' - f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,' + f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/tcp,' f'{constants.SKY_REMOTE_RAY_PORT}/http'), support_public_ip=True, ) From eba6787023eb457b02d6b795f92a2f71ddff7cf1 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 09:06:24 +0000 Subject: [PATCH 61/84] fix ssh port --- sky/provision/provisioner.py | 2 +- sky/provision/runpod/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 07393d3574f..a621ff7200b 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -255,7 +255,7 @@ def _wait_ssh_connection_direct( # following error: # "System is booting up. Unprivileged users are not permitted to # log in yet". 
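
For context: the comment above explains why a direct SSH probe can fail on a freshly booted VM, so the code falls back to the indirect waiter, and the one-line fix below additionally threads `ssh_port` into that fallback call. A toy illustration of the positional-argument hazard being fixed (names and signatures here are illustrative, not the project's actual API):

    def wait_indirect(ip: str, ssh_port: int, ssh_user: str) -> bool:
        # Stand-in for the indirect SSH waiter.
        print(f'probing {ssh_user}@{ip}:{ssh_port}')
        return True

    wait_indirect('203.0.113.7', 40022, 'root')  # port threaded through
    # Pre-fix shape, for contrast: omitting the port shifts every later
    # positional argument one slot to the left, so 'root' would bind to
    # ssh_port instead of ssh_user.
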
- return _wait_ssh_connection_indirect(ip, ssh_user, + return _wait_ssh_connection_indirect(ip, ssh_port, ssh_user, ssh_private_key, ssh_control_name, ssh_proxy_command) diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index 089807d5f97..545248721c0 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -145,8 +145,8 @@ def launch(name: str, instance_type: str, region: str, disk_size: int): min_memory_in_gb=gpu_specs['memoryInGb'] * gpu_quantity, country_code=region, ports=(f'22/tcp,' - f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/tcp,' - f'{constants.SKY_REMOTE_RAY_PORT}/http'), + f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,' + f'{constants.SKY_REMOTE_RAY_PORT}/tcp'), support_public_ip=True, ) From 9bbf3c25121c1c6c399ce27742d78e596ab2b945 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 10:24:31 +0000 Subject: [PATCH 62/84] Unsupported feature --- sky/backends/cloud_vm_ray_backend.py | 1 - sky/clouds/runpod.py | 5 +++++ sky/clouds/service_catalog/runpod_catalog.py | 1 - sky/provision/runpod/instance.py | 5 +---- sky/provision/runpod/utils.py | 4 +--- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 82ae663da64..f2c26c81289 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1171,7 +1171,6 @@ def _update_blocklist_on_oci_error( self._blocked_resources.add( launchable_resources.copy(zone=zone.name)) - def _update_blocklist_on_error( self, launchable_resources: 'resources_lib.Resources', region: 'clouds.Region', zones: Optional[List['clouds.Zone']], diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 53841e41da9..e5de5493565 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -25,6 +25,11 @@ class RunPod(clouds.Cloud): _CLOUD_UNSUPPORTED_FEATURES = { clouds.CloudImplementationFeatures.AUTOSTOP: 'Stopping not supported.', clouds.CloudImplementationFeatures.STOP: 'Stopping not supported.', + clouds.CloudImplementationFeatures.SPOT_INSTANCE: + ('Spot is not supported, as runpod API does not implement spot .'), + clouds.CloudImplementationFeatures.MULTI_NODE: + ('Multi-node not supported yet, as the interconnection among nodes ' + 'are non-trival on RunPod.'), } _MAX_CLUSTER_NAME_LEN_LIMIT = 120 _regions: List[clouds.Region] = [] diff --git a/sky/clouds/service_catalog/runpod_catalog.py b/sky/clouds/service_catalog/runpod_catalog.py index aa8ea5539af..f8e7b926e92 100644 --- a/sky/clouds/service_catalog/runpod_catalog.py +++ b/sky/clouds/service_catalog/runpod_catalog.py @@ -45,7 +45,6 @@ def get_hourly_cost(instance_type: str, region: Optional[str] = None, zone: Optional[str] = None) -> float: """Returns the cost, or the cheapest cost among all zones for spot.""" - assert not use_spot, 'FluffyCloud does not support spot.' 
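
Aside on the deletion below: the `assert not use_spot, 'FluffyCloud does not support spot.'` line was a leftover from the template cloud this catalog was copied from. Spot rejection now flows through the cloud's declared unsupported features (see PATCH 62 in this series) and a shared checker. A simplified, self-contained sketch of that gating pattern, with invented enum members and messages:

    import enum

    class Feature(enum.Enum):
        SPOT_INSTANCE = 'spot'
        MULTI_NODE = 'multi-node'

    _UNSUPPORTED = {
        Feature.SPOT_INSTANCE: 'RunPod API does not implement spot.',
        Feature.MULTI_NODE: 'Inter-node networking is non-trivial on RunPod.',
    }

    def check_features(requested: set) -> None:
        blocked = {f: _UNSUPPORTED[f] for f in requested if f in _UNSUPPORTED}
        if blocked:
            raise ValueError(f'Unsupported features requested: {blocked}')

    check_features(set())                      # plain on-demand launch passes
    # check_features({Feature.SPOT_INSTANCE})  # would raise ValueError
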
if zone is not None: with ux_utils.print_exception_no_traceback(): raise ValueError('RunPod does not support zones.') diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index d36a43e7e2e..f009faeb9b9 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -6,7 +6,6 @@ from sky import status_lib from sky.provision import common from sky.provision.runpod import utils -from sky.utils import common_utils POLL_INTERVAL = 5 @@ -155,9 +154,7 @@ def get_cluster_info( for node_id, node_info in nodes.items(): instances[node_id] = common.InstanceInfo( instance_id=node_id, - # We have to use external IP as we are not able to connect to the - # ray cluster using internal IP. - internal_ip=node_info['external_ip'], + internal_ip=node_info['internal_ip'], external_ip=node_info['external_ip'], ssh_port=node_info['ssh_port'], tags={}, diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index 545248721c0..decc3ac2389 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -96,7 +96,6 @@ def get_set_tags(instance_id: str, new_tags: Optional[Dict]) -> Dict: return tags.get(instance_id, {}) -# @retry def list_instances(): """Lists instances associated with API key.""" instances = runpod.runpod().get_pods() @@ -146,7 +145,7 @@ def launch(name: str, instance_type: str, region: str, disk_size: int): country_code=region, ports=(f'22/tcp,' f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,' - f'{constants.SKY_REMOTE_RAY_PORT}/tcp'), + f'{constants.SKY_REMOTE_RAY_PORT}/http'), support_public_ip=True, ) @@ -158,7 +157,6 @@ def set_tags(instance_id: str, tags: Dict): get_set_tags(instance_id, tags) -# @retry def remove(instance_id: str): """Terminates the given instance.""" runpod.runpod().terminate_pod(instance_id) From 472f1f61723eb4cfa627059278a925bd245a8f7e Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 1 Dec 2023 10:25:52 +0000 Subject: [PATCH 63/84] typo --- sky/clouds/runpod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index e5de5493565..82ba4762a5a 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -29,7 +29,7 @@ class RunPod(clouds.Cloud): ('Spot is not supported, as runpod API does not implement spot .'), clouds.CloudImplementationFeatures.MULTI_NODE: ('Multi-node not supported yet, as the interconnection among nodes ' - 'are non-trival on RunPod.'), + 'are non-trivial on RunPod.'), } _MAX_CLUSTER_NAME_LEN_LIMIT = 120 _regions: List[clouds.Region] = [] From 0b582aede7ffee73ba84e73d7f38139d7ef8f680 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 1 Jan 2024 02:35:00 +0000 Subject: [PATCH 64/84] fix ssh ports --- sky/provision/common.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sky/provision/common.py b/sky/provision/common.py index 27d8c31ed0e..10384ab2f50 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -169,10 +169,9 @@ def get_feasible_ips(self, force_internal_ips: bool = False) -> List[str]: def get_ssh_ports(self) -> List[int]: """Get the SSH port of all the instances.""" - head_node_port, other_ports = [], [] - for instance in self.instances.values(): - if instance.instance_id == self.head_instance_id: - head_node_port.append(instance.ssh_port) - else: - other_ports.append(instance.ssh_port) - return head_node_port + other_ports + head_node = self.get_head_instance() + head_node_port = [head_node.ssh_port] + + worker_nodes = 
self.get_worker_instances() + worker_node_ports = [instance.ssh_port for instance in worker_nodes] + return head_node_port + worker_node_ports From 106eefa13096e4399e46f1803797b927b577eba6 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 1 Jan 2024 02:35:50 +0000 Subject: [PATCH 65/84] rename var --- sky/provision/common.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sky/provision/common.py b/sky/provision/common.py index 10384ab2f50..8d5b0dac60e 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -124,16 +124,16 @@ def ip_tuples(self) -> List[Tuple[str, Optional[str]]]: Returns: A list of tuples (internal_ip, external_ip) of all instances. """ - head_node = self.get_head_instance() - if head_node is None: - head_node_ip = [] + head_instance = self.get_head_instance() + if head_instance is None: + head_instance_ip = [] else: - head_node_ip = [(head_node.internal_ip, head_node.external_ip)] + head_instance_ip = [(head_instance.internal_ip, head_instance.external_ip)] other_ips = [] for instance in self.get_worker_instances(): pair = (instance.internal_ip, instance.external_ip) other_ips.append(pair) - return head_node_ip + other_ips + return head_instance_ip + other_ips def has_external_ips(self) -> bool: """True if the cluster has external IP.""" @@ -169,9 +169,9 @@ def get_feasible_ips(self, force_internal_ips: bool = False) -> List[str]: def get_ssh_ports(self) -> List[int]: """Get the SSH port of all the instances.""" - head_node = self.get_head_instance() - head_node_port = [head_node.ssh_port] + head_instance = self.get_head_instance() + head_instance_port = [head_instance.ssh_port] - worker_nodes = self.get_worker_instances() - worker_node_ports = [instance.ssh_port for instance in worker_nodes] - return head_node_port + worker_node_ports + worker_instances = self.get_worker_instances() + worker_instance_ports = [instance.ssh_port for instance in worker_instances] + return head_instance_port + worker_instance_ports From 8e8501f3abee070d240d1728a262ca78205a9761 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 2 Jan 2024 11:20:59 +0000 Subject: [PATCH 66/84] format --- sky/clouds/runpod.py | 1 - sky/provision/common.py | 8 ++++++-- sky/provision/instance_setup.py | 8 ++++++-- sky/provision/runpod/instance.py | 18 ++++++++++-------- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 82ba4762a5a..84f64ce347c 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -23,7 +23,6 @@ class RunPod(clouds.Cloud): """ _REPR = 'RunPod' _CLOUD_UNSUPPORTED_FEATURES = { - clouds.CloudImplementationFeatures.AUTOSTOP: 'Stopping not supported.', clouds.CloudImplementationFeatures.STOP: 'Stopping not supported.', clouds.CloudImplementationFeatures.SPOT_INSTANCE: ('Spot is not supported, as runpod API does not implement spot .'), diff --git a/sky/provision/common.py b/sky/provision/common.py index 8d5b0dac60e..66d214516d6 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -128,7 +128,8 @@ def ip_tuples(self) -> List[Tuple[str, Optional[str]]]: if head_instance is None: head_instance_ip = [] else: - head_instance_ip = [(head_instance.internal_ip, head_instance.external_ip)] + head_instance_ip = [(head_instance.internal_ip, + head_instance.external_ip)] other_ips = [] for instance in self.get_worker_instances(): pair = (instance.internal_ip, instance.external_ip) @@ -170,8 +171,11 @@ def get_feasible_ips(self, force_internal_ips: bool = False) -> 
List[str]: def get_ssh_ports(self) -> List[int]: """Get the SSH port of all the instances.""" head_instance = self.get_head_instance() + assert head_instance is not None, self head_instance_port = [head_instance.ssh_port] worker_instances = self.get_worker_instances() - worker_instance_ports = [instance.ssh_port for instance in worker_instances] + worker_instance_ports = [ + instance.ssh_port for instance in worker_instances + ] return head_instance_port + worker_instance_ports diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 071a63dbde3..9f84ada199e 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -106,7 +106,9 @@ def _parallel_ssh_with_cache(func, cluster_name: str, stage_name: str, for i, metadata in enumerate(metadatas): cache_id = f'{instance_id}-{i}' runner = command_runner.SSHCommandRunner( - metadata.get_feasible_ip(), port=metadata.ssh_port, **ssh_credentials) + metadata.get_feasible_ip(), + port=metadata.ssh_port, + **ssh_credentials) wrapper = metadata_utils.cache_func(cluster_name, cache_id, stage_name, digest) if (cluster_info.head_instance_id == instance_id and i == 0): @@ -255,7 +257,9 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, _hint_worker_log_path(cluster_name, cluster_info, 'ray_cluster') ip_list = cluster_info.get_feasible_ips() ssh_runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list[1:], port_list=cluster_info.get_ssh_ports()[1:], **ssh_credentials) + ip_list[1:], + port_list=cluster_info.get_ssh_ports()[1:], + **ssh_credentials) worker_instances = cluster_info.get_worker_instances() cache_ids = [] prev_instance_id = None diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index f009faeb9b9..b11f2e5a91e 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -149,16 +149,18 @@ def get_cluster_info( provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: del region, provider_config nodes = _filter_instances(cluster_name_on_cloud, ['RUNNING']) - instances: Dict[str, common.InstanceInfo] = {} + instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None for node_id, node_info in nodes.items(): - instances[node_id] = common.InstanceInfo( - instance_id=node_id, - internal_ip=node_info['internal_ip'], - external_ip=node_info['external_ip'], - ssh_port=node_info['ssh_port'], - tags={}, - ) + instances[node_id] = [ + common.InstanceInfo( + instance_id=node_id, + internal_ip=node_info['internal_ip'], + external_ip=node_info['external_ip'], + ssh_port=node_info['ssh_port'], + tags={}, + ) + ] if node_info['name'].endswith('-head'): head_instance_id = node_id From 1e92cfde25a02c483fb5542c111f15ed9c39f34b Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 2 Jan 2024 11:33:20 +0000 Subject: [PATCH 67/84] Fix cloud unsupported resources --- sky/clouds/runpod.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 84f64ce347c..0e5ccdde71a 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -37,8 +37,19 @@ class RunPod(clouds.Cloud): STATUS_VERSION = clouds.StatusVersion.SKYPILOT @classmethod - def _cloud_unsupported_features( - cls) -> Dict[clouds.CloudImplementationFeatures, str]: + def _unsupported_features_for_resources( + cls, resources: 'resources_lib.Resources' + ) -> Dict[clouds.CloudImplementationFeatures, str]: + """The features not supported based on the resources 
provided. + + This method is used by check_features_are_supported() to check if the + cloud implementation supports all the requested features. + + Returns: + A dict of {feature: reason} for the features not supported by the + cloud implementation. + """ + del resources # unused return cls._CLOUD_UNSUPPORTED_FEATURES @classmethod From fa07c7252cc5b95e61d8f139392f938d5a1cc3ae Mon Sep 17 00:00:00 2001 From: Doyoung Kim <34902420+landscapepainter@users.noreply.github.com> Date: Fri, 5 Jan 2024 10:08:26 -0800 Subject: [PATCH 68/84] Runpod update name mapping (#2945) --- sky/clouds/service_catalog/__init__.py | 2 +- sky/provision/runpod/utils.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index 524a8fdf6ce..01b9bc7ff56 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -15,7 +15,7 @@ CloudFilter = Optional[Union[List[str], str]] ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci', - 'kubernetes') + 'kubernetes', 'runpod') def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs): diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index decc3ac2389..93700927339 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -16,7 +16,7 @@ GPU_NAME_MAP = { 'A100-80GB': 'NVIDIA A100 80GB PCIe', 'A100-40GB': 'NVIDIA A100-PCIE-40GB', - 'A100-80GB-SXM4': 'NVIDIA A100-SXM4-80GB', + 'A100-80GB-SXM': 'NVIDIA A100-SXM4-80GB', 'A30': 'NVIDIA A30', 'A40': 'NVIDIA A40', 'RTX3070': 'NVIDIA GeForce RTX 3070', @@ -27,11 +27,14 @@ 'RTX4070Ti': 'NVIDIA GeForce RTX 4070 Ti', 'RTX4080': 'NVIDIA GeForce RTX 4080', 'RTX4090': 'NVIDIA GeForce RTX 4090', - 'H100-80GB-HBM3': 'NVIDIA H100 80GB HBM3', - 'H100-PCIe': 'NVIDIA H100 PCIe', + # Following instance is displayed as SXM at the console + # but the ID from the API appears as HBM + 'H100-SXM': 'NVIDIA H100 80GB HBM3', + 'H100': 'NVIDIA H100 PCIe', 'L4': 'NVIDIA L4', 'L40': 'NVIDIA L40', 'RTX4000-Ada-SFF': 'NVIDIA RTX 4000 SFF Ada Generation', + 'RTX4000-Ada': 'NVIDIA RTX 4000 Ada Generation', 'RTX6000-Ada': 'NVIDIA RTX 6000 Ada Generation', 'RTXA4000': 'NVIDIA RTX A4000', 'RTXA4500': 'NVIDIA RTX A4500', From 3489df51bbb10852a26278c6dedbe296952240ba Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 7 Jan 2024 08:36:20 +0000 Subject: [PATCH 69/84] Avoid using GpuInfo --- sky/clouds/service_catalog/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index a84ad9f55c7..afa58730d4c 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -463,7 +463,7 @@ def list_accelerators_impl( instance types offered by this cloud. 
""" if gpus_only: - df = df[~df['GpuInfo'].isna()] + df = df[~df['AcceleratorName'].isna()] df = df.copy() # avoid column assignment warning try: From 18591010eeb4ade8256179c4f7740df57272f27f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 7 Jan 2024 17:31:10 +0000 Subject: [PATCH 70/84] fix all_regions --- sky/clouds/service_catalog/runpod_catalog.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sky/clouds/service_catalog/runpod_catalog.py b/sky/clouds/service_catalog/runpod_catalog.py index f8e7b926e92..122b4c3b2d3 100644 --- a/sky/clouds/service_catalog/runpod_catalog.py +++ b/sky/clouds/service_catalog/runpod_catalog.py @@ -100,13 +100,15 @@ def get_region_zones_for_instance_type(instance_type: str, def list_accelerators( - gpus_only: bool, - name_filter: Optional[str], - region_filter: Optional[str], - quantity_filter: Optional[int], - case_sensitive: bool = True + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + quantity_filter: Optional[int], + case_sensitive: bool = True, + all_regions: bool = False, ) -> Dict[str, List[common.InstanceTypeInfo]]: """Returns all instance types in RunPod offering GPUs.""" return common.list_accelerators_impl('RunPodCloud', _df, gpus_only, name_filter, region_filter, - quantity_filter, case_sensitive) + quantity_filter, case_sensitive, + all_regions) From ed955df1052f1443d1e78fda313889e5386f078f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 7 Jan 2024 17:48:18 +0000 Subject: [PATCH 71/84] Fix runpod list accelerators --- sky/clouds/service_catalog/runpod_catalog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/clouds/service_catalog/runpod_catalog.py b/sky/clouds/service_catalog/runpod_catalog.py index 122b4c3b2d3..90b64c3d1b0 100644 --- a/sky/clouds/service_catalog/runpod_catalog.py +++ b/sky/clouds/service_catalog/runpod_catalog.py @@ -108,7 +108,7 @@ def list_accelerators( all_regions: bool = False, ) -> Dict[str, List[common.InstanceTypeInfo]]: """Returns all instance types in RunPod offering GPUs.""" - return common.list_accelerators_impl('RunPodCloud', _df, gpus_only, + return common.list_accelerators_impl('RunPod', _df, gpus_only, name_filter, region_filter, quantity_filter, case_sensitive, all_regions) From 045bab6364d4e8cfeb4d053a1f700c7745df683c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sun, 7 Jan 2024 17:50:16 +0000 Subject: [PATCH 72/84] format --- sky/clouds/service_catalog/runpod_catalog.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sky/clouds/service_catalog/runpod_catalog.py b/sky/clouds/service_catalog/runpod_catalog.py index 90b64c3d1b0..811ea9a1eab 100644 --- a/sky/clouds/service_catalog/runpod_catalog.py +++ b/sky/clouds/service_catalog/runpod_catalog.py @@ -108,7 +108,6 @@ def list_accelerators( all_regions: bool = False, ) -> Dict[str, List[common.InstanceTypeInfo]]: """Returns all instance types in RunPod offering GPUs.""" - return common.list_accelerators_impl('RunPod', _df, gpus_only, - name_filter, region_filter, - quantity_filter, case_sensitive, - all_regions) + return common.list_accelerators_impl('RunPod', _df, gpus_only, name_filter, + region_filter, quantity_filter, + case_sensitive, all_regions) From 9630832fd64c087087bada4eecc0c97fecff3f5e Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jan 2024 01:21:44 +0000 Subject: [PATCH 73/84] revert to GpuInfo --- sky/clouds/service_catalog/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index afa58730d4c..a84ad9f55c7 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -463,7 +463,7 @@ def list_accelerators_impl( instance types offered by this cloud. """ if gpus_only: - df = df[~df['AcceleratorName'].isna()] + df = df[~df['GpuInfo'].isna()] df = df.copy() # avoid column assignment warning try: From 97165a467d8447011311264770cd8b2897fbd26f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jan 2024 20:31:03 +0000 Subject: [PATCH 74/84] Fix get_feasible_launchable_resources --- sky/clouds/runpod.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 0e5ccdde71a..0f159e181b5 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -199,7 +199,9 @@ def _make(instance_list): if accelerators is None: # Return a default instance type default_instance_type = RunPod.get_default_instance_type( - cpus=resources.cpus) + cpus=resources.cpus, + memory=resources.memory, + disk_tier=resources.disk_tier) if default_instance_type is None: return ([], []) else: From f527545e00f6f85fd41267b952cd9ac0ca445630 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jan 2024 20:55:46 +0000 Subject: [PATCH 75/84] Add error --- sky/clouds/service_catalog/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index a84ad9f55c7..d72abb634ee 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -470,7 +470,7 @@ def list_accelerators_impl( gpu_info_df = df['GpuInfo'].apply(ast.literal_eval) df['DeviceMemoryGiB'] = gpu_info_df.apply( lambda row: row['Gpus'][0]['MemoryInfo']['SizeInMiB']) / 1024.0 - except ValueError: + except (ValueError, SyntaxError): # TODO(zongheng,woosuk): GCP/Azure catalogs do not have well-formed # GpuInfo fields. 
So the above will throw: # ValueError: malformed node or string: <_ast.Name object at ..> From 680beca99df6b33543bcd4edcb225d25b61b5392 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jan 2024 22:36:47 +0000 Subject: [PATCH 76/84] Fix optimizer random_dag for feature check --- tests/test_optimizer_random_dag.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/test_optimizer_random_dag.py b/tests/test_optimizer_random_dag.py index 5c828ba3130..6681f061d6f 100644 --- a/tests/test_optimizer_random_dag.py +++ b/tests/test_optimizer_random_dag.py @@ -1,11 +1,13 @@ import copy import random +import sys import numpy as np import pandas as pd import sky from sky import clouds +from sky import exceptions from sky.clouds import service_catalog ALL_INSTANCE_TYPE_INFOS = sum( @@ -57,8 +59,8 @@ def generate_random_dag( op.set_outputs('CLOUD', random.randint(0, max_data_size)) num_candidates = random.randint(1, max_num_candidate_resources) - candidate_instance_types = random.choices(ALL_INSTANCE_TYPE_INFOS, - k=num_candidates) + candidate_instance_types = random.choices( + ALL_INSTANCE_TYPE_INFOS, k=len(ALL_INSTANCE_TYPE_INFOS)) candidate_resources = set() for candidate in candidate_instance_types: @@ -80,7 +82,18 @@ def generate_random_dag( accelerators={ candidate.accelerator_name: candidate.accelerator_count }) + requested_features = set() + if op.num_nodes > 1: + requested_features.add( + clouds.CloudImplementationFeatures.MULTI_NODE) + try: + resources.cloud.check_features_are_supported( + resources, requested_features) + except exceptions.NotSupportedError: + continue candidate_resources.add(resources) + if len(candidate_resources) >= num_candidates: + break op.set_resources(candidate_resources) return dag @@ -121,7 +134,7 @@ def _optimize_by_brute_force(tasks, plan): resources_stack.pop() _optimize_by_brute_force(topo_order, {}) - print(final_plan) + print(final_plan, file=sys.stderr) return min_objective @@ -140,6 +153,9 @@ def compare_optimization_results(dag: sky.Dag, minimize_cost: bool): objective = sky.Optimizer._compute_total_time(dag.get_graph(), dag.tasks, optimizer_plan) + print('=== optimizer plan ===', file=sys.stderr) + print(optimizer_plan, file=sys.stderr) + print('=== brute force ===', file=sys.stderr) min_objective = find_min_objective(copy_dag, minimize_cost) assert abs(objective - min_objective) < 5e-2 From 8193931c67568c736164b195f9c29638c9d81eca Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 13 Jan 2024 01:11:04 +0000 Subject: [PATCH 77/84] address comments --- sky/clouds/runpod.py | 21 ++++--- sky/clouds/service_catalog/runpod_catalog.py | 2 +- sky/provision/common.py | 1 + sky/provision/provisioner.py | 1 + sky/provision/runpod/instance.py | 31 ++++++---- sky/provision/runpod/utils.py | 65 ++++---------------- sky/templates/runpod-ray.yml.j2 | 2 +- 7 files changed, 45 insertions(+), 78 deletions(-) diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 0f159e181b5..c362ab39d04 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -25,10 +25,15 @@ class RunPod(clouds.Cloud): _CLOUD_UNSUPPORTED_FEATURES = { clouds.CloudImplementationFeatures.STOP: 'Stopping not supported.', clouds.CloudImplementationFeatures.SPOT_INSTANCE: - ('Spot is not supported, as runpod API does not implement spot .'), + ('Spot is not supported, as runpod API does not implement spot.'), clouds.CloudImplementationFeatures.MULTI_NODE: ('Multi-node not supported yet, as the interconnection among nodes ' 'are non-trivial on 
RunPod.'), + clouds.CloudImplementationFeatures.OPEN_PORTS: + ('Opening ports is not ' + 'supported yet on RunPod.'), + clouds.CloudImplementationFeatures.CUSTOM_DISK_TIER: + ('Customizing disk tier is not supported yet on RunPod.') } _MAX_CLUSTER_NAME_LEN_LIMIT = 120 _regions: List[clouds.Region] = [] @@ -124,9 +129,6 @@ def accelerators_to_hourly_cost(self, def get_egress_cost(self, num_gigabytes: float) -> float: return 0.0 - def __repr__(self): - return 'RunPod' - def is_same_cloud(self, other: clouds.Cloud) -> bool: # Returns true if the two clouds are the same cloud type. return isinstance(other, RunPod) @@ -157,7 +159,7 @@ def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', cluster_name_on_cloud: str, region: 'clouds.Region', zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]: - del zones + del zones # unused r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) @@ -173,7 +175,8 @@ def make_deploy_resources_variables( } def _get_feasible_launchable_resources( - self, resources: 'resources_lib.Resources'): + self, resources: 'resources_lib.Resources' + ) -> Tuple[List['resources_lib.Resources'], List[str]]: """Returns a list of feasible resources for the given resources.""" if resources.use_spot: return ([], []) @@ -241,10 +244,8 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: return True, None except ImportError: - return False, ( - 'Failed to import runpod.' - 'To install, run: "pip install runpod" or "pip install sky[runpod]"' # pylint: disable=line-too-long - ) + return False, ('Failed to import runpod.' + 'To install, run: pip install skypilot[runpod]') def get_credential_file_mounts(self) -> Dict[str, str]: return { diff --git a/sky/clouds/service_catalog/runpod_catalog.py b/sky/clouds/service_catalog/runpod_catalog.py index 811ea9a1eab..bb23b85c832 100644 --- a/sky/clouds/service_catalog/runpod_catalog.py +++ b/sky/clouds/service_catalog/runpod_catalog.py @@ -1,7 +1,7 @@ """ RunPod | Catalog This module loads the service catalog file and can be used to -quarry instance types and pricing information for RunPod. +query instance types and pricing information for RunPod. 
""" import typing diff --git a/sky/provision/common.py b/sky/provision/common.py index db35848c31f..6074549883b 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -184,6 +184,7 @@ def get_ssh_ports(self) -> List[int]: ] return head_instance_port + worker_instance_ports + class Endpoint: """Base class for endpoints.""" pass diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 6afbfafb063..5a662ac1076 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -159,6 +159,7 @@ def bulk_provision( logger.debug( 'Provision config:\n' f'{json.dumps(dataclasses.asdict(bootstrap_config), indent=2)}') + raise RuntimeError('test') return _bulk_provision(cloud, region, zones, cluster_name, bootstrap_config) except Exception: # pylint: disable=broad-except diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index b11f2e5a91e..f7ac263eae8 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -1,4 +1,4 @@ -"""GCP instance provisioning.""" +"""RunPod instance provisioning.""" import time from typing import Any, Dict, List, Optional @@ -6,6 +6,7 @@ from sky import status_lib from sky.provision import common from sky.provision.runpod import utils +from sky.utils import common_utils POLL_INTERVAL = 5 @@ -20,14 +21,14 @@ def _filter_instances(cluster_name_on_cloud: str, f'{cluster_name_on_cloud}-head', f'{cluster_name_on_cloud}-worker' ] - filtered_nodes = {} + filtered_instances = {} for instance_id, instance in instances.items(): if (status_filters is not None and instance['status'] not in status_filters): continue if instance.get('name') in possible_names: - filtered_nodes[instance_id] = instance - return filtered_nodes + filtered_instances[instance_id] = instance + return filtered_instances def _get_head_instance_id(instances: Dict[str, Any]) -> Optional[str]: @@ -97,11 +98,11 @@ def run_instances(region: str, cluster_name_on_cloud: str, for instance_id, instance in instances.items(): if instance.get('ssh_port') is not None: ready_instance_cnt += 1 + logger.info('Waiting for instances to be ready: ' + f'({ready_instance_cnt}/{config.count}).') if ready_instance_cnt == config.count: break - logger.info('Waiting for instances to be ready ' - f'({len(instances)}/{config.count}).') time.sleep(POLL_INTERVAL) assert head_instance_id is not None, 'head_instance_id should not be None' return common.ProvisionRecord(provider_name='runpod', @@ -132,26 +133,30 @@ def terminate_instances( worker_only: bool = False, ) -> None: """See sky/provision/__init__.py""" - del provider_config + del provider_config # unused instances = _filter_instances(cluster_name_on_cloud, None) for inst_id, inst in instances.items(): - logger.info(f'Terminating instance {inst_id}.' 
- f'{inst}') + logger.info(f'Terminating instance {inst_id}: {inst}') if worker_only and inst['name'].endswith('-head'): continue logger.info(f'Start {inst_id}: {inst}') - utils.remove(inst_id) + try: + utils.remove(inst_id) + except Exception as e: + raise RuntimeError( + f'Failed to terminate instance {inst_id}: ' + f'{common_utils.format_exception(e, use_bracket=False)}') def get_cluster_info( region: str, cluster_name_on_cloud: str, provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo: - del region, provider_config - nodes = _filter_instances(cluster_name_on_cloud, ['RUNNING']) + del region, provider_config # unused + running_instances = _filter_instances(cluster_name_on_cloud, ['RUNNING']) instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None - for node_id, node_info in nodes.items(): + for node_id, node_info in running_instances.items(): instances[node_id] = [ common.InstanceInfo( instance_id=node_id, diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index 93700927339..cafcab08976 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -1,10 +1,7 @@ """RunPod library wrapper for SkyPilot.""" -import json -import os -from pathlib import Path import time -from typing import Dict, Optional +from typing import Dict from sky import sky_logging from sky.adaptors import runpod @@ -67,65 +64,32 @@ def wrapper(*args, **kwargs): return wrapper -def get_set_tags(instance_id: str, new_tags: Optional[Dict]) -> Dict: - """Gets the tags for the given instance. - - Creates the tag file if it doesn't exist. - - Returns the tags for the given instance. - - If tags are provided, sets the tags for the given instance. - """ - tag_file_path = os.path.expanduser('~/.runpod/skypilot_tags.json') - - # Ensure the tag file exists, create it if it doesn't. 
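
For readers of the removal that continues below: the deleted `get_set_tags` helper persisted per-instance tags in a local JSON file, using a plain read-modify-write cycle. A compact sketch of that pattern (the path and helper name are illustrative):

    import json
    import os
    from pathlib import Path
    from typing import Dict, Optional

    TAG_FILE = os.path.expanduser('~/.runpod/skypilot_tags.json')

    def get_set_tags_sketch(instance_id: str,
                            new_tags: Optional[Dict[str, str]] = None
                           ) -> Dict[str, str]:
        # Create the file's directory on first use.
        Path(TAG_FILE).parent.mkdir(parents=True, exist_ok=True)
        tags: Dict[str, Dict[str, str]] = {}
        if os.path.exists(TAG_FILE):
            with open(TAG_FILE, encoding='utf-8') as f:
                tags = json.load(f) or {}
        if new_tags:  # Merge and persist the update.
            tags.setdefault(instance_id, {}).update(new_tags)
            with open(TAG_FILE, 'w', encoding='utf-8') as f:
                json.dump(tags, f, indent=4)
        return tags.get(instance_id, {})
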
- if not os.path.exists(tag_file_path): - Path(os.path.dirname(tag_file_path)).mkdir(parents=True, exist_ok=True) - with open(tag_file_path, 'w', encoding='UTF-8') as tag_file: - json.dump({}, tag_file, indent=4) - - # Read existing tags - with open(tag_file_path, 'r', encoding='UTF-8') as tag_file: - tags = json.load(tag_file) - - if tags is None: - tags = {} - - # If new_tags is provided, update the tags for the instance - if new_tags: - instance_tags = tags.get(instance_id, {}) - instance_tags.update(new_tags) - tags[instance_id] = instance_tags - with open(tag_file_path, 'w', encoding='UTF-8') as tag_file: - json.dump(tags, tag_file, indent=4) - - return tags.get(instance_id, {}) - - -def list_instances(): +def list_instances() -> Dict[str, dict]: """Lists instances associated with API key.""" instances = runpod.runpod().get_pods() - instance_list = {} + instance_dict: Dict[str, dict] = {} for instance in instances: - instance_list[instance['id']] = {} + info = {} - instance_list[instance['id']]['status'] = instance['desiredStatus'] - instance_list[instance['id']]['name'] = instance['name'] + info['status'] = instance['desiredStatus'] + info['name'] = instance['name'] if instance['desiredStatus'] == 'RUNNING' and instance.get('runtime'): for port in instance['runtime']['ports']: if port['privatePort'] == 22 and port['isIpPublic']: - instance_list[instance['id']]['external_ip'] = port['ip'] - instance_list[ + info['external_ip'] = port['ip'] + instance_dict[ instance['id']]['ssh_port'] = port['publicPort'] elif not port['isIpPublic']: - instance_list[instance['id']]['internal_ip'] = port['ip'] + info['internal_ip'] = port['ip'] - instance_list[instance['id']]['tags'] = get_set_tags( - instance['id'], None) + instance_dict[instance['id']] = info - return instance_list + return instance_dict -def launch(name: str, instance_type: str, region: str, disk_size: int): +def launch(name: str, instance_type: str, region: str, disk_size: int) -> str: """Launches an instance with the given parameters. 
Converts the instance_type to the RunPod GPU name, finds the specs for the @@ -155,11 +119,6 @@ def launch(name: str, instance_type: str, region: str, disk_size: int): return new_instance['id'] -def set_tags(instance_id: str, tags: Dict): - """Sets the tags for the given instance.""" - get_set_tags(instance_id, tags) - - def remove(instance_id: str): """Terminates the given instance.""" runpod.runpod().terminate_pod(instance_id) diff --git a/sky/templates/runpod-ray.yml.j2 b/sky/templates/runpod-ray.yml.j2 index 182834dd979..a8350ae265a 100644 --- a/sky/templates/runpod-ray.yml.j2 +++ b/sky/templates/runpod-ray.yml.j2 @@ -7,7 +7,7 @@ idle_timeout_minutes: 60 provider: type: external - module: sky.skylet.providers.runpod.RunPodNodeProvider + module: sky.provision.runpod region: "{{region}}" disable_launch_config_check: true From e5631f315de8eb93fcff359a0941466fc0b2e227 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 13 Jan 2024 01:12:23 +0000 Subject: [PATCH 78/84] remove test code --- sky/provision/provisioner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 5a662ac1076..6afbfafb063 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -159,7 +159,6 @@ def bulk_provision( logger.debug( 'Provision config:\n' f'{json.dumps(dataclasses.asdict(bootstrap_config), indent=2)}') - raise RuntimeError('test') return _bulk_provision(cloud, region, zones, cluster_name, bootstrap_config) except Exception: # pylint: disable=broad-except From 7399d53a0b6c066155a97c5342d4bfc5b82ce746 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 13 Jan 2024 01:16:17 +0000 Subject: [PATCH 79/84] format --- sky/provision/runpod/instance.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index f7ac263eae8..9608b8a1941 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -7,6 +7,7 @@ from sky.provision import common from sky.provision.runpod import utils from sky.utils import common_utils +from sky.utils import ux_utils POLL_INTERVAL = 5 @@ -143,9 +144,11 @@ def terminate_instances( try: utils.remove(inst_id) except Exception as e: - raise RuntimeError( - f'Failed to terminate instance {inst_id}: ' - f'{common_utils.format_exception(e, use_bracket=False)}') + with ux_utils.print_exception_no_traceback(): + raise RuntimeError( + f'Failed to terminate instance {inst_id}: ' + f'{common_utils.format_exception(e, use_bracket=False)}' + ) from e def get_cluster_info( From 9342d2039bd73994143154731be004b3907f3958 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 13 Jan 2024 01:20:41 +0000 Subject: [PATCH 80/84] Add type hints --- sky/provision/runpod/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index cafcab08976..910a6624527 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -1,7 +1,7 @@ """RunPod library wrapper for SkyPilot.""" import time -from typing import Dict +from typing import Dict, List from sky import sky_logging from sky.adaptors import runpod @@ -119,12 +119,12 @@ def launch(name: str, instance_type: str, region: str, disk_size: int) -> str: return new_instance['id'] -def remove(instance_id: str): +def remove(instance_id: str) -> None: """Terminates the given instance.""" runpod.runpod().terminate_pod(instance_id) -def get_ssh_ports(cluster_name): +def 
get_ssh_ports(cluster_name) -> List[int]: """Gets the SSH ports for the given cluster.""" logger.debug(f'Getting SSH ports for cluster {cluster_name}.') From 19553807de8ef426a499266a78d9f4fdc3841407 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 13 Jan 2024 01:21:03 +0000 Subject: [PATCH 81/84] format --- sky/provision/runpod/instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index 9608b8a1941..b893e73baf4 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -143,7 +143,7 @@ def terminate_instances( logger.info(f'Start {inst_id}: {inst}') try: utils.remove(inst_id) - except Exception as e: + except Exception as e: # pylint: disable=broad-except with ux_utils.print_exception_no_traceback(): raise RuntimeError( f'Failed to terminate instance {inst_id}: ' From 8b48f32db11e0da8182e6b8c963b309e97c73ab1 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 13 Jan 2024 01:26:03 +0000 Subject: [PATCH 82/84] format --- sky/provision/runpod/instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index b893e73baf4..c51d8b9ab3f 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -143,7 +143,7 @@ def terminate_instances( logger.info(f'Start {inst_id}: {inst}') try: utils.remove(inst_id) - except Exception as e: # pylint: disable=broad-except + except Exception as e: # pylint: disable=broad-except with ux_utils.print_exception_no_traceback(): raise RuntimeError( f'Failed to terminate instance {inst_id}: ' From 07498d92990f4caac635a4014ff4e257c2d63b4f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 13 Jan 2024 01:35:24 +0000 Subject: [PATCH 83/84] fix keyerror --- sky/provision/runpod/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index 910a6624527..3040b0e329b 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -79,8 +79,7 @@ def list_instances() -> Dict[str, dict]: for port in instance['runtime']['ports']: if port['privatePort'] == 22 and port['isIpPublic']: info['external_ip'] = port['ip'] - instance_dict[ - instance['id']]['ssh_port'] = port['publicPort'] + info['ssh_port'] = port['publicPort'] elif not port['isIpPublic']: info['internal_ip'] = port['ip'] From b11446e3c37fc6b8632efced49d628f638383136 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Sat, 13 Jan 2024 03:56:27 +0000 Subject: [PATCH 84/84] Address comments --- sky/clouds/runpod.py | 4 +--- sky/provision/runpod/instance.py | 16 +++++++------- sky/provision/runpod/utils.py | 6 +++--- sky/templates/runpod-ray.yml.j2 | 36 ++------------------------------ 4 files changed, 14 insertions(+), 48 deletions(-) diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index c362ab39d04..f07bd18ee46 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -178,8 +178,6 @@ def _get_feasible_launchable_resources( self, resources: 'resources_lib.Resources' ) -> Tuple[List['resources_lib.Resources'], List[str]]: """Returns a list of feasible resources for the given resources.""" - if resources.use_spot: - return ([], []) if resources.instance_type is not None: assert resources.is_launchable(), resources resources = resources.copy(accelerators=None) @@ -244,7 +242,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: return True, None except ImportError: - 
return False, ('Failed to import runpod.' + return False, ('Failed to import runpod. ' 'To install, run: pip install skypilot[runpod]') def get_credential_file_mounts(self) -> Dict[str, str]: diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py index c51d8b9ab3f..9f3a1d92886 100644 --- a/sky/provision/runpod/instance.py +++ b/sky/provision/runpod/instance.py @@ -159,18 +159,18 @@ def get_cluster_info( running_instances = _filter_instances(cluster_name_on_cloud, ['RUNNING']) instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None - for node_id, node_info in running_instances.items(): - instances[node_id] = [ + for instance_id, instance_info in running_instances.items(): + instances[instance_id] = [ common.InstanceInfo( - instance_id=node_id, - internal_ip=node_info['internal_ip'], - external_ip=node_info['external_ip'], - ssh_port=node_info['ssh_port'], + instance_id=instance_id, + internal_ip=instance_info['internal_ip'], + external_ip=instance_info['external_ip'], + ssh_port=instance_info['ssh_port'], tags={}, ) ] - if node_info['name'].endswith('-head'): - head_instance_id = node_id + if instance_info['name'].endswith('-head'): + head_instance_id = instance_id return common.ClusterInfo( instances=instances, diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py index 3040b0e329b..00b24aee0a8 100644 --- a/sky/provision/runpod/utils.py +++ b/sky/provision/runpod/utils.py @@ -1,7 +1,7 @@ """RunPod library wrapper for SkyPilot.""" import time -from typing import Dict, List +from typing import Any, Dict, List from sky import sky_logging from sky.adaptors import runpod @@ -64,11 +64,11 @@ def wrapper(*args, **kwargs): return wrapper -def list_instances() -> Dict[str, dict]: +def list_instances() -> Dict[str, Dict[str, Any]]: """Lists instances associated with API key.""" instances = runpod.runpod().get_pods() - instance_dict: Dict[str, dict] = {} + instance_dict: Dict[str, Dict[str, Any]] = {} for instance in instances: info = {} diff --git a/sky/templates/runpod-ray.yml.j2 b/sky/templates/runpod-ray.yml.j2 index a8350ae265a..fa3598e429e 100644 --- a/sky/templates/runpod-ray.yml.j2 +++ b/sky/templates/runpod-ray.yml.j2 @@ -72,37 +72,5 @@ setup_commands: python3 -c "from sky.skylet.ray_patches import patch; patch()" || exit 1; [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); -# Command to start ray on the head node. You don't need to change this. -# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH -# connection, which is expensive. Try your best to co-locate commands into fewer -# items! The same comment applies for worker_start_ray_commands. -# -# Increment the following for catching performance bugs easier: -# current num items (num SSH connections): 1 -head_start_ray_commands: - # NOTE: --disable-usage-stats in `ray start` saves 10 seconds of idle wait. - # Line "which prlimit ..": increase the limit of the number of open files for the raylet process, as the `ulimit` may not take effect at this point, because it requires - # all the sessions to be reloaded. This is a workaround. 
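
The `prlimit` call in the removed command below raises the raylet's open-file limit after the process is already running, because a plain `ulimit` would not take effect until the sessions reload. For comparison, the same adjustment made from inside a Python process before it spawns children would look roughly like this POSIX-only sketch using the standard-library resource module (note that, unlike `prlimit`, this cannot target an already-running PID):

    import resource

    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if hard == resource.RLIM_INFINITY:
        target = 1048576
    else:
        target = min(1048576, hard)
    resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    print(f'RLIMIT_NOFILE soft limit: {soft} -> {target}')
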
- - export SKYPILOT_NUM_GPUS=0 && which nvidia-smi > /dev/null && SKYPILOT_NUM_GPUS=$(nvidia-smi --query-gpu=index,name --format=csv,noheader | wc -l); ray stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ray start --disable-usage-stats --head --port={{ray_port}} --dashboard-port={{ray_dashboard_port}} --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml {{"--resources='%s'" % custom_resources if custom_resources}} --num-gpus=$SKYPILOT_NUM_GPUS --temp-dir {{ray_temp_dir}} || exit 1; which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; {{dump_port_command}}; -# Worker commands are needed for TPU VM Pods -{%- if num_nodes > 1 or tpu_vm %} -worker_start_ray_commands: - - SKYPILOT_NUM_GPUS=0 && which nvidia-smi > /dev/null && SKYPILOT_NUM_GPUS=$(nvidia-smi --query-gpu=index,name --format=csv,noheader | wc -l); ray stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ray start --disable-usage-stats --address=$RAY_HEAD_IP:{{ray_port}} --object-manager-port=8076 {{"--resources='%s'" % custom_resources if custom_resources}} --num-gpus=$SKYPILOT_NUM_GPUS --temp-dir {{ray_temp_dir}} || exit 1; which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; -{%- else %} -worker_start_ray_commands: [] -{%- endif %} -head_node: {} -worker_nodes: {} -# These fields are required for external cloud providers. -head_setup_commands: [] -worker_setup_commands: [] -cluster_synced_files: [] -file_mounts_sync_continuously: False +# Commands to start Ray clusters are now placed in `sky.provision.instance_setup`. +# We do not need to list them here anymore.
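
End-to-end, the series leaves RunPod provisioning on the new sky.provision interface: `utils.list_instances()` parses each pod's runtime port mappings, and `instance.get_cluster_info()` wraps the result in `InstanceInfo` records. A usage sketch of that parsing step, run against a fabricated pod payload (the field names mirror the RunPod responses handled in utils.py; the sample values are invented):

    from typing import Any, Dict

    def parse_pod(instance: Dict[str, Any]) -> Dict[str, Any]:
        # The public port-22 mapping carries the external IP and SSH port;
        # a non-public mapping carries the pod's internal IP.
        info: Dict[str, Any] = {
            'status': instance['desiredStatus'],
            'name': instance['name'],
        }
        if instance['desiredStatus'] == 'RUNNING' and instance.get('runtime'):
            for port in instance['runtime']['ports']:
                if port['privatePort'] == 22 and port['isIpPublic']:
                    info['external_ip'] = port['ip']
                    info['ssh_port'] = port['publicPort']
                elif not port['isIpPublic']:
                    info['internal_ip'] = port['ip']
        return info

    sample = {
        'desiredStatus': 'RUNNING',
        'name': 'mycluster-head',
        'runtime': {'ports': [
            {'privatePort': 22, 'isIpPublic': True,
             'ip': '203.0.113.7', 'publicPort': 40022},
            {'privatePort': 22, 'isIpPublic': False, 'ip': '10.0.0.5'},
        ]},
    }
    print(parse_pod(sample))
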