From 2ac6aa1b9949e09b354f8655f4b5b41081e649ca Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 12 Jan 2024 20:36:28 -0800 Subject: [PATCH] New provisioner for RunPod (#2829) * init * remove ray * update config * update * update * update * complete bootstrapping * add start instance * fix * fix * fix * update * wait stopping instances * support normal gcp tpus first * fix gcp * support get cluster info * fix * update * wait for instance starting * rename * hide gcp package import * fix * fix * update constants * fix comments * remove unused methods * fix comments * sync 'config' & 'constants' with upstream, Nov 16 * sync 'instace_utils' with the upstream, Nov 16 * fix typing * parallelize provisioning * Fix TPU node * Fix TPU NAME env for tpu node * implement bulk provision * refactor selflink * format * reduce the sleep time for autostop * provisioner version refactoring * refactor * Add logging * avoid saving the provisioner version * format * format * Fix scheduling field in config * format * fix public key content * Fix provisioner version for azure * Use ray port from head node for workers * format * fix ray_port * fix smoke tests * shorter sleep time * refactor status refresh version * Use new provisioner to launch runpod to avoid issue with ray autoscaler on head Co-authored-by: Justin Merrell * Add wait for the instances to be ready * fix setup * Retry and give for getting internal IP * comment * Remove internal IP * use external IP TODO: use external ray port * fix ssh port * Unsupported feature * typo * fix ssh ports * rename var * format * Fix cloud unsupported resources * Runpod update name mapping (#2945) * Avoid using GpuInfo * fix all_regions * Fix runpod list accelerators * format * revert to GpuInfo * Fix get_feasible_launchable_resources * Add error * Fix optimizer random_dag for feature check * address comments * remove test code * format * Add type hints * format * format * fix keyerror * Address comments --------- Co-authored-by: Siyuan Co-authored-by: Doyoung Kim <34902420+landscapepainter@users.noreply.github.com> --- sky/__init__.py | 2 + sky/adaptors/runpod.py | 29 ++ sky/authentication.py | 15 + sky/backends/backend_utils.py | 2 + sky/backends/cloud_vm_ray_backend.py | 10 + sky/clouds/__init__.py | 2 + sky/clouds/runpod.py | 274 +++++++++++++++++++ sky/clouds/service_catalog/__init__.py | 2 +- sky/clouds/service_catalog/common.py | 2 +- sky/clouds/service_catalog/runpod_catalog.py | 113 ++++++++ sky/provision/__init__.py | 1 + sky/provision/common.py | 24 +- sky/provision/instance_setup.py | 14 +- sky/provision/provisioner.py | 21 +- sky/provision/runpod/__init__.py | 10 + sky/provision/runpod/config.py | 11 + sky/provision/runpod/instance.py | 209 ++++++++++++++ sky/provision/runpod/utils.py | 141 ++++++++++ sky/setup_files/setup.py | 8 +- sky/templates/runpod-ray.yml.j2 | 76 +++++ tests/test_optimizer_random_dag.py | 22 +- 21 files changed, 965 insertions(+), 23 deletions(-) create mode 100644 sky/adaptors/runpod.py create mode 100644 sky/clouds/runpod.py create mode 100644 sky/clouds/service_catalog/runpod_catalog.py create mode 100644 sky/provision/runpod/__init__.py create mode 100644 sky/provision/runpod/config.py create mode 100644 sky/provision/runpod/instance.py create mode 100644 sky/provision/runpod/utils.py create mode 100644 sky/templates/runpod-ray.yml.j2 diff --git a/sky/__init__.py b/sky/__init__.py index b27de4a5c3ff..a673669774ad 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -82,6 +82,7 @@ def get_git_commit(): Local = clouds.Local 
 Kubernetes = clouds.Kubernetes
 OCI = clouds.OCI
+RunPod = clouds.RunPod
 optimize = Optimizer.optimize
 
 __all__ = [
@@ -94,6 +95,7 @@ def get_git_commit():
     'Lambda',
     'Local',
     'OCI',
+    'RunPod',
     'SCP',
     'Optimizer',
     'OptimizeTarget',
diff --git a/sky/adaptors/runpod.py b/sky/adaptors/runpod.py
new file mode 100644
index 000000000000..2f4699f80bf6
--- /dev/null
+++ b/sky/adaptors/runpod.py
@@ -0,0 +1,29 @@
+"""RunPod cloud adaptor."""
+
+import functools
+
+_runpod_sdk = None
+
+
+def import_package(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        global _runpod_sdk
+        if _runpod_sdk is None:
+            try:
+                import runpod as _runpod  # pylint: disable=import-outside-toplevel
+                _runpod_sdk = _runpod
+            except ImportError:
+                raise ImportError(
+                    'Failed to import dependencies for RunPod. '
+                    'Try: pip install "skypilot[runpod]"') from None
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+@import_package
+def runpod():
+    """Return the runpod package."""
+    return _runpod_sdk
diff --git a/sky/authentication.py b/sky/authentication.py
index efce12defa49..b52f62c44e59 100644
--- a/sky/authentication.py
+++ b/sky/authentication.py
@@ -41,6 +41,7 @@
 from sky import skypilot_config
 from sky.adaptors import gcp
 from sky.adaptors import ibm
+from sky.adaptors import runpod
 from sky.clouds.utils import lambda_utils
 from sky.utils import common_utils
 from sky.utils import kubernetes_enums
@@ -449,3 +450,17 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
         config['auth']['ssh_proxy_command'] = ssh_proxy_cmd
 
     return config
+
+
+# ---------------------------------- RunPod ---------------------------------- #
+def setup_runpod_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
+    """Sets up SSH authentication for RunPod.
+    - Generates a new SSH key pair if one does not exist.
+    - Adds the public SSH key to the user's RunPod account.
+    """
+    _, public_key_path = get_or_generate_keys()
+    with open(public_key_path, 'r', encoding='UTF-8') as pub_key_file:
+        public_key = pub_key_file.read().strip()
+    runpod.runpod().cli.groups.ssh.functions.add_ssh_key(public_key)
+
+    return configure_ssh_info(config)
diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py
index 41880f18fdb7..76c77dcdb0b2 100644
--- a/sky/backends/backend_utils.py
+++ b/sky/backends/backend_utils.py
@@ -1006,6 +1006,8 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
         config = auth.setup_kubernetes_authentication(config)
     elif isinstance(cloud, clouds.IBM):
         config = auth.setup_ibm_authentication(config)
+    elif isinstance(cloud, clouds.RunPod):
+        config = auth.setup_runpod_authentication(config)
     else:
         assert isinstance(cloud, clouds.Local), cloud
         # Local cluster case, authentication is already filled by the user
diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index e8f76e2ec844..2403ede12319 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -148,6 +148,7 @@ def _get_cluster_config_template(cloud):
         clouds.Local: 'local-ray.yml.j2',
         clouds.SCP: 'scp-ray.yml.j2',
         clouds.OCI: 'oci-ray.yml.j2',
+        clouds.RunPod: 'runpod-ray.yml.j2',
         clouds.Kubernetes: 'kubernetes-ray.yml.j2',
     }
     return cloud_to_template[type(cloud)]
@@ -2291,6 +2292,15 @@ def update_ssh_ports(self, max_attempts: int = 1) -> None:
         Use this method to use any cloud-specific port fetching logic.
         """
         del max_attempts  # Unused.
+        if isinstance(self.launched_resources.cloud, clouds.RunPod):
+            cluster_info = provision_lib.get_cluster_info(
+                str(self.launched_resources.cloud).lower(),
+                region=self.launched_resources.region,
+                cluster_name_on_cloud=self.cluster_name_on_cloud,
+                provider_config=None)
+            self.stable_ssh_ports = cluster_info.get_ssh_ports()
+            return
+
         head_ssh_port = 22
         self.stable_ssh_ports = (
             [head_ssh_port] + [22] *
diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py
index 36d843e267ea..0860d26b47fe 100644
--- a/sky/clouds/__init__.py
+++ b/sky/clouds/__init__.py
@@ -17,6 +17,7 @@
 from sky.clouds.lambda_cloud import Lambda
 from sky.clouds.local import Local
 from sky.clouds.oci import OCI
+from sky.clouds.runpod import RunPod
 from sky.clouds.scp import SCP
 
 __all__ = [
@@ -28,6 +29,7 @@
     'Lambda',
     'Local',
     'SCP',
+    'RunPod',
     'OCI',
     'Kubernetes',
     'CloudImplementationFeatures',
diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py
new file mode 100644
index 000000000000..f07bd18ee465
--- /dev/null
+++ b/sky/clouds/runpod.py
@@ -0,0 +1,274 @@
+""" RunPod Cloud. """
+
+import json
+import typing
+from typing import Dict, Iterator, List, Optional, Tuple
+
+from sky import clouds
+from sky.clouds import service_catalog
+
+if typing.TYPE_CHECKING:
+    from sky import resources as resources_lib
+
+_CREDENTIAL_FILES = [
+    'config.toml',
+]
+
+
+@clouds.CLOUD_REGISTRY.register
+class RunPod(clouds.Cloud):
+    """ RunPod GPU Cloud
+
+    _REPR | The string representation for the RunPod GPU cloud object.
+    """
+    _REPR = 'RunPod'
+    _CLOUD_UNSUPPORTED_FEATURES = {
+        clouds.CloudImplementationFeatures.STOP: 'Stopping not supported.',
+        clouds.CloudImplementationFeatures.SPOT_INSTANCE:
+            ('Spot is not supported, as the RunPod API does not '
+             'implement spot.'),
+        clouds.CloudImplementationFeatures.MULTI_NODE:
+            ('Multi-node is not supported yet, as the interconnection among '
+             'nodes is non-trivial on RunPod.'),
+        clouds.CloudImplementationFeatures.OPEN_PORTS:
+            ('Opening ports is not '
+             'supported yet on RunPod.'),
+        clouds.CloudImplementationFeatures.CUSTOM_DISK_TIER:
+            ('Customizing disk tier is not supported yet on RunPod.')
+    }
+    _MAX_CLUSTER_NAME_LEN_LIMIT = 120
+    _regions: List[clouds.Region] = []
+
+    PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT
+    STATUS_VERSION = clouds.StatusVersion.SKYPILOT
+
+    @classmethod
+    def _unsupported_features_for_resources(
+        cls, resources: 'resources_lib.Resources'
+    ) -> Dict[clouds.CloudImplementationFeatures, str]:
+        """The features not supported based on the resources provided.
+
+        This method is used by check_features_are_supported() to check if the
+        cloud implementation supports all the requested features.
+
+        Returns:
+            A dict of {feature: reason} for the features not supported by the
+            cloud implementation.
+        """
+        del resources  # unused
+        return cls._CLOUD_UNSUPPORTED_FEATURES
+
+    @classmethod
+    def _max_cluster_name_length(cls) -> Optional[int]:
+        return cls._MAX_CLUSTER_NAME_LEN_LIMIT
+
+    @classmethod
+    def regions_with_offering(cls, instance_type: str,
+                              accelerators: Optional[Dict[str, int]],
+                              use_spot: bool, region: Optional[str],
+                              zone: Optional[str]) -> List[clouds.Region]:
+        assert zone is None, 'RunPod does not support zones.'
+ del accelerators, zone # unused + if use_spot: + return [] + else: + regions = service_catalog.get_region_zones_for_instance_type( + instance_type, use_spot, 'runpod') + + if region is not None: + regions = [r for r in regions if r.name == region] + return regions + + @classmethod + def get_vcpus_mem_from_instance_type( + cls, + instance_type: str, + ) -> Tuple[Optional[float], Optional[float]]: + return service_catalog.get_vcpus_mem_from_instance_type(instance_type, + clouds='runpod') + + @classmethod + def zones_provision_loop( + cls, + *, + region: str, + num_nodes: int, + instance_type: str, + accelerators: Optional[Dict[str, int]] = None, + use_spot: bool = False, + ) -> Iterator[None]: + del num_nodes # unused + regions = cls.regions_with_offering(instance_type, + accelerators, + use_spot, + region=region, + zone=None) + for r in regions: + assert r.zones is None, r + yield r.zones + + def instance_type_to_hourly_cost(self, + instance_type: str, + use_spot: bool, + region: Optional[str] = None, + zone: Optional[str] = None) -> float: + return service_catalog.get_hourly_cost(instance_type, + use_spot=use_spot, + region=region, + zone=zone, + clouds='runpod') + + def accelerators_to_hourly_cost(self, + accelerators: Dict[str, int], + use_spot: bool, + region: Optional[str] = None, + zone: Optional[str] = None) -> float: + """Returns the hourly cost of the accelerators, in dollars/hour.""" + del accelerators, use_spot, region, zone # unused + return 0.0 # RunPod includes accelerators in the hourly cost. + + def get_egress_cost(self, num_gigabytes: float) -> float: + return 0.0 + + def is_same_cloud(self, other: clouds.Cloud) -> bool: + # Returns true if the two clouds are the same cloud type. + return isinstance(other, RunPod) + + @classmethod + def get_default_instance_type( + cls, + cpus: Optional[str] = None, + memory: Optional[str] = None, + disk_tier: Optional[str] = None) -> Optional[str]: + """Returns the default instance type for RunPod.""" + return service_catalog.get_default_instance_type(cpus=cpus, + memory=memory, + disk_tier=disk_tier, + clouds='runpod') + + @classmethod + def get_accelerators_from_instance_type( + cls, instance_type: str) -> Optional[Dict[str, int]]: + return service_catalog.get_accelerators_from_instance_type( + instance_type, clouds='runpod') + + @classmethod + def get_zone_shell_cmd(cls) -> Optional[str]: + return None + + def make_deploy_resources_variables( + self, resources: 'resources_lib.Resources', + cluster_name_on_cloud: str, region: 'clouds.Region', + zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]: + del zones # unused + + r = resources + acc_dict = self.get_accelerators_from_instance_type(r.instance_type) + if acc_dict is not None: + custom_resources = json.dumps(acc_dict, separators=(',', ':')) + else: + custom_resources = None + + return { + 'instance_type': resources.instance_type, + 'custom_resources': custom_resources, + 'region': region.name, + } + + def _get_feasible_launchable_resources( + self, resources: 'resources_lib.Resources' + ) -> Tuple[List['resources_lib.Resources'], List[str]]: + """Returns a list of feasible resources for the given resources.""" + if resources.instance_type is not None: + assert resources.is_launchable(), resources + resources = resources.copy(accelerators=None) + return ([resources], []) + + def _make(instance_list): + resource_list = [] + for instance_type in instance_list: + r = resources.copy( + cloud=RunPod(), + instance_type=instance_type, + accelerators=None, + cpus=None, + ) + 
resource_list.append(r) + return resource_list + + # Currently, handle a filter on accelerators only. + accelerators = resources.accelerators + if accelerators is None: + # Return a default instance type + default_instance_type = RunPod.get_default_instance_type( + cpus=resources.cpus, + memory=resources.memory, + disk_tier=resources.disk_tier) + if default_instance_type is None: + return ([], []) + else: + return (_make([default_instance_type]), []) + + assert len(accelerators) == 1, resources + acc, acc_count = list(accelerators.items())[0] + (instance_list, fuzzy_candidate_list + ) = service_catalog.get_instance_type_for_accelerator( + acc, + acc_count, + use_spot=resources.use_spot, + cpus=resources.cpus, + region=resources.region, + zone=resources.zone, + clouds='runpod') + if instance_list is None: + return ([], fuzzy_candidate_list) + return (_make(instance_list), fuzzy_candidate_list) + + @classmethod + def check_credentials(cls) -> Tuple[bool, Optional[str]]: + """ Verify that the user has valid credentials for RunPod. """ + try: + import runpod # pylint: disable=import-outside-toplevel + valid, error = runpod.check_credentials() + + if not valid: + return False, ( + f'{error} \n' # First line is indented by 4 spaces + ' Credentials can be set up by running: \n' + f' $ pip install runpod \n' + f' $ runpod store_api_key \n' + ' For more information, see https://docs.runpod.io/docs/skypilot' # pylint: disable=line-too-long + ) + + return True, None + + except ImportError: + return False, ('Failed to import runpod. ' + 'To install, run: pip install skypilot[runpod]') + + def get_credential_file_mounts(self) -> Dict[str, str]: + return { + f'~/.runpod/{filename}': f'~/.runpod/{filename}' + for filename in _CREDENTIAL_FILES + } + + @classmethod + def get_current_user_identity(cls) -> Optional[List[str]]: + # NOTE: used for very advanced SkyPilot functionality + # Can implement later if desired + return None + + def instance_type_exists(self, instance_type: str) -> bool: + return service_catalog.instance_type_exists(instance_type, 'runpod') + + def validate_region_zone(self, region: Optional[str], zone: Optional[str]): + return service_catalog.validate_region_zone(region, + zone, + clouds='runpod') + + def accelerator_in_region_or_zone(self, + accelerator: str, + acc_count: int, + region: Optional[str] = None, + zone: Optional[str] = None) -> bool: + return service_catalog.accelerator_in_region_or_zone( + accelerator, acc_count, region, zone, 'runpod') diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index 524a8fdf6ce8..01b9bc7ff560 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -15,7 +15,7 @@ CloudFilter = Optional[Union[List[str], str]] ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci', - 'kubernetes') + 'kubernetes', 'runpod') def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs): diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index a84ad9f55c7c..d72abb634ee8 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -470,7 +470,7 @@ def list_accelerators_impl( gpu_info_df = df['GpuInfo'].apply(ast.literal_eval) df['DeviceMemoryGiB'] = gpu_info_df.apply( lambda row: row['Gpus'][0]['MemoryInfo']['SizeInMiB']) / 1024.0 - except ValueError: + except (ValueError, SyntaxError): # TODO(zongheng,woosuk): GCP/Azure catalogs do not have well-formed # GpuInfo fields. 
So the above will throw: # ValueError: malformed node or string: <_ast.Name object at ..> diff --git a/sky/clouds/service_catalog/runpod_catalog.py b/sky/clouds/service_catalog/runpod_catalog.py new file mode 100644 index 000000000000..bb23b85c832c --- /dev/null +++ b/sky/clouds/service_catalog/runpod_catalog.py @@ -0,0 +1,113 @@ +""" RunPod | Catalog + +This module loads the service catalog file and can be used to +query instance types and pricing information for RunPod. +""" + +import typing +from typing import Dict, List, Optional, Tuple + +from sky.clouds.service_catalog import common +from sky.utils import ux_utils + +if typing.TYPE_CHECKING: + from sky.clouds import cloud + +_df = common.read_catalog('runpod/vms.csv') + + +def instance_type_exists(instance_type: str) -> bool: + return common.instance_type_exists_impl(_df, instance_type) + + +def validate_region_zone( + region: Optional[str], + zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]: + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('RunPod does not support zones.') + return common.validate_region_zone_impl('runpod', _df, region, zone) + + +def accelerator_in_region_or_zone(acc_name: str, + acc_count: int, + region: Optional[str] = None, + zone: Optional[str] = None) -> bool: + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('RunPod does not support zones.') + return common.accelerator_in_region_or_zone_impl(_df, acc_name, acc_count, + region, zone) + + +def get_hourly_cost(instance_type: str, + use_spot: bool = False, + region: Optional[str] = None, + zone: Optional[str] = None) -> float: + """Returns the cost, or the cheapest cost among all zones for spot.""" + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('RunPod does not support zones.') + return common.get_hourly_cost_impl(_df, instance_type, use_spot, region, + zone) + + +def get_vcpus_mem_from_instance_type( + instance_type: str) -> Tuple[Optional[float], Optional[float]]: + return common.get_vcpus_mem_from_instance_type_impl(_df, instance_type) + + +def get_default_instance_type(cpus: Optional[str] = None, + memory: Optional[str] = None, + disk_tier: Optional[str] = None) -> Optional[str]: + del disk_tier # RunPod does not support disk tiers. + # NOTE: After expanding catalog to multiple entries, you may + # want to specify a default instance type or family. 
+ return common.get_instance_type_for_cpus_mem_impl(_df, cpus, memory) + + +def get_accelerators_from_instance_type( + instance_type: str) -> Optional[Dict[str, int]]: + return common.get_accelerators_from_instance_type_impl(_df, instance_type) + + +def get_instance_type_for_accelerator( + acc_name: str, + acc_count: int, + cpus: Optional[str] = None, + memory: Optional[str] = None, + use_spot: bool = False, + region: Optional[str] = None, + zone: Optional[str] = None) -> Tuple[Optional[List[str]], List[str]]: + """Returns a list of instance types that have the given accelerator.""" + if zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('RunPod does not support zones.') + return common.get_instance_type_for_accelerator_impl(df=_df, + acc_name=acc_name, + acc_count=acc_count, + cpus=cpus, + memory=memory, + use_spot=use_spot, + region=region, + zone=zone) + + +def get_region_zones_for_instance_type(instance_type: str, + use_spot: bool) -> List['cloud.Region']: + df = _df[_df['InstanceType'] == instance_type] + return common.get_region_zones(df, use_spot) + + +def list_accelerators( + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + quantity_filter: Optional[int], + case_sensitive: bool = True, + all_regions: bool = False, +) -> Dict[str, List[common.InstanceTypeInfo]]: + """Returns all instance types in RunPod offering GPUs.""" + return common.list_accelerators_impl('RunPod', _df, gpus_only, name_filter, + region_filter, quantity_filter, + case_sensitive, all_regions) diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index 0a35bfebabfe..6a107cc5f27d 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -17,6 +17,7 @@ from sky.provision import common from sky.provision import gcp from sky.provision import kubernetes +from sky.provision import runpod logger = sky_logging.init_logger(__name__) diff --git a/sky/provision/common.py b/sky/provision/common.py index 6dd82e5868b8..6074549883b6 100644 --- a/sky/provision/common.py +++ b/sky/provision/common.py @@ -79,6 +79,7 @@ class InstanceInfo: internal_ip: str external_ip: Optional[str] tags: Dict[str, str] + ssh_port: int = 22 def get_feasible_ip(self) -> str: """Get the most feasible IPs of the instance. This function returns @@ -127,16 +128,17 @@ def ip_tuples(self) -> List[Tuple[str, Optional[str]]]: Returns: A list of tuples (internal_ip, external_ip) of all instances. 
""" - head_node = self.get_head_instance() - if head_node is None: - head_node_ip = [] + head_instance = self.get_head_instance() + if head_instance is None: + head_instance_ip = [] else: - head_node_ip = [(head_node.internal_ip, head_node.external_ip)] + head_instance_ip = [(head_instance.internal_ip, + head_instance.external_ip)] other_ips = [] for instance in self.get_worker_instances(): pair = (instance.internal_ip, instance.external_ip) other_ips.append(pair) - return head_node_ip + other_ips + return head_instance_ip + other_ips def has_external_ips(self) -> bool: """True if the cluster has external IP.""" @@ -170,6 +172,18 @@ def get_feasible_ips(self, force_internal_ips: bool = False) -> List[str]: """Get external IPs if they exist, otherwise get internal ones.""" return self._get_ips(not self.has_external_ips() or force_internal_ips) + def get_ssh_ports(self) -> List[int]: + """Get the SSH port of all the instances.""" + head_instance = self.get_head_instance() + assert head_instance is not None, self + head_instance_port = [head_instance.ssh_port] + + worker_instances = self.get_worker_instances() + worker_instance_ports = [ + instance.ssh_port for instance in worker_instances + ] + return head_instance_port + worker_instance_ports + class Endpoint: """Base class for endpoints.""" diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 68ed6665b35a..9f84ada199e5 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -106,7 +106,9 @@ def _parallel_ssh_with_cache(func, cluster_name: str, stage_name: str, for i, metadata in enumerate(metadatas): cache_id = f'{instance_id}-{i}' runner = command_runner.SSHCommandRunner( - metadata.get_feasible_ip(), port=22, **ssh_credentials) + metadata.get_feasible_ip(), + port=metadata.ssh_port, + **ssh_credentials) wrapper = metadata_utils.cache_func(cluster_name, cache_id, stage_name, digest) if (cluster_info.head_instance_id == instance_id and i == 0): @@ -201,8 +203,9 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], ssh_credentials: Dict[str, Any]) -> None: """Start Ray on the head node.""" ip_list = cluster_info.get_feasible_ips() + port_list = cluster_info.get_ssh_ports() ssh_runner = command_runner.SSHCommandRunner(ip_list[0], - port=22, + port=port_list[0], **ssh_credentials) assert cluster_info.head_instance_id is not None, (cluster_name, cluster_info) @@ -254,7 +257,9 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, _hint_worker_log_path(cluster_name, cluster_info, 'ray_cluster') ip_list = cluster_info.get_feasible_ips() ssh_runners = command_runner.SSHCommandRunner.make_runner_list( - ip_list[1:], port_list=None, **ssh_credentials) + ip_list[1:], + port_list=cluster_info.get_ssh_ports()[1:], + **ssh_credentials) worker_instances = cluster_info.get_worker_instances() cache_ids = [] prev_instance_id = None @@ -329,8 +334,9 @@ def start_skylet_on_head_node(cluster_name: str, """Start skylet on the head node.""" del cluster_name ip_list = cluster_info.get_feasible_ips() + port_list = cluster_info.get_ssh_ports() ssh_runner = command_runner.SSHCommandRunner(ip_list[0], - port=22, + port=port_list[0], **ssh_credentials) assert cluster_info.head_instance_id is not None, cluster_info log_path_abs = str(provision_logging.get_log_path()) diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index 0a3c5af5d9a8..6afbfafb0635 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -197,6 +197,7 @@ 
def teardown_cluster(cloud_name: str, cluster_name: ClusterName, def _ssh_probe_command(ip: str, + ssh_port: int, ssh_user: str, ssh_private_key: str, ssh_proxy_command: Optional[str] = None) -> List[str]: @@ -208,6 +209,8 @@ def _ssh_probe_command(ip: str, '-i', ssh_private_key, f'{ssh_user}@{ip}', + '-p', + str(ssh_port), '-o', 'StrictHostKeyChecking=no', '-o', @@ -240,19 +243,20 @@ def _shlex_join(command: List[str]) -> str: def _wait_ssh_connection_direct( ip: str, + ssh_port: int, ssh_user: str, ssh_private_key: str, ssh_control_name: Optional[str] = None, ssh_proxy_command: Optional[str] = None) -> bool: assert ssh_proxy_command is None, 'SSH proxy command is not supported.' try: - with socket.create_connection((ip, 22), timeout=1) as s: + with socket.create_connection((ip, ssh_port), timeout=1) as s: if s.recv(100).startswith(b'SSH'): # Wait for SSH being actually ready, otherwise we may get the # following error: # "System is booting up. Unprivileged users are not permitted to # log in yet". - return _wait_ssh_connection_indirect(ip, ssh_user, + return _wait_ssh_connection_indirect(ip, ssh_port, ssh_user, ssh_private_key, ssh_control_name, ssh_proxy_command) @@ -260,7 +264,7 @@ def _wait_ssh_connection_direct( pass except Exception: # pylint: disable=broad-except pass - command = _ssh_probe_command(ip, ssh_user, ssh_private_key, + command = _ssh_probe_command(ip, ssh_port, ssh_user, ssh_private_key, ssh_proxy_command) logger.debug(f'Waiting for SSH to {ip}. Try: ' f'{_shlex_join(command)}') @@ -269,12 +273,13 @@ def _wait_ssh_connection_direct( def _wait_ssh_connection_indirect( ip: str, + ssh_port: int, ssh_user: str, ssh_private_key: str, ssh_control_name: Optional[str] = None, ssh_proxy_command: Optional[str] = None) -> bool: del ssh_control_name - command = _ssh_probe_command(ip, ssh_user, ssh_private_key, + command = _ssh_probe_command(ip, ssh_port, ssh_user, ssh_private_key, ssh_proxy_command) proc = subprocess.run(command, shell=False, @@ -300,14 +305,17 @@ def wait_for_ssh(cluster_info: provision_common.ClusterInfo, # See https://github.com/skypilot-org/skypilot/pull/1512 waiter = _wait_ssh_connection_indirect ip_list = cluster_info.get_feasible_ips() + port_list = cluster_info.get_ssh_ports() timeout = 60 * 10 # 10-min maximum timeout start = time.time() # use a queue for SSH querying ips = collections.deque(ip_list) + ssh_ports = collections.deque(port_list) while ips: ip = ips.popleft() - if not waiter(ip, **ssh_credentials): + ssh_port = ssh_ports.popleft() + if not waiter(ip, ssh_port, **ssh_credentials): ips.append(ip) if time.time() - start > timeout: with ux_utils.print_exception_no_traceback(): @@ -349,6 +357,7 @@ def _post_provision_setup( # TODO(suquark): Move wheel build here in future PRs. 
     ip_list = cluster_info.get_feasible_ips()
+    port_list = cluster_info.get_ssh_ports()
     ssh_credentials = backend_utils.ssh_credential_from_yaml(cluster_yaml)
 
     with rich_utils.safe_status(
@@ -407,7 +416,7 @@ def _post_provision_setup(
                                             cluster_info, ssh_credentials)
 
         head_runner = command_runner.SSHCommandRunner(ip_list[0],
-                                                      port=22,
+                                                      port=port_list[0],
                                                       **ssh_credentials)
 
         status.update(
diff --git a/sky/provision/runpod/__init__.py b/sky/provision/runpod/__init__.py
new file mode 100644
index 000000000000..59297d7a3b01
--- /dev/null
+++ b/sky/provision/runpod/__init__.py
@@ -0,0 +1,10 @@
+"""RunPod provisioner for SkyPilot."""
+
+from sky.provision.runpod.config import bootstrap_instances
+from sky.provision.runpod.instance import cleanup_ports
+from sky.provision.runpod.instance import get_cluster_info
+from sky.provision.runpod.instance import query_instances
+from sky.provision.runpod.instance import run_instances
+from sky.provision.runpod.instance import stop_instances
+from sky.provision.runpod.instance import terminate_instances
+from sky.provision.runpod.instance import wait_instances
diff --git a/sky/provision/runpod/config.py b/sky/provision/runpod/config.py
new file mode 100644
index 000000000000..f0d6ca3488de
--- /dev/null
+++ b/sky/provision/runpod/config.py
@@ -0,0 +1,11 @@
+"""RunPod configuration bootstrapping."""
+
+from sky.provision import common
+
+
+def bootstrap_instances(
+        region: str, cluster_name: str,
+        config: common.ProvisionConfig) -> common.ProvisionConfig:
+    """Bootstraps instances for the given cluster."""
+    del region, cluster_name  # unused
+    return config
diff --git a/sky/provision/runpod/instance.py b/sky/provision/runpod/instance.py
new file mode 100644
index 000000000000..9f3a1d928862
--- /dev/null
+++ b/sky/provision/runpod/instance.py
@@ -0,0 +1,209 @@
+"""RunPod instance provisioning."""
+import time
+from typing import Any, Dict, List, Optional
+
+from sky import sky_logging
+from sky import status_lib
+from sky.provision import common
+from sky.provision.runpod import utils
+from sky.utils import common_utils
+from sky.utils import ux_utils
+
+POLL_INTERVAL = 5
+
+logger = sky_logging.init_logger(__name__)
+
+
+def _filter_instances(cluster_name_on_cloud: str,
+                      status_filters: Optional[List[str]]) -> Dict[str, Any]:
+
+    instances = utils.list_instances()
+    possible_names = [
+        f'{cluster_name_on_cloud}-head', f'{cluster_name_on_cloud}-worker'
+    ]
+
+    filtered_instances = {}
+    for instance_id, instance in instances.items():
+        if (status_filters is not None and
+                instance['status'] not in status_filters):
+            continue
+        if instance.get('name') in possible_names:
+            filtered_instances[instance_id] = instance
+    return filtered_instances
+
+
+def _get_head_instance_id(instances: Dict[str, Any]) -> Optional[str]:
+    head_instance_id = None
+    for inst_id, inst in instances.items():
+        if inst['name'].endswith('-head'):
+            head_instance_id = inst_id
+            break
+    return head_instance_id
+
+
+def run_instances(region: str, cluster_name_on_cloud: str,
+                  config: common.ProvisionConfig) -> common.ProvisionRecord:
+    """Runs instances for the given cluster."""
+
+    pending_status = ['CREATED', 'RESTARTING']
+
+    while True:
+        instances = _filter_instances(cluster_name_on_cloud, pending_status)
+        if not instances:
+            break
+        logger.info(f'Waiting for {len(instances)} instances to be ready.')
+        time.sleep(POLL_INTERVAL)
+    exist_instances = _filter_instances(cluster_name_on_cloud, ['RUNNING'])
+    head_instance_id = _get_head_instance_id(exist_instances)
+
+    to_start_count = config.count - len(exist_instances)
+    if to_start_count < 0:
+        raise RuntimeError(
+            f'Cluster {cluster_name_on_cloud} already has '
+            f'{len(exist_instances)} nodes, but {config.count} are required.')
+    if to_start_count == 0:
+        if head_instance_id is None:
+            raise RuntimeError(
+                f'Cluster {cluster_name_on_cloud} has no head node.')
+        logger.info(f'Cluster {cluster_name_on_cloud} already has '
+                    f'{len(exist_instances)} nodes, no need to start more.')
+        return common.ProvisionRecord(provider_name='runpod',
+                                      cluster_name=cluster_name_on_cloud,
+                                      region=region,
+                                      zone=None,
+                                      head_instance_id=head_instance_id,
+                                      resumed_instance_ids=[],
+                                      created_instance_ids=[])
+
+    created_instance_ids = []
+    for _ in range(to_start_count):
+        node_type = 'head' if head_instance_id is None else 'worker'
+        try:
+            instance_id = utils.launch(
+                name=f'{cluster_name_on_cloud}-{node_type}',
+                instance_type=config.node_config['InstanceType'],
+                region=region,
+                disk_size=config.node_config['DiskSize'])
+        except Exception as e:  # pylint: disable=broad-except
+            logger.warning(f'run_instances error: {e}')
+            raise
+        logger.info(f'Launched instance {instance_id}.')
+        created_instance_ids.append(instance_id)
+        if head_instance_id is None:
+            head_instance_id = instance_id
+
+    # Wait for instances to be ready.
+    while True:
+        instances = _filter_instances(cluster_name_on_cloud, ['RUNNING'])
+        ready_instance_cnt = 0
+        for instance in instances.values():
+            if instance.get('ssh_port') is not None:
+                ready_instance_cnt += 1
+        logger.info('Waiting for instances to be ready: '
+                    f'({ready_instance_cnt}/{config.count}).')
+        if ready_instance_cnt == config.count:
+            break
+
+        time.sleep(POLL_INTERVAL)
+    assert head_instance_id is not None, 'head_instance_id should not be None'
+    return common.ProvisionRecord(provider_name='runpod',
+                                  cluster_name=cluster_name_on_cloud,
+                                  region=region,
+                                  zone=None,
+                                  head_instance_id=head_instance_id,
+                                  resumed_instance_ids=[],
+                                  created_instance_ids=created_instance_ids)
+
+
+def wait_instances(region: str, cluster_name_on_cloud: str,
+                   state: Optional[status_lib.ClusterStatus]) -> None:
+    del region, cluster_name_on_cloud, state
+
+
+def stop_instances(
+    cluster_name_on_cloud: str,
+    provider_config: Optional[Dict[str, Any]] = None,
+    worker_only: bool = False,
+) -> None:
+    raise NotImplementedError()
+
+
+def terminate_instances(
+    cluster_name_on_cloud: str,
+    provider_config: Optional[Dict[str, Any]] = None,
+    worker_only: bool = False,
+) -> None:
+    """See sky/provision/__init__.py"""
+    del provider_config  # unused
+    instances = _filter_instances(cluster_name_on_cloud, None)
+    for inst_id, inst in instances.items():
+        if worker_only and inst['name'].endswith('-head'):
+            continue
+        logger.info(f'Terminating instance {inst_id}: {inst}')
+        try:
+            utils.remove(inst_id)
+        except Exception as e:  # pylint: disable=broad-except
+            with ux_utils.print_exception_no_traceback():
+                raise RuntimeError(
+                    f'Failed to terminate instance {inst_id}: '
+                    f'{common_utils.format_exception(e, use_bracket=False)}'
+                ) from e
+
+
+def get_cluster_info(
+        region: str,
+        cluster_name_on_cloud: str,
+        provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo:
+    del region, provider_config  # unused
+    running_instances = _filter_instances(cluster_name_on_cloud, ['RUNNING'])
+    instances: Dict[str, List[common.InstanceInfo]] = {}
+    head_instance_id = None
+    for instance_id, instance_info in running_instances.items():
+        instances[instance_id] = [
+            common.InstanceInfo(
+                instance_id=instance_id,
+                internal_ip=instance_info['internal_ip'],
+                external_ip=instance_info['external_ip'],
+                ssh_port=instance_info['ssh_port'],
+                tags={},
+            )
+        ]
+        if instance_info['name'].endswith('-head'):
+            head_instance_id = instance_id
+
+    return common.ClusterInfo(
+        instances=instances,
+        head_instance_id=head_instance_id,
+    )
+
+
+def query_instances(
+    cluster_name_on_cloud: str,
+    provider_config: Optional[Dict[str, Any]] = None,
+    non_terminated_only: bool = True,
+) -> Dict[str, Optional[status_lib.ClusterStatus]]:
+    """See sky/provision/__init__.py"""
+    assert provider_config is not None, (cluster_name_on_cloud, provider_config)
+    instances = _filter_instances(cluster_name_on_cloud, None)
+
+    status_map = {
+        'CREATED': status_lib.ClusterStatus.INIT,
+        'RESTARTING': status_lib.ClusterStatus.INIT,
+        'PAUSED': status_lib.ClusterStatus.INIT,
+        'RUNNING': status_lib.ClusterStatus.UP,
+    }
+    statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {}
+    for inst_id, inst in instances.items():
+        status = status_map[inst['status']]
+        if non_terminated_only and status is None:
+            continue
+        statuses[inst_id] = status
+    return statuses
+
+
+def cleanup_ports(
+    cluster_name_on_cloud: str,
+    provider_config: Optional[Dict[str, Any]] = None,
+) -> None:
+    del cluster_name_on_cloud, provider_config
diff --git a/sky/provision/runpod/utils.py b/sky/provision/runpod/utils.py
new file mode 100644
index 000000000000..00b24aee0a86
--- /dev/null
+++ b/sky/provision/runpod/utils.py
@@ -0,0 +1,141 @@
+"""RunPod library wrapper for SkyPilot."""
+
+import time
+from typing import Any, Dict, List
+
+from sky import sky_logging
+from sky.adaptors import runpod
+from sky.skylet import constants
+from sky.utils import common_utils
+
+logger = sky_logging.init_logger(__name__)
+
+GPU_NAME_MAP = {
+    'A100-80GB': 'NVIDIA A100 80GB PCIe',
+    'A100-40GB': 'NVIDIA A100-PCIE-40GB',
+    'A100-80GB-SXM': 'NVIDIA A100-SXM4-80GB',
+    'A30': 'NVIDIA A30',
+    'A40': 'NVIDIA A40',
+    'RTX3070': 'NVIDIA GeForce RTX 3070',
+    'RTX3080': 'NVIDIA GeForce RTX 3080',
+    'RTX3080Ti': 'NVIDIA GeForce RTX 3080 Ti',
+    'RTX3090': 'NVIDIA GeForce RTX 3090',
+    'RTX3090Ti': 'NVIDIA GeForce RTX 3090 Ti',
+    'RTX4070Ti': 'NVIDIA GeForce RTX 4070 Ti',
+    'RTX4080': 'NVIDIA GeForce RTX 4080',
+    'RTX4090': 'NVIDIA GeForce RTX 4090',
+    # The following entry is displayed as SXM in the console,
+    # but the API reports its ID as HBM.
+    'H100-SXM': 'NVIDIA H100 80GB HBM3',
+    'H100': 'NVIDIA H100 PCIe',
+    'L4': 'NVIDIA L4',
+    'L40': 'NVIDIA L40',
+    'RTX4000-Ada-SFF': 'NVIDIA RTX 4000 SFF Ada Generation',
+    'RTX4000-Ada': 'NVIDIA RTX 4000 Ada Generation',
+    'RTX6000-Ada': 'NVIDIA RTX 6000 Ada Generation',
+    'RTXA4000': 'NVIDIA RTX A4000',
+    'RTXA4500': 'NVIDIA RTX A4500',
+    'RTXA5000': 'NVIDIA RTX A5000',
+    'RTXA6000': 'NVIDIA RTX A6000',
+    'RTX5000': 'Quadro RTX 5000',
+    'V100-16GB-FHHL': 'Tesla V100-FHHL-16GB',
+    'V100-16GB-SXM2': 'V100-SXM2-16GB',
+    'RTXA2000': 'NVIDIA RTX A2000',
+    'V100-16GB-PCIe': 'Tesla V100-PCIE-16GB'
+}
+
+
+def retry(func):
+    """Decorator to retry a function."""
+
+    def wrapper(*args, **kwargs):
+        """Wrapper for retrying a function."""
+        cnt = 0
+        while True:
+            try:
+                return func(*args, **kwargs)
+            except runpod.runpod().error.QueryError as e:
+                if cnt >= 3:
+                    raise
+                cnt += 1
+                logger.warning('Retrying for exception: '
+                               f'{common_utils.format_exception(e)}.')
+                time.sleep(1)
+
+    return wrapper
+
+
+def list_instances() -> Dict[str, Dict[str, Any]]:
+    """Lists instances associated with API key."""
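+    # A minimal, hypothetical example of one entry returned by get_pods()
+    # (only the fields read below are shown; all values are made up):
+    #   {'id': 'pod-aaaa',
+    #    'name': 'mycluster-2ea4-head',
+    #    'desiredStatus': 'RUNNING',
+    #    'runtime': {'ports': [
+    #        {'ip': '10.0.0.2', 'isIpPublic': False,
+    #         'privatePort': 22, 'publicPort': 22},
+    #        {'ip': '38.99.10.7', 'isIpPublic': True,
+    #         'privatePort': 22, 'publicPort': 30022}]}}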
instances = runpod.runpod().get_pods() + + instance_dict: Dict[str, Dict[str, Any]] = {} + for instance in instances: + info = {} + + info['status'] = instance['desiredStatus'] + info['name'] = instance['name'] + + if instance['desiredStatus'] == 'RUNNING' and instance.get('runtime'): + for port in instance['runtime']['ports']: + if port['privatePort'] == 22 and port['isIpPublic']: + info['external_ip'] = port['ip'] + info['ssh_port'] = port['publicPort'] + elif not port['isIpPublic']: + info['internal_ip'] = port['ip'] + + instance_dict[instance['id']] = info + + return instance_dict + + +def launch(name: str, instance_type: str, region: str, disk_size: int) -> str: + """Launches an instance with the given parameters. + + Converts the instance_type to the RunPod GPU name, finds the specs for the + GPU, and launches the instance. + """ + gpu_type = GPU_NAME_MAP[instance_type.split('_')[1]] + gpu_quantity = int(instance_type.split('_')[0].replace('x', '')) + cloud_type = instance_type.split('_')[2] + + gpu_specs = runpod.runpod().get_gpu(gpu_type) + + new_instance = runpod.runpod().create_pod( + name=name, + image_name='runpod/base:0.0.2', + gpu_type_id=gpu_type, + cloud_type=cloud_type, + container_disk_in_gb=disk_size, + min_vcpu_count=4 * gpu_quantity, + min_memory_in_gb=gpu_specs['memoryInGb'] * gpu_quantity, + country_code=region, + ports=(f'22/tcp,' + f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,' + f'{constants.SKY_REMOTE_RAY_PORT}/http'), + support_public_ip=True, + ) + + return new_instance['id'] + + +def remove(instance_id: str) -> None: + """Terminates the given instance.""" + runpod.runpod().terminate_pod(instance_id) + + +def get_ssh_ports(cluster_name) -> List[int]: + """Gets the SSH ports for the given cluster.""" + logger.debug(f'Getting SSH ports for cluster {cluster_name}.') + + instances = list_instances() + possible_names = [f'{cluster_name}-head', f'{cluster_name}-worker'] + + ssh_ports = [] + + for instance in instances.values(): + if instance['name'] in possible_names: + ssh_ports.append(instance['ssh_port']) + assert ssh_ports, ( + f'Could not find any instances for cluster {cluster_name}.') + + return ssh_ports diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index b0389cd73daf..c54187aa658d 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -135,9 +135,7 @@ def parse_readme(readme: str) -> str: 'cachetools', # NOTE: ray requires click>=7.0. 'click >= 7.0', - # NOTE: required by awscli. To avoid ray automatically installing - # the latest version. - 'colorama < 0.4.5', + 'colorama', 'cryptography', # Jinja has a bug in older versions because of the lack of pinning # the version of the underlying markupsafe package. See: @@ -207,6 +205,9 @@ def parse_readme(readme: str) -> str: 'awscli>=1.27.10', 'botocore>=1.29.10', 'boto3>=1.26.1', + # NOTE: required by awscli. To avoid ray automatically installing + # the latest version. 
+    'colorama < 0.4.5',
 ]
 
 extras_require: Dict[str, List[str]] = {
     'aws': aws_dependencies,
@@ -233,6 +234,7 @@ def parse_readme(readme: str) -> str:
     'oci': ['oci'] + local_ray,
     'kubernetes': ['kubernetes>=20.0.0'] + local_ray,
     'remote': remote,
+    'runpod': ['runpod>=1.3.7']
 }
 
 extras_require['all'] = sum(extras_require.values(), [])
diff --git a/sky/templates/runpod-ray.yml.j2 b/sky/templates/runpod-ray.yml.j2
new file mode 100644
index 000000000000..fa3598e429e5
--- /dev/null
+++ b/sky/templates/runpod-ray.yml.j2
@@ -0,0 +1,76 @@
+cluster_name: {{cluster_name_on_cloud}}
+
+# The maximum number of worker nodes to launch in addition to the head node.
+max_workers: {{num_nodes - 1}}
+upscaling_speed: {{num_nodes - 1}}
+idle_timeout_minutes: 60
+
+provider:
+  type: external
+  module: sky.provision.runpod
+  region: "{{region}}"
+  disable_launch_config_check: true
+
+auth:
+  ssh_user: root
+  ssh_private_key: {{ssh_private_key}}
+
+available_node_types:
+  ray_head_default:
+    resources: {}
+    node_config:
+      InstanceType: {{instance_type}}
+      DiskSize: {{disk_size}}
+
+head_node_type: ray_head_default
+
+# Format: `REMOTE_PATH : LOCAL_PATH`
+file_mounts: {
+  "{{sky_ray_yaml_remote_path}}": "{{sky_ray_yaml_local_path}}",
+  "{{sky_remote_path}}/{{sky_wheel_hash}}": "{{sky_local_path}}",
+{%- for remote_path, local_path in credentials.items() %}
+  "{{remote_path}}": "{{local_path}}",
+{%- endfor %}
+}
+
+rsync_exclude: []
+
+initialization_commands: []
+
+# List of shell commands to run to set up nodes.
+# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH
+# connection, which is expensive. Try your best to co-locate commands into fewer
+# items!
+#
+# Increment the following to make performance bugs easier to catch:
+# current num items (num SSH connections): 1
+setup_commands:
+  # Disable `unattended-upgrades` to prevent apt-get from hanging. This should run at the very beginning, before the apt processes start, to avoid being blocked. (This is a temporary fix.)
+  # Create ~/.ssh/config file in case the file does not exist in the image.
+  # Line 'rm ..': there is another installation of pip.
+  # Line 'sudo bash ..': set the ulimit as suggested by ray docs for performance. https://docs.ray.io/en/latest/cluster/vms/user-guides/large-cluster-best-practices.html#system-configuration
+  # Line 'sudo grep ..': set the number of threads per process to unlimited, to avoid `ray job submit` getting stuck when the number of running ray jobs increases.
+  # Line 'mkdir -p ..': disable host key check
+  # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys`
+  - sudo systemctl stop unattended-upgrades || true;
+    sudo systemctl disable unattended-upgrades || true;
+    sudo sed -i 's/Unattended-Upgrade "1"/Unattended-Upgrade "0"/g' /etc/apt/apt.conf.d/20auto-upgrades || true;
+    sudo kill -9 `sudo lsof /var/lib/dpkg/lock-frontend | awk '{print $2}' | tail -n 1` || true;
+    sudo pkill -9 apt-get;
+    sudo pkill -9 dpkg;
+    sudo dpkg --configure -a;
+    mkdir -p ~/.ssh; touch ~/.ssh/config;
+    {{ conda_installation_commands }}
+    (type -a python | grep -q python3) || echo 'alias python=python3' >> ~/.bashrc;
+    (type -a pip | grep -q pip3) || echo 'alias pip=pip3' >> ~/.bashrc;
+    source ~/.bashrc;
+    (pip3 list | grep ray | grep {{ray_version}} 2>&1 > /dev/null || pip3 install -U ray[default]=={{ray_version}}) && mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app && touch ~/.sudo_as_admin_successful;
+    (pip3 list | grep skypilot && [ "$(cat {{sky_remote_path}}/current_sky_wheel_hash)" == "{{sky_wheel_hash}}" ]) || (pip3 uninstall skypilot -y; pip3 install "$(echo {{sky_remote_path}}/{{sky_wheel_hash}}/skypilot-{{sky_version}}*.whl)[runpod,remote]" && echo "{{sky_wheel_hash}}" > {{sky_remote_path}}/current_sky_wheel_hash || exit 1);
+    sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf';
+    sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload;
+    mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config;
+    python3 -c "from sky.skylet.ray_patches import patch; patch()" || exit 1;
+    [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf');
+
+# Commands to start ray clusters are now placed in `sky.provision.instance_setup`,
+# so we no longer need to list them here.
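The per-instance `ssh_port` added to `InstanceInfo`/`ClusterInfo` in sky/provision/common.py above is the crux of this provisioner: RunPod exposes a pod's SSH daemon on an arbitrary public port rather than port 22, so every consumer of `ClusterInfo` now reads ports alongside IPs. A minimal sketch of the intended behavior, using only the fields visible in this diff (the pod IDs, IPs, and ports below are hypothetical):

    from sky.provision.common import ClusterInfo, InstanceInfo

    head = InstanceInfo(instance_id='pod-aaaa', internal_ip='10.0.0.2',
                        external_ip='38.99.10.7', tags={}, ssh_port=30022)
    worker = InstanceInfo(instance_id='pod-bbbb', internal_ip='10.0.0.3',
                          external_ip='38.99.10.8', tags={}, ssh_port=30155)
    cluster = ClusterInfo(instances={'pod-aaaa': [head], 'pod-bbbb': [worker]},
                          head_instance_id='pod-aaaa')

    # Head port first, then worker ports -- the same ordering as
    # get_feasible_ips(), so wait_for_ssh() can pair each IP with its port.
    assert cluster.get_ssh_ports() == [30022, 30155]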
diff --git a/tests/test_optimizer_random_dag.py b/tests/test_optimizer_random_dag.py index 5c828ba3130f..6681f061d6ff 100644 --- a/tests/test_optimizer_random_dag.py +++ b/tests/test_optimizer_random_dag.py @@ -1,11 +1,13 @@ import copy import random +import sys import numpy as np import pandas as pd import sky from sky import clouds +from sky import exceptions from sky.clouds import service_catalog ALL_INSTANCE_TYPE_INFOS = sum( @@ -57,8 +59,8 @@ def generate_random_dag( op.set_outputs('CLOUD', random.randint(0, max_data_size)) num_candidates = random.randint(1, max_num_candidate_resources) - candidate_instance_types = random.choices(ALL_INSTANCE_TYPE_INFOS, - k=num_candidates) + candidate_instance_types = random.choices( + ALL_INSTANCE_TYPE_INFOS, k=len(ALL_INSTANCE_TYPE_INFOS)) candidate_resources = set() for candidate in candidate_instance_types: @@ -80,7 +82,18 @@ def generate_random_dag( accelerators={ candidate.accelerator_name: candidate.accelerator_count }) + requested_features = set() + if op.num_nodes > 1: + requested_features.add( + clouds.CloudImplementationFeatures.MULTI_NODE) + try: + resources.cloud.check_features_are_supported( + resources, requested_features) + except exceptions.NotSupportedError: + continue candidate_resources.add(resources) + if len(candidate_resources) >= num_candidates: + break op.set_resources(candidate_resources) return dag @@ -121,7 +134,7 @@ def _optimize_by_brute_force(tasks, plan): resources_stack.pop() _optimize_by_brute_force(topo_order, {}) - print(final_plan) + print(final_plan, file=sys.stderr) return min_objective @@ -140,6 +153,9 @@ def compare_optimization_results(dag: sky.Dag, minimize_cost: bool): objective = sky.Optimizer._compute_total_time(dag.get_graph(), dag.tasks, optimizer_plan) + print('=== optimizer plan ===', file=sys.stderr) + print(optimizer_plan, file=sys.stderr) + print('=== brute force ===', file=sys.stderr) min_objective = find_min_objective(copy_dag, minimize_cost) assert abs(objective - min_objective) < 5e-2
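A note on the catalog naming convention assumed by `utils.launch()` above: instance types are encoded as `<count>x_<GPU key>_<cloud type>`, where the GPU key indexes `GPU_NAME_MAP` and the cloud type is forwarded to `create_pod()`. A small sketch of the parsing; the concrete type name here is hypothetical, but follows the format that `launch()` splits on:

    from sky.provision.runpod.utils import GPU_NAME_MAP

    instance_type = '2x_A100-80GB_SECURE'  # hypothetical catalog entry
    count_part, gpu_key, cloud_type = instance_type.split('_')
    gpu_quantity = int(count_part.replace('x', ''))  # -> 2
    gpu_type_id = GPU_NAME_MAP[gpu_key]  # -> 'NVIDIA A100 80GB PCIe'
    assert (gpu_quantity, cloud_type) == (2, 'SECURE')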