Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aws] cache user identity by 'aws configure list' #4507

Merged
merged 3 commits into from
Jan 3, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import enum
import fnmatch
import functools
import hashlib
import json
import os
import re
import subprocess
Expand All @@ -16,6 +18,7 @@
from sky import skypilot_config
from sky.adaptors import aws
from sky.clouds import service_catalog
from sky.clouds.service_catalog import common as catalog_common
from sky.clouds.utils import aws_utils
from sky.skylet import constants
from sky.utils import common_utils
Expand Down Expand Up @@ -624,14 +627,10 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:

@classmethod
def _current_identity_type(cls) -> Optional[AWSIdentityType]:
proc = subprocess.run('aws configure list',
shell=True,
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
if proc.returncode != 0:
stdout = cls._aws_configure_list()
if stdout is None:
return None
stdout = proc.stdout.decode()
output = stdout.decode()

# We determine the identity type by looking at the output of
# `aws configure list`. The output looks like:
Expand All @@ -646,10 +645,10 @@ def _current_identity_type(cls) -> Optional[AWSIdentityType]:

def _is_access_key_of_type(type_str: str) -> bool:
# The dot (.) does not match line separators.
results = re.findall(fr'access_key.*{type_str}', stdout)
results = re.findall(fr'access_key.*{type_str}', output)
if len(results) > 1:
raise RuntimeError(
f'Unexpected `aws configure list` output:\n{stdout}')
f'Unexpected `aws configure list` output:\n{output}')
return len(results) == 1

if _is_access_key_of_type(AWSIdentityType.SSO.value):
Expand All @@ -663,9 +662,21 @@ def _is_access_key_of_type(type_str: str) -> bool:
else:
return AWSIdentityType.SHARED_CREDENTIALS_FILE

@classmethod
@functools.lru_cache(maxsize=1)
def _aws_configure_list(cls) -> Optional[bytes]:
    """Run `aws configure list` and hand back its raw stdout.

    The command is executed through the shell with both output streams
    captured. The result is memoized by `functools.lru_cache`, so the
    subprocess is spawned at most once per cached call; a `None` result
    is memoized as well.

    Returns:
        The undecoded stdout bytes when the command exits with status 0,
        otherwise `None`.
    """
    completed = subprocess.run('aws configure list',
                               shell=True,
                               check=False,
                               capture_output=True)
    # A non-zero exit status is reported to callers as "no output".
    return completed.stdout if completed.returncode == 0 else None

@classmethod
@functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
def get_user_identities(cls) -> Optional[List[List[str]]]:
def _sts_get_caller_identity(cls) -> Optional[List[List[str]]]:
"""Returns a [UserId, Account] list that uniquely identifies the user.

These fields come from `aws sts get-caller-identity`. We permit the same
Expand Down Expand Up @@ -773,6 +784,38 @@ def get_user_identities(cls) -> Optional[List[List[str]]]:
# automatic switching for AWS. Currently we only support one identity.
return [user_ids]

@classmethod
@functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
def get_user_identities(cls) -> Optional[List[List[str]]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a docstr for this function for the behavior of caching and returning identity. It might also be good to move the docstr from _sts_get_caller_identity to here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch! just moved the docstring here and added the description of caching behavior

stdout = cls._aws_configure_list()
if stdout is None:
# `aws configure list` is not available, possible reasons:
# - awscli is not installed but credentials are valid, e.g. run from
# an EC2 instance with IAM role
# - aws credentials are not set, proceed anyway to get unified error
# message for users
return cls._sts_get_caller_identity()
config_hash = hashlib.md5(stdout).hexdigest()[:8]
# Getting the AWS identity costs ~1s, so we cache the result, using the
# output of `aws configure list` as the cache key. Different outputs can map
# to the same AWS identity; we assume the output is stable in practice, so
# the number of cache files stays small.
# TODO(aylei): consider using a more stable cache key and evaluate eviction.
cache_path = catalog_common.get_catalog_path(
f'aws/user-identity-{config_hash}.txt')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we move it to aws/.cache/user-identity-{config_hash}.txt

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good!

if os.path.exists(cache_path):
try:
with open(cache_path, 'r', encoding='utf-8') as f:
return json.loads(f.read())
except json.JSONDecodeError:
# cache is invalid, ignore it and fetch identity again
pass

result = cls._sts_get_caller_identity()
with open(cache_path, 'w', encoding='utf-8') as f:
f.write(json.dumps(result))
return result

@classmethod
def get_active_user_identity_str(cls) -> Optional[str]:
user_identity = cls.get_active_user_identity()
Expand Down
Loading