From be9da2899e95d2d82eeee52753da613e28686199 Mon Sep 17 00:00:00 2001 From: Bryan Ayala Date: Thu, 18 May 2023 15:03:47 -0700 Subject: [PATCH] Remove dependency on AWS Encryption SDK for HOOK types (#260) * Remove dependency on AWS Encryption SDK for HOOK types --- src/cloudformation_cli_python_lib/cipher.py | 111 --------------- .../exceptions.py | 4 - src/cloudformation_cli_python_lib/hook.py | 30 +--- src/cloudformation_cli_python_lib/utils.py | 45 +++++- src/setup.py | 1 - tests/lib/cipher_test.py | 134 ------------------ tests/lib/hook_test.py | 56 +------- tests/lib/utils_test.py | 57 +++++++- 8 files changed, 101 insertions(+), 337 deletions(-) delete mode 100644 src/cloudformation_cli_python_lib/cipher.py delete mode 100644 tests/lib/cipher_test.py diff --git a/src/cloudformation_cli_python_lib/cipher.py b/src/cloudformation_cli_python_lib/cipher.py deleted file mode 100644 index 2bb071c..0000000 --- a/src/cloudformation_cli_python_lib/cipher.py +++ /dev/null @@ -1,111 +0,0 @@ -# boto3, botocore, aws_encryption_sdk don't have stub files -import boto3 # type: ignore - -import aws_encryption_sdk # type: ignore -import base64 -import json -import uuid -from aws_encryption_sdk.exceptions import AWSEncryptionSDKClientError # type: ignore -from aws_encryption_sdk.identifiers import CommitmentPolicy # type: ignore -from botocore.client import BaseClient # type: ignore -from botocore.config import Config # type: ignore -from botocore.credentials import ( # type: ignore - DeferredRefreshableCredentials, - create_assume_role_refresher, -) -from botocore.session import Session, get_session # type: ignore -from typing import Optional - -from .exceptions import _EncryptionError -from .utils import Credentials - - -class Cipher: - def decrypt_credentials( - self, encrypted_credentials: str - ) -> Optional[Credentials]: # pragma: no cover - raise NotImplementedError() - - -class KmsCipher(Cipher): - """ - This class is decrypted the encrypted credentials sent by CFN in the request: - * Credentials are encrypted service side - * The encrypted credentials along with an IAM role - * Credentials are decrypetd using the AWS Encryption - SDK while assuming the encryption role - """ - - def __init__( - self, encryption_key_arn: Optional[str], encryption_key_role: Optional[str] - ) -> None: - self._crypto_client = aws_encryption_sdk.EncryptionSDKClient( - commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT - ) - self._key_provider = None - if encryption_key_arn and encryption_key_role: - self._key_provider = aws_encryption_sdk.StrictAwsKmsMasterKeyProvider( - key_ids=[encryption_key_arn], - botocore_session=self._get_assume_role_session( - encryption_key_role, self._create_client() - ), - ) - - def decrypt_credentials( - self, encrypted_credentials: Optional[str] - ) -> Optional[Credentials]: - if not encrypted_credentials: - return None - - # If no kms key and role arn provided - # Attempt to deserialize unencrypted credentials - # This happens during contract tests - if not self._key_provider: - try: - credentials_data = json.loads(encrypted_credentials) - return Credentials(**credentials_data) - except (json.JSONDecodeError, TypeError, ValueError): - return None - - try: - decrypted_credentials, _decryptor_header = self._crypto_client.decrypt( - source=base64.b64decode(encrypted_credentials), - key_provider=self._key_provider, - ) - credentials_data = json.loads(decrypted_credentials.decode("UTF-8")) - if credentials_data is None: - raise _EncryptionError( - "Failed to decrypt credentials. Decrypted credentials are 'null'." - ) - - return Credentials(**credentials_data) - except ( - AWSEncryptionSDKClientError, - json.JSONDecodeError, - TypeError, - ValueError, - ) as e: - raise _EncryptionError("Failed to decrypt credentials.") from e - - @staticmethod - def _get_assume_role_session( - encryption_key_role: str, client: BaseClient - ) -> Session: - params = {"RoleArn": encryption_key_role, "RoleSessionName": str(uuid.uuid4())} - - session = get_session() - # pylint: disable=protected-access - session._credentials = DeferredRefreshableCredentials( - refresh_using=create_assume_role_refresher(client, params), - method="sts-assume-role", - ) - return session - - @staticmethod - def _create_client() -> BaseClient: - return boto3.client( - "sts", - config=Config( - connect_timeout=10, read_timeout=60, retries={"max_attempts": 3} - ), - ) diff --git a/src/cloudformation_cli_python_lib/exceptions.py b/src/cloudformation_cli_python_lib/exceptions.py index e891f84..7ab9766 100644 --- a/src/cloudformation_cli_python_lib/exceptions.py +++ b/src/cloudformation_cli_python_lib/exceptions.py @@ -106,7 +106,3 @@ def __init__(self, hook_type_name: str, target_type_name: str): class Unknown(_HandlerError): pass - - -class _EncryptionError(Exception): - pass diff --git a/src/cloudformation_cli_python_lib/hook.py b/src/cloudformation_cli_python_lib/hook.py index 5504d6a..04dadb0 100644 --- a/src/cloudformation_cli_python_lib/hook.py +++ b/src/cloudformation_cli_python_lib/hook.py @@ -6,14 +6,7 @@ from typing import Any, Callable, MutableMapping, Optional, Tuple, Type, Union from .boto3_proxy import SessionProxy, _get_boto_session -from .cipher import Cipher, KmsCipher -from .exceptions import ( - AccessDenied, - InternalFailure, - InvalidRequest, - _EncryptionError, - _HandlerError, -) +from .exceptions import InternalFailure, InvalidRequest, _HandlerError from .interface import ( BaseHookHandlerRequest, HandlerErrorCode, @@ -173,31 +166,16 @@ def _parse_request( ]: try: event = HookInvocationRequest.deserialize(event_data) - cipher: Cipher = KmsCipher( - event.requestData.hookEncryptionKeyArn, - event.requestData.hookEncryptionKeyRole, - ) - - caller_credentials = cipher.decrypt_credentials( - event.requestData.callerCredentials - ) - provider_credentials = cipher.decrypt_credentials( - event.requestData.providerCredentials - ) - - caller_sess = _get_boto_session(caller_credentials) - provider_sess = _get_boto_session(provider_credentials) + caller_sess = _get_boto_session(event.requestData.callerCredentials) + provider_sess = _get_boto_session(event.requestData.providerCredentials) # credentials are used when rescheduling, so can't zero them out (for now) invocation_point = HookInvocationPoint[event.actionInvocationPoint] callback_context = event.requestContext.callbackContext or {} - except _EncryptionError as e: - LOG.exception("Failed to decrypt credentials") - raise AccessDenied(f"{e} ({type(e).__name__})") from e except Exception as e: LOG.exception("Invalid request") raise InvalidRequest(f"{e} ({type(e).__name__})") from e - return ((caller_sess, provider_sess)), invocation_point, callback_context, event + return (caller_sess, provider_sess), invocation_point, callback_context, event def _cast_hook_request( self, request: HookInvocationRequest diff --git a/src/cloudformation_cli_python_lib/utils.py b/src/cloudformation_cli_python_lib/utils.py index 0ef80de..f610487 100644 --- a/src/cloudformation_cli_python_lib/utils.py +++ b/src/cloudformation_cli_python_lib/utils.py @@ -204,6 +204,9 @@ def deserialize(cls, json_data: MutableMapping[str, Any]) -> "HookRequestContext return HookRequestContext() return HookRequestContext(**json_data) + def serialize(self) -> Mapping[str, Any]: + return {key: value for key, value in self.__dict__.items() if value is not None} + @dataclass class HookRequestData: @@ -211,15 +214,38 @@ class HookRequestData: targetType: str targetLogicalId: str targetModel: Mapping[str, Any] - callerCredentials: Optional[str] = None - providerCredentials: Optional[str] = None + callerCredentials: Optional[Credentials] = None + providerCredentials: Optional[Credentials] = None providerLogGroupName: Optional[str] = None - hookEncryptionKeyArn: Optional[str] = None - hookEncryptionKeyRole: Optional[str] = None + + def __init__(self, **kwargs: Any) -> None: + dataclass_fields = {f.name for f in fields(self)} + for k, v in kwargs.items(): + if k in dataclass_fields: + setattr(self, k, v) @classmethod def deserialize(cls, json_data: MutableMapping[str, Any]) -> "HookRequestData": - return HookRequestData(**json_data) + req_data = HookRequestData(**json_data) + for key in json_data: + if not key.endswith("Credentials"): + continue + creds = json_data.get(key) + if creds: + cred_data = json.loads(creds) + setattr(req_data, key, Credentials(**cred_data)) + return req_data + + def serialize(self) -> Mapping[str, Any]: + return { + key: {k: v for k, v in value.items() if v is not None} + if key == "targetModel" + else value.__dict__.copy() + if key.endswith("Credentials") + else value + for key, value in self.__dict__.items() + if value is not None + } @dataclass @@ -252,6 +278,15 @@ def deserialize(cls, json_data: MutableMapping[str, Any]) -> Any: ) return event + def serialize(self) -> Mapping[str, Any]: + return { + key: value.serialize() + if key in ("requestData", "requestContext") + else value + for key, value in self.__dict__.items() + if value is not None + } + @dataclass class UnmodelledHookRequest: diff --git a/src/setup.py b/src/setup.py index 7d925c2..4f14dfb 100644 --- a/src/setup.py +++ b/src/setup.py @@ -15,7 +15,6 @@ python_requires=">=3.6", install_requires=[ "boto3>=1.10.20", - "aws-encryption-sdk==3.1.1", 'dataclasses;python_version<"3.7"', ], license="Apache License 2.0", diff --git a/tests/lib/cipher_test.py b/tests/lib/cipher_test.py deleted file mode 100644 index ee451cc..0000000 --- a/tests/lib/cipher_test.py +++ /dev/null @@ -1,134 +0,0 @@ -# pylint: disable=wrong-import-order,line-too-long -import pytest -from cloudformation_cli_python_lib.cipher import KmsCipher -from cloudformation_cli_python_lib.exceptions import _EncryptionError -from cloudformation_cli_python_lib.utils import Credentials - -from aws_encryption_sdk.exceptions import AWSEncryptionSDKClientError -from unittest.mock import Mock, patch - - -def mock_session(): - return Mock(spec_set=["client"]) - - -def test_create_kms_cipher(): - with patch( - "cloudformation_cli_python_lib.cipher.aws_encryption_sdk.StrictAwsKmsMasterKeyProvider", - autospec=True, - ), patch("boto3.client", autospec=True): - cipher = KmsCipher("encryptionKeyArn", "encryptionKeyRole") - assert cipher - - -def test_decrypt_credentials_success(): - expected_credentials = { - "accessKeyId": "IASAYK835GAIFHAHEI23", - "secretAccessKey": "66iOGPN5LnpZorcLr8Kh25u8AbjHVllv5poh2O0", - "sessionToken": "lameHS2vQOknSHWhdFYTxm2eJc1JMn9YBNI4nV4mXue945KPL6DH" - "fW8EsUQT5zwssYEC1NvYP9yD6Y5s5lKR3chflOHPFsIe6eqg", - } - - with patch( - "cloudformation_cli_python_lib.cipher.aws_encryption_sdk.StrictAwsKmsMasterKeyProvider", - autospec=True, - ), patch( - "cloudformation_cli_python_lib.cipher.aws_encryption_sdk.EncryptionSDKClient.decrypt" - ) as mock_decrypt, patch( - "boto3.client", autospec=True - ): - mock_decrypt.return_value = ( - b'{"accessKeyId": "IASAYK835GAIFHAHEI23", "secretAccessKey": "66iOGPN5LnpZorcLr8Kh25u8AbjHVllv5poh2O0", "sessionToken": "lameHS2vQOknSHWhdFYTxm2eJc1JMn9YBNI4nV4mXue945KPL6DHfW8EsUQT5zwssYEC1NvYP9yD6Y5s5lKR3chflOHPFsIe6eqg"}', # noqa: B950 - Mock(), - ) - cipher = KmsCipher("encryptionKeyArn", "encryptionKeyRole") - - credentials = cipher.decrypt_credentials( - "ewogICAgICAgICAgICAiYWNjZXNzS2V5SWQiOiAiSUFTQVlLODM1R0FJRkhBSEVJMjMiLAogICAg" - ) - assert Credentials(**expected_credentials) == credentials - - -def test_decrypt_credentials_fail(): - with patch( - "cloudformation_cli_python_lib.cipher.aws_encryption_sdk.StrictAwsKmsMasterKeyProvider", - autospec=True, - ), patch( - "cloudformation_cli_python_lib.cipher.aws_encryption_sdk.EncryptionSDKClient.decrypt" - ) as mock_decrypt, pytest.raises( - _EncryptionError - ) as excinfo, patch( - "boto3.client", autospec=True - ): - mock_decrypt.side_effect = AWSEncryptionSDKClientError() - cipher = KmsCipher("encryptionKeyArn", "encryptionKeyRole") - cipher.decrypt_credentials( - "ewogICAgICAgICAgICAiYWNjZXNzS2V5SWQiOiAiSUFTQVlLODM1R0FJRkhBSEVJMjMiLAogICAg" - ) - assert str(excinfo.value) == "Failed to decrypt credentials." - - -def test_decrypt_credentials_returns_null_fail(): - with patch( - "cloudformation_cli_python_lib.cipher.aws_encryption_sdk.StrictAwsKmsMasterKeyProvider", - autospec=True, - ), patch( - "cloudformation_cli_python_lib.cipher.aws_encryption_sdk.EncryptionSDKClient.decrypt" - ) as mock_decrypt, pytest.raises( - _EncryptionError - ) as excinfo, patch( - "boto3.client", autospec=True - ): - mock_decrypt.return_value = ( - b"null", - Mock(), - ) - cipher = KmsCipher("encryptionKeyArn", "encryptionKeyRole") - cipher.decrypt_credentials( - "ewogICAgICAgICAgICAiYWNjZXNzS2V5SWQiOiAiSUFTQVlLODM1R0FJRkhBSEVJMjMiLAogICAg" - ) - assert ( - str(excinfo.value) - == "Failed to decrypt credentials. Decrypted credentials are 'null'." - ) - - -@pytest.mark.parametrize( - "encryption_key_arn,encryption_key_role", - [ - (None, "encryptionKeyRole"), - ("encryptionKeyArn", None), - (None, None), - ], -) -def test_decrypt_unencrypted_credentials_success( - encryption_key_arn, encryption_key_role -): - expected_credentials = { - "accessKeyId": "IASAYK835GAIFHAHEI23", - "secretAccessKey": "66iOGPN5LnpZorcLr8Kh25u8AbjHVllv5poh2O0", - "sessionToken": "lameHS2vQOknSHWhdFYTxm2eJc1JMn9YBNI4nV4mXue945KPL6DH" - "fW8EsUQT5zwssYEC1NvYP9yD6Y5s5lKR3chflOHPFsIe6eqg", - } - - cipher = KmsCipher(encryption_key_arn, encryption_key_role) - - credentials = cipher.decrypt_credentials( - '{"accessKeyId": "IASAYK835GAIFHAHEI23", "secretAccessKey": "66iOGPN5LnpZorcLr8Kh25u8AbjHVllv5poh2O0", "sessionToken": "lameHS2vQOknSHWhdFYTxm2eJc1JMn9YBNI4nV4mXue945KPL6DHfW8EsUQT5zwssYEC1NvYP9yD6Y5s5lKR3chflOHPFsIe6eqg"}' # noqa: B950 - ) - assert Credentials(**expected_credentials) == credentials - - -@pytest.mark.parametrize( - "encryption_key_arn,encryption_key_role", - [ - (None, "encryptionKeyRole"), - ("encryptionKeyArn", None), - (None, None), - ], -) -def test_decrypt_unencrypted_credentials_fail(encryption_key_arn, encryption_key_role): - cipher = KmsCipher(encryption_key_arn, encryption_key_role) - - credentials = cipher.decrypt_credentials("{ Not JSON") # noqa: B950 - assert not credentials diff --git a/tests/lib/hook_test.py b/tests/lib/hook_test.py index 41ccd85..65f993a 100644 --- a/tests/lib/hook_test.py +++ b/tests/lib/hook_test.py @@ -3,11 +3,7 @@ import pytest from cloudformation_cli_python_lib import Hook -from cloudformation_cli_python_lib.exceptions import ( - InternalFailure, - InvalidRequest, - _EncryptionError, -) +from cloudformation_cli_python_lib.exceptions import InternalFailure, InvalidRequest from cloudformation_cli_python_lib.hook import _ensure_serialize from cloudformation_cli_python_lib.interface import ( BaseModel, @@ -88,11 +84,8 @@ def test_entrypoint_success(): with patch( "cloudformation_cli_python_lib.hook.HookProviderLogHandler.setup" ) as mock_log_delivery, patch( - "cloudformation_cli_python_lib.hook.KmsCipher.decrypt_credentials" - ) as mock_cipher, patch( "cloudformation_cli_python_lib.hook._get_boto_session", autospec=True ): - mock_cipher.side_effect = lambda c: Credentials(**json.loads(c)) event = hook.__call__.__wrapped__( # pylint: disable=no-member hook, ENTRYPOINT_PAYLOAD, None ) @@ -125,11 +118,8 @@ def _deserialize(cls, json_data): "cloudformation_cli_python_lib.hook.MetricsPublisherProxy" ) as mock_metrics, patch( "cloudformation_cli_python_lib.hook.Hook._invoke_handler" - ) as mock__invoke_handler, patch( - "cloudformation_cli_python_lib.hook.KmsCipher.decrypt_credentials" - ) as mock_cipher: + ) as mock__invoke_handler: mock__invoke_handler.side_effect = InvalidRequest("handler failed") - mock_cipher.side_effect = lambda c: Credentials(**json.loads(c)) event = hook.__call__.__wrapped__( # pylint: disable=no-member hook, ENTRYPOINT_PAYLOAD, None ) @@ -158,12 +148,7 @@ def test_entrypoint_with_context(): with patch( "cloudformation_cli_python_lib.hook.HookProviderLogHandler.setup" - ), patch( - "cloudformation_cli_python_lib.hook.KmsCipher.decrypt_credentials" - ) as mock_cipher, patch( - "cloudformation_cli_python_lib.hook._get_boto_session", autospec=True - ): - mock_cipher.side_effect = lambda c: Credentials(**json.loads(c)) + ), patch("cloudformation_cli_python_lib.hook._get_boto_session", autospec=True): hook.__call__.__wrapped__(hook, payload, None) # pylint: disable=no-member mock_handler.assert_called_once() @@ -203,34 +188,6 @@ def test_entrypoint_success_without_caller_provider_creds(): assert event == expected -def test_entrypoint_encryption_error_raises_access_denied(): - @dataclass - class TypeConfigurationModel(BaseModel): - a_string: str - - @classmethod - def _deserialize(cls, json_data): - return cls("test") - - hook = Hook(Mock(), TypeConfigurationModel) - - with patch( - "cloudformation_cli_python_lib.hook.HookProviderLogHandler.setup" - ), patch("cloudformation_cli_python_lib.hook.MetricsPublisherProxy"), patch( - "cloudformation_cli_python_lib.hook.KmsCipher.decrypt_credentials" - ) as mock_cipher: - mock_cipher.side_effect = _EncryptionError("Failed to decrypt credentials.") - event = hook.__call__.__wrapped__( # pylint: disable=no-member - hook, ENTRYPOINT_PAYLOAD, None - ) - - assert event["errorCode"] == "AccessDenied" - assert event["hookStatus"] == "FAILED" - assert event["callbackDelaySeconds"] == 0 - assert event["clientRequestToken"] == "4b90a7e4-b790-456b-a937-0cfdfa211dfe" - assert "Failed to decrypt credentials" in event["message"] - - def test_cast_hook_request_invalid_request(hook): request = HookInvocationRequest.deserialize(ENTRYPOINT_PAYLOAD) request.requestData = None @@ -248,12 +205,7 @@ def test__parse_request_valid_request_and__cast_hook_request(): hook = Hook(TYPE_NAME, mock_type_configuration_model) - with patch( - "cloudformation_cli_python_lib.hook._get_boto_session" - ) as mock_session, patch( - "cloudformation_cli_python_lib.hook.KmsCipher.decrypt_credentials" - ) as mock_cipher: - mock_cipher.side_effect = lambda c: Credentials(**json.loads(c)) + with patch("cloudformation_cli_python_lib.hook._get_boto_session") as mock_session: ret = hook._parse_request(ENTRYPOINT_PAYLOAD) sessions, invocation_point, callback_context, request = ret diff --git a/tests/lib/utils_test.py b/tests/lib/utils_test.py index cf81cb9..041c080 100644 --- a/tests/lib/utils_test.py +++ b/tests/lib/utils_test.py @@ -1,9 +1,10 @@ -# pylint: disable=protected-access +# pylint: disable=protected-access,line-too-long import pytest from cloudformation_cli_python_lib.exceptions import InvalidRequest from cloudformation_cli_python_lib.interface import BaseModel from cloudformation_cli_python_lib.utils import ( HandlerRequest, + HookInvocationRequest, KitchenSinkEncoder, UnmodelledRequest, deserialize_list, @@ -97,6 +98,57 @@ def test_handler_request_serde_roundtrip(): for k, v in payload.items() if v is not None and k not in undesired } + assert ser == expected + + +def test_hook_handler_request_serde_roundtrip(): + payload = { + "awsAccountId": "123456789012", + "clientRequestToken": "4b90a7e4-b790-456b-a937-0cfdfa211dfe", + "actionInvocationPoint": "CREATE_PRE_PROVISION", + "hookTypeName": "AWS::Test::TestHook", + "hookTypeVersion": "1.0", + "requestContext": { + "invocation": 1, + "callbackContext": {}, + }, + "requestData": { + "callerCredentials": '{"accessKeyId": "IASAYK835GAIFHAHEI23", "secretAccessKey": "66iOGPN5LnpZorcLr8Kh25u8AbjHVllv5poh2O0", "sessionToken": "lameHS2vQOknSHWhdFYTxm2eJc1JMn9YBNI4nV4mXue945KPL6DHfW8EsUQT5zwssYEC1NvYP9yD6Y5s5lKR3chflOHPFsIe6eqg"}', # noqa: B950 + "providerCredentials": None, + "providerLogGroupName": "providerLoggingGroupName", + "targetName": "AWS::Test::Resource", + "targetType": "RESOURCE", + "targetLogicalId": "myResource", + "hookEncryptionKeyArn": None, + "hookEncryptionKeyRole": None, + "targetModel": { + "resourceProperties": {}, + "previousResourceProperties": None, + }, + "undesiredField": "value", + }, + "stackId": "arn:aws:cloudformation:us-east-1:123456789012:stack/SampleStack/e" + "722ae60-fe62-11e8-9a0e-0ae8cc519968", + "hookModel": {}, + } + undesired = "undesiredField" + ser = HookInvocationRequest.deserialize(payload).serialize() + # remove None values from payload + expected = { + k: { + k: {k: v for k, v in v.items() if v is not None} + if k == "targetModel" + else json.loads(v) + if k.endswith("Credentials") + else v + for k, v in payload[k].items() + if v is not None and k not in undesired + } + if k in ("requestData", "requestContext") + else v + for k, v in payload.items() + if v is not None and k not in undesired + } assert ser == expected @@ -144,6 +196,3 @@ def test_deserialize_list_empty(): def test_deserialize_list_invalid(): with pytest.raises(InvalidRequest): deserialize_list([(1, 2)], BaseModel) - - -# test_hook_response