diff --git a/sky/authentication.py b/sky/authentication.py index eb51aad02ad..41a7d02dfb7 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -43,9 +43,9 @@ from sky.adaptors import ibm from sky.adaptors import kubernetes from sky.adaptors import runpod -from sky.clouds.utils import lambda_utils from sky.provision.fluidstack import fluidstack_utils from sky.provision.kubernetes import utils as kubernetes_utils +from sky.provision.lambda_cloud import lambda_utils from sky.utils import common_utils from sky.utils import kubernetes_enums from sky.utils import subprocess_utils diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py index d3d20fbd41a..d2573ebbb29 100644 --- a/sky/clouds/lambda_cloud.py +++ b/sky/clouds/lambda_cloud.py @@ -8,7 +8,7 @@ from sky import clouds from sky import status_lib from sky.clouds import service_catalog -from sky.clouds.utils import lambda_utils +from sky.provision.lambda_cloud import lambda_utils from sky.utils import resources_utils if typing.TYPE_CHECKING: @@ -48,6 +48,9 @@ class Lambda(clouds.Cloud): clouds.CloudImplementationFeatures.HOST_CONTROLLERS: f'Host controllers are not supported in {_REPR}.', } + PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT + STATUS_VERSION = clouds.StatusVersion.SKYPILOT + @classmethod def _unsupported_features_for_resources( cls, resources: 'resources_lib.Resources' diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index 41d985ade41..bbe92b68c3a 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -19,6 +19,7 @@ from sky.provision import fluidstack from sky.provision import gcp from sky.provision import kubernetes +from sky.provision import lambda_cloud from sky.provision import runpod from sky.provision import vsphere from sky.utils import command_runner @@ -39,6 +40,8 @@ def _wrapper(*args, **kwargs): provider_name = kwargs.pop('provider_name') module_name = provider_name.lower() + if module_name == 'lambda': + module_name = 'lambda_cloud' module = globals().get(module_name) assert module is not None, f'Unknown provider: {module_name}' diff --git a/sky/provision/lambda_cloud/__init__.py b/sky/provision/lambda_cloud/__init__.py new file mode 100644 index 00000000000..4992df4531b --- /dev/null +++ b/sky/provision/lambda_cloud/__init__.py @@ -0,0 +1,11 @@ +"""Lambda provisioner for SkyPilot.""" + +from sky.provision.lambda_cloud.config import bootstrap_instances +from sky.provision.lambda_cloud.instance import cleanup_ports +from sky.provision.lambda_cloud.instance import get_cluster_info +from sky.provision.lambda_cloud.instance import open_ports +from sky.provision.lambda_cloud.instance import query_instances +from sky.provision.lambda_cloud.instance import run_instances +from sky.provision.lambda_cloud.instance import stop_instances +from sky.provision.lambda_cloud.instance import terminate_instances +from sky.provision.lambda_cloud.instance import wait_instances diff --git a/sky/provision/lambda_cloud/config.py b/sky/provision/lambda_cloud/config.py new file mode 100644 index 00000000000..3066e7747fd --- /dev/null +++ b/sky/provision/lambda_cloud/config.py @@ -0,0 +1,10 @@ +"""Lambda Cloud configuration bootstrapping""" + +from sky.provision import common + + +def bootstrap_instances( + region: str, cluster_name: str, + config: common.ProvisionConfig) -> common.ProvisionConfig: + del region, cluster_name # unused + return config diff --git a/sky/provision/lambda_cloud/instance.py b/sky/provision/lambda_cloud/instance.py new file mode 100644 index 00000000000..d10c36496ab --- /dev/null +++ b/sky/provision/lambda_cloud/instance.py @@ -0,0 +1,261 @@ +"""Lambda instance provisioning.""" + +import time +from typing import Any, Dict, List, Optional + +from sky import authentication as auth +from sky import sky_logging +from sky import status_lib +from sky.provision import common +import sky.provision.lambda_cloud.lambda_utils as lambda_utils +from sky.utils import common_utils +from sky.utils import ux_utils + +POLL_INTERVAL = 1 + +logger = sky_logging.init_logger(__name__) +_lambda_client = None + + +def _get_lambda_client(): + global _lambda_client + if _lambda_client is None: + _lambda_client = lambda_utils.LambdaCloudClient() + return _lambda_client + + +def _filter_instances( + cluster_name_on_cloud: str, + status_filters: Optional[List[str]]) -> Dict[str, Dict[str, Any]]: + lambda_client = _get_lambda_client() + instances = lambda_client.list_instances() + possible_names = [ + f'{cluster_name_on_cloud}-head', + f'{cluster_name_on_cloud}-worker', + ] + + filtered_instances = {} + for instance in instances: + if (status_filters is not None and + instance['status'] not in status_filters): + continue + if instance.get('name') in possible_names: + filtered_instances[instance['id']] = instance + return filtered_instances + + +def _get_head_instance_id(instances: Dict[str, Any]) -> Optional[str]: + head_instance_id = None + for instance_id, instance in instances.items(): + if instance['name'].endswith('-head'): + head_instance_id = instance_id + break + return head_instance_id + + +def _get_ssh_key_name(prefix: str = '') -> str: + lambda_client = _get_lambda_client() + _, public_key_path = auth.get_or_generate_keys() + with open(public_key_path, 'r', encoding='utf-8') as f: + public_key = f.read() + name, exists = lambda_client.get_unique_ssh_key_name(prefix, public_key) + if not exists: + raise lambda_utils.LambdaCloudError('SSH key not found') + return name + + +def run_instances(region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: + """Runs instances for the given cluster""" + lambda_client = _get_lambda_client() + pending_status = ['booting'] + while True: + instances = _filter_instances(cluster_name_on_cloud, pending_status) + if not instances: + break + logger.info(f'Waiting for {len(instances)} instances to be ready.') + time.sleep(POLL_INTERVAL) + exist_instances = _filter_instances(cluster_name_on_cloud, ['active']) + head_instance_id = _get_head_instance_id(exist_instances) + + to_start_count = config.count - len(exist_instances) + if to_start_count < 0: + raise RuntimeError( + f'Cluster {cluster_name_on_cloud} already has ' + f'{len(exist_instances)} nodes, but {config.count} are required.') + if to_start_count == 0: + if head_instance_id is None: + raise RuntimeError( + f'Cluster {cluster_name_on_cloud} has no head node.') + logger.info(f'Cluster {cluster_name_on_cloud} already has ' + f'{len(exist_instances)} nodes, no need to start more.') + return common.ProvisionRecord( + provider_name='lambda', + cluster_name=cluster_name_on_cloud, + region=region, + zone=None, + head_instance_id=head_instance_id, + resumed_instance_ids=[], + created_instance_ids=[], + ) + + created_instance_ids = [] + ssh_key_name = _get_ssh_key_name() + + def launch_nodes(node_type: str, quantity: int) -> List[str]: + try: + instance_ids = lambda_client.create_instances( + instance_type=config.node_config['InstanceType'], + region=region, + name=f'{cluster_name_on_cloud}-{node_type}', + quantity=quantity, + ssh_key_name=ssh_key_name, + ) + logger.info(f'Launched {len(instance_ids)} {node_type} node(s), ' + f'instance_ids: {instance_ids}') + return instance_ids + except Exception as e: + logger.warning(f'run_instances error: {e}') + raise + + if head_instance_id is None: + instance_ids = launch_nodes('head', 1) + assert len(instance_ids) == 1 + created_instance_ids.append(instance_ids[0]) + head_instance_id = instance_ids[0] + + assert head_instance_id is not None, 'head_instance_id should not be None' + + worker_node_count = to_start_count - 1 + if worker_node_count > 0: + instance_ids = launch_nodes('worker', worker_node_count) + created_instance_ids.extend(instance_ids) + + while True: + instances = _filter_instances(cluster_name_on_cloud, ['active']) + if len(instances) == config.count: + break + + time.sleep(POLL_INTERVAL) + + return common.ProvisionRecord( + provider_name='lambda', + cluster_name=cluster_name_on_cloud, + region=region, + zone=None, + head_instance_id=head_instance_id, + resumed_instance_ids=[], + created_instance_ids=created_instance_ids, + ) + + +def wait_instances(region: str, cluster_name_on_cloud: str, + state: Optional[status_lib.ClusterStatus]) -> None: + del region, cluster_name_on_cloud, state # Unused. + + +def stop_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + worker_only: bool = False, +) -> None: + raise NotImplementedError( + 'stop_instances is not supported for Lambda Cloud') + + +def terminate_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + worker_only: bool = False, +) -> None: + """See sky/provision/__init__.py""" + del provider_config + lambda_client = _get_lambda_client() + instances = _filter_instances(cluster_name_on_cloud, None) + + instance_ids_to_terminate = [] + for instance_id, instance in instances.items(): + if worker_only and not instance['name'].endswith('-worker'): + continue + instance_ids_to_terminate.append(instance_id) + + try: + logger.debug( + f'Terminating instances {", ".join(instance_ids_to_terminate)}') + lambda_client.remove_instances(instance_ids_to_terminate) + except Exception as e: # pylint: disable=broad-except + with ux_utils.print_exception_no_traceback(): + raise RuntimeError( + f'Failed to terminate instances {instance_ids_to_terminate}: ' + f'{common_utils.format_exception(e, use_bracket=False)}') from e + + +def get_cluster_info( + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, +) -> common.ClusterInfo: + del region # unused + running_instances = _filter_instances(cluster_name_on_cloud, ['active']) + instances: Dict[str, List[common.InstanceInfo]] = {} + head_instance_id = None + for instance_id, instance_info in running_instances.items(): + instances[instance_id] = [ + common.InstanceInfo( + instance_id=instance_id, + internal_ip=instance_info['private_ip'], + external_ip=instance_info['ip'], + ssh_port=22, + tags={}, + ) + ] + if instance_info['name'].endswith('-head'): + head_instance_id = instance_id + + return common.ClusterInfo( + instances=instances, + head_instance_id=head_instance_id, + provider_name='lambda', + provider_config=provider_config, + ) + + +def query_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + non_terminated_only: bool = True, +) -> Dict[str, Optional[status_lib.ClusterStatus]]: + """See sky/provision/__init__.py""" + assert provider_config is not None, (cluster_name_on_cloud, provider_config) + instances = _filter_instances(cluster_name_on_cloud, None) + + status_map = { + 'booting': status_lib.ClusterStatus.INIT, + 'active': status_lib.ClusterStatus.UP, + 'unhealthy': status_lib.ClusterStatus.INIT, + 'terminating': status_lib.ClusterStatus.INIT, + } + statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {} + for instance_id, instance in instances.items(): + status = status_map.get(instance['status']) + if non_terminated_only and status is None: + continue + statuses[instance_id] = status + return statuses + + +def open_ports( + cluster_name_on_cloud: str, + ports: List[str], + provider_config: Optional[Dict[str, Any]] = None, +) -> None: + raise NotImplementedError('open_ports is not supported for Lambda Cloud') + + +def cleanup_ports( + cluster_name_on_cloud: str, + ports: List[str], + provider_config: Optional[Dict[str, Any]] = None, +) -> None: + """See sky/provision/__init__.py""" + del cluster_name_on_cloud, ports, provider_config # Unused. diff --git a/sky/clouds/utils/lambda_utils.py b/sky/provision/lambda_cloud/lambda_utils.py similarity index 92% rename from sky/clouds/utils/lambda_utils.py rename to sky/provision/lambda_cloud/lambda_utils.py index 61c4b33ebe9..339919e80e7 100644 --- a/sky/clouds/utils/lambda_utils.py +++ b/sky/provision/lambda_cloud/lambda_utils.py @@ -1,4 +1,5 @@ """Lambda Cloud helper functions.""" + import json import os import time @@ -76,7 +77,7 @@ def refresh(self, instance_ids: List[str]) -> None: def raise_lambda_error(response: requests.Response) -> None: - """Raise LambdaCloudError if appropriate. """ + """Raise LambdaCloudError if appropriate.""" status_code = response.status_code if status_code == 200: return @@ -131,20 +132,22 @@ def __init__(self) -> None: self.api_key = self._credentials['api_key'] self.headers = {'Authorization': f'Bearer {self.api_key}'} - def create_instances(self, - instance_type: str = 'gpu_1x_a100_sxm4', - region: str = 'us-east-1', - quantity: int = 1, - name: str = '', - ssh_key_name: str = '') -> List[str]: + def create_instances( + self, + instance_type: str = 'gpu_1x_a100_sxm4', + region: str = 'us-east-1', + quantity: int = 1, + name: str = '', + ssh_key_name: str = '', + ) -> List[str]: """Launch new instances.""" # Optimization: # Most API requests are rate limited at ~1 request every second but # launch requests are rate limited at ~1 request every 10 seconds. # So don't use launch requests to check availability. # See https://docs.lambdalabs.com/cloud/rate-limiting/ for more. - available_regions = self.list_catalog()[instance_type]\ - ['regions_with_capacity_available'] + available_regions = (self.list_catalog()[instance_type] + ['regions_with_capacity_available']) available_regions = [reg['name'] for reg in available_regions] if region not in available_regions: if len(available_regions) > 0: @@ -163,27 +166,25 @@ def create_instances(self, 'instance_type_name': instance_type, 'ssh_key_names': [ssh_key_name], 'quantity': quantity, - 'name': name + 'name': name, }) response = _try_request_with_backoff( 'post', f'{API_ENDPOINT}/instance-operations/launch', data=data, - headers=self.headers) + headers=self.headers, + ) return response.json().get('data', []).get('instance_ids', []) - def remove_instances(self, *instance_ids: str) -> Dict[str, Any]: + def remove_instances(self, instance_ids: List[str]) -> Dict[str, Any]: """Terminate instances.""" - data = json.dumps({ - 'instance_ids': [ - instance_ids[0] # TODO(ewzeng) don't hardcode - ] - }) + data = json.dumps({'instance_ids': instance_ids}) response = _try_request_with_backoff( 'post', f'{API_ENDPOINT}/instance-operations/terminate', data=data, - headers=self.headers) + headers=self.headers, + ) return response.json().get('data', []).get('terminated_instances', []) def list_instances(self) -> List[Dict[str, Any]]: diff --git a/sky/setup_files/MANIFEST.in b/sky/setup_files/MANIFEST.in index 54ab3b55a32..0cd93f485e0 100644 --- a/sky/setup_files/MANIFEST.in +++ b/sky/setup_files/MANIFEST.in @@ -6,7 +6,6 @@ include sky/setup_files/* include sky/skylet/*.sh include sky/skylet/LICENSE include sky/skylet/providers/ibm/* -include sky/skylet/providers/lambda_cloud/* include sky/skylet/providers/oci/* include sky/skylet/providers/scp/* include sky/skylet/providers/*.py diff --git a/sky/skylet/providers/lambda_cloud/__init__.py b/sky/skylet/providers/lambda_cloud/__init__.py deleted file mode 100644 index 64dac295eb5..00000000000 --- a/sky/skylet/providers/lambda_cloud/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Lambda Cloud node provider""" -from sky.skylet.providers.lambda_cloud.node_provider import LambdaNodeProvider diff --git a/sky/skylet/providers/lambda_cloud/node_provider.py b/sky/skylet/providers/lambda_cloud/node_provider.py deleted file mode 100644 index 557afe75568..00000000000 --- a/sky/skylet/providers/lambda_cloud/node_provider.py +++ /dev/null @@ -1,320 +0,0 @@ -import logging -import os -from threading import RLock -import time -from typing import Any, Dict, List, Optional - -from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import NODE_KIND_HEAD -from ray.autoscaler.tags import NODE_KIND_WORKER -from ray.autoscaler.tags import STATUS_UP_TO_DATE -from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME -from ray.autoscaler.tags import TAG_RAY_NODE_KIND -from ray.autoscaler.tags import TAG_RAY_NODE_NAME -from ray.autoscaler.tags import TAG_RAY_NODE_STATUS -from ray.autoscaler.tags import TAG_RAY_USER_NODE_TYPE - -from sky import authentication as auth -from sky.clouds.utils import lambda_utils -from sky.utils import command_runner -from sky.utils import common_utils -from sky.utils import subprocess_utils -from sky.utils import ux_utils - -_TAG_PATH_PREFIX = '~/.sky/generated/lambda_cloud/metadata' -_REMOTE_SSH_KEY_NAME = '~/.lambda_cloud/ssh_key_name' -_REMOTE_RAY_SSH_KEY = '~/ray_bootstrap_key.pem' -_REMOTE_RAY_YAML = '~/ray_bootstrap_config.yaml' -_GET_INTERNAL_IP_CMD = 's=$(ip -4 -br addr show | grep UP); echo "$s"; echo "$s" | grep -Eo "(10\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)|172\.(1[6-9]|2[0-9]|3[0-1])|104\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"' - -logger = logging.getLogger(__name__) - - -def synchronized(f): - - def wrapper(self, *args, **kwargs): - self.lock.acquire() - try: - return f(self, *args, **kwargs) - finally: - self.lock.release() - - return wrapper - - -class LambdaNodeProvider(NodeProvider): - """Node Provider for Lambda Cloud. - - This provider assumes Lambda Cloud credentials are set. - """ - - def __init__(self, provider_config: Dict[str, Any], - cluster_name: str) -> None: - NodeProvider.__init__(self, provider_config, cluster_name) - self.lock = RLock() - self.lambda_client = lambda_utils.LambdaCloudClient() - self.cached_nodes: Dict[str, Dict[str, Any]] = {} - self.metadata = lambda_utils.Metadata(_TAG_PATH_PREFIX, cluster_name) - self.ssh_key_path = os.path.expanduser(auth.PRIVATE_SSH_KEY_PATH) - - def _get_ssh_key_name(prefix: str) -> str: - public_key_path = os.path.expanduser(auth.PUBLIC_SSH_KEY_PATH) - with open(public_key_path, 'r') as f: - public_key = f.read() - name, exists = self.lambda_client.get_unique_ssh_key_name( - prefix, public_key) - if not exists: - raise lambda_utils.LambdaCloudError('SSH key not found') - return name - - ray_yaml_path = os.path.expanduser(_REMOTE_RAY_YAML) - self.on_head = (os.path.exists(ray_yaml_path) and - common_utils.read_yaml(ray_yaml_path)['cluster_name'] - == cluster_name) - - if self.on_head: - self.ssh_key_path = os.path.expanduser(_REMOTE_RAY_SSH_KEY) - ssh_key_name_path = os.path.expanduser(_REMOTE_SSH_KEY_NAME) - if os.path.exists(ssh_key_name_path): - with open(ssh_key_name_path, 'r') as f: - self.ssh_key_name = f.read() - else: - # At this point, `~/.ssh/sky-key.pub` contains the public - # key used to launch this cluster. Use it to determine - # ssh key name and store the name in _REMOTE_SSH_KEY_NAME. - # Note: this case only runs during cluster launch, so it is - # not possible for ~/.ssh/sky-key.pub to already be regenerated - # by the user. - self.ssh_key_name = _get_ssh_key_name('') - with open(ssh_key_name_path, 'w', encoding='utf-8') as f: - f.write(self.ssh_key_name) - else: - # On local - self.ssh_key_name = _get_ssh_key_name( - f'sky-key-{common_utils.get_user_hash()}') - - def _guess_and_add_missing_tags(self, vms: List[Dict[str, Any]]) -> None: - """Adds missing vms to local tag file and guesses their tags.""" - for node in vms: - if self.metadata.get(node['id']) is not None: - pass - elif node['name'] == f'{self.cluster_name}-head': - self.metadata.set( - node['id'], { - 'tags': { - TAG_RAY_CLUSTER_NAME: self.cluster_name, - TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, - TAG_RAY_NODE_KIND: NODE_KIND_HEAD, - TAG_RAY_USER_NODE_TYPE: 'ray_head_default', - TAG_RAY_NODE_NAME: f'ray-{self.cluster_name}-head', - } - }) - elif node['name'] == f'{self.cluster_name}-worker': - self.metadata.set( - node['id'], { - 'tags': { - TAG_RAY_CLUSTER_NAME: self.cluster_name, - TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, - TAG_RAY_NODE_KIND: NODE_KIND_WORKER, - TAG_RAY_USER_NODE_TYPE: 'ray_worker_default', - TAG_RAY_NODE_NAME: f'ray-{self.cluster_name}-worker', - } - }) - - def _list_instances_in_cluster(self) -> List[Dict[str, Any]]: - """List running instances in cluster.""" - vms = self.lambda_client.list_instances() - possible_names = [ - f'{self.cluster_name}-head', f'{self.cluster_name}-worker' - ] - return [node for node in vms if node.get('name') in possible_names] - - @synchronized - def _get_filtered_nodes(self, tag_filters: Dict[str, - str]) -> Dict[str, Any]: - - def _extract_metadata(vm: Dict[str, Any]) -> Dict[str, Any]: - metadata = {'id': vm['id'], 'status': vm['status'], 'tags': {}} - instance_info = self.metadata.get(vm['id']) - if instance_info is not None: - metadata['tags'] = instance_info['tags'] - metadata['external_ip'] = vm.get('ip') - return metadata - - def _match_tags(vm: Dict[str, Any]): - vm_info = self.metadata.get(vm['id']) - tags = {} if vm_info is None else vm_info['tags'] - for k, v in tag_filters.items(): - if tags.get(k) != v: - return False - return True - - def _get_internal_ip(node: Dict[str, Any]): - # TODO(ewzeng): cache internal ips in metadata file to reduce - # ssh overhead. - if node['external_ip'] is None or node['status'] != 'active': - node['internal_ip'] = None - return - runner = command_runner.SSHCommandRunner( - node=(node['external_ip'], 22), - ssh_user='ubuntu', - ssh_private_key=self.ssh_key_path) - rc, stdout, stderr = runner.run(_GET_INTERNAL_IP_CMD, - require_outputs=True, - stream_logs=False) - subprocess_utils.handle_returncode( - rc, - _GET_INTERNAL_IP_CMD, - 'Failed get obtain private IP from node', - stderr=stdout + stderr) - node['internal_ip'] = stdout.strip() - - vms = self._list_instances_in_cluster() - self.metadata.refresh([node['id'] for node in vms]) - self._guess_and_add_missing_tags(vms) - nodes = [_extract_metadata(vm) for vm in filter(_match_tags, vms)] - nodes = [ - node for node in nodes - if node['status'] not in ['terminating', 'terminated'] - ] - subprocess_utils.run_in_parallel(_get_internal_ip, nodes) - self.cached_nodes = {node['id']: node for node in nodes} - return self.cached_nodes - - def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]: - """Return a list of node ids filtered by the specified tags dict. - - This list must not include terminated nodes. For performance reasons, - providers are allowed to cache the result of a call to - non_terminated_nodes() to serve single-node queries - (e.g. is_running(node_id)). This means that non_terminated_nodes() must - be called again to refresh results. - - Examples: - >>> provider.non_terminated_nodes({TAG_RAY_NODE_KIND: "worker"}) - ["node-1", "node-2"] - """ - nodes = self._get_filtered_nodes(tag_filters=tag_filters) - return [k for k, _ in nodes.items()] - - def is_running(self, node_id: str) -> bool: - """Return whether the specified node is running.""" - return self._get_cached_node(node_id=node_id) is not None - - def is_terminated(self, node_id: str) -> bool: - """Return whether the specified node is terminated.""" - return self._get_cached_node(node_id=node_id) is None - - def node_tags(self, node_id: str) -> Dict[str, str]: - """Returns the tags of the given node (string dict).""" - node = self._get_cached_node(node_id=node_id) - if node is None: - return {} - return node['tags'] - - def external_ip(self, node_id: str) -> Optional[str]: - """Returns the external ip of the given node.""" - node = self._get_cached_node(node_id=node_id) - if node is None: - return None - ip = node.get('external_ip') - with ux_utils.print_exception_no_traceback(): - if ip is None: - raise lambda_utils.LambdaCloudError( - 'A node ip address was not found. Either ' - '(1) Lambda Cloud has internally errored, or ' - '(2) the cluster is still booting. ' - 'You can manually terminate the cluster on the ' - 'Lambda Cloud console or (in case 2) wait for ' - 'booting to finish (~2 minutes).') - return ip - - def internal_ip(self, node_id: str) -> Optional[str]: - """Returns the internal ip (Ray ip) of the given node.""" - node = self._get_cached_node(node_id=node_id) - if node is None: - return None - ip = node.get('internal_ip') - with ux_utils.print_exception_no_traceback(): - if ip is None: - raise lambda_utils.LambdaCloudError( - 'A node ip address was not found. Either ' - '(1) Lambda Cloud has internally errored, or ' - '(2) the cluster is still booting. ' - 'You can manually terminate the cluster on the ' - 'Lambda Cloud console or (in case 2) wait for ' - 'booting to finish (~2 minutes).') - return ip - - def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str], - count: int) -> None: - """Creates a number of nodes within the namespace.""" - # Get tags - config_tags = node_config.get('tags', {}).copy() - config_tags.update(tags) - config_tags[TAG_RAY_CLUSTER_NAME] = self.cluster_name - - # Create nodes - instance_type = node_config['InstanceType'] - region = self.provider_config['region'] - - if config_tags[TAG_RAY_NODE_KIND] == NODE_KIND_HEAD: - name = f'{self.cluster_name}-head' - # Occasionally, the head node will continue running for a short - # period after termination. This can lead to the following bug: - # 1. Head node autodowns but continues running. - # 2. The next autodown event is triggered, which executes ray up. - # 3. Head node stops running. - # In this case, a new head node is created after the cluster has - # terminated. We avoid this with the following check: - if self.on_head: - raise lambda_utils.LambdaCloudError('Head already exists.') - else: - name = f'{self.cluster_name}-worker' - - # Lambda launch api only supports launching one node at a time, - # so we do a loop. Remove loop when launch api allows quantity > 1 - booting_list = [] - for _ in range(count): - vm_id = self.lambda_client.create_instances( - instance_type=instance_type, - region=region, - quantity=1, - name=name, - ssh_key_name=self.ssh_key_name)[0] - self.metadata.set(vm_id, {'tags': config_tags}) - booting_list.append(vm_id) - time.sleep(10) # Avoid api rate limits - - # Wait for nodes to finish booting - while True: - vms = self._list_instances_in_cluster() - for vm_id in booting_list.copy(): - for vm in vms: - if vm['id'] == vm_id and vm['status'] == 'active': - booting_list.remove(vm_id) - if len(booting_list) == 0: - return - time.sleep(10) - - @synchronized - def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None: - """Sets the tag values (string dict) for the specified node.""" - node = self._get_node(node_id) - assert node is not None, node_id - node['tags'].update(tags) - self.metadata.set(node_id, {'tags': node['tags']}) - - def terminate_node(self, node_id: str) -> None: - """Terminates the specified node.""" - self.lambda_client.remove_instances(node_id) - self.metadata.set(node_id, None) - - def _get_node(self, node_id: str) -> Optional[Dict[str, Any]]: - self._get_filtered_nodes({}) # Side effect: updates cache - return self.cached_nodes.get(node_id, None) - - def _get_cached_node(self, node_id: str) -> Optional[Dict[str, Any]]: - if node_id in self.cached_nodes: - return self.cached_nodes[node_id] - return self._get_node(node_id=node_id) diff --git a/sky/templates/lambda-ray.yml.j2 b/sky/templates/lambda-ray.yml.j2 index 6b6d94cfb3c..c4b8dba1a9f 100644 --- a/sky/templates/lambda-ray.yml.j2 +++ b/sky/templates/lambda-ray.yml.j2 @@ -7,7 +7,7 @@ idle_timeout_minutes: 60 provider: type: external - module: sky.skylet.providers.lambda_cloud.LambdaNodeProvider + module: sky.provision.lambda region: {{region}} # Disable launch config check for worker nodes as it can cause resource # leakage. @@ -25,14 +25,6 @@ available_node_types: resources: {} node_config: InstanceType: {{instance_type}} -{% if num_nodes > 1 %} - ray_worker_default: - min_workers: {{num_nodes - 1}} - max_workers: {{num_nodes - 1}} - resources: {} - node_config: - InstanceType: {{instance_type}} -{%- endif %} head_node_type: ray_head_default @@ -64,7 +56,10 @@ setup_commands: # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit stucking issue when the number of running ray jobs increase. # Line 'mkdir -p ..': disable host key check # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys` - - sudo systemctl stop unattended-upgrades || true; + - {%- for initial_setup_command in initial_setup_commands %} + {{ initial_setup_command }} + {%- endfor %} + sudo systemctl stop unattended-upgrades || true; sudo systemctl disable unattended-upgrades || true; sudo sed -i 's/Unattended-Upgrade "1"/Unattended-Upgrade "0"/g' /etc/apt/apt.conf.d/20auto-upgrades || true; sudo kill -9 `sudo lsof /var/lib/dpkg/lock-frontend | awk '{print $2}' | tail -n 1` || true; @@ -81,31 +76,5 @@ setup_commands: mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config; [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); -# Command to start ray on the head node. You don't need to change this. -# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH -# connection, which is expensive. Try your best to co-locate commands into fewer -# items! The same comment applies for worker_start_ray_commands. -# -# Increment the following for catching performance bugs easier: -# current num items (num SSH connections): 2 -head_start_ray_commands: - - {{ sky_activate_python_env }}; {{ sky_ray_cmd }} stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 {{ sky_ray_cmd }} start --disable-usage-stats --head --port={{ray_port}} --min-worker-port 11002 --dashboard-port={{ray_dashboard_port}} --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml {{"--resources='%s'" % custom_resources if custom_resources}} --temp-dir {{ray_temp_dir}} || exit 1; - which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; - {{dump_port_command}}; {{ray_head_wait_initialized_command}} - -{%- if num_nodes > 1 %} -worker_start_ray_commands: - - {{ sky_activate_python_env }}; {{ sky_ray_cmd }} stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 {{ sky_ray_cmd }} start --disable-usage-stats --address=$RAY_HEAD_IP:{{ray_port}} --min-worker-port 11002 --object-manager-port=8076 {{"--resources='%s'" % custom_resources if custom_resources}} --temp-dir {{ray_temp_dir}} || exit 1; - which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; -{%- else %} -worker_start_ray_commands: [] -{%- endif %} - -head_node: {} -worker_nodes: {} - -# These fields are required for external cloud providers. -head_setup_commands: [] -worker_setup_commands: [] -cluster_synced_files: [] -file_mounts_sync_continuously: False +# Command to start ray clusters are now placed in `sky.provision.instance_setup`. +# We do not need to list it here anymore.