Skip to content

Commit

Permalink
Step Functions - Updated State Machine to support logging, tracing & …
Browse files Browse the repository at this point in the history
…encryption config. (#8200)
  • Loading branch information
zkarpinski authored Oct 8, 2024
1 parent 21873d1 commit 5cbfacf
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 8 deletions.
32 changes: 30 additions & 2 deletions moto/stepfunctions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def __init__(
definition: str,
roleArn: str,
tags: Optional[List[Dict[str, str]]] = None,
encryptionConfiguration: Optional[Dict[str, Any]] = None,
loggingConfiguration: Optional[Dict[str, Any]] = None,
tracingConfiguration: Optional[Dict[str, Any]] = None,
):
self.creation_date = iso_8601_datetime_with_milliseconds()
self.update_date = self.creation_date
Expand All @@ -46,6 +49,12 @@ def __init__(
if tags:
self.add_tags(tags)
self.version = 0
self.type = "STANDARD"
self.encryptionConfiguration = encryptionConfiguration or {
"type": "AWS_OWNED_KEY"
}
self.loggingConfiguration = loggingConfiguration or {"level": "OFF"}
self.tracingConfiguration = tracingConfiguration or {"enabled": False}

def start_execution(
self,
Expand Down Expand Up @@ -523,14 +532,26 @@ def create_state_machine(
roleArn: str,
tags: Optional[List[Dict[str, str]]] = None,
publish: Optional[bool] = None,
loggingConfiguration: Optional[Dict[str, Any]] = None,
tracingConfiguration: Optional[Dict[str, Any]] = None,
encryptionConfiguration: Optional[Dict[str, Any]] = None,
) -> StateMachine:
self._validate_name(name)
self._validate_role_arn(roleArn)
arn = f"arn:{get_partition(self.region_name)}:states:{self.region_name}:{self.account_id}:stateMachine:{name}"
try:
return self.describe_state_machine(arn)
except StateMachineDoesNotExist:
state_machine = StateMachine(arn, name, definition, roleArn, tags)
state_machine = StateMachine(
arn,
name,
definition,
roleArn,
tags,
encryptionConfiguration,
loggingConfiguration,
tracingConfiguration,
)
if publish:
state_machine.version += 1
self.state_machines.append(state_machine)
Expand Down Expand Up @@ -562,13 +583,20 @@ def update_state_machine(
role_arn: Optional[str] = None,
logging_configuration: Optional[Dict[str, bool]] = None,
tracing_configuration: Optional[Dict[str, bool]] = None,
encryption_configuration: Optional[Dict[str, Any]] = None,
publish: Optional[bool] = None,
) -> StateMachine:
sm = self.describe_state_machine(arn)
updates = {
updates: Dict[str, Any] = {
"definition": definition,
"roleArn": role_arn,
}
if encryption_configuration:
updates["encryptionConfiguration"] = encryption_configuration
if logging_configuration:
updates["loggingConfiguration"] = logging_configuration
if tracing_configuration:
updates["tracingConfiguration"] = tracing_configuration
sm.update(**updates)
if publish:
sm.version += 1
Expand Down
9 changes: 9 additions & 0 deletions moto/stepfunctions/parser/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ class CreateStateMachineAliasOutput(TypedDict, total=False):
creationDate: Timestamp


class EncryptionConfiguration(TypedDict, total=False):
type: str
kmsDataKeyReusePeriodSeconds: Optional[int]
kmsKeyId: Optional[str]


class TracingConfiguration(TypedDict, total=False):
enabled: Optional[Enabled]

Expand Down Expand Up @@ -472,6 +478,7 @@ class LoggingConfiguration(TypedDict, total=False):
"loggingConfiguration": Optional[LoggingConfiguration],
"tags": Optional[TagList],
"tracingConfiguration": Optional[TracingConfiguration],
"encryptionConfiguration": Optional[EncryptionConfiguration],
"publish": Optional[Publish],
"versionDescription": Optional[VersionDescription],
},
Expand Down Expand Up @@ -595,6 +602,7 @@ class DescribeStateMachineForExecutionOutput(TypedDict, total=False):
updateDate: Timestamp
loggingConfiguration: Optional[LoggingConfiguration]
tracingConfiguration: Optional[TracingConfiguration]
encryptionConfiguration: Optional[EncryptionConfiguration]
mapRunArn: Optional[str]
label: Optional[MapRunLabel]
revisionId: Optional[RevisionId]
Expand All @@ -612,6 +620,7 @@ class DescribeStateMachineForExecutionOutput(TypedDict, total=False):
"creationDate": Timestamp,
"loggingConfiguration": Optional[LoggingConfiguration],
"tracingConfiguration": Optional[TracingConfiguration],
"encryptionConfiguration": Optional[EncryptionConfiguration],
"label": Optional[MapRunLabel],
"revisionId": Optional[RevisionId],
"description": Optional[VersionDescription],
Expand Down
6 changes: 6 additions & 0 deletions moto/stepfunctions/parser/backend/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from moto.stepfunctions.parser.api import (
Definition,
DescribeStateMachineOutput,
EncryptionConfiguration,
LoggingConfiguration,
Name,
RevisionId,
Expand All @@ -36,6 +37,7 @@ class StateMachineInstance:
logging_config: Optional[LoggingConfiguration]
tags: Optional[TagList]
tracing_config: Optional[TracingConfiguration]
encryption_config: Optional[EncryptionConfiguration]

def __init__(
self,
Expand All @@ -48,6 +50,7 @@ def __init__(
logging_config: Optional[LoggingConfiguration] = None,
tags: Optional[TagList] = None,
tracing_config: Optional[TracingConfiguration] = None,
encryption_config: Optional[EncryptionConfiguration] = None,
):
self.name = name
self.arn = arn
Expand All @@ -61,6 +64,7 @@ def __init__(
self.logging_config = logging_config
self.tags = tags
self.tracing_config = tracing_config
self.encryption_config = encryption_config

def describe(self) -> DescribeStateMachineOutput:
describe_output = DescribeStateMachineOutput(
Expand All @@ -72,6 +76,7 @@ def describe(self) -> DescribeStateMachineOutput:
type=self.sm_type,
creationDate=self.create_date,
loggingConfiguration=self.logging_config,
encryptionConfiguration=self.encryption_config,
)
if self.revision_id:
describe_output["revisionId"] = self.revision_id
Expand Down Expand Up @@ -219,6 +224,7 @@ def __init__(
logging_config=state_machine_revision.logging_config,
tags=state_machine_revision.tags,
tracing_config=state_machine_revision.tracing_config,
encryption_config=state_machine_revision.encryption_config,
)
self.source_arn = state_machine_revision.arn
self.revision_id = state_machine_revision.revision_id
Expand Down
19 changes: 17 additions & 2 deletions moto/stepfunctions/parser/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from moto.stepfunctions.models import StateMachine, StepFunctionBackend
from moto.stepfunctions.parser.api import (
Definition,
EncryptionConfiguration,
ExecutionStatus,
GetExecutionHistoryOutput,
InvalidDefinition,
Expand Down Expand Up @@ -81,6 +82,9 @@ def create_state_machine(
roleArn: str,
tags: Optional[List[Dict[str, str]]] = None,
publish: Optional[bool] = None,
loggingConfiguration: Optional[LoggingConfiguration] = None,
tracingConfiguration: Optional[TracingConfiguration] = None,
encryptionConfiguration: Optional[EncryptionConfiguration] = None,
) -> StateMachine:
StepFunctionsParserBackend._validate_definition(definition=definition)

Expand All @@ -90,6 +94,9 @@ def create_state_machine(
roleArn=roleArn,
tags=tags,
publish=publish,
loggingConfiguration=loggingConfiguration,
tracingConfiguration=tracingConfiguration,
encryptionConfiguration=encryptionConfiguration,
)

def send_task_heartbeat(self, task_token: TaskToken) -> SendTaskHeartbeatOutput:
Expand Down Expand Up @@ -192,14 +199,21 @@ def update_state_machine(
role_arn: str = None,
logging_configuration: LoggingConfiguration = None,
tracing_configuration: TracingConfiguration = None,
encryption_configuration: EncryptionConfiguration = None,
publish: Optional[bool] = None,
version_description: VersionDescription = None,
) -> StateMachine:
if not any(
[definition, role_arn, logging_configuration, tracing_configuration]
[
definition,
role_arn,
logging_configuration,
tracing_configuration,
encryption_configuration,
]
):
raise MissingRequiredParameter(
"Either the definition, the role ARN, the LoggingConfiguration, or the TracingConfiguration must be specified"
"Either the definition, the role ARN, the LoggingConfiguration, the EncryptionConfiguration or the TracingConfiguration must be specified"
)

if definition is not None:
Expand All @@ -211,6 +225,7 @@ def update_state_machine(
role_arn,
logging_configuration=logging_configuration,
tracing_configuration=tracing_configuration,
encryption_configuration=encryption_configuration,
publish=publish,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class StepFunctionsStateMachineProperties(TypedDict):
StateMachineType: Optional[str]
Tags: Optional[List[TagsEntry]]
TracingConfiguration: Optional[TracingConfiguration]
EncryptionConfiguration: Optional[EncryptionConfiguration]


class CloudWatchLogsLogGroup(TypedDict):
Expand All @@ -51,6 +52,12 @@ class TracingConfiguration(TypedDict):
Enabled: Optional[bool]


class EncryptionConfiguration(TypedDict):
Type: Optional[str]
KmsKeyID: Optional[str]
KmsDataKeyReusePeriodSeconds: Optional[int]


class S3Location(TypedDict):
Bucket: Optional[str]
Key: Optional[str]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@
}
}
},
"EncryptionConfiguration": {
"type": "object",
"additionalProperties": false,
"properties": {
"Type": {
"type": "string"
},
"KmsKeyId": {
"type": "string"
},
"KmsDataKeyReusePeriodSeconds": {
"type": "integer"
}
},
"S3Location": {
"type": "object",
"additionalProperties": false,
Expand Down Expand Up @@ -166,6 +180,9 @@
"TracingConfiguration": {
"$ref": "#/definitions/TracingConfiguration"
},
"EncryptionConfiguration": {
"$ref": "#/definitions/EncryptionConfiguration"
},
"DefinitionS3Location": {
"$ref": "#/definitions/S3Location"
},
Expand Down
14 changes: 14 additions & 0 deletions moto/stepfunctions/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@ def create_state_machine(self) -> TYPE_RESPONSE:
roleArn = self._get_param("roleArn")
tags = self._get_param("tags")
publish = self._get_param("publish")
encryptionConfiguration = self._get_param("encryptionConfiguration")
loggingConfiguration = self._get_param("loggingConfiguration")
tracingConfiguration = self._get_param("tracingConfiguration")
state_machine = self.stepfunction_backend.create_state_machine(
name=name,
definition=definition,
roleArn=roleArn,
tags=tags,
publish=publish,
loggingConfiguration=loggingConfiguration,
tracingConfiguration=tracingConfiguration,
encryptionConfiguration=encryptionConfiguration,
)
response = {
"creationDate": state_machine.creation_date,
Expand Down Expand Up @@ -80,6 +86,10 @@ def _describe_state_machine(self, state_machine_arn: str) -> TYPE_RESPONSE:
"name": state_machine.name,
"roleArn": state_machine.roleArn,
"status": "ACTIVE",
"type": state_machine.type,
"encryptionConfiguration": state_machine.encryptionConfiguration,
"tracingConfiguration": state_machine.tracingConfiguration,
"loggingConfiguration": state_machine.loggingConfiguration,
}
return 200, {}, json.dumps(response)

Expand All @@ -93,12 +103,16 @@ def update_state_machine(self) -> TYPE_RESPONSE:
definition = self._get_param("definition")
role_arn = self._get_param("roleArn")
tracing_config = self._get_param("tracingConfiguration")
encryption_config = self._get_param("encryptionConfiguration")
logging_config = self._get_param("loggingConfiguration")
publish = self._get_param("publish")
state_machine = self.stepfunction_backend.update_state_machine(
arn=arn,
definition=definition,
role_arn=role_arn,
tracing_configuration=tracing_config,
encryption_configuration=encryption_config,
logging_configuration=logging_config,
publish=publish,
)
response = {
Expand Down
21 changes: 19 additions & 2 deletions tests/test_stepfunctions/parser/test_stepfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,27 @@ def test_version_is_only_available_when_published():
PolicyName="allowLambdaInvoke",
RoleName=role_name,
)
kms_key_id = boto3.client("kms", region_name="us-east-1").create_key()[
"KeyMetadata"
]["KeyId"]
sleep(10 if allow_aws_request() else 0)

client = boto3.client("stepfunctions", region_name="us-east-1")

try:
name1 = f"sfn_name_{str(uuid4())[0:6]}"
encryption_config = {
"kmsDataKeyReusePeriodSeconds": 60,
"kmsKeyId": kms_key_id,
"type": "CUSTOMER_MANAGED_CMK",
}
response = client.create_state_machine(
name=name1, definition=simple_definition, roleArn=sfn_role
name=name1,
definition=simple_definition,
roleArn=sfn_role,
tracingConfiguration={"enabled": False},
loggingConfiguration={"level": "OFF"},
encryptionConfiguration=encryption_config,
)
assert "stateMachineVersionArn" not in response
arn1 = response["stateMachineArn"]
Expand All @@ -140,7 +153,11 @@ def test_version_is_only_available_when_published():
assert response["stateMachineVersionArn"] == f"{arn2}:1"

resp = client.update_state_machine(
stateMachineArn=arn2, publish=True, tracingConfiguration={"enabled": True}
stateMachineArn=arn2,
publish=True,
tracingConfiguration={"enabled": True},
loggingConfiguration={"level": "OFF"},
encryptionConfiguration=encryption_config,
)
assert resp["stateMachineVersionArn"] == f"{arn2}:2"
finally:
Expand Down
Loading

0 comments on commit 5cbfacf

Please sign in to comment.