Skip to content

Commit

Permalink
tests for aws credentials data
Browse files Browse the repository at this point in the history
  • Loading branch information
kalaspuff committed Feb 5, 2024
1 parent dec4bf2 commit 96e46ae
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 16 deletions.
170 changes: 170 additions & 0 deletions tests/test_aws_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from typing import cast

import pytest
from botocore.credentials import Credentials as BotocoreCredentials

from tomodachi.helpers.aws_credentials import Credentials, CredentialsDict

TEST_AWS_ACCESS_KEY_ID = "AKIAXXXXXXXXXXXXXXXX" # example, not real access key
TEST_AWS_SECRET_ACCESS_KEY = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" # example, not real secret key
TEST_AWS_SESSION_TOKEN = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" # example, not real session token


def test_aws_credentials_object() -> None:
credentials = Credentials(
region_name="eu-west-1",
aws_access_key_id=TEST_AWS_ACCESS_KEY_ID,
aws_secret_access_key=TEST_AWS_SECRET_ACCESS_KEY,
)
assert credentials.region_name == "eu-west-1"
assert credentials.aws_access_key_id == TEST_AWS_ACCESS_KEY_ID
assert credentials.aws_secret_access_key == TEST_AWS_SECRET_ACCESS_KEY

assert credentials.dict() == {
"region_name": "eu-west-1",
"aws_access_key_id": TEST_AWS_ACCESS_KEY_ID,
"aws_secret_access_key": TEST_AWS_SECRET_ACCESS_KEY,
"aws_session_token": None,
"endpoint_url": None,
}

assert list(credentials.keys()) == [
"region_name",
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"endpoint_url",
]

assert list(credentials.values()) == [
"eu-west-1",
TEST_AWS_ACCESS_KEY_ID,
TEST_AWS_SECRET_ACCESS_KEY,
None,
None,
]

assert list(credentials.items()) == [
("region_name", "eu-west-1"),
("aws_access_key_id", TEST_AWS_ACCESS_KEY_ID),
("aws_secret_access_key", TEST_AWS_SECRET_ACCESS_KEY),
("aws_session_token", None),
("endpoint_url", None),
]

assert "region_name" in credentials
assert "aws_session_token" in credentials
assert "not_a_valid_key" not in credentials

assert credentials["region_name"] == "eu-west-1"
assert credentials.get("region_name") == "eu-west-1"

assert credentials.endpoint_url is None
assert credentials["endpoint_url"] is None
assert credentials.get("endpoint_url") is None
assert credentials.get("endpoint_url", "a default value") == "a default value"
assert credentials.get("not_a_valid_key") is None # type: ignore[call-overload]
assert credentials.get("not_a_valid_key", "a default value") == "a default value"

keys = []
values = []

for key in credentials:
keys.append(key)
values.append(credentials[key])
assert credentials[key] == getattr(credentials, key)
assert credentials[key] == credentials.get(key)

assert len(keys) == len(values)
assert len(keys) == 5
assert set(keys) == credentials.keys()
assert values == list(credentials.values())

assert credentials.dict() == credentials.dict()
assert credentials.dict() == dict(credentials)
assert credentials.dict() == Credentials(credentials).dict()
assert credentials.dict() == Credentials(credentials.dict()).dict()

_credentials: CredentialsDict = cast(CredentialsDict, credentials)
assert credentials.dict() == Credentials(**_credentials).dict()


def test_aws_credentials_object_invalid_accessor() -> None:
credentials = Credentials(
region_name="eu-west-1",
aws_access_key_id=TEST_AWS_ACCESS_KEY_ID,
aws_secret_access_key=TEST_AWS_SECRET_ACCESS_KEY,
)

with pytest.raises(AttributeError):
_ = credentials["not_a_valid_key"] # type: ignore[index]


def test_aws_credentials_object_invalid_attribute() -> None:
credentials = Credentials(
region_name="eu-west-1",
aws_access_key_id=TEST_AWS_ACCESS_KEY_ID,
aws_secret_access_key=TEST_AWS_SECRET_ACCESS_KEY,
)

with pytest.raises(AttributeError):
_ = credentials.not_a_valid_key # type: ignore[attr-defined]


def test_aws_credentials_object_invalid_argument() -> None:
with pytest.raises(TypeError):
Credentials(
region_name="eu-west-1",
aws_access_key_id=TEST_AWS_ACCESS_KEY_ID,
aws_secret_access_key=TEST_AWS_SECRET_ACCESS_KEY,
not_a_valid_argument="invalid", # type: ignore[call-overload]
)


def test_dict_credentials() -> None:
credentials = Credentials({"aws_session_token": TEST_AWS_SESSION_TOKEN})
assert credentials.aws_session_token == TEST_AWS_SESSION_TOKEN
assert credentials.endpoint_url is None

assert credentials.dict() == {
"region_name": None,
"aws_access_key_id": None,
"aws_secret_access_key": None,
"aws_session_token": TEST_AWS_SESSION_TOKEN,
"endpoint_url": None,
}


def test_botocore_credentials() -> None:
credentials = Credentials(BotocoreCredentials(TEST_AWS_ACCESS_KEY_ID, TEST_AWS_SECRET_ACCESS_KEY))
assert credentials.aws_access_key_id == TEST_AWS_ACCESS_KEY_ID
assert credentials.aws_secret_access_key == TEST_AWS_SECRET_ACCESS_KEY

assert credentials.dict() == {
"region_name": None,
"aws_access_key_id": TEST_AWS_ACCESS_KEY_ID,
"aws_secret_access_key": TEST_AWS_SECRET_ACCESS_KEY,
"aws_session_token": None,
"endpoint_url": None,
}


def test_botocore_credentials_with_extra_input() -> None:
credentials = Credentials(
BotocoreCredentials(TEST_AWS_ACCESS_KEY_ID, TEST_AWS_SECRET_ACCESS_KEY, TEST_AWS_SESSION_TOKEN),
region_name="eu-west-1",
endpoint_url="http://localhost:4567",
)
assert credentials.region_name == "eu-west-1"
assert credentials.aws_access_key_id == TEST_AWS_ACCESS_KEY_ID
assert credentials.aws_secret_access_key == TEST_AWS_SECRET_ACCESS_KEY
assert credentials.aws_session_token == TEST_AWS_SESSION_TOKEN
assert credentials.endpoint_url == "http://localhost:4567"

assert credentials.dict() == {
"region_name": "eu-west-1",
"aws_access_key_id": TEST_AWS_ACCESS_KEY_ID,
"aws_secret_access_key": TEST_AWS_SECRET_ACCESS_KEY,
"aws_session_token": TEST_AWS_SESSION_TOKEN,
"endpoint_url": "http://localhost:4567",
}
63 changes: 47 additions & 16 deletions tomodachi/helpers/aws_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
KeysView,
Literal,
Optional,
Set,
TypedDict,
TypeVar,
Union,
Expand Down Expand Up @@ -47,6 +48,7 @@ class Credentials:
aws_secret_access_key: Optional[str]
aws_session_token: Optional[str]
endpoint_url: Optional[str]
__unset: Set[str]

def keys(
self,
Expand All @@ -59,7 +61,10 @@ def keys(
def dict(self) -> CredentialsDict:
result: CredentialsDict = {}
for key in self.keys():
result[key] = getattr(self, key, None)
value = getattr(self, key, None)
if value is Ellipsis:
value = None
result[key] = value
return result

def values(
Expand All @@ -70,12 +75,24 @@ def values(
def items(
self,
) -> ItemsView[CredentialsTypeKeys, Optional[str]]:
return cast(Dict[CredentialsTypeKeys, Optional[str]], self.items()).items()
return cast(Dict[CredentialsTypeKeys, Optional[str]], self.dict()).items()

@overload
def __init__(
self,
__map: Union[CredentialsDict, CredentialsTypeProtocol, BotocoreCredentials, BotocoreReadOnlyCredentials],
__map: Union[
CredentialsDict,
CredentialsTypeProtocol,
"Credentials",
BotocoreCredentials,
BotocoreReadOnlyCredentials,
Dict[
Literal[
"region_name", "aws_access_key_id", "aws_secret_access_key", "aws_session_token", "endpoint_url"
],
Optional[str],
],
],
/,
*,
region_name: Optional[str] = ...,
Expand All @@ -100,7 +117,19 @@ def __init__(
def __init__(
self,
__map: Optional[
Union[CredentialsDict, CredentialsTypeProtocol, BotocoreCredentials, BotocoreReadOnlyCredentials]
Union[
CredentialsDict,
CredentialsTypeProtocol,
"Credentials",
BotocoreCredentials,
BotocoreReadOnlyCredentials,
Dict[
Literal[
"region_name", "aws_access_key_id", "aws_secret_access_key", "aws_session_token", "endpoint_url"
],
Optional[str],
],
]
] = None,
/,
**kwargs: Any,
Expand All @@ -116,23 +145,25 @@ def __init__(
}
if not isinstance(__map, dict):
try:
__map = cast(CredentialsDict, dict(__map)) # type: ignore[call-overload]
__map = cast(CredentialsDict, dict(__map)) # type: ignore[arg-type]
except TypeError:
__map = cast(CredentialsDict, dict(__map.__dict__))
if __map and isinstance(__map, dict):
for key, value in __map.items():
setattr(self, key, value)
for key in kwargs:
if key not in (
"region_name",
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"endpoint_url",
):
if key not in self.keys():
raise TypeError(f"__init__() got an unexpected keyword argument '{key}'")
setattr(self, key, kwargs[key])

self.__unset = set()
for key in self.keys():
try:
self.get(key, default=Ellipsis)
except AttributeError:
setattr(self, key, None)
self.__unset.add(key)

def __iter__(
self,
) -> Iterator[CredentialsTypeKeys]:
Expand All @@ -142,7 +173,7 @@ def __getitem__(
self,
key: CredentialsTypeKeys,
) -> Optional[str]:
return self.get(key)
return self.get(key, default=Ellipsis) if key not in self.__unset else None

def __contains__(
self,
Expand All @@ -167,11 +198,11 @@ def get(
def get(
self,
key: str,
default: Any = ...,
default: Any = None,
) -> Any:
if default is Ellipsis:
return getattr(self, key)
return getattr(self, key, default)
return getattr(self, key) if key not in self.__unset else None
return getattr(self, key, default) if key not in self.__unset else default


CredentialsMapping = Union[
Expand Down

0 comments on commit 96e46ae

Please sign in to comment.