diff --git a/sky/provision/lambda_cloud/config.py b/sky/provision/lambda_cloud/config.py index 3d4cf64e2981..3066e7747fd9 100644 --- a/sky/provision/lambda_cloud/config.py +++ b/sky/provision/lambda_cloud/config.py @@ -4,7 +4,7 @@ def bootstrap_instances( - region: str, cluster_name: str, config: common.ProvisionConfig -) -> common.ProvisionConfig: + region: str, cluster_name: str, + config: common.ProvisionConfig) -> common.ProvisionConfig: del region, cluster_name # unused return config diff --git a/sky/provision/lambda_cloud/instance.py b/sky/provision/lambda_cloud/instance.py index 0c35b676a2bd..0617c4f583e5 100644 --- a/sky/provision/lambda_cloud/instance.py +++ b/sky/provision/lambda_cloud/instance.py @@ -26,9 +26,8 @@ def _get_lambda_client(): return _lambda_client -def _filter_instances( - cluster_name_on_cloud: str, status_filters: Optional[List[str]] -) -> Dict[str, Any]: +def _filter_instances(cluster_name_on_cloud: str, + status_filters: Optional[List[str]]) -> Dict[str, Any]: lambda_client = _get_lambda_client() instances = lambda_client.list_instances() possible_names = [ @@ -38,7 +37,8 @@ def _filter_instances( filtered_instances = {} for instance in instances: - if status_filters is not None and instance['status'] not in status_filters: + if status_filters is not None and instance[ + 'status'] not in status_filters: continue if instance.get('name') in possible_names: filtered_instances[instance['id']] = instance @@ -65,9 +65,8 @@ def _get_ssh_key_name(prefix: str = '') -> str: return name -def run_instances( - region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig -) -> common.ProvisionRecord: +def run_instances(region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: """Runs instances for the given cluster""" lambda_client = _get_lambda_client() pending_status = ['booting'] @@ -87,7 +86,8 @@ def run_instances( ) if to_start_count == 0: if head_instance_id is None: - raise RuntimeError(f'Cluster {cluster_name_on_cloud} has no head node.') + raise RuntimeError( + f'Cluster {cluster_name_on_cloud} has no head node.') logger.info( f'Cluster {cluster_name_on_cloud} already has {len(exist_instances)} nodes, no need to start more.' ) @@ -145,9 +145,8 @@ def run_instances( ) -def wait_instances( - region: str, cluster_name_on_cloud: str, state: Optional[status_lib.ClusterStatus] -) -> None: +def wait_instances(region: str, cluster_name_on_cloud: str, + state: Optional[status_lib.ClusterStatus]) -> None: del region, cluster_name_on_cloud, state # Unused. @@ -156,7 +155,8 @@ def stop_instances( provider_config: Optional[Dict[str, Any]] = None, worker_only: bool = False, ) -> None: - raise NotImplementedError('stop_instances is not supported for Lambda Cloud') + raise NotImplementedError( + 'stop_instances is not supported for Lambda Cloud') def terminate_instances( @@ -196,7 +196,8 @@ def get_cluster_info( instances[instance_id] = [ common.InstanceInfo( instance_id=instance_id, - internal_ip="", # TODO (kmushegi): check if this is correct; external ip is preferred + internal_ip= + "", # TODO (kmushegi): check if this is correct; external ip is preferred external_ip=instance_info["ip"], ssh_port=22, tags={}, diff --git a/sky/provision/lambda_cloud/lambda_utils.py b/sky/provision/lambda_cloud/lambda_utils.py index 3b78b45bf6b1..9d12523d7400 100644 --- a/sky/provision/lambda_cloud/lambda_utils.py +++ b/sky/provision/lambda_cloud/lambda_utils.py @@ -92,17 +92,16 @@ def raise_lambda_error(response: requests.Response) -> None: raise LambdaCloudError( 'Response cannot be parsed into JSON. Status ' f'code: {status_code}; reason: {response.reason}; ' - f'content: {response.text}' - ) from e + f'content: {response.text}') from e raise LambdaCloudError(f'{code}: {message}') -def _try_request_with_backoff( - method: str, url: str, headers: Dict[str, str], data: Optional[str] = None -): - backoff = common_utils.Backoff( - initial_backoff=INITIAL_BACKOFF_SECONDS, max_backoff_factor=MAX_BACKOFF_FACTOR - ) +def _try_request_with_backoff(method: str, + url: str, + headers: Dict[str, str], + data: Optional[str] = None): + backoff = common_utils.Backoff(initial_backoff=INITIAL_BACKOFF_SECONDS, + max_backoff_factor=MAX_BACKOFF_FACTOR) for i in range(MAX_ATTEMPTS): if method == 'get': response = requests.get(url, headers=headers) @@ -147,35 +146,28 @@ def create_instances( # launch requests are rate limited at ~1 request every 10 seconds. # So don't use launch requests to check availability. # See https://docs.lambdalabs.com/cloud/rate-limiting/ for more. - available_regions = self.list_catalog()[instance_type][ - 'regions_with_capacity_available' - ] + available_regions = self.list_catalog( + )[instance_type]['regions_with_capacity_available'] available_regions = [reg['name'] for reg in available_regions] if region not in available_regions: if len(available_regions) > 0: aval_reg = ' '.join(available_regions) else: aval_reg = 'None' - raise LambdaCloudError( - ( - 'instance-operations/launch/' - 'insufficient-capacity: Not enough ' - 'capacity to fulfill launch request. ' - 'Regions with capacity available: ' - f'{aval_reg}' - ) - ) + raise LambdaCloudError(('instance-operations/launch/' + 'insufficient-capacity: Not enough ' + 'capacity to fulfill launch request. ' + 'Regions with capacity available: ' + f'{aval_reg}')) # Try to launch instance - data = json.dumps( - { - 'region_name': region, - 'instance_type_name': instance_type, - 'ssh_key_names': [ssh_key_name], - 'quantity': quantity, - 'name': name, - } - ) + data = json.dumps({ + 'region_name': region, + 'instance_type_name': instance_type, + 'ssh_key_names': [ssh_key_name], + 'quantity': quantity, + 'name': name, + }) response = _try_request_with_backoff( 'post', f'{API_ENDPOINT}/instance-operations/launch', @@ -186,9 +178,9 @@ def create_instances( def remove_instances(self, *instance_ids: str) -> Dict[str, Any]: """Terminate instances.""" - data = json.dumps( - {'instance_ids': [instance_ids[0]]} # TODO(ewzeng) don't hardcode - ) + data = json.dumps({'instance_ids': [instance_ids[0]] + } # TODO(ewzeng) don't hardcode + ) response = _try_request_with_backoff( 'post', f'{API_ENDPOINT}/instance-operations/terminate', @@ -199,19 +191,20 @@ def remove_instances(self, *instance_ids: str) -> Dict[str, Any]: def list_instances(self) -> List[Dict[str, Any]]: """List existing instances.""" - response = _try_request_with_backoff( - 'get', f'{API_ENDPOINT}/instances', headers=self.headers - ) + response = _try_request_with_backoff('get', + f'{API_ENDPOINT}/instances', + headers=self.headers) return response.json().get('data', []) def list_ssh_keys(self) -> List[Dict[str, str]]: """List ssh keys.""" - response = _try_request_with_backoff( - 'get', f'{API_ENDPOINT}/ssh-keys', headers=self.headers - ) + response = _try_request_with_backoff('get', + f'{API_ENDPOINT}/ssh-keys', + headers=self.headers) return response.json().get('data', []) - def get_unique_ssh_key_name(self, prefix: str, pub_key: str) -> Tuple[str, bool]: + def get_unique_ssh_key_name(self, prefix: str, + pub_key: str) -> Tuple[str, bool]: """Returns a ssh key name with the given prefix. If no names have given prefix, return prefix. If pub_key exists and @@ -221,7 +214,8 @@ def get_unique_ssh_key_name(self, prefix: str, pub_key: str) -> Tuple[str, bool] The second return value is True iff the returned name already exists. """ candidate_keys = [ - k for k in self.list_ssh_keys() if k.get('name', '').startswith(prefix) + k for k in self.list_ssh_keys() + if k.get('name', '').startswith(prefix) ] # Prefix not found @@ -234,24 +228,22 @@ def get_unique_ssh_key_name(self, prefix: str, pub_key: str) -> Tuple[str, bool] if key_info.get('public_key', '').strip() == pub_key.strip(): # Pub key already exists. Use strip to avoid whitespace diffs. return name, True - if ( - len(name) > len(prefix) + 1 - and name[len(prefix)] == '-' - and name[len(prefix) + 1 :].isdigit() - ): - suffix_digits.append(int(name[len(prefix) + 1 :])) + if (len(name) > len(prefix) + 1 and name[len(prefix)] == '-' and + name[len(prefix) + 1:].isdigit()): + suffix_digits.append(int(name[len(prefix) + 1:])) return f'{prefix}-{max(suffix_digits) + 1}', False def register_ssh_key(self, name: str, pub_key: str) -> None: """Register ssh key with Lambda.""" data = json.dumps({'name': name, 'public_key': pub_key}) - _try_request_with_backoff( - 'post', f'{API_ENDPOINT}/ssh-keys', data=data, headers=self.headers - ) + _try_request_with_backoff('post', + f'{API_ENDPOINT}/ssh-keys', + data=data, + headers=self.headers) def list_catalog(self) -> Dict[str, Any]: """List offered instances and their availability.""" - response = _try_request_with_backoff( - 'get', f'{API_ENDPOINT}/instance-types', headers=self.headers - ) + response = _try_request_with_backoff('get', + f'{API_ENDPOINT}/instance-types', + headers=self.headers) return response.json().get('data', [])