From 406f39e82d583c0e4914d5129d0ad87021484f6f Mon Sep 17 00:00:00 2001 From: Aylei Date: Wed, 25 Dec 2024 22:03:26 +0800 Subject: [PATCH 1/3] [aws] cache user identity by 'aws configure list' Signed-off-by: Aylei --- sky/clouds/aws.py | 63 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 53 insertions(+), 10 deletions(-) diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index cafc789c5be..159d6305ed2 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -2,6 +2,8 @@ import enum import fnmatch import functools +import hashlib +import json import os import re import subprocess @@ -16,6 +18,7 @@ from sky import skypilot_config from sky.adaptors import aws from sky.clouds import service_catalog +from sky.clouds.service_catalog import common as catalog_common from sky.clouds.utils import aws_utils from sky.skylet import constants from sky.utils import common_utils @@ -624,14 +627,10 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: @classmethod def _current_identity_type(cls) -> Optional[AWSIdentityType]: - proc = subprocess.run('aws configure list', - shell=True, - check=False, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - if proc.returncode != 0: + stdout = cls._aws_configure_list() + if stdout is None: return None - stdout = proc.stdout.decode() + output = stdout.decode() # We determine the identity type by looking at the output of # `aws configure list`. The output looks like: @@ -646,10 +645,10 @@ def _current_identity_type(cls) -> Optional[AWSIdentityType]: def _is_access_key_of_type(type_str: str) -> bool: # The dot (.) does not match line separators. - results = re.findall(fr'access_key.*{type_str}', stdout) + results = re.findall(fr'access_key.*{type_str}', output) if len(results) > 1: raise RuntimeError( - f'Unexpected `aws configure list` output:\n{stdout}') + f'Unexpected `aws configure list` output:\n{output}') return len(results) == 1 if _is_access_key_of_type(AWSIdentityType.SSO.value): @@ -663,9 +662,21 @@ def _is_access_key_of_type(type_str: str) -> bool: else: return AWSIdentityType.SHARED_CREDENTIALS_FILE + @classmethod + @functools.lru_cache(maxsize=1) + def _aws_configure_list(cls) -> Optional[bytes]: + proc = subprocess.run('aws configure list', + shell=True, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + if proc.returncode != 0: + return None + return proc.stdout + @classmethod @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. - def get_user_identities(cls) -> Optional[List[List[str]]]: + def _sts_get_caller_identity(cls) -> Optional[List[List[str]]]: """Returns a [UserId, Account] list that uniquely identifies the user. These fields come from `aws sts get-caller-identity`. We permit the same @@ -773,6 +784,38 @@ def get_user_identities(cls) -> Optional[List[List[str]]]: # automatic switching for AWS. Currently we only support one identity. return [user_ids] + @classmethod + @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. + def get_user_identities(cls) -> Optional[List[List[str]]]: + stdout = cls._aws_configure_list() + if stdout is None: + # `aws configure list` is not available, possible reasons: + # - awscli is not installed but credentials are valid, e.g. run from + # an EC2 instance with IAM role + # - aws credentials are not set, proceed anyway to get unified error + # message for users + return cls._sts_get_caller_identity() + config_hash = hashlib.md5(stdout).hexdigest()[:8] + # Getting aws identity cost ~1s, so we cache the result with the output of + # `aws configure list` as cache key. Different `aws configure list` output + # can have same aws identity, our assumption is the output would be stable + # in real world, so the number of cache files would be limited. + # TODO(aylei): consider using a more stable cache key and evalute eviction. + cache_path = catalog_common.get_catalog_path( + f'aws/user-identity-{config_hash}.txt') + if os.path.exists(cache_path): + try: + with open(cache_path, 'r', encoding='utf-8') as f: + return json.loads(f.read()) + except json.JSONDecodeError: + # cache is invalid, ignore it and fetch identity again + pass + + result = cls._sts_get_caller_identity() + with open(cache_path, 'w', encoding='utf-8') as f: + f.write(json.dumps(result)) + return result + @classmethod def get_active_user_identity_str(cls) -> Optional[str]: user_identity = cls.get_active_user_identity() From c4b52305b5aad9e4dbbf9c5fe77e7d3d22d8d3c6 Mon Sep 17 00:00:00 2001 From: Aylei Date: Thu, 2 Jan 2025 20:18:08 +0800 Subject: [PATCH 2/3] refine get_user_identities docstring Signed-off-by: Aylei --- sky/clouds/aws.py | 65 ++++++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 159d6305ed2..37da24e771d 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -677,35 +677,6 @@ def _aws_configure_list(cls) -> Optional[bytes]: @classmethod @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. def _sts_get_caller_identity(cls) -> Optional[List[List[str]]]: - """Returns a [UserId, Account] list that uniquely identifies the user. - - These fields come from `aws sts get-caller-identity`. We permit the same - actual user to: - - - switch between different root accounts (after which both elements - of the list will be different) and have their clusters owned by - each account be protected; or - - - within the same root account, switch between different IAM - users, and treat [user_id=1234, account=A] and - [user_id=4567, account=A] to be the *same*. Namely, switching - between these IAM roles within the same root account will cause - the first element of the returned list to differ, and will allow - the same actual user to continue to interact with their clusters. - Note: this is not 100% safe, since the IAM users can have very - specific permissions, that disallow them to access the clusters - but it is a reasonable compromise as that could be rare. - - Returns: - A list of strings that uniquely identifies the user on this cloud. - For identity check, we will fallback through the list of strings - until we find a match, and print a warning if we fail for the - first string. - - Raises: - exceptions.CloudUserIdentityError: if the user identity cannot be - retrieved. - """ try: sts = aws.client('sts') # The caller identity contains 3 fields: UserId, Account, Arn. @@ -787,6 +758,42 @@ def _sts_get_caller_identity(cls) -> Optional[List[List[str]]]: @classmethod @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. def get_user_identities(cls) -> Optional[List[List[str]]]: + """Returns a [UserId, Account] list that uniquely identifies current AWS + principal (user, role or federated identity) whose credentials are used + to run current `sky` process. These identities are assumed to be stable + for the duration of the `sky` process. Modifying the credentials while + the `sky` process is running will not affect the identity returned by + this function. + + These fields come from `aws sts get-caller-identity` and are cached + locally by `aws configure list` output. + + We permit the same actual user to: + + - switch between different root accounts (after which both elements + of the list will be different) and have their clusters owned by + each account be protected; or + + - within the same root account, switch between different IAM + users, and treat [user_id=1234, account=A] and + [user_id=4567, account=A] to be the *same*. Namely, switching + between these IAM roles within the same root account will cause + the first element of the returned list to differ, and will allow + the same actual user to continue to interact with their clusters. + Note: this is not 100% safe, since the IAM users can have very + specific permissions, that disallow them to access the clusters + but it is a reasonable compromise as that could be rare. + + Returns: + A list of strings that uniquely identifies the user on this cloud. + For identity check, we will fallback through the list of strings + until we find a match, and print a warning if we fail for the + first string. + + Raises: + exceptions.CloudUserIdentityError: if the user identity cannot be + retrieved. + """ stdout = cls._aws_configure_list() if stdout is None: # `aws configure list` is not available, possible reasons: From 6e4b49f0a308ad3a39537f9d13709ec1496508b1 Mon Sep 17 00:00:00 2001 From: Aylei Date: Fri, 3 Jan 2025 22:56:06 +0800 Subject: [PATCH 3/3] address review comments Signed-off-by: Aylei --- sky/clouds/aws.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 37da24e771d..c665263e22e 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -758,15 +758,13 @@ def _sts_get_caller_identity(cls) -> Optional[List[List[str]]]: @classmethod @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. def get_user_identities(cls) -> Optional[List[List[str]]]: - """Returns a [UserId, Account] list that uniquely identifies current AWS - principal (user, role or federated identity) whose credentials are used - to run current `sky` process. These identities are assumed to be stable - for the duration of the `sky` process. Modifying the credentials while - the `sky` process is running will not affect the identity returned by - this function. + """Returns a [UserId, Account] list that uniquely identifies the user. These fields come from `aws sts get-caller-identity` and are cached - locally by `aws configure list` output. + locally by `aws configure list` output. The identities are assumed to + be stable for the duration of the `sky` process. Modifying the + credentials while the `sky` process is running will not affect the + identity returned by this function. We permit the same actual user to: @@ -809,7 +807,7 @@ def get_user_identities(cls) -> Optional[List[List[str]]]: # in real world, so the number of cache files would be limited. # TODO(aylei): consider using a more stable cache key and evalute eviction. cache_path = catalog_common.get_catalog_path( - f'aws/user-identity-{config_hash}.txt') + f'aws/.cache/user-identity-{config_hash}.txt') if os.path.exists(cache_path): try: with open(cache_path, 'r', encoding='utf-8') as f: