Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aws] cache user identity by 'aws configure list' #4507

Merged
merged 3 commits into from
Jan 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 87 additions & 39 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import enum
import fnmatch
import functools
import hashlib
import json
import os
import re
import subprocess
Expand All @@ -16,6 +18,7 @@
from sky import skypilot_config
from sky.adaptors import aws
from sky.clouds import service_catalog
from sky.clouds.service_catalog import common as catalog_common
from sky.clouds.utils import aws_utils
from sky.skylet import constants
from sky.utils import common_utils
Expand Down Expand Up @@ -624,14 +627,10 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:

@classmethod
def _current_identity_type(cls) -> Optional[AWSIdentityType]:
proc = subprocess.run('aws configure list',
shell=True,
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
if proc.returncode != 0:
stdout = cls._aws_configure_list()
if stdout is None:
return None
stdout = proc.stdout.decode()
output = stdout.decode()

# We determine the identity type by looking at the output of
# `aws configure list`. The output looks like:
Expand All @@ -646,10 +645,10 @@ def _current_identity_type(cls) -> Optional[AWSIdentityType]:

def _is_access_key_of_type(type_str: str) -> bool:
# The dot (.) does not match line separators.
results = re.findall(fr'access_key.*{type_str}', stdout)
results = re.findall(fr'access_key.*{type_str}', output)
if len(results) > 1:
raise RuntimeError(
f'Unexpected `aws configure list` output:\n{stdout}')
f'Unexpected `aws configure list` output:\n{output}')
return len(results) == 1

if _is_access_key_of_type(AWSIdentityType.SSO.value):
Expand All @@ -664,37 +663,20 @@ def _is_access_key_of_type(type_str: str) -> bool:
return AWSIdentityType.SHARED_CREDENTIALS_FILE

@classmethod
@functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
def get_user_identities(cls) -> Optional[List[List[str]]]:
"""Returns a [UserId, Account] list that uniquely identifies the user.

These fields come from `aws sts get-caller-identity`. We permit the same
actual user to:

- switch between different root accounts (after which both elements
of the list will be different) and have their clusters owned by
each account be protected; or

- within the same root account, switch between different IAM
users, and treat [user_id=1234, account=A] and
[user_id=4567, account=A] to be the *same*. Namely, switching
between these IAM roles within the same root account will cause
the first element of the returned list to differ, and will allow
the same actual user to continue to interact with their clusters.
Note: this is not 100% safe, since the IAM users can have very
specific permissions, that disallow them to access the clusters
but it is a reasonable compromise as that could be rare.

Returns:
A list of strings that uniquely identifies the user on this cloud.
For identity check, we will fallback through the list of strings
until we find a match, and print a warning if we fail for the
first string.
@functools.lru_cache(maxsize=1)
def _aws_configure_list(cls) -> Optional[bytes]:
proc = subprocess.run('aws configure list',
shell=True,
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
if proc.returncode != 0:
return None
return proc.stdout

Raises:
exceptions.CloudUserIdentityError: if the user identity cannot be
retrieved.
"""
@classmethod
@functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
def _sts_get_caller_identity(cls) -> Optional[List[List[str]]]:
try:
sts = aws.client('sts')
# The caller identity contains 3 fields: UserId, Account, Arn.
Expand Down Expand Up @@ -773,6 +755,72 @@ def get_user_identities(cls) -> Optional[List[List[str]]]:
# automatic switching for AWS. Currently we only support one identity.
return [user_ids]

@classmethod
@functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
def get_user_identities(cls) -> Optional[List[List[str]]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a docstr for this function for the behavior of caching and returning identity. It might also be good to move the docstr from _sts_get_caller_identity to here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch! just moved the docstring here and added the description of caching behavior

"""Returns a [UserId, Account] list that uniquely identifies the user.

These fields come from `aws sts get-caller-identity` and are cached
locally by `aws configure list` output. The identities are assumed to
be stable for the duration of the `sky` process. Modifying the
credentials while the `sky` process is running will not affect the
identity returned by this function.

We permit the same actual user to:

- switch between different root accounts (after which both elements
of the list will be different) and have their clusters owned by
each account be protected; or

- within the same root account, switch between different IAM
users, and treat [user_id=1234, account=A] and
[user_id=4567, account=A] to be the *same*. Namely, switching
between these IAM roles within the same root account will cause
the first element of the returned list to differ, and will allow
the same actual user to continue to interact with their clusters.
Note: this is not 100% safe, since the IAM users can have very
specific permissions, that disallow them to access the clusters
but it is a reasonable compromise as that could be rare.

Returns:
A list of strings that uniquely identifies the user on this cloud.
For identity check, we will fallback through the list of strings
until we find a match, and print a warning if we fail for the
first string.

Raises:
exceptions.CloudUserIdentityError: if the user identity cannot be
retrieved.
"""
stdout = cls._aws_configure_list()
if stdout is None:
# `aws configure list` is not available, possible reasons:
# - awscli is not installed but credentials are valid, e.g. run from
# an EC2 instance with IAM role
# - aws credentials are not set, proceed anyway to get unified error
# message for users
return cls._sts_get_caller_identity()
config_hash = hashlib.md5(stdout).hexdigest()[:8]
# Getting aws identity cost ~1s, so we cache the result with the output of
# `aws configure list` as cache key. Different `aws configure list` output
# can have same aws identity, our assumption is the output would be stable
# in real world, so the number of cache files would be limited.
# TODO(aylei): consider using a more stable cache key and evalute eviction.
cache_path = catalog_common.get_catalog_path(
f'aws/.cache/user-identity-{config_hash}.txt')
if os.path.exists(cache_path):
try:
with open(cache_path, 'r', encoding='utf-8') as f:
return json.loads(f.read())
except json.JSONDecodeError:
# cache is invalid, ignore it and fetch identity again
pass

result = cls._sts_get_caller_identity()
with open(cache_path, 'w', encoding='utf-8') as f:
f.write(json.dumps(result))
return result

@classmethod
def get_active_user_identity_str(cls) -> Optional[str]:
user_identity = cls.get_active_user_identity()
Expand Down
Loading