Skip to content

Commit

Permalink
feat: address cblmemo reviews and other reviews + make multi-node work again
Browse files Browse the repository at this point in the history
  • Loading branch information
kmushegi committed Sep 11, 2024
1 parent 78bf722 commit a7d4351
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 49 deletions.
2 changes: 2 additions & 0 deletions sky/provision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def _wrapper(*args, **kwargs):
provider_name = kwargs.pop('provider_name')

module_name = provider_name.lower()
if module_name == 'lambda':
module_name = 'lambda_cloud'
module = globals().get(module_name)
assert module is not None, f'Unknown provider: {module_name}'

Expand Down
94 changes: 50 additions & 44 deletions sky/provision/lambda_cloud/instance.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
"""Lambda instance provisioning."""

import os
import time
from typing import Any, Dict, List, Optional

from sky import authentication as auth
from sky import sky_logging
from sky import status_lib
from sky.provision import common
from sky.provision.lambda_cloud.lambda_utils import LambdaCloudClient
from sky.provision.lambda_cloud.lambda_utils import LambdaCloudError
import sky.provision.lambda_cloud.lambda_utils as lambda_utils
from sky.utils import common_utils
from sky.utils import ux_utils

POLL_INTERVAL = 5
POLL_INTERVAL = 1

logger = sky_logging.init_logger(__name__)
_lambda_client = None
Expand All @@ -22,7 +20,7 @@
def _get_lambda_client():
global _lambda_client
if _lambda_client is None:
_lambda_client = LambdaCloudClient()
_lambda_client = lambda_utils.LambdaCloudClient()
return _lambda_client


Expand All @@ -37,8 +35,8 @@ def _filter_instances(cluster_name_on_cloud: str,

filtered_instances = {}
for instance in instances:
if status_filters is not None and instance[
'status'] not in status_filters:
if (status_filters is not None and
instance['status'] not in status_filters):
continue
if instance.get('name') in possible_names:
filtered_instances[instance['id']] = instance
Expand All @@ -56,12 +54,12 @@ def _get_head_instance_id(instances: Dict[str, Any]) -> Optional[str]:

def _get_ssh_key_name(prefix: str = '') -> str:
lambda_client = _get_lambda_client()
public_key_path = os.path.expanduser(auth.PUBLIC_SSH_KEY_PATH)
_, public_key_path = auth.get_or_generate_keys()
with open(public_key_path, 'r', encoding='utf-8') as f:
public_key = f.read()
name, exists = lambda_client.get_unique_ssh_key_name(prefix, public_key)
if not exists:
raise LambdaCloudError('SSH key not found')
raise lambda_utils.LambdaCloudError('SSH key not found')
return name


Expand Down Expand Up @@ -102,37 +100,45 @@ def run_instances(region: str, cluster_name_on_cloud: str,

created_instance_ids = []
ssh_key_name = _get_ssh_key_name()
for _ in range(to_start_count):
node_type = 'head' if head_instance_id is None else 'worker'

def launch_nodes(node_type: str, quantity: int):
try:
instance_id = lambda_client.create_instances(
instance_ids = lambda_client.create_instances(
instance_type=config.node_config['InstanceType'],
region=region,
name=f'{cluster_name_on_cloud}-{node_type}',
quantity=1,
quantity=quantity,
ssh_key_name=ssh_key_name,
)[0]
)
logger.info(f'Launched {len(instance_ids)} {node_type} node(s), '
f'instance_ids: {instance_ids}')
return instance_ids
except Exception as e:
logger.warning(f'run_instances error: {e}')
raise
logger.info(f'Launched instance {instance_id}')
created_instance_ids.append(instance_id)
if head_instance_id is None:
head_instance_id = instance_id
time.sleep(10) # Avoid api rate limits

if head_instance_id is None:
instance_ids = launch_nodes('head', 1)
if len(instance_ids) != 1:
raise RuntimeError(
f"Expected exactly one instance, got {len(instance_ids)}")
created_instance_ids.append(instance_ids[0])
head_instance_id = instance_ids[0]

assert head_instance_id is not None, 'head_instance_id should not be None'

worker_node_count = to_start_count - 1
if worker_node_count > 0:
instance_ids = launch_nodes('worker', worker_node_count)
created_instance_ids.extend(instance_ids)

while True:
instances = _filter_instances(cluster_name_on_cloud, ['active'])
ready_instance_cnt = 0
for instance_id, instance in instances.items():
if instance.get('status') == 'active':
ready_instance_cnt += 1
if ready_instance_cnt == config.count:
if len(instances) == config.count:
break

time.sleep(POLL_INTERVAL)

assert head_instance_id is not None, 'head_instance_id should not be None'
return common.ProvisionRecord(
provider_name='lambda',
cluster_name=cluster_name_on_cloud,
Expand Down Expand Up @@ -167,21 +173,22 @@ def terminate_instances(
del provider_config
lambda_client = _get_lambda_client()
instances = _filter_instances(cluster_name_on_cloud, None)
# TODO (kmushegi): terminate all instances together since
# remove_instances accepts multiple instance ids

instance_ids_to_terminate = []
for instance_id, instance in instances.items():
logger.debug(f'Terminating instance {instance_id}: {instance}')
if worker_only and not instance['name'].endswith('-worker'):
continue
try:
lambda_client.remove_instances(instance_id)
logger.info(f'Terminated instance {instance_id}')
except Exception as e: # pylint: disable=broad-except
with ux_utils.print_exception_no_traceback():
raise RuntimeError(
f'Failed to terminate instance {instance_id}: '
f'{common_utils.format_exception(e, use_bracket=False)}'
) from e
instance_ids_to_terminate.append(instance_id)

try:
logger.debug(
f'Terminating instances {", ".join(instance_ids_to_terminate)}')
lambda_client.remove_instances(*instance_ids_to_terminate)
except Exception as e: # pylint: disable=broad-except
with ux_utils.print_exception_no_traceback():
raise RuntimeError(
f'Failed to terminate instances {instance_ids_to_terminate}: '
f'{common_utils.format_exception(e, use_bracket=False)}') from e


def get_cluster_info(
Expand All @@ -197,10 +204,8 @@ def get_cluster_info(
instances[instance_id] = [
common.InstanceInfo(
instance_id=instance_id,
# TODO (kmushegi): check if this is correct;
# external ip is preferred
internal_ip='',
external_ip=instance_info['ip'],
internal_ip=instance_info["private_ip"],
external_ip=instance_info["ip"],
ssh_port=22,
tags={},
)
Expand All @@ -213,6 +218,9 @@ def get_cluster_info(
head_instance_id=head_instance_id,
provider_name='lambda',
provider_config=provider_config,
custom_ray_options={
'use_external_ip': True,
},
)


Expand Down Expand Up @@ -246,9 +254,7 @@ def open_ports(
ports: List[str],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
# Do nothing because we assume the user has updated their firewall
# rules to allow these ports
del cluster_name_on_cloud, ports, provider_config # Unused.
raise NotImplementedError()


def cleanup_ports(
Expand Down
8 changes: 3 additions & 5 deletions sky/provision/lambda_cloud/lambda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def create_instances(
# launch requests are rate limited at ~1 request every 10 seconds.
# So don't use launch requests to check availability.
# See https://docs.lambdalabs.com/cloud/rate-limiting/ for more.
available_regions = self.list_catalog(
)[instance_type]['regions_with_capacity_available']
available_regions = (self.list_catalog()[instance_type]
['regions_with_capacity_available'])
available_regions = [reg['name'] for reg in available_regions]
if region not in available_regions:
if len(available_regions) > 0:
Expand Down Expand Up @@ -178,9 +178,7 @@ def create_instances(

def remove_instances(self, *instance_ids: str) -> Dict[str, Any]:
"""Terminate instances."""
data = json.dumps({'instance_ids': [instance_ids[0]]
} # TODO(ewzeng) don't hardcode
)
data = json.dumps({'instance_ids': list(instance_ids)})
response = _try_request_with_backoff(
'post',
f'{API_ENDPOINT}/instance-operations/terminate',
Expand Down

0 comments on commit a7d4351

Please sign in to comment.