From b5ba4b6b45ea822b0b6fe2bbf032e67f1ab5fad7 Mon Sep 17 00:00:00 2001 From: liamschn Date: Thu, 19 Dec 2024 17:17:11 -0700 Subject: [PATCH] fix multiple accounts for eval job --- .../sra_bedrock_check_eval_job_bucket/app.py | 24 +++++++++++++++---- .../templates/sra-bedrock-org-main.yaml | 4 ++-- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/aws_sra_examples/solutions/genai/bedrock_org/lambda/rules/sra_bedrock_check_eval_job_bucket/app.py b/aws_sra_examples/solutions/genai/bedrock_org/lambda/rules/sra_bedrock_check_eval_job_bucket/app.py index ab64cfad..db2d687e 100644 --- a/aws_sra_examples/solutions/genai/bedrock_org/lambda/rules/sra_bedrock_check_eval_job_bucket/app.py +++ b/aws_sra_examples/solutions/genai/bedrock_org/lambda/rules/sra_bedrock_check_eval_job_bucket/app.py @@ -29,6 +29,7 @@ # Define the AWS Config rule parameters RULE_NAME = "sra-bedrock-check-eval-job-bucket" SERVICE_NAME = "bedrock.amazonaws.com" +BUCKET_NAME = "" def evaluate_compliance(event: dict, context: Any) -> tuple[str, str]: # noqa: U100, CCR001, C901 @@ -41,14 +42,23 @@ def evaluate_compliance(event: dict, context: Any) -> tuple[str, str]: # noqa: Returns: tuple[str, str]: The compliance status and annotation """ + global BUCKET_NAME LOGGER.info(f"Evaluate Compliance Event: {event}") # Initialize AWS clients s3 = boto3.client("s3") - + sts = boto3.client("sts") + account = sts.get_caller_identity().get("Account") # Get rule parameters params = ast.literal_eval(event["ruleParameters"]) LOGGER.info(f"Parameters: {params}") - bucket_name = params.get("BucketName", "") + LOGGER.info(f"Account: {account}") + buckets = params.get("Buckets", {account: ""}) + LOGGER.info(f"Buckets: {buckets}") + buckets = ast.literal_eval(buckets) + bucket_name = buckets.get(account, "") + LOGGER.info(f"Bucket Name: {bucket_name}") + BUCKET_NAME = bucket_name + check_retention = params.get("CheckRetention", "true").lower() != "false" check_encryption = params.get("CheckEncryption", "true").lower() != "false" check_logging = params.get("CheckLogging", "true").lower() != "false" @@ -56,6 +66,8 @@ def evaluate_compliance(event: dict, context: Any) -> tuple[str, str]: # noqa: check_versioning = params.get("CheckVersioning", "true").lower() != "false" # Check if the bucket exists + if bucket_name == "": + return build_evaluation("NOT_APPLICABLE", "No bucket name provided") if not check_bucket_exists(bucket_name): return build_evaluation("NOT_APPLICABLE", f"Bucket {bucket_name} does not exist or is not accessible") @@ -64,6 +76,7 @@ def evaluate_compliance(event: dict, context: Any) -> tuple[str, str]: # noqa: # Check retention if check_retention: + LOGGER.info(f"Checking retention policy for bucket {bucket_name}") try: retention = s3.get_bucket_lifecycle_configuration(Bucket=bucket_name) if not any(rule.get("Expiration") for rule in retention.get("Rules", [])): @@ -75,6 +88,7 @@ def evaluate_compliance(event: dict, context: Any) -> tuple[str, str]: # noqa: # Check encryption if check_encryption: + LOGGER.info(f"Checking encryption for bucket {bucket_name}") try: encryption = s3.get_bucket_encryption(Bucket=bucket_name) if "ServerSideEncryptionConfiguration" not in encryption: @@ -86,6 +100,7 @@ def evaluate_compliance(event: dict, context: Any) -> tuple[str, str]: # noqa: # Check logging if check_logging: + LOGGER.info(f"Checking logging for bucket {bucket_name}") logging = s3.get_bucket_logging(Bucket=bucket_name) if "LoggingEnabled" not in logging: compliance_type = "NON_COMPLIANT" @@ -93,6 +108,7 @@ def evaluate_compliance(event: dict, context: Any) -> tuple[str, str]: # noqa: # Check object locking if check_object_locking: + LOGGER.info(f"Checking object locking for bucket {bucket_name}") try: object_locking = s3.get_object_lock_configuration(Bucket=bucket_name) if "ObjectLockConfiguration" not in object_locking: @@ -104,6 +120,7 @@ def evaluate_compliance(event: dict, context: Any) -> tuple[str, str]: # noqa: # Check versioning if check_versioning: + LOGGER.info(f"Checking versioning for bucket {bucket_name}") versioning = s3.get_bucket_versioning(Bucket=bucket_name) if versioning.get("Status") != "Enabled": compliance_type = "NON_COMPLIANT" @@ -157,12 +174,11 @@ def lambda_handler(event: dict, context: Any) -> None: LOGGER.info(f"Lambda Handler Event: {event}") evaluation = evaluate_compliance(event, context) config = boto3.client("config") - params = ast.literal_eval(event["ruleParameters"]) config.put_evaluations( Evaluations=[ { "ComplianceResourceType": "AWS::S3::Bucket", - "ComplianceResourceId": params.get("BucketName"), + "ComplianceResourceId": BUCKET_NAME, "ComplianceType": evaluation["ComplianceType"], # type: ignore "Annotation": evaluation["Annotation"], # type: ignore "OrderingTimestamp": evaluation["OrderingTimestamp"], # type: ignore diff --git a/aws_sra_examples/solutions/genai/bedrock_org/templates/sra-bedrock-org-main.yaml b/aws_sra_examples/solutions/genai/bedrock_org/templates/sra-bedrock-org-main.yaml index 0090f925..0e9855f7 100644 --- a/aws_sra_examples/solutions/genai/bedrock_org/templates/sra-bedrock-org-main.yaml +++ b/aws_sra_examples/solutions/genai/bedrock_org/templates/sra-bedrock-org-main.yaml @@ -96,9 +96,9 @@ Parameters: pBedrockModelEvalBucketRuleParams: Type: String - Default: '{"deploy": "true", "accounts": ["444455556666"], "regions": ["us-west-2"], "input_params": {"BucketName": "model-invocation-log-bucket-444455556666-us-west-2"}}' + Default: '{"deploy": "true", "accounts": ["444455556666"], "regions": ["us-west-2"], "input_params": {"Buckets": {"444455556666": "model-invocation-log-bucket-444455556666"},"CheckRetention": "true", "CheckEncryption": "true", "CheckLogging": "true", "CheckObjectLocking": "true", "CheckVersioning": "true"}}' Description: Bedrock Model Evaluation Job Config Rule Parameters - AllowedPattern: ^\{"deploy"\s*:\s*"(true|false)",\s*"accounts"\s*:\s*\[((?:"[0-9]+"(?:\s*,\s*)?)*)\],\s*"regions"\s*:\s*\[((?:"[a-z0-9-]+"(?:\s*,\s*)?)*)\],\s*"input_params"\s*:\s*(\{\s*(?:"BucketName"\s*:\s*"([a-zA-Z0-9-]*)"\s*)?})\}$ + AllowedPattern: ^\{"deploy"\s*:\s*"(true|false)",\s*"accounts"\s*:\s*\[((?:"[0-9]+"(?:\s*,\s*)?)*)\],\s*"regions"\s*:\s*\[((?:"[a-z0-9-]+"(?:\s*,\s*)?)*)\],\s*"input_params"\s*:\s*(\{\s*(?:"Buckets"\s*:\s*(\{\s*"[0-9]+"\s*:\s*"[a-zA-Z0-9-]*"\s*)?},\s*"CheckRetention"\s*:\s*"(true|false)",\s*"CheckEncryption"\s*:\s*"(true|false)",\s*"CheckLogging"\s*:\s*"(true|false)",\s*"CheckObjectLocking"\s*:\s*"(true|false)",\s*"CheckVersioning"\s*:\s*"(true|false)"\s*)}})$ ConstraintDescription: "Must be a valid JSON string containing: 'deploy' (true/false), 'accounts' (array of account numbers), 'regions' (array of region names), and 'input_params' object (can be empty or contain 'BucketName'). Arrays can be empty.