From f67d98708f9bc59a927700ba7e1dae6509e45de8 Mon Sep 17 00:00:00 2001 From: hong Date: Wed, 18 Dec 2024 12:37:38 +0800 Subject: [PATCH] wip --- sky/cli.py | 26 ++++++++++++-------------- sky/clouds/aws.py | 8 ++++---- sky/clouds/gcp.py | 4 +++- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sky/cli.py b/sky/cli.py index 4a612ace894..735eefab9e2 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -3738,15 +3738,14 @@ def jobs_launch( common_utils.check_cluster_name_is_valid(name) reauth_needed_clouds = [ - resource.cloud for task in dag.tasks for resource in task.resources - if resource.cloud.can_credential_expire() + resource.cloud for task in dag.tasks for resource in task.resources if + resource.cloud is not None and resource.cloud.can_credential_expire() ] - if reauth_needed_clouds: + if len(reauth_needed_clouds) > 0: prompt = ( - f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to jobs ' - 'being out of control. It is recommended to use credentials that never ' - 'expire or a service account. Proceed?' - ) + f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to ' + 'jobs being out of control. It is recommended to use credentials ' + 'that never expire or a service account. Proceed?') click.confirm(prompt, default=False, abort=True, show_default=True) managed_jobs.launch(dag, @@ -4236,15 +4235,14 @@ def serve_up( click.confirm(prompt, default=True, abort=True, show_default=True) reauth_needed_clouds = [ - resource.cloud for task in dag.tasks for resource in task.resources - if resource.cloud.can_credential_expire() + resource.cloud for task in dag.tasks for resource in task.resources if + resource.cloud is not None and resource.cloud.can_credential_expire() ] - if reauth_needed_clouds: + if len(reauth_needed_clouds) > 0: prompt = ( - f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to jobs ' - 'being out of control. It is recommended to use credentials that never ' - 'expire or a service account. Proceed?' - ) + f'Launching jobs with cloud(s) {reauth_needed_clouds} may lead to ' + 'jobs being out of control. It is recommended to use credentials ' + 'that never expire or a service account. Proceed?') click.confirm(prompt, default=False, abort=True, show_default=True) serve_lib.up(task, service_name) diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 70eaa8fafb7..448ec808076 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -103,9 +103,7 @@ class AWSIdentityType(enum.Enum): def can_credential_expire(self) -> bool: """Check if the AWS identity type can expire.""" expirable_types = { - AWSIdentityType.SSO, - AWSIdentityType.ENV, - AWSIdentityType.IAM_ROLE, + AWSIdentityType.SSO, AWSIdentityType.ENV, AWSIdentityType.IAM_ROLE, AWSIdentityType.CONTAINER_ROLE } return self in expirable_types @@ -823,7 +821,9 @@ def get_credential_file_mounts(self) -> Dict[str, str]: } def can_credential_expire(self) -> bool: - return self._current_identity_type().can_credential_expire() + identity_type = self._current_identity_type() + return identity_type is not None and identity_type.can_credential_expire( + ) def instance_type_exists(self, instance_type): return service_catalog.instance_type_exists(instance_type, clouds='aws') diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 84264a1e7bc..d44a6579222 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -867,7 +867,9 @@ def get_credential_file_mounts(self) -> Dict[str, str]: return credentials def can_credential_expire(self) -> bool: - return self._get_identity_type().can_credential_expire() + identity_type = self._get_identity_type() + return identity_type is not None and identity_type.can_credential_expire( + ) @classmethod def _get_identity_type(cls) -> Optional[GCPIdentityType]: