diff --git a/packages/smithy-aws-core/src/smithy_aws_core/identity/container.py b/packages/smithy-aws-core/src/smithy_aws_core/identity/container.py new file mode 100644 index 00000000..30785c5e --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/identity/container.py @@ -0,0 +1,192 @@ +import asyncio +import ipaddress +import json +import os +from dataclasses import dataclass +from datetime import UTC, datetime +from urllib.parse import urlparse + +from smithy_core import URI +from smithy_core.aio.interfaces.identity import IdentityResolver +from smithy_core.exceptions import SmithyIdentityError +from smithy_http import Field, Fields +from smithy_http.aio import HTTPRequest +from smithy_http.aio.interfaces import HTTPClient, HTTPResponse + +from smithy_aws_core.identity import AWSCredentialsIdentity, AWSIdentityProperties + +_CONTAINER_METADATA_IP = "169.254.170.2" +_CONTAINER_METADATA_ALLOWED_HOSTS = { + _CONTAINER_METADATA_IP, + "169.254.170.23", + "fd00:ec2::23", + "localhost", +} +_DEFAULT_TIMEOUT = 2 +_DEFAULT_RETRIES = 3 +_SLEEP_SECONDS = 1 + + +@dataclass +class ContainerCredentialConfig: + """Configuration for container credential retrieval operations.""" + + timeout: int = _DEFAULT_TIMEOUT + retries: int = _DEFAULT_RETRIES + + +class ContainerMetadataClient: + """Client for remote credential retrieval in Container environments like ECS/EKS.""" + + def __init__(self, http_client: HTTPClient, config: ContainerCredentialConfig): + self._http_client = http_client + self._config = config + + def _validate_allowed_url(self, uri: URI) -> None: + if self._is_loopback(uri.host): + return + + if not self._is_allowed_container_metadata_host(uri.host): + raise SmithyIdentityError( + f"Unsupported host '{uri.host}'. " + f"Can only retrieve metadata from a loopback address or " + f"one of: {', '.join(_CONTAINER_METADATA_ALLOWED_HOSTS)}" + ) + + async def get_credentials(self, uri: URI, fields: Fields) -> dict[str, str]: + self._validate_allowed_url(uri) + fields.set_field(Field(name="Accept", values=["application/json"])) + + attempts = 0 + last_exc = None + while attempts < self._config.retries: + try: + request = HTTPRequest( + method="GET", + destination=uri, + fields=fields, + ) + response: HTTPResponse = await self._http_client.send(request) + body = await response.consume_body_async() + if response.status != 200: + raise SmithyIdentityError( + f"Container metadata service returned {response.status}: " + f"{body.decode('utf-8')}" + ) + try: + return json.loads(body.decode("utf-8")) + except Exception as e: + raise SmithyIdentityError( + f"Unable to parse JSON from container metadata: {body.decode('utf-8')}" + ) from e + except Exception as e: + last_exc = e + await asyncio.sleep(_SLEEP_SECONDS) + attempts += 1 + + raise SmithyIdentityError( + f"Failed to retrieve container metadata after {self._config.retries} attempt(s)" + ) from last_exc + + def _is_loopback(self, hostname: str) -> bool: + try: + return ipaddress.ip_address(hostname).is_loopback + except ValueError: + return False + + def _is_allowed_container_metadata_host(self, hostname: str) -> bool: + return hostname in _CONTAINER_METADATA_ALLOWED_HOSTS + + +class ContainerCredentialResolver( + IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] +): + """Resolves AWS Credentials from container credential sources.""" + + ENV_VAR = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" + ENV_VAR_FULL = "AWS_CONTAINER_CREDENTIALS_FULL_URI" + ENV_VAR_AUTH_TOKEN = "AWS_CONTAINER_AUTHORIZATION_TOKEN" # noqa: S105 + ENV_VAR_AUTH_TOKEN_FILE = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" # noqa: S105 + + def __init__( + self, + http_client: HTTPClient, + config: ContainerCredentialConfig | None = None, + ): + self._http_client = http_client + self._config = config or ContainerCredentialConfig() + self._client = ContainerMetadataClient(http_client, self._config) + self._credentials = None + + async def _resolve_uri_from_env(self) -> URI: + if self.ENV_VAR in os.environ: + return URI( + scheme="http", + host=_CONTAINER_METADATA_IP, + path=os.environ[self.ENV_VAR], + ) + elif self.ENV_VAR_FULL in os.environ: + parsed = urlparse(os.environ[self.ENV_VAR_FULL]) + return URI( + scheme=parsed.scheme, + host=parsed.hostname or "", + port=parsed.port, + path=parsed.path, + ) + else: + raise SmithyIdentityError( + f"Neither {self.ENV_VAR} or {self.ENV_VAR_FULL} environment " + "variables are set. Unable to resolve credentials." + ) + + async def _resolve_fields_from_env(self) -> Fields: + fields = Fields() + if self.ENV_VAR_AUTH_TOKEN_FILE in os.environ: + try: + filename = os.environ[self.ENV_VAR_AUTH_TOKEN_FILE] + auth_token = await asyncio.to_thread(self._read_file, filename) + except (FileNotFoundError, PermissionError) as e: + raise SmithyIdentityError( + f"Unable to open {os.environ[self.ENV_VAR_AUTH_TOKEN_FILE]}." + ) from e + + fields.set_field(Field(name="Authorization", values=[auth_token])) + elif self.ENV_VAR_AUTH_TOKEN in os.environ: + auth_token = os.environ[self.ENV_VAR_AUTH_TOKEN] + fields.set_field(Field(name="Authorization", values=[auth_token])) + + return fields + + def _read_file(self, filename: str) -> str: + with open(filename) as f: + return f.read().strip() + + async def get_identity( + self, *, properties: AWSIdentityProperties + ) -> AWSCredentialsIdentity: + uri = await self._resolve_uri_from_env() + fields = await self._resolve_fields_from_env() + creds = await self._client.get_credentials(uri, fields) + + access_key_id = creds.get("AccessKeyId") + secret_access_key = creds.get("SecretAccessKey") + session_token = creds.get("Token") + expiration = creds.get("Expiration") + account_id = creds.get("AccountId", None) + + if isinstance(expiration, str): + expiration = datetime.fromisoformat(expiration).replace(tzinfo=UTC) + + if access_key_id is None or secret_access_key is None: + raise SmithyIdentityError( + "AccessKeyId and SecretAccessKey are required for container credentials" + ) + + self._credentials = AWSCredentialsIdentity( + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + expiration=expiration, + account_id=account_id, + ) + return self._credentials diff --git a/packages/smithy-aws-core/tests/unit/identity/test_container.py b/packages/smithy-aws-core/tests/unit/identity/test_container.py new file mode 100644 index 00000000..d7941914 --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/identity/test_container.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import json +import os +import typing +from unittest.mock import AsyncMock, patch + +import pytest +from smithy_aws_core.identity import AWSCredentialsIdentity +from smithy_aws_core.identity.container import ( + ContainerCredentialConfig, + ContainerCredentialResolver, + ContainerMetadataClient, +) +from smithy_core import URI +from smithy_core.exceptions import SmithyIdentityError +from smithy_http import Fields + +if typing.TYPE_CHECKING: + import pathlib + +DEFAULT_RESPONSE_DATA = { + "AccessKeyId": "akid123", + "SecretAccessKey": "s3cr3t", + "Token": "session_token", +} + + +def test_config_custom_values(): + config = ContainerCredentialConfig(timeout=10, retries=5) + assert config.timeout == 10 + assert config.retries == 5 + + +def mock_http_client_response(status: int, body: bytes): + http_client = AsyncMock() + response = AsyncMock() + response.status = status + response.consume_body_async.return_value = body + http_client.send.return_value = response + return http_client + + +@pytest.mark.asyncio +async def test_metadata_client_valid_host(): + resp_body = json.dumps(DEFAULT_RESPONSE_DATA) + http_client = mock_http_client_response(200, resp_body.encode("utf-8")) + config = ContainerCredentialConfig() + client = ContainerMetadataClient(http_client, config) + + # Valid Host + uri = URI(scheme="http", host="169.254.170.2") + + creds = await client.get_credentials(uri, Fields()) + assert creds["AccessKeyId"] == "akid123" + assert creds["SecretAccessKey"] == "s3cr3t" + assert creds["Token"] == "session_token" + + +@pytest.mark.asyncio +async def test_metadata_client_invalid_host(): + resp_body = json.dumps(DEFAULT_RESPONSE_DATA) + http_client = mock_http_client_response(200, resp_body.encode("utf-8")) + config = ContainerCredentialConfig(retries=0) + client = ContainerMetadataClient(http_client, config) + + # Invalid Host + uri = URI(scheme="http", host="169.254.169.254") + + with pytest.raises(SmithyIdentityError): + await client.get_credentials(uri, Fields()) + + +@pytest.mark.asyncio +async def test_metadata_client_non_200_response(): + http_client = mock_http_client_response(404, b"not found") + config = ContainerCredentialConfig(retries=1) + client = ContainerMetadataClient(http_client, config) + + uri = URI(scheme="http", host="169.254.170.2") + with pytest.raises(SmithyIdentityError) as e: + await client.get_credentials(uri, Fields()) + + # Ensure both the received retry error and underlying error are what we expect. + assert "Container metadata service returned 404" in str(e.value.__cause__) + assert "Failed to retrieve container metadata after 1 attempt(s)" in str(e.value) + + +@pytest.mark.asyncio +async def test_metadata_client_invalid_json(): + http_client = mock_http_client_response( + 200, b"proxy" + ) + config = ContainerCredentialConfig(retries=1) + client = ContainerMetadataClient(http_client, config) + + uri = URI(scheme="http", host="169.254.170.2") + with pytest.raises(SmithyIdentityError): + await client.get_credentials(uri, Fields()) + + +def _assert_expected_identity(identity: AWSCredentialsIdentity) -> None: + assert identity.access_key_id == DEFAULT_RESPONSE_DATA["AccessKeyId"] + assert identity.secret_access_key == DEFAULT_RESPONSE_DATA["SecretAccessKey"] + assert identity.session_token == DEFAULT_RESPONSE_DATA["Token"] + + +@pytest.mark.asyncio +async def test_metadata_client_retries(): + http_client = AsyncMock() + config = ContainerCredentialConfig(retries=2) + client = ContainerMetadataClient(http_client, config) + uri = URI(scheme="http", host="169.254.170.2", path="/task") + http_client.send.side_effect = Exception() + + with pytest.raises(SmithyIdentityError): + await client.get_credentials(uri, Fields()) + assert http_client.send.call_count == 2 + + +@pytest.mark.asyncio +async def test_resolver_env_relative(): + resp_body = json.dumps(DEFAULT_RESPONSE_DATA) + http_client = mock_http_client_response(200, resp_body.encode("utf-8")) + + with patch.dict(os.environ, {ContainerCredentialResolver.ENV_VAR: "/test"}): + resolver = ContainerCredentialResolver(http_client) + identity = await resolver.get_identity(properties={}) + + # Ensure we derive the correct destination + expected_url = URI( + scheme="http", + host="169.254.170.2", + path="/test", + ) + http_request = http_client.send.call_args_list[0].args[0] + assert http_request.destination == expected_url + + _assert_expected_identity(identity) + + +@pytest.mark.asyncio +async def test_resolver_env_full(): + resp_body = json.dumps(DEFAULT_RESPONSE_DATA) + http_client = mock_http_client_response(200, resp_body.encode("utf-8")) + + with patch.dict( + os.environ, + {ContainerCredentialResolver.ENV_VAR_FULL: "http://169.254.170.23/full"}, + ): + resolver = ContainerCredentialResolver(http_client) + identity = await resolver.get_identity(properties={}) + + # Ensure we derive the correct destination + expected_url = URI( + scheme="http", + host="169.254.170.23", + path="/full", + ) + http_request = http_client.send.call_args_list[0].args[0] + assert http_request.destination == expected_url + + _assert_expected_identity(identity) + + +@pytest.mark.asyncio +async def test_resolver_env_token(): + resp_body = json.dumps(DEFAULT_RESPONSE_DATA) + http_client = mock_http_client_response(200, resp_body.encode("utf-8")) + + with patch.dict( + os.environ, + { + ContainerCredentialResolver.ENV_VAR_FULL: "http://169.254.170.23/full", + ContainerCredentialResolver.ENV_VAR_AUTH_TOKEN: "Bearer foobar", + }, + ): + resolver = ContainerCredentialResolver(http_client) + identity = await resolver.get_identity(properties={}) + + # Ensure we derive the correct destination and fields + expected_url = URI( + scheme="http", + host="169.254.170.23", + path="/full", + ) + http_request = http_client.send.call_args_list[0].args[0] + assert http_request.destination == expected_url + + assert "Authorization" in http_request.fields + auth_field = http_request.fields.get("Authorization") + assert auth_field.as_string() == "Bearer foobar" + + _assert_expected_identity(identity) + + +@pytest.mark.asyncio +async def test_resolver_env_token_file(tmp_path: pathlib.Path): + resp_body = json.dumps(DEFAULT_RESPONSE_DATA) + http_client = mock_http_client_response(200, resp_body.encode("utf-8")) + + token_file = tmp_path / "token_file" + token_file.write_text("Bearer barfoo") + + with patch.dict( + os.environ, + { + ContainerCredentialResolver.ENV_VAR_FULL: "http://169.254.170.23/full", + ContainerCredentialResolver.ENV_VAR_AUTH_TOKEN_FILE: str(token_file), + }, + ): + resolver = ContainerCredentialResolver(http_client) + identity = await resolver.get_identity(properties={}) + + # Ensure we derive the correct destination and fields + expected_url = URI( + scheme="http", + host="169.254.170.23", + path="/full", + ) + http_request = http_client.send.call_args_list[0].args[0] + assert http_request.destination == expected_url + + assert "Authorization" in http_request.fields + auth_field = http_request.fields.get("Authorization") + assert auth_field.as_string() == "Bearer barfoo" + + _assert_expected_identity(identity) + + +@pytest.mark.asyncio +async def test_resolver_env_token_file_precedence(tmp_path: pathlib.Path): + """Validate the token file is used over the explicit value if both are set.""" + resp_body = json.dumps(DEFAULT_RESPONSE_DATA) + http_client = mock_http_client_response(200, resp_body.encode("utf-8")) + + token_file = tmp_path / "token_file" + token_file.write_text("Bearer barfoo") + + with patch.dict( + os.environ, + { + ContainerCredentialResolver.ENV_VAR_FULL: "http://169.254.170.23/full", + ContainerCredentialResolver.ENV_VAR_AUTH_TOKEN_FILE: str(token_file), + ContainerCredentialResolver.ENV_VAR_AUTH_TOKEN: "Bearer foobar", + }, + ): + resolver = ContainerCredentialResolver(http_client) + identity = await resolver.get_identity(properties={}) + + # Ensure we derive the correct destination and fields + expected_url = URI( + scheme="http", + host="169.254.170.23", + path="/full", + ) + http_request = http_client.send.call_args_list[0].args[0] + assert http_request.destination == expected_url + + assert "Authorization" in http_request.fields + auth_field = http_request.fields.get("Authorization") + assert auth_field.as_string() == "Bearer barfoo" + + _assert_expected_identity(identity) + + +@pytest.mark.asyncio +async def test_resolver_missing_env(): + resp_body = json.dumps(DEFAULT_RESPONSE_DATA) + http_client = mock_http_client_response(200, resp_body.encode("utf-8")) + + with patch.dict(os.environ, {}): + resolver = ContainerCredentialResolver(http_client) + with pytest.raises(SmithyIdentityError): + await resolver.get_identity(properties={})