Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lambda] Lambda Cloud SkyPilot provisioner #3865

Merged
2 changes: 1 addition & 1 deletion sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
from sky.adaptors import ibm
from sky.adaptors import kubernetes
from sky.adaptors import runpod
from sky.clouds.utils import lambda_utils
from sky.provision.fluidstack import fluidstack_utils
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.provision.lambda_cloud import lambda_utils
from sky.utils import common_utils
from sky.utils import kubernetes_enums
from sky.utils import subprocess_utils
Expand Down
5 changes: 4 additions & 1 deletion sky/clouds/lambda_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sky import clouds
from sky import status_lib
from sky.clouds import service_catalog
from sky.clouds.utils import lambda_utils
from sky.provision.lambda_cloud import lambda_utils
from sky.utils import resources_utils

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -48,6 +48,9 @@ class Lambda(clouds.Cloud):
clouds.CloudImplementationFeatures.HOST_CONTROLLERS: f'Host controllers are not supported in {_REPR}.',
}

PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT
STATUS_VERSION = clouds.StatusVersion.SKYPILOT

@classmethod
def _unsupported_features_for_resources(
cls, resources: 'resources_lib.Resources'
Expand Down
3 changes: 3 additions & 0 deletions sky/provision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sky.provision import fluidstack
from sky.provision import gcp
from sky.provision import kubernetes
from sky.provision import lambda_cloud
from sky.provision import runpod
from sky.provision import vsphere
from sky.utils import command_runner
Expand All @@ -39,6 +40,8 @@ def _wrapper(*args, **kwargs):
provider_name = kwargs.pop('provider_name')

module_name = provider_name.lower()
if module_name == 'lambda':
module_name = 'lambda_cloud'
module = globals().get(module_name)
assert module is not None, f'Unknown provider: {module_name}'

Expand Down
11 changes: 11 additions & 0 deletions sky/provision/lambda_cloud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Lambda provisioner for SkyPilot."""

from sky.provision.lambda_cloud.config import bootstrap_instances
from sky.provision.lambda_cloud.instance import cleanup_ports
from sky.provision.lambda_cloud.instance import get_cluster_info
from sky.provision.lambda_cloud.instance import open_ports
from sky.provision.lambda_cloud.instance import query_instances
from sky.provision.lambda_cloud.instance import run_instances
from sky.provision.lambda_cloud.instance import stop_instances
from sky.provision.lambda_cloud.instance import terminate_instances
from sky.provision.lambda_cloud.instance import wait_instances
10 changes: 10 additions & 0 deletions sky/provision/lambda_cloud/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Lambda Cloud configuration bootstrapping"""

from sky.provision import common


def bootstrap_instances(
        region: str, cluster_name: str,
        config: common.ProvisionConfig) -> common.ProvisionConfig:
    """Return the provisioning config unchanged.

    No bootstrapping work is done in this function: ``config`` is handed
    back as-is and the other two arguments are ignored.

    Args:
        region: Ignored.
        cluster_name: Ignored.
        config: The provisioning config to pass through.

    Returns:
        The same ``config`` object that was passed in.
    """
    del cluster_name, region  # unused
    return config
261 changes: 261 additions & 0 deletions sky/provision/lambda_cloud/instance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
"""Lambda instance provisioning."""

import time
from typing import Any, Dict, List, Optional

from sky import authentication as auth
from sky import sky_logging
from sky import status_lib
from sky.provision import common
import sky.provision.lambda_cloud.lambda_utils as lambda_utils
from sky.utils import common_utils
from sky.utils import ux_utils

POLL_INTERVAL = 1

logger = sky_logging.init_logger(__name__)
_lambda_client = None


def _get_lambda_client():
    """Return the module-wide Lambda client, building it on first use."""
    global _lambda_client
    if _lambda_client is not None:
        return _lambda_client
    # First call: create the client once and cache it in the module global.
    _lambda_client = lambda_utils.LambdaCloudClient()
    return _lambda_client


def _filter_instances(
        cluster_name_on_cloud: str,
        status_filters: Optional[List[str]]) -> Dict[str, Dict[str, Any]]:
    """Collect the instances that belong to a cluster, keyed by instance id.

    An instance belongs to the cluster when its name is exactly
    ``<cluster_name_on_cloud>-head`` or ``<cluster_name_on_cloud>-worker``.

    Args:
        cluster_name_on_cloud: Cluster name used to build the two accepted
            instance names.
        status_filters: Statuses to keep. ``None`` keeps every status.

    Returns:
        A dict mapping instance id to the instance record returned by the
        client's ``list_instances()``.
    """
    accepted_names = (
        f'{cluster_name_on_cloud}-head',
        f'{cluster_name_on_cloud}-worker',
    )

    def _belongs(record: Dict[str, Any]) -> bool:
        # The status check comes first, mirroring the lookup order callers
        # have always seen (a record's 'status' is read before its 'name').
        if status_filters is not None:
            if record['status'] not in status_filters:
                return False
        return record.get('name') in accepted_names

    return {
        record['id']: record
        for record in _get_lambda_client().list_instances()
        if _belongs(record)
    }


def _get_head_instance_id(instances: Dict[str, Any]) -> Optional[str]:
head_instance_id = None
for instance_id, instance in instances.items():
if instance['name'].endswith('-head'):
head_instance_id = instance_id
break
return head_instance_id


def _get_ssh_key_name(prefix: str = '') -> str:
    """Look up the name under which the local public key is registered.

    The public key file returned by ``auth.get_or_generate_keys()`` is read
    and handed to the client's ``get_unique_ssh_key_name``.

    Args:
        prefix: Forwarded unchanged to ``get_unique_ssh_key_name``.

    Returns:
        The key name reported by the client.

    Raises:
        lambda_utils.LambdaCloudError: If the client reports that the key
            does not exist.
    """
    client = _get_lambda_client()
    _, pub_key_file = auth.get_or_generate_keys()
    with open(pub_key_file, 'r', encoding='utf-8') as key_file:
        key_material = key_file.read()
    key_name, registered = client.get_unique_ssh_key_name(prefix, key_material)
    if registered:
        return key_name
    raise lambda_utils.LambdaCloudError('SSH key not found')


def run_instances(region: str, cluster_name_on_cloud: str,
config: common.ProvisionConfig) -> common.ProvisionRecord:
"""Runs instances for the given cluster"""
lambda_client = _get_lambda_client()
pending_status = ['booting']
while True:
instances = _filter_instances(cluster_name_on_cloud, pending_status)
if not instances:
break
logger.info(f'Waiting for {len(instances)} instances to be ready.')
time.sleep(POLL_INTERVAL)
exist_instances = _filter_instances(cluster_name_on_cloud, ['active'])
head_instance_id = _get_head_instance_id(exist_instances)

to_start_count = config.count - len(exist_instances)
if to_start_count < 0:
raise RuntimeError(
f'Cluster {cluster_name_on_cloud} already has '
f'{len(exist_instances)} nodes, but {config.count} are required.')
if to_start_count == 0:
if head_instance_id is None:
raise RuntimeError(
f'Cluster {cluster_name_on_cloud} has no head node.')
Comment on lines +87 to +89
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Instead of erroring out, can we just patch one of the nodes to make it the head?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have bandwidth this week to implement this. If someone wants to take a stab at it, please do, because I like the idea.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Michaelvll I'm trying to take a stab at this, but I just realized that the worker node and the head node might have different runtimes (e.g. Ray config), which would require complex pattern recognition on the node runtime. Also, the only scenario I can think of that would trigger this case is a user manually changing the instance name on the console, which is rare, and RunPod's implementation seems to take a similar approach of raising directly. Do you think we could leave this to another PR?

if head_instance_id is None:
raise RuntimeError(
f'Cluster {cluster_name_on_cloud} has no head node.')

Just filed an issue for that: #4087

logger.info(f'Cluster {cluster_name_on_cloud} already has '
f'{len(exist_instances)} nodes, no need to start more.')
return common.ProvisionRecord(
provider_name='lambda',
cluster_name=cluster_name_on_cloud,
region=region,
zone=None,
head_instance_id=head_instance_id,
resumed_instance_ids=[],
created_instance_ids=[],
)

created_instance_ids = []
ssh_key_name = _get_ssh_key_name()

def launch_nodes(node_type: str, quantity: int) -> List[str]:
    """Create ``quantity`` instances named ``<cluster>-<node_type>``.

    Any exception raised while creating the instances (or while logging
    the result) is logged as a warning and then re-raised unchanged.

    Args:
        node_type: Suffix appended to the cluster name to form the
            instance name.
        quantity: Number of instances to request.

    Returns:
        The instance ids returned by the client's ``create_instances``.
    """
    node_name = f'{cluster_name_on_cloud}-{node_type}'
    try:
        launched_ids = lambda_client.create_instances(
            instance_type=config.node_config['InstanceType'],
            region=region,
            name=node_name,
            quantity=quantity,
            ssh_key_name=ssh_key_name,
        )
        logger.info(f'Launched {len(launched_ids)} {node_type} node(s), '
                    f'instance_ids: {launched_ids}')
    except Exception as e:
        logger.warning(f'run_instances error: {e}')
        raise
    else:
        return launched_ids

if head_instance_id is None:
instance_ids = launch_nodes('head', 1)
assert len(instance_ids) == 1
created_instance_ids.append(instance_ids[0])
head_instance_id = instance_ids[0]

assert head_instance_id is not None, 'head_instance_id should not be None'

worker_node_count = to_start_count - 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
worker_node_count = to_start_count - 1
worker_node_count = to_start_count - 1

This should only subtract one if the head instance id is None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, we assert above that the head instance id is not None?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion. I mean this

image

Logically, it is possible that our head instance is already provisioned and this time we only want to add some workers. In that case, to_start_count should be equal to worker_node_count. In practice this won't occur in our system, but we should handle it for future expansion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So are you suggesting something like

if head_instance_id is None:
    worker_node_count = to_start_count - 1
else:
    worker_node_count = to_start_count

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See this comment: #3865 (comment)

kmushegi marked this conversation as resolved.
Show resolved Hide resolved
if worker_node_count > 0:
instance_ids = launch_nodes('worker', worker_node_count)
created_instance_ids.extend(instance_ids)

while True:
instances = _filter_instances(cluster_name_on_cloud, ['active'])
if len(instances) == config.count:
break

time.sleep(POLL_INTERVAL)

return common.ProvisionRecord(
provider_name='lambda',
cluster_name=cluster_name_on_cloud,
region=region,
zone=None,
head_instance_id=head_instance_id,
resumed_instance_ids=[],
kmushegi marked this conversation as resolved.
Show resolved Hide resolved
created_instance_ids=created_instance_ids,
)


def wait_instances(region: str, cluster_name_on_cloud: str,
                   state: Optional[status_lib.ClusterStatus]) -> None:
    """Do nothing; every argument is ignored.

    Args:
        region: Ignored.
        cluster_name_on_cloud: Ignored.
        state: Ignored.
    """
    del state, cluster_name_on_cloud, region  # Unused.


def stop_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    worker_only: bool = False,
) -> None:
    """Always fail: stopping instances is not implemented here.

    Args:
        cluster_name_on_cloud: Ignored.
        provider_config: Ignored.
        worker_only: Ignored.

    Raises:
        NotImplementedError: Unconditionally.
    """
    del cluster_name_on_cloud, provider_config, worker_only  # Unused.
    message = 'stop_instances is not supported for Lambda Cloud'
    raise NotImplementedError(message)


def terminate_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    worker_only: bool = False,
) -> None:
    """See sky/provision/__init__.py

    Terminates every instance of the cluster regardless of its status, or
    only the instances whose name ends with ``-worker`` when ``worker_only``
    is set.

    Args:
        cluster_name_on_cloud: Cluster whose instances are terminated.
        provider_config: Unused.
        worker_only: When True, instances that are not workers are kept.

    Raises:
        RuntimeError: If the client fails to remove the instances; the
            original exception is chained as the cause.
    """
    del provider_config  # Unused.
    lambda_client = _get_lambda_client()
    instances = _filter_instances(cluster_name_on_cloud, None)

    instance_ids_to_terminate = [
        instance_id for instance_id, instance in instances.items()
        if not worker_only or instance['name'].endswith('-worker')
    ]

    if not instance_ids_to_terminate:
        # Nothing matched (e.g. worker_only on a single-node cluster, or the
        # cluster is already gone). Do not send a terminate request with an
        # empty id list, which would surface as a spurious failure.
        logger.debug(f'No instances to terminate for cluster '
                     f'{cluster_name_on_cloud}.')
        return

    try:
        logger.debug(
            f'Terminating instances {", ".join(instance_ids_to_terminate)}')
        lambda_client.remove_instances(instance_ids_to_terminate)
    except Exception as e:  # pylint: disable=broad-except
        with ux_utils.print_exception_no_traceback():
            raise RuntimeError(
                f'Failed to terminate instances {instance_ids_to_terminate}: '
                f'{common_utils.format_exception(e, use_bracket=False)}') from e


def get_cluster_info(
    region: str,
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
) -> common.ClusterInfo:
    """Describe the cluster's instances that are in the ``active`` status.

    Args:
        region: Ignored.
        cluster_name_on_cloud: Cluster whose instances are looked up.
        provider_config: Passed through unchanged into the result.

    Returns:
        A ``common.ClusterInfo`` with one single-element ``InstanceInfo``
        list per active instance. ``head_instance_id`` is the id of an
        instance whose name ends with ``-head`` (the last such instance
        seen), or ``None`` if there is none.
    """
    del region  # unused
    active_nodes = _filter_instances(cluster_name_on_cloud, ['active'])

    head_id = None
    info_by_id: Dict[str, List[common.InstanceInfo]] = {}
    for node_id, node in active_nodes.items():
        node_info = common.InstanceInfo(
            instance_id=node_id,
            internal_ip=node['private_ip'],
            external_ip=node['ip'],
            ssh_port=22,
            tags={},
        )
        info_by_id[node_id] = [node_info]
        if node['name'].endswith('-head'):
            head_id = node_id

    return common.ClusterInfo(
        instances=info_by_id,
        head_instance_id=head_id,
        provider_name='lambda',
        provider_config=provider_config,
    )


def query_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    non_terminated_only: bool = True,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
    """See sky/provision/__init__.py

    Maps each cluster instance's status to a ``ClusterStatus``: ``active``
    becomes ``UP``; ``booting``, ``unhealthy`` and ``terminating`` become
    ``INIT``; any other status becomes ``None``.

    Args:
        cluster_name_on_cloud: Cluster whose instances are queried.
        provider_config: Must not be ``None`` (asserted); not otherwise
            used.
        non_terminated_only: When True, instances whose mapped status is
            ``None`` are left out of the result.

    Returns:
        A dict mapping instance id to its mapped status.
    """
    assert provider_config is not None, (cluster_name_on_cloud, provider_config)
    nodes = _filter_instances(cluster_name_on_cloud, None)

    lambda_to_sky_status = {
        'booting': status_lib.ClusterStatus.INIT,
        'active': status_lib.ClusterStatus.UP,
        'unhealthy': status_lib.ClusterStatus.INIT,
        'terminating': status_lib.ClusterStatus.INIT,
    }

    result: Dict[str, Optional[status_lib.ClusterStatus]] = {}
    for node_id, node in nodes.items():
        sky_status = lambda_to_sky_status.get(node['status'])
        # Keep the entry unless it is unmapped (None) and the caller asked
        # for non-terminated instances only.
        if sky_status is not None or not non_terminated_only:
            result[node_id] = sky_status
    return result


def open_ports(
    cluster_name_on_cloud: str,
    ports: List[str],
    provider_config: Optional[Dict[str, Any]] = None,
) -> None:
    """Always fail: opening ports is not implemented here.

    Args:
        cluster_name_on_cloud: Ignored.
        ports: Ignored.
        provider_config: Ignored.

    Raises:
        NotImplementedError: Unconditionally.
    """
    del cluster_name_on_cloud, ports, provider_config  # Unused.
    message = 'open_ports is not supported for Lambda Cloud'
    raise NotImplementedError(message)


def cleanup_ports(
    cluster_name_on_cloud: str,
    ports: List[str],
    provider_config: Optional[Dict[str, Any]] = None,
) -> None:
    """See sky/provision/__init__.py

    Does nothing; every argument is ignored.
    """
    del provider_config, ports, cluster_name_on_cloud  # Unused.
Loading
Loading