Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UX] warning before launching jobs/serve when using a reauth required credentials #4479

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,42 @@ def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]):
return common_utils.dump_yaml_str(new_config)


def get_expirable_clouds(
        enabled_clouds: Sequence[clouds.Cloud]) -> List[clouds.Cloud]:
    """Select the clouds whose local credentials are in use and may expire.

    For every cloud, the `remote_identity` setting is read from the SkyPilot
    config under the lower-cased cloud name; when it is unset, the schema's
    default remote identity for that cloud is used instead. A cloud is
    selected when `cloud.can_credential_expire()` is true and either:

    * the remote identity is a string equal to LOCAL_CREDENTIALS, or
    * the remote identity is a list of mappings and the first value of at
      least one mapping equals LOCAL_CREDENTIALS.

    Args:
        enabled_clouds: The cloud objects to inspect.

    Returns:
        The subset of `enabled_clouds`, in input order, that satisfies the
        conditions above. Each cloud appears at most once per input entry.
    """
    local_value = schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value
    selected = []
    for candidate in enabled_clouds:
        cloud_key = str(candidate).lower()
        identity = skypilot_config.get_nested((cloud_key, 'remote_identity'),
                                              None)
        if identity is None:
            identity = schemas.get_default_remote_identity(cloud_key)

        may_expire = candidate.can_credential_expire()
        if isinstance(identity, str):
            if identity == local_value and may_expire:
                selected.append(candidate)
            continue
        if not isinstance(identity, list):
            # Any other shape is ignored.
            continue
        for entry in identity:
            # Each entry is a mapping; only its first value is compared.
            first_value = list(entry.values())[0]
            if first_value == local_value and may_expire:
                selected.append(candidate)
                break
    return selected


# TODO: too many things happening here - leaky abstraction. Refactor.
@timeline.event
def write_cluster_config(
Expand Down Expand Up @@ -771,9 +807,7 @@ def write_cluster_config(
if (remote_identity_config ==
schemas.RemoteIdentityOptions.NO_UPLOAD.value):
excluded_clouds.add(cloud_obj)

credentials = sky_check.get_cloud_credential_file_mounts(excluded_clouds)

auth_config = {'ssh_private_key': auth.PRIVATE_SSH_KEY_PATH}
region_name = resources_vars.get('region')

Expand Down
18 changes: 18 additions & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from typing import (Any, Callable, Dict, Iterable, List, Optional, Set, Tuple,
Union)

import click
import colorama
import filelock

import sky
from sky import backends
from sky import check as sky_check
from sky import cloud_stores
from sky import clouds
from sky import exceptions
Expand Down Expand Up @@ -1996,6 +1998,22 @@ def provision_with_retries(
skip_unnecessary_provisioning else None)

failover_history: List[Exception] = list()
# If the jobs/serve controller uses local credentials that can expire,
# clusters may be leaked once those credentials lapse. Check the enabled
# clouds for expirable credentials and warn the user to switch to
# credentials that never expire or to a service account.
if task.is_controller_task():
enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh()
expirable_clouds = backend_utils.get_expirable_clouds(
enabled_clouds)

if len(expirable_clouds) > 0:
warnings = (f'\nWarning: Expiring credentials detected for '
f'{expirable_clouds}. Clusters may be leaked if '
f'the credentials expire while jobs are running. '
f'It is recommended to use credentials that never'
f' expire or a service account.')
click.secho(warnings, fg='yellow')

# Retrying launchable resources.
while True:
Expand Down
23 changes: 23 additions & 0 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,23 @@ class AWSIdentityType(enum.Enum):
# region us-east-1 config-file ~/.aws/config
SHARED_CREDENTIALS_FILE = 'shared-credentials-file'

def can_credential_expire(self) -> bool:
    """Return whether this AWS identity type is treated as expirable.

    ENV and SHARED_CREDENTIALS_FILE are regarded as short-lived credentials
    that are not refreshed, so they are reported as expirable. SSO, IAM_ROLE
    and CONTAINER_ROLE are temporary credentials that are refreshed
    automatically, so they are not.

    References:
        IAM role:
        https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
        SSO / container-role refresh token:
        https://docs.aws.amazon.com/solutions/latest/dea-api/auth-refreshtoken.html
    """
    # TODO(hong): Add a check for the expiration of the temporary
    # credentials.
    return self in (AWSIdentityType.ENV,
                    AWSIdentityType.SHARED_CREDENTIALS_FILE)


@clouds.CLOUD_REGISTRY.register
class AWS(clouds.Cloud):
Expand Down Expand Up @@ -812,6 +829,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
if os.path.exists(os.path.expanduser(f'~/.aws/{filename}'))
}

@functools.lru_cache(maxsize=1)
def can_credential_expire(self) -> bool:
    """Return whether the currently active AWS credential can expire.

    Looks up the current identity type and delegates to its
    `can_credential_expire()`; returns False when no identity type is found.
    The result is memoized by the `lru_cache` decorator (keyed on `self`).
    """
    current_type = self._current_identity_type()
    if current_type is None:
        return False
    return current_type.can_credential_expire()

def instance_type_exists(self, instance_type):
return service_catalog.instance_type_exists(instance_type, clouds='aws')

Expand Down
4 changes: 4 additions & 0 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,10 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
"""
raise NotImplementedError

def can_credential_expire(self) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
def can_credential_expire(self) -> bool:
def can_credentials_expire(self) -> bool:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This checks only the single active credential, so I think the original (singular) name makes sense.

"""Returns whether the cloud credential can expire."""
return False

@classmethod
def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
"""Check the image size from the cloud.
Expand Down
9 changes: 9 additions & 0 deletions sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ class GCPIdentityType(enum.Enum):

SHARED_CREDENTIALS_FILE = ''

def can_credential_expire(self) -> bool:
    """Return True only for the SHARED_CREDENTIALS_FILE identity type."""
    return self is GCPIdentityType.SHARED_CREDENTIALS_FILE


@clouds.CLOUD_REGISTRY.register
class GCP(clouds.Cloud):
Expand Down Expand Up @@ -863,6 +866,12 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
pass
return credentials

@functools.lru_cache(maxsize=1)
def can_credential_expire(self) -> bool:
    """Return whether the currently active GCP credential can expire.

    Resolves the identity type via `_get_identity_type()` and delegates to
    its `can_credential_expire()`; returns False when no identity type is
    found. The result is memoized by the `lru_cache` decorator (keyed on
    `self`).
    """
    current_type = self._get_identity_type()
    if current_type is None:
        return False
    return current_type.can_credential_expire()

@classmethod
def _get_identity_type(cls) -> Optional[GCPIdentityType]:
try:
Expand Down
Loading