diff --git a/sky/cli.py b/sky/cli.py index 12f77e9f6c9..4a612ace894 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -3737,6 +3737,18 @@ def jobs_launch( common_utils.check_cluster_name_is_valid(name) + reauth_needed_clouds = [ + resource.cloud for task in dag.tasks for resource in task.resources + if resource.cloud.can_credential_expire() + ] + if reauth_needed_clouds: + prompt = ( + f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to jobs ' + 'being out of control. It is recommended to use credentials that never ' + 'expire or a service account. Proceed?' + ) + click.confirm(prompt, default=False, abort=True, show_default=True) + managed_jobs.launch(dag, name, detach_run=detach_run, @@ -4223,6 +4235,18 @@ def serve_up( if prompt is not None: click.confirm(prompt, default=True, abort=True, show_default=True) + reauth_needed_clouds = [ + resource.cloud for task in dag.tasks for resource in task.resources + if resource.cloud.can_credential_expire() + ] + if reauth_needed_clouds: + prompt = ( + f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to jobs ' + 'being out of control. It is recommended to use credentials that never ' + 'expire or a service account. Proceed?' + ) + click.confirm(prompt, default=False, abort=True, show_default=True) + serve_lib.up(task, service_name) diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index c42d67f8ba4..70eaa8fafb7 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -100,6 +100,16 @@ class AWSIdentityType(enum.Enum): # region us-east-1 config-file ~/.aws/config SHARED_CREDENTIALS_FILE = 'shared-credentials-file' + def can_credential_expire(self) -> bool: + """Check if the AWS identity type can expire.""" + expirable_types = { + AWSIdentityType.SSO, + AWSIdentityType.ENV, + AWSIdentityType.IAM_ROLE, + AWSIdentityType.CONTAINER_ROLE + } + return self in expirable_types + @clouds.CLOUD_REGISTRY.register class AWS(clouds.Cloud): @@ -812,6 +822,9 @@ def get_credential_file_mounts(self) -> Dict[str, str]: if os.path.exists(os.path.expanduser(f'~/.aws/{filename}')) } + def can_credential_expire(self) -> bool: + return self._current_identity_type().can_credential_expire() + def instance_type_exists(self, instance_type): return service_catalog.instance_type_exists(instance_type, clouds='aws') diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 455baeaf5d9..2cb45ca14fc 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -536,6 +536,10 @@ def get_credential_file_mounts(self) -> Dict[str, str]: """ raise NotImplementedError + def can_credential_expire(self) -> bool: + """Returns whether the cloud credential can expire.""" + return False + @classmethod def get_image_size(cls, image_id: str, region: Optional[str]) -> float: """Check the image size from the cloud. diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 8a28a35505e..84264a1e7bc 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -132,6 +132,9 @@ class GCPIdentityType(enum.Enum): SHARED_CREDENTIALS_FILE = '' + def can_credential_expire(self) -> bool: + return self == GCPIdentityType.SHARED_CREDENTIALS_FILE + @clouds.CLOUD_REGISTRY.register class GCP(clouds.Cloud): @@ -863,6 +866,9 @@ def get_credential_file_mounts(self) -> Dict[str, str]: pass return credentials + def can_credential_expire(self) -> bool: + return self._get_identity_type().can_credential_expire() + @classmethod def _get_identity_type(cls) -> Optional[GCPIdentityType]: try: