Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
hong authored and hong committed Dec 18, 2024
1 parent f0ebf13 commit 5c5ebc8
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 0 deletions.
24 changes: 24 additions & 0 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3737,6 +3737,18 @@ def jobs_launch(

common_utils.check_cluster_name_is_valid(name)

reauth_needed_clouds = [
resource.cloud for task in dag.tasks for resource in task.resources
if resource.cloud.can_credential_expire()
]
if reauth_needed_clouds:
prompt = (
f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to jobs '
'being out of control. It is recommended to use credentials that never '
'expire or a service account. Proceed?'
)
click.confirm(prompt, default=False, abort=True, show_default=True)

managed_jobs.launch(dag,
name,
detach_run=detach_run,
Expand Down Expand Up @@ -4223,6 +4235,18 @@ def serve_up(
if prompt is not None:
click.confirm(prompt, default=True, abort=True, show_default=True)

reauth_needed_clouds = [
resource.cloud for task in dag.tasks for resource in task.resources
if resource.cloud.can_credential_expire()
]
if reauth_needed_clouds:
prompt = (
f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to jobs '
'being out of control. It is recommended to use credentials that never '
'expire or a service account. Proceed?'
)
click.confirm(prompt, default=False, abort=True, show_default=True)

serve_lib.up(task, service_name)


Expand Down
13 changes: 13 additions & 0 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ class AWSIdentityType(enum.Enum):
# region us-east-1 config-file ~/.aws/config
SHARED_CREDENTIALS_FILE = 'shared-credentials-file'

def can_credential_expire(self) -> bool:
"""Check if the AWS identity type can expire."""
expirable_types = {
AWSIdentityType.SSO,
AWSIdentityType.ENV,
AWSIdentityType.IAM_ROLE,
AWSIdentityType.CONTAINER_ROLE
}
return self in expirable_types


@clouds.CLOUD_REGISTRY.register
class AWS(clouds.Cloud):
Expand Down Expand Up @@ -812,6 +822,9 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
if os.path.exists(os.path.expanduser(f'~/.aws/{filename}'))
}

def can_credential_expire(self) -> bool:
return self._current_identity_type().can_credential_expire()

def instance_type_exists(self, instance_type):
return service_catalog.instance_type_exists(instance_type, clouds='aws')

Expand Down
4 changes: 4 additions & 0 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,10 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
"""
raise NotImplementedError

def can_credential_expire(self) -> bool:
"""Returns whether the cloud credential can expire."""
return False

@classmethod
def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
"""Check the image size from the cloud.
Expand Down
6 changes: 6 additions & 0 deletions sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ class GCPIdentityType(enum.Enum):

SHARED_CREDENTIALS_FILE = ''

def can_credential_expire(self) -> bool:
return self == GCPIdentityType.SHARED_CREDENTIALS_FILE


@clouds.CLOUD_REGISTRY.register
class GCP(clouds.Cloud):
Expand Down Expand Up @@ -863,6 +866,9 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
pass
return credentials

def can_credential_expire(self) -> bool:
return self._get_identity_type().can_credential_expire()

@classmethod
def _get_identity_type(cls) -> Optional[GCPIdentityType]:
try:
Expand Down

0 comments on commit 5c5ebc8

Please sign in to comment.