Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UX] warning before launching jobs/serve when using a reauth required credentials #4479

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3737,6 +3737,17 @@ def jobs_launch(

common_utils.check_cluster_name_is_valid(name)

reauth_needed_clouds = [
resource.cloud for task in dag.tasks for resource in task.resources if
resource.cloud is not None and resource.cloud.can_credential_expire()
]
if len(reauth_needed_clouds) > 0:
prompt = (
f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to '
'jobs being out of control. It is recommended to use credentials '
'that never expire or a service account. Proceed?')
romilbhardwaj marked this conversation as resolved.
Show resolved Hide resolved
click.confirm(prompt, default=False, abort=True, show_default=True)
romilbhardwaj marked this conversation as resolved.
Show resolved Hide resolved

managed_jobs.launch(dag,
name,
detach_run=detach_run,
Expand Down Expand Up @@ -4223,6 +4234,17 @@ def serve_up(
if prompt is not None:
click.confirm(prompt, default=True, abort=True, show_default=True)

reauth_needed_clouds = [
resource.cloud for task in dag.tasks for resource in task.resources if
resource.cloud is not None and resource.cloud.can_credential_expire()
]
if len(reauth_needed_clouds) > 0:
prompt = (
f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to '
'jobs being out of control. It is recommended to use credentials '
'that never expire or a service account. Proceed?')
click.confirm(prompt, default=False, abort=True, show_default=True)

serve_lib.up(task, service_name)


Expand Down
13 changes: 13 additions & 0 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ class AWSIdentityType(enum.Enum):
# region us-east-1 config-file ~/.aws/config
SHARED_CREDENTIALS_FILE = 'shared-credentials-file'

def can_credential_expire(self) -> bool:
"""Check if the AWS identity type can expire."""
expirable_types = {
AWSIdentityType.SSO, AWSIdentityType.ENV, AWSIdentityType.IAM_ROLE,
weih1121 marked this conversation as resolved.
Show resolved Hide resolved
AWSIdentityType.CONTAINER_ROLE
}
return self in expirable_types


@clouds.CLOUD_REGISTRY.register
class AWS(clouds.Cloud):
Expand Down Expand Up @@ -812,6 +820,11 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
if os.path.exists(os.path.expanduser(f'~/.aws/{filename}'))
}

def can_credential_expire(self) -> bool:
identity_type = self._current_identity_type()
return identity_type is not None and identity_type.can_credential_expire(
)

def instance_type_exists(self, instance_type):
return service_catalog.instance_type_exists(instance_type, clouds='aws')

Expand Down
4 changes: 4 additions & 0 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,10 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
"""
raise NotImplementedError

def can_credential_expire(self) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
def can_credential_expire(self) -> bool:
def can_credentials_expire(self) -> bool:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checks the active credential(only one), the original make sense I think.

"""Returns whether the cloud credential can expire."""
return False

@classmethod
def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
"""Check the image size from the cloud.
Expand Down
8 changes: 8 additions & 0 deletions sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ class GCPIdentityType(enum.Enum):

SHARED_CREDENTIALS_FILE = ''

def can_credential_expire(self) -> bool:
return self == GCPIdentityType.SHARED_CREDENTIALS_FILE


@clouds.CLOUD_REGISTRY.register
class GCP(clouds.Cloud):
Expand Down Expand Up @@ -863,6 +866,11 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
pass
return credentials

def can_credential_expire(self) -> bool:
identity_type = self._get_identity_type()
return identity_type is not None and identity_type.can_credential_expire(
)

@classmethod
def _get_identity_type(cls) -> Optional[GCPIdentityType]:
try:
Expand Down
Loading