diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index f8008e0cf6de..f4726fed1d76 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -34,6 +34,9 @@ def __init__( definition: str, roleArn: str, tags: Optional[List[Dict[str, str]]] = None, + encryptionConfiguration: Optional[Dict[str, Any]] = None, + loggingConfiguration: Optional[Dict[str, Any]] = None, + tracingConfiguration: Optional[Dict[str, Any]] = None, ): self.creation_date = iso_8601_datetime_with_milliseconds() self.update_date = self.creation_date @@ -46,6 +49,12 @@ def __init__( if tags: self.add_tags(tags) self.version = 0 + self.type = "STANDARD" + self.encryptionConfiguration = encryptionConfiguration or { + "type": "AWS_OWNED_KEY" + } + self.loggingConfiguration = loggingConfiguration or {"level": "OFF"} + self.tracingConfiguration = tracingConfiguration or {"enabled": False} def start_execution( self, @@ -523,6 +532,9 @@ def create_state_machine( roleArn: str, tags: Optional[List[Dict[str, str]]] = None, publish: Optional[bool] = None, + loggingConfiguration: Optional[Dict[str, Any]] = None, + tracingConfiguration: Optional[Dict[str, Any]] = None, + encryptionConfiguration: Optional[Dict[str, Any]] = None, ) -> StateMachine: self._validate_name(name) self._validate_role_arn(roleArn) @@ -530,7 +542,16 @@ def create_state_machine( try: return self.describe_state_machine(arn) except StateMachineDoesNotExist: - state_machine = StateMachine(arn, name, definition, roleArn, tags) + state_machine = StateMachine( + arn, + name, + definition, + roleArn, + tags, + encryptionConfiguration, + loggingConfiguration, + tracingConfiguration, + ) if publish: state_machine.version += 1 self.state_machines.append(state_machine) @@ -562,13 +583,20 @@ def update_state_machine( role_arn: Optional[str] = None, logging_configuration: Optional[Dict[str, bool]] = None, tracing_configuration: Optional[Dict[str, bool]] = None, + encryption_configuration: Optional[Dict[str, Any]] = None, publish: Optional[bool] = None, ) -> StateMachine: sm = self.describe_state_machine(arn) - updates = { + updates: Dict[str, Any] = { "definition": definition, "roleArn": role_arn, } + if encryption_configuration: + updates["encryptionConfiguration"] = encryption_configuration + if logging_configuration: + updates["loggingConfiguration"] = logging_configuration + if tracing_configuration: + updates["tracingConfiguration"] = tracing_configuration sm.update(**updates) if publish: sm.version += 1 diff --git a/moto/stepfunctions/parser/api.py b/moto/stepfunctions/parser/api.py index 173867900fff..579b0b395e5c 100644 --- a/moto/stepfunctions/parser/api.py +++ b/moto/stepfunctions/parser/api.py @@ -445,6 +445,12 @@ class CreateStateMachineAliasOutput(TypedDict, total=False): creationDate: Timestamp +class EncryptionConfiguration(TypedDict, total=False): + type: str + kmsDataKeyReusePeriodSeconds: Optional[int] + kmsKeyId: Optional[str] + + class TracingConfiguration(TypedDict, total=False): enabled: Optional[Enabled] @@ -472,6 +478,7 @@ class LoggingConfiguration(TypedDict, total=False): "loggingConfiguration": Optional[LoggingConfiguration], "tags": Optional[TagList], "tracingConfiguration": Optional[TracingConfiguration], + "encryptionConfiguration": Optional[EncryptionConfiguration], "publish": Optional[Publish], "versionDescription": Optional[VersionDescription], }, @@ -595,6 +602,7 @@ class DescribeStateMachineForExecutionOutput(TypedDict, total=False): updateDate: Timestamp loggingConfiguration: Optional[LoggingConfiguration] tracingConfiguration: Optional[TracingConfiguration] + encryptionConfiguration: Optional[EncryptionConfiguration] mapRunArn: Optional[str] label: Optional[MapRunLabel] revisionId: Optional[RevisionId] @@ -612,6 +620,7 @@ class DescribeStateMachineForExecutionOutput(TypedDict, total=False): "creationDate": Timestamp, "loggingConfiguration": Optional[LoggingConfiguration], "tracingConfiguration": Optional[TracingConfiguration], + "encryptionConfiguration": Optional[EncryptionConfiguration], "label": Optional[MapRunLabel], "revisionId": Optional[RevisionId], "description": Optional[VersionDescription], diff --git a/moto/stepfunctions/parser/backend/state_machine.py b/moto/stepfunctions/parser/backend/state_machine.py index d04494bcb276..fe6f54865143 100644 --- a/moto/stepfunctions/parser/backend/state_machine.py +++ b/moto/stepfunctions/parser/backend/state_machine.py @@ -10,6 +10,7 @@ from moto.stepfunctions.parser.api import ( Definition, DescribeStateMachineOutput, + EncryptionConfiguration, LoggingConfiguration, Name, RevisionId, @@ -36,6 +37,7 @@ class StateMachineInstance: logging_config: Optional[LoggingConfiguration] tags: Optional[TagList] tracing_config: Optional[TracingConfiguration] + encryption_config: Optional[EncryptionConfiguration] def __init__( self, @@ -48,6 +50,7 @@ def __init__( logging_config: Optional[LoggingConfiguration] = None, tags: Optional[TagList] = None, tracing_config: Optional[TracingConfiguration] = None, + encryption_config: Optional[EncryptionConfiguration] = None, ): self.name = name self.arn = arn @@ -61,6 +64,7 @@ def __init__( self.logging_config = logging_config self.tags = tags self.tracing_config = tracing_config + self.encryption_config = encryption_config def describe(self) -> DescribeStateMachineOutput: describe_output = DescribeStateMachineOutput( @@ -72,6 +76,7 @@ def describe(self) -> DescribeStateMachineOutput: type=self.sm_type, creationDate=self.create_date, loggingConfiguration=self.logging_config, + encryptionConfiguration=self.encryption_config, ) if self.revision_id: describe_output["revisionId"] = self.revision_id @@ -219,6 +224,7 @@ def __init__( logging_config=state_machine_revision.logging_config, tags=state_machine_revision.tags, tracing_config=state_machine_revision.tracing_config, + encryption_config=state_machine_revision.encryption_config, ) self.source_arn = state_machine_revision.arn self.revision_id = state_machine_revision.revision_id diff --git a/moto/stepfunctions/parser/models.py b/moto/stepfunctions/parser/models.py index ca75eff7f919..cf5fb0951cd5 100644 --- a/moto/stepfunctions/parser/models.py +++ b/moto/stepfunctions/parser/models.py @@ -6,6 +6,7 @@ from moto.stepfunctions.models import StateMachine, StepFunctionBackend from moto.stepfunctions.parser.api import ( Definition, + EncryptionConfiguration, ExecutionStatus, GetExecutionHistoryOutput, InvalidDefinition, @@ -81,6 +82,9 @@ def create_state_machine( roleArn: str, tags: Optional[List[Dict[str, str]]] = None, publish: Optional[bool] = None, + loggingConfiguration: Optional[LoggingConfiguration] = None, + tracingConfiguration: Optional[TracingConfiguration] = None, + encryptionConfiguration: Optional[EncryptionConfiguration] = None, ) -> StateMachine: StepFunctionsParserBackend._validate_definition(definition=definition) @@ -90,6 +94,9 @@ def create_state_machine( roleArn=roleArn, tags=tags, publish=publish, + loggingConfiguration=loggingConfiguration, + tracingConfiguration=tracingConfiguration, + encryptionConfiguration=encryptionConfiguration, ) def send_task_heartbeat(self, task_token: TaskToken) -> SendTaskHeartbeatOutput: @@ -192,14 +199,21 @@ def update_state_machine( role_arn: str = None, logging_configuration: LoggingConfiguration = None, tracing_configuration: TracingConfiguration = None, + encryption_configuration: EncryptionConfiguration = None, publish: Optional[bool] = None, version_description: VersionDescription = None, ) -> StateMachine: if not any( - [definition, role_arn, logging_configuration, tracing_configuration] + [ + definition, + role_arn, + logging_configuration, + tracing_configuration, + encryption_configuration, + ] ): raise MissingRequiredParameter( - "Either the definition, the role ARN, the LoggingConfiguration, or the TracingConfiguration must be specified" + "Either the definition, the role ARN, the LoggingConfiguration, the EncryptionConfiguration or the TracingConfiguration must be specified" ) if definition is not None: @@ -211,6 +225,7 @@ def update_state_machine( role_arn, logging_configuration=logging_configuration, tracing_configuration=tracing_configuration, + encryption_configuration=encryption_configuration, publish=publish, ) diff --git a/moto/stepfunctions/parser/resource_providers/aws_stepfunctions_statemachine.py b/moto/stepfunctions/parser/resource_providers/aws_stepfunctions_statemachine.py index 51a8bbc11a71..a4f33f9e1adb 100644 --- a/moto/stepfunctions/parser/resource_providers/aws_stepfunctions_statemachine.py +++ b/moto/stepfunctions/parser/resource_providers/aws_stepfunctions_statemachine.py @@ -31,6 +31,7 @@ class StepFunctionsStateMachineProperties(TypedDict): StateMachineType: Optional[str] Tags: Optional[List[TagsEntry]] TracingConfiguration: Optional[TracingConfiguration] + EncryptionConfiguration: Optional[EncryptionConfiguration] class CloudWatchLogsLogGroup(TypedDict): @@ -51,6 +52,12 @@ class TracingConfiguration(TypedDict): Enabled: Optional[bool] +class EncryptionConfiguration(TypedDict): + Type: Optional[str] + KmsKeyID: Optional[str] + KmsDataKeyReusePeriodSeconds: Optional[int] + + class S3Location(TypedDict): Bucket: Optional[str] Key: Optional[str] diff --git a/moto/stepfunctions/parser/resource_providers/aws_stepfunctions_statemachine.schema.json b/moto/stepfunctions/parser/resource_providers/aws_stepfunctions_statemachine.schema.json index 607e1a9bccda..ecc15e42f672 100644 --- a/moto/stepfunctions/parser/resource_providers/aws_stepfunctions_statemachine.schema.json +++ b/moto/stepfunctions/parser/resource_providers/aws_stepfunctions_statemachine.schema.json @@ -78,6 +78,20 @@ } } }, + "EncryptionConfiguration": { + "type": "object", + "additionalProperties": false, + "properties": { + "Type": { + "type": "string" + }, + "KmsKeyId": { + "type": "string" + }, + "KmsDataKeyReusePeriodSeconds": { + "type": "integer" + } + }, "S3Location": { "type": "object", "additionalProperties": false, @@ -166,6 +180,9 @@ "TracingConfiguration": { "$ref": "#/definitions/TracingConfiguration" }, + "EncryptionConfiguration": { + "$ref": "#/definitions/EncryptionConfiguration" + }, "DefinitionS3Location": { "$ref": "#/definitions/S3Location" }, diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py index b258b460c299..3191bd9998ca 100644 --- a/moto/stepfunctions/responses.py +++ b/moto/stepfunctions/responses.py @@ -29,12 +29,18 @@ def create_state_machine(self) -> TYPE_RESPONSE: roleArn = self._get_param("roleArn") tags = self._get_param("tags") publish = self._get_param("publish") + encryptionConfiguration = self._get_param("encryptionConfiguration") + loggingConfiguration = self._get_param("loggingConfiguration") + tracingConfiguration = self._get_param("tracingConfiguration") state_machine = self.stepfunction_backend.create_state_machine( name=name, definition=definition, roleArn=roleArn, tags=tags, publish=publish, + loggingConfiguration=loggingConfiguration, + tracingConfiguration=tracingConfiguration, + encryptionConfiguration=encryptionConfiguration, ) response = { "creationDate": state_machine.creation_date, @@ -80,6 +86,10 @@ def _describe_state_machine(self, state_machine_arn: str) -> TYPE_RESPONSE: "name": state_machine.name, "roleArn": state_machine.roleArn, "status": "ACTIVE", + "type": state_machine.type, + "encryptionConfiguration": state_machine.encryptionConfiguration, + "tracingConfiguration": state_machine.tracingConfiguration, + "loggingConfiguration": state_machine.loggingConfiguration, } return 200, {}, json.dumps(response) @@ -93,12 +103,16 @@ def update_state_machine(self) -> TYPE_RESPONSE: definition = self._get_param("definition") role_arn = self._get_param("roleArn") tracing_config = self._get_param("tracingConfiguration") + encryption_config = self._get_param("encryptionConfiguration") + logging_config = self._get_param("loggingConfiguration") publish = self._get_param("publish") state_machine = self.stepfunction_backend.update_state_machine( arn=arn, definition=definition, role_arn=role_arn, tracing_configuration=tracing_config, + encryption_configuration=encryption_config, + logging_configuration=logging_config, publish=publish, ) response = { diff --git a/tests/test_stepfunctions/parser/test_stepfunctions.py b/tests/test_stepfunctions/parser/test_stepfunctions.py index 96c8d2402620..91f97776502a 100644 --- a/tests/test_stepfunctions/parser/test_stepfunctions.py +++ b/tests/test_stepfunctions/parser/test_stepfunctions.py @@ -110,14 +110,27 @@ def test_version_is_only_available_when_published(): PolicyName="allowLambdaInvoke", RoleName=role_name, ) + kms_key_id = boto3.client("kms", region_name="us-east-1").create_key()[ + "KeyMetadata" + ]["KeyId"] sleep(10 if allow_aws_request() else 0) client = boto3.client("stepfunctions", region_name="us-east-1") try: name1 = f"sfn_name_{str(uuid4())[0:6]}" + encryption_config = { + "kmsDataKeyReusePeriodSeconds": 60, + "kmsKeyId": kms_key_id, + "type": "CUSTOMER_MANAGED_CMK", + } response = client.create_state_machine( - name=name1, definition=simple_definition, roleArn=sfn_role + name=name1, + definition=simple_definition, + roleArn=sfn_role, + tracingConfiguration={"enabled": False}, + loggingConfiguration={"level": "OFF"}, + encryptionConfiguration=encryption_config, ) assert "stateMachineVersionArn" not in response arn1 = response["stateMachineArn"] @@ -140,7 +153,11 @@ def test_version_is_only_available_when_published(): assert response["stateMachineVersionArn"] == f"{arn2}:1" resp = client.update_state_machine( - stateMachineArn=arn2, publish=True, tracingConfiguration={"enabled": True} + stateMachineArn=arn2, + publish=True, + tracingConfiguration={"enabled": True}, + loggingConfiguration={"level": "OFF"}, + encryptionConfiguration=encryption_config, ) assert resp["stateMachineVersionArn"] == f"{arn2}:2" finally: diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py index 1f7688ff764e..dccfa013eeaf 100644 --- a/tests/test_stepfunctions/test_stepfunctions.py +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -28,11 +28,9 @@ def test_state_machine_creation_succeeds(): client = boto3.client("stepfunctions", region_name=region) name = "example_step_function" - # response = client.create_state_machine( name=name, definition=str(simple_definition), roleArn=_get_default_role() ) - # assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 assert isinstance(response["creationDate"], datetime) assert response["stateMachineArn"] == ( @@ -40,6 +38,30 @@ def test_state_machine_creation_succeeds(): ) +@mock_aws +def test_state_machine_with_cmk(): + client = boto3.client("stepfunctions", region_name=region) + kms_key_id = boto3.client("kms", region_name=region).create_key()["KeyMetadata"][ + "KeyId" + ] + name = "example_step_function_cmk" + encryption_config = { + "kmsDataKeyReusePeriodSeconds": 60, + "kmsKeyId": kms_key_id, + "type": "CUSTOMER_MANAGED_CMK", + } + + state_machine_arn = client.create_state_machine( + name=name, + definition=str(simple_definition), + roleArn=_get_default_role(), + encryptionConfiguration=encryption_config, + )["stateMachineArn"] + + desc = client.describe_state_machine(stateMachineArn=state_machine_arn) + assert desc["encryptionConfiguration"] == encryption_config + + @mock_aws def test_state_machine_creation_fails_with_invalid_names(): client = boto3.client("stepfunctions", region_name=region) @@ -170,10 +192,32 @@ def test_update_state_machine(): updated_definition = str(simple_definition).replace( "DefaultState", "DefaultStateUpdated" ) + kms_key_id = boto3.client("kms", region_name=region).create_key()["KeyMetadata"][ + "KeyId" + ] + encryption_config = { + "kmsDataKeyReusePeriodSeconds": 60, + "kmsKeyId": kms_key_id, + "type": "CUSTOMER_MANAGED_CMK", + } + updated_logging_config = { + "level": "ALL", + "destinations": [ + { + "cloudWatchLogsLogGroup": { + "logGroupArn": "arn:aws:logs:us-east-1:123456789012:log-group:my-log-group" + } + } + ], + } + updated_tracing_config = {"enabled": True} resp = client.update_state_machine( stateMachineArn=state_machine_arn, definition=updated_definition, roleArn=updated_role, + encryptionConfiguration=encryption_config, + loggingConfiguration=updated_logging_config, + tracingConfiguration=updated_tracing_config, ) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 assert isinstance(resp["updateDate"], datetime) @@ -181,6 +225,9 @@ def test_update_state_machine(): desc = client.describe_state_machine(stateMachineArn=state_machine_arn) assert desc["definition"] == updated_definition assert desc["roleArn"] == updated_role + assert desc["encryptionConfiguration"] == encryption_config + assert desc["loggingConfiguration"] == updated_logging_config + assert desc["tracingConfiguration"] == updated_tracing_config @mock_aws @@ -279,6 +326,10 @@ def test_state_machine_creation_can_be_described(): assert desc["roleArn"] == _get_default_role() assert desc["stateMachineArn"] == sm["stateMachineArn"] assert desc["status"] == "ACTIVE" + assert desc["type"] == "STANDARD" + assert desc["encryptionConfiguration"] == {"type": "AWS_OWNED_KEY"} + assert desc["loggingConfiguration"] == {"level": "OFF"} + assert desc["tracingConfiguration"] == {"enabled": False} @mock_aws @@ -362,10 +413,21 @@ def test_state_machine_untagging_non_existent_resource_fails(): @mock_aws def test_state_machine_tagging(): client = boto3.client("stepfunctions", region_name=region) + # Test tags are added on resource creation tags = [ {"key": "tag_key1", "value": "tag_value1"}, {"key": "tag_key2", "value": "tag_value2"}, ] + machine = client.create_state_machine( + name="test-with-tags", + definition=str(simple_definition), + roleArn=_get_default_role(), + tags=tags, + ) + resp = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + assert resp["tags"] == tags + + # Test tags are added after creation with tag_resource machine = client.create_state_machine( name="test", definition=str(simple_definition), roleArn=_get_default_role() ) diff --git a/tests/test_stepfunctions/test_stepfunctions_cloudformation.py b/tests/test_stepfunctions/test_stepfunctions_cloudformation.py index a1a69d0fb913..ca6836f40ca6 100644 --- a/tests/test_stepfunctions/test_stepfunctions_cloudformation.py +++ b/tests/test_stepfunctions/test_stepfunctions_cloudformation.py @@ -31,6 +31,9 @@ def test_state_machine_cloudformation(): "StateMachineType": "STANDARD", "DefinitionString": definition, "RoleArn": role_arn, + "TracingConfiguration": {"Enabled": False}, + "EncryptionConfiguration": {"Type": "AWS_OWNED_KEY"}, + "LoggingConfiguration": {"Level": "OFF"}, "Tags": [ {"Key": "key1", "Value": "value1"}, {"Key": "key2", "Value": "value2"}, @@ -51,6 +54,9 @@ def test_state_machine_cloudformation(): assert state_machine["name"] == output["StateMachineName"] assert state_machine["roleArn"] == role_arn assert state_machine["definition"] == definition + assert state_machine["tracingConfiguration"] == {"enabled": False} + assert state_machine["encryptionConfiguration"]["type"] == "AWS_OWNED_KEY" + assert state_machine["loggingConfiguration"]["level"] == "OFF" tags = sf.list_tags_for_resource(resourceArn=output["StateMachineArn"]).get("tags") for i, tag in enumerate(tags, 1): assert tag["key"] == f"key{i}"