-
Notifications
You must be signed in to change notification settings - Fork 531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Lambda] Lambda Cloud SkyPilot provisioner #3865
Changes from all commits
22d2a65
9c0e204
eca0df1
f47bf31
eb2c76d
86d6a3f
099a67b
a2e31a5
ec3e815
0bc4509
c612df8
0025a3a
f8eb44c
21a3475
b1dd794
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""Lambda provisioner for SkyPilot.""" | ||
|
||
from sky.provision.lambda_cloud.config import bootstrap_instances | ||
from sky.provision.lambda_cloud.instance import cleanup_ports | ||
from sky.provision.lambda_cloud.instance import get_cluster_info | ||
from sky.provision.lambda_cloud.instance import open_ports | ||
from sky.provision.lambda_cloud.instance import query_instances | ||
from sky.provision.lambda_cloud.instance import run_instances | ||
from sky.provision.lambda_cloud.instance import stop_instances | ||
from sky.provision.lambda_cloud.instance import terminate_instances | ||
from sky.provision.lambda_cloud.instance import wait_instances |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Lambda Cloud configuration bootstrapping""" | ||
|
||
from sky.provision import common | ||
|
||
|
||
def bootstrap_instances(
        region: str, cluster_name: str,
        config: common.ProvisionConfig) -> common.ProvisionConfig:
    """Return the provision config unchanged.

    Lambda Cloud needs no region- or cluster-specific bootstrapping
    (no security groups, IAM roles, etc.), so this is a passthrough.
    """
    del region  # unused
    del cluster_name  # unused
    return config
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,261 @@ | ||||||
"""Lambda instance provisioning.""" | ||||||
|
||||||
import time | ||||||
from typing import Any, Dict, List, Optional | ||||||
|
||||||
from sky import authentication as auth | ||||||
from sky import sky_logging | ||||||
from sky import status_lib | ||||||
from sky.provision import common | ||||||
import sky.provision.lambda_cloud.lambda_utils as lambda_utils | ||||||
from sky.utils import common_utils | ||||||
from sky.utils import ux_utils | ||||||
|
||||||
POLL_INTERVAL = 1 | ||||||
|
||||||
logger = sky_logging.init_logger(__name__) | ||||||
_lambda_client = None | ||||||
|
||||||
|
||||||
def _get_lambda_client():
    """Return the process-wide LambdaCloudClient, creating it on first use."""
    global _lambda_client
    if _lambda_client is not None:
        return _lambda_client
    _lambda_client = lambda_utils.LambdaCloudClient()
    return _lambda_client
|
||||||
|
||||||
def _filter_instances(
        cluster_name_on_cloud: str,
        status_filters: Optional[List[str]]) -> Dict[str, Dict[str, Any]]:
    """Map instance id -> instance info for this cluster's nodes.

    Only instances named '<cluster>-head' or '<cluster>-worker' are
    considered. When status_filters is given, instances whose 'status'
    is not in the list are dropped; None means no status filtering.
    """
    lambda_client = _get_lambda_client()
    valid_names = (
        f'{cluster_name_on_cloud}-head',
        f'{cluster_name_on_cloud}-worker',
    )

    matched: Dict[str, Dict[str, Any]] = {}
    for info in lambda_client.list_instances():
        status_ok = (status_filters is None or
                     info['status'] in status_filters)
        if status_ok and info.get('name') in valid_names:
            matched[info['id']] = info
    return matched
|
||||||
|
||||||
def _get_head_instance_id(instances: Dict[str, Any]) -> Optional[str]: | ||||||
head_instance_id = None | ||||||
for instance_id, instance in instances.items(): | ||||||
if instance['name'].endswith('-head'): | ||||||
head_instance_id = instance_id | ||||||
break | ||||||
return head_instance_id | ||||||
|
||||||
|
||||||
def _get_ssh_key_name(prefix: str = '') -> str:
    """Return the name of the registered SSH key matching our public key.

    Reads the locally generated public key and looks it up on Lambda
    Cloud.

    Raises:
        lambda_utils.LambdaCloudError: If the key is not registered.
    """
    client = _get_lambda_client()
    _, pub_key_path = auth.get_or_generate_keys()
    with open(pub_key_path, 'r', encoding='utf-8') as pub_key_file:
        pub_key = pub_key_file.read()
    key_name, key_exists = client.get_unique_ssh_key_name(prefix, pub_key)
    if key_exists:
        return key_name
    raise lambda_utils.LambdaCloudError('SSH key not found')
|
||||||
|
||||||
def run_instances(region: str, cluster_name_on_cloud: str,
                  config: common.ProvisionConfig) -> common.ProvisionRecord:
    """Run instances for the given cluster.

    Waits for any booting instances to settle, then launches a head node
    (if one is missing) plus enough workers to reach ``config.count``
    active instances, and polls until all of them are active.

    Args:
        region: Lambda Cloud region to launch in.
        cluster_name_on_cloud: Cluster name; nodes are named
            '<cluster>-head' and '<cluster>-worker'.
        config: Provision config carrying the target node count and the
            node config (instance type).

    Returns:
        A ProvisionRecord describing the (possibly pre-existing) cluster.

    Raises:
        RuntimeError: If the cluster already has more nodes than
            requested, or has the right count but no head node.
    """
    lambda_client = _get_lambda_client()
    pending_status = ['booting']
    # Let in-flight launches settle before counting what already exists.
    while True:
        instances = _filter_instances(cluster_name_on_cloud, pending_status)
        if not instances:
            break
        logger.info(f'Waiting for {len(instances)} instances to be ready.')
        time.sleep(POLL_INTERVAL)
    exist_instances = _filter_instances(cluster_name_on_cloud, ['active'])
    head_instance_id = _get_head_instance_id(exist_instances)

    to_start_count = config.count - len(exist_instances)
    if to_start_count < 0:
        raise RuntimeError(
            f'Cluster {cluster_name_on_cloud} already has '
            f'{len(exist_instances)} nodes, but {config.count} are required.')
    if to_start_count == 0:
        if head_instance_id is None:
            raise RuntimeError(
                f'Cluster {cluster_name_on_cloud} has no head node.')
        logger.info(f'Cluster {cluster_name_on_cloud} already has '
                    f'{len(exist_instances)} nodes, no need to start more.')
        return common.ProvisionRecord(
            provider_name='lambda',
            cluster_name=cluster_name_on_cloud,
            region=region,
            zone=None,
            head_instance_id=head_instance_id,
            resumed_instance_ids=[],
            created_instance_ids=[],
        )

    created_instance_ids = []
    ssh_key_name = _get_ssh_key_name()

    def launch_nodes(node_type: str, quantity: int) -> List[str]:
        """Launch `quantity` nodes of `node_type`; return their ids."""
        try:
            instance_ids = lambda_client.create_instances(
                instance_type=config.node_config['InstanceType'],
                region=region,
                name=f'{cluster_name_on_cloud}-{node_type}',
                quantity=quantity,
                ssh_key_name=ssh_key_name,
            )
            logger.info(f'Launched {len(instance_ids)} {node_type} node(s), '
                        f'instance_ids: {instance_ids}')
            return instance_ids
        except Exception as e:
            logger.warning(f'run_instances error: {e}')
            raise

    if head_instance_id is None:
        instance_ids = launch_nodes('head', 1)
        assert len(instance_ids) == 1
        created_instance_ids.append(instance_ids[0])
        head_instance_id = instance_ids[0]
        # The newly launched head consumes one of the to-start slots.
        worker_node_count = to_start_count - 1
    else:
        # The head already exists, so every node to start is a worker.
        # (Unconditionally subtracting one here would launch one worker
        # too few and hang forever in the poll loop below.)
        worker_node_count = to_start_count

    assert head_instance_id is not None, 'head_instance_id should not be None'

    if worker_node_count > 0:
        instance_ids = launch_nodes('worker', worker_node_count)
        created_instance_ids.extend(instance_ids)

    # Poll until every requested instance reports 'active'.
    while True:
        instances = _filter_instances(cluster_name_on_cloud, ['active'])
        if len(instances) == config.count:
            break
        time.sleep(POLL_INTERVAL)

    return common.ProvisionRecord(
        provider_name='lambda',
        cluster_name=cluster_name_on_cloud,
        region=region,
        zone=None,
        head_instance_id=head_instance_id,
        resumed_instance_ids=[],
        created_instance_ids=created_instance_ids,
    )
|
||||||
|
||||||
def wait_instances(region: str, cluster_name_on_cloud: str,
                   state: Optional[status_lib.ClusterStatus]) -> None:
    """No-op for Lambda Cloud; provisioning polls for readiness itself."""
    del region  # Unused.
    del cluster_name_on_cloud  # Unused.
    del state  # Unused.
|
||||||
|
||||||
def stop_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    worker_only: bool = False,
) -> None:
    """Unsupported: Lambda Cloud instances cannot be stopped, only terminated."""
    del cluster_name_on_cloud, provider_config, worker_only  # Unused.
    raise NotImplementedError(
        'stop_instances is not supported for Lambda Cloud')
|
||||||
|
||||||
def terminate_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    worker_only: bool = False,
) -> None:
    """See sky/provision/__init__.py

    Terminates this cluster's instances. With worker_only=True, only
    instances whose name ends in '-worker' are terminated and the head
    node is kept.

    Raises:
        RuntimeError: If the Lambda Cloud API call to remove the
            instances fails.
    """
    del provider_config
    lambda_client = _get_lambda_client()
    instances = _filter_instances(cluster_name_on_cloud, None)

    instance_ids_to_terminate = [
        instance_id for instance_id, instance in instances.items()
        if not worker_only or instance['name'].endswith('-worker')
    ]
    if not instance_ids_to_terminate:
        # Nothing matched (e.g. worker_only on a head-only cluster);
        # skip the needless API call.
        logger.debug(
            f'No instances to terminate for {cluster_name_on_cloud}.')
        return

    try:
        logger.debug(
            f'Terminating instances {", ".join(instance_ids_to_terminate)}')
        lambda_client.remove_instances(instance_ids_to_terminate)
    except Exception as e:  # pylint: disable=broad-except
        with ux_utils.print_exception_no_traceback():
            raise RuntimeError(
                f'Failed to terminate instances {instance_ids_to_terminate}: '
                f'{common_utils.format_exception(e, use_bracket=False)}') from e
|
||||||
|
||||||
def get_cluster_info(
    region: str,
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
) -> common.ClusterInfo:
    """Build a ClusterInfo for this cluster's currently active instances."""
    del region  # unused
    active = _filter_instances(cluster_name_on_cloud, ['active'])

    head_instance_id = None
    instances: Dict[str, List[common.InstanceInfo]] = {}
    for instance_id, info in active.items():
        if info['name'].endswith('-head'):
            head_instance_id = instance_id
        instances[instance_id] = [
            common.InstanceInfo(
                instance_id=instance_id,
                internal_ip=info['private_ip'],
                external_ip=info['ip'],
                ssh_port=22,
                tags={},
            )
        ]

    return common.ClusterInfo(
        instances=instances,
        head_instance_id=head_instance_id,
        provider_name='lambda',
        provider_config=provider_config,
    )
|
||||||
|
||||||
def query_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    non_terminated_only: bool = True,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
    """See sky/provision/__init__.py"""
    assert provider_config is not None, (cluster_name_on_cloud, provider_config)
    instances = _filter_instances(cluster_name_on_cloud, None)

    # Map Lambda Cloud instance states onto SkyPilot cluster statuses;
    # states absent from the map translate to None (terminated).
    lambda_to_sky_status = {
        'booting': status_lib.ClusterStatus.INIT,
        'active': status_lib.ClusterStatus.UP,
        'unhealthy': status_lib.ClusterStatus.INIT,
        'terminating': status_lib.ClusterStatus.INIT,
    }

    statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {}
    for instance_id, instance in instances.items():
        converted = lambda_to_sky_status.get(instance['status'])
        if converted is None and non_terminated_only:
            continue
        statuses[instance_id] = converted

    return statuses
|
||||||
|
||||||
def open_ports(
    cluster_name_on_cloud: str,
    ports: List[str],
    provider_config: Optional[Dict[str, Any]] = None,
) -> None:
    """Unsupported: Lambda Cloud exposes no firewall/port API to SkyPilot."""
    del cluster_name_on_cloud, ports, provider_config  # Unused.
    raise NotImplementedError('open_ports is not supported for Lambda Cloud')
|
||||||
|
||||||
def cleanup_ports(
    cluster_name_on_cloud: str,
    ports: List[str],
    provider_config: Optional[Dict[str, Any]] = None,
) -> None:
    """See sky/provision/__init__.py"""
    # Nothing to clean up: open_ports is unsupported, so no firewall
    # state is ever created for Lambda Cloud clusters.
    del cluster_name_on_cloud  # Unused.
    del ports  # Unused.
    del provider_config  # Unused.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Instead of error out, can we just patch one of the nodes to make it head?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have bandwidth this week to implement this, if someone wants to take a stab at it pls do cause I like the idea
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Michaelvll I'm trying to give this a stab, but I just realized that the worker node and head node might have different runtimes (e.g. Ray config), which would involve complex pattern recognition on the node runtime. Also, the only scenario I can think of that would trigger this case is the user manually changing the instance name in the console, which is rare, and it seems RunPod's implementation takes a similar approach of directly raising an error. Do you think we could leave this to another PR?
skypilot/sky/provision/runpod/instance.py
Lines 68 to 70 in a4e2fcd
Just filed an issue for that: #4087