Skip to content

Commit

Permalink
Update cluster when config file changes (#205)
Browse files Browse the repository at this point in the history
Move the code that expands the config file template to a separate lambda and
custom resource so we can get its sha512 hash and use it to determine when
the cluster needs to be updated.

Resolves #202
  • Loading branch information
cartalla authored Feb 26, 2024
1 parent 9d3c16f commit a8b6555
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 48 deletions.
92 changes: 79 additions & 13 deletions source/cdk/cdk_slurm_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,11 +1049,59 @@ def create_parallel_cluster_lambdas(self):
)
)

createParallelClusterConfigLambdaAsset = s3_assets.Asset(self, "CreateParallelClusterConfigAsset", path="resources/lambdas/CreateParallelClusterConfig")
self.create_parallel_cluster_config_lambda = aws_lambda.Function(
self, "CreateParallelClusterConfigLambda",
function_name=f"{self.stack_name}-CreateParallelClusterConfig",
description="Create ParallelCluster config",
memory_size=2048,
runtime=aws_lambda.Runtime.PYTHON_3_9,
architecture=aws_lambda.Architecture.X86_64,
timeout=Duration.minutes(15),
log_retention=logs.RetentionDays.INFINITE,
handler="CreateParallelClusterConfig.lambda_handler",
code=aws_lambda.Code.from_bucket(createParallelClusterConfigLambdaAsset.bucket, createParallelClusterConfigLambdaAsset.s3_object_key),
layers=[self.parallel_cluster_lambda_layer],
environment = {
'ClusterName': self.config['slurm']['ClusterName'],
'ErrorSnsTopicArn': self.config.get('ErrorSnsTopicArn', ''),
'ParallelClusterConfigS3Bucket': self.assets_bucket,
'ParallelClusterConfigYamlTemplateS3Key': self.parallel_cluster_config_template_yaml_s3_key,
'ParallelClusterConfigYamlS3Key': self.parallel_cluster_config_yaml_s3_key,
'Region': self.cluster_region
}
)
self.create_parallel_cluster_config_lambda.add_to_role_policy(
statement=iam.PolicyStatement(
effect=iam.Effect.ALLOW,
actions=[
's3:DeleteObject',
's3:GetObject',
's3:PutObject'
],
resources=[
f"arn:{Aws.PARTITION}:s3:::{self.assets_bucket}/{self.config['slurm']['ClusterName']}/*",
f"arn:{Aws.PARTITION}:s3:::{self.assets_bucket}/{self.config['slurm']['ClusterName']}/{self.parallel_cluster_config_template_yaml_s3_key}",
f"arn:{Aws.PARTITION}:s3:::{self.assets_bucket}/{self.config['slurm']['ClusterName']}/{self.parallel_cluster_config_yaml_s3_key}"
]
)
)
if 'ErrorSnsTopicArn' in self.config:
self.create_parallel_cluster_config_lambda.add_to_role_policy(
statement=iam.PolicyStatement(
effect=iam.Effect.ALLOW,
actions=[
'sns:Publish'
],
resources=[self.config['ErrorSnsTopicArn']]
)
)

createParallelClusterLambdaAsset = s3_assets.Asset(self, "CreateParallelClusterAsset", path="resources/lambdas/CreateParallelCluster")
self.create_parallel_cluster_lambda = aws_lambda.Function(
self, "CreateParallelClusterLambda",
function_name=f"{self.stack_name}-CreateParallelCluster",
description="Create ParallelCluster from json string",
description="Create ParallelCluster",
memory_size=2048,
runtime=aws_lambda.Runtime.PYTHON_3_9,
architecture=aws_lambda.Architecture.X86_64,
Expand Down Expand Up @@ -2380,7 +2428,7 @@ def create_parallel_cluster_config(self):
index = 0
for extra_mount_sg_name, extra_mount_sg in self.extra_mount_security_groups[fs_type].items():
template_var = f"ExtraMountSecurityGroupId{index}"
self.create_parallel_cluster_lambda.add_environment(
self.create_parallel_cluster_config_lambda.add_environment(
key = template_var,
value = extra_mount_sg.security_group_id
)
Expand Down Expand Up @@ -2838,50 +2886,64 @@ def create_parallel_cluster_config(self):
self.parallel_cluster_config['SharedStorage'].append(parallel_cluster_storage_dict)

# Save the config template to s3.
self.parallel_cluster_config_template_yaml = yaml.dump(self.parallel_cluster_config)
self.parallel_cluster_config_template_yaml_hash = sha512()
self.parallel_cluster_config_template_yaml_hash.update(bytes(self.parallel_cluster_config_template_yaml, 'utf-8'))
self.assets_hash.update(bytes(self.parallel_cluster_config_template_yaml, 'utf-8'))
self.s3_client.put_object(
Bucket = self.assets_bucket,
Key = self.parallel_cluster_config_template_yaml_s3_key,
Body = yaml.dump(self.parallel_cluster_config)
Body = self.parallel_cluster_config_template_yaml
)

self.build_config_files = CustomResource(
self, "BuildConfigFiles",
service_token = self.create_build_files_lambda.function_arn
)

self.create_parallel_cluster_lambda.add_environment(
self.create_parallel_cluster_config_lambda.add_environment(
key = 'ParallelClusterAssetReadPolicyArn',
value = self.parallel_cluster_asset_read_policy.managed_policy_arn
)
self.create_parallel_cluster_lambda.add_environment(
self.create_parallel_cluster_config_lambda.add_environment(
key = 'ParallelClusterJwtWritePolicyArn',
value = self.parallel_cluster_jwt_write_policy.managed_policy_arn
)
self.create_parallel_cluster_lambda.add_environment(
self.create_parallel_cluster_config_lambda.add_environment(
key = 'ParallelClusterMungeKeyWritePolicyArn',
value = self.parallel_cluster_munge_key_write_policy.managed_policy_arn
)
self.create_parallel_cluster_lambda.add_environment(
self.create_parallel_cluster_config_lambda.add_environment(
key = 'ParallelClusterSnsPublishPolicyArn',
value = self.parallel_cluster_sns_publish_policy.managed_policy_arn
)
self.create_parallel_cluster_lambda.add_environment(
self.create_parallel_cluster_config_lambda.add_environment(
key = 'SlurmCtlSecurityGroupId',
value = self.slurmctl_sg.security_group_id
)
self.create_parallel_cluster_lambda.add_environment(
self.create_parallel_cluster_config_lambda.add_environment(
key = 'SlurmNodeSecurityGroupId',
value = self.slurmnode_sg.security_group_id
)
self.parallel_cluster_config = CustomResource(
self, "ParallelClusterConfig",
service_token = self.create_parallel_cluster_config_lambda.function_arn,
properties = {
'ParallelClusterConfigTemplateYamlHash': self.parallel_cluster_config_template_yaml_hash.hexdigest()
}
)
self.parallel_cluster_config_template_yaml_s3_url = self.parallel_cluster_config.get_att_string('ConfigTemplateYamlS3Url')
self.parallel_cluster_config_yaml_s3_url = self.parallel_cluster_config.get_att_string('ConfigYamlS3Url')
self.parallel_cluster_config_yaml_hash = self.parallel_cluster_config.get_att_string('ConfigYamlHash')
self.assets_hash.update(bytes(self.parallel_cluster_config_yaml_hash, 'utf-8'))

self.parallel_cluster = CustomResource(
self, "ParallelCluster",
service_token = self.create_parallel_cluster_lambda.function_arn,
properties = {
'ParallelClusterConfigHash': self.assets_hash.hexdigest()
'ParallelClusterConfigHash': self.parallel_cluster_config_yaml_hash
}
)
self.parallel_cluster_config_template_yaml_s3_url = self.parallel_cluster.get_att_string('ConfigTemplateYamlS3Url')
self.parallel_cluster_config_yaml_s3_url = self.parallel_cluster.get_att_string('ConfigYamlS3Url')
# The lambda to create an A record for the head node must be built before the parallel cluster.
self.parallel_cluster.node.add_dependency(self.create_head_node_a_record_lambda)
self.parallel_cluster.node.add_dependency(self.update_head_node_lambda)
Expand All @@ -2891,6 +2953,7 @@ def create_parallel_cluster_config(self):
self.parallel_cluster.node.add_dependency(self.configure_res_submitters_lambda)
# Build config files need to be created before cluster so that they can be downloaded as part of on_head_node_configures
self.parallel_cluster.node.add_dependency(self.build_config_files)
self.parallel_cluster.node.add_dependency(self.parallel_cluster_config)

self.call_slurm_rest_api_lambda.node.add_dependency(self.parallel_cluster)

Expand All @@ -2899,7 +2962,7 @@ def create_parallel_cluster_config(self):
self, "UpdateHeadNode",
service_token = self.update_head_node_lambda.function_arn,
properties = {
'ParallelClusterConfigHash': self.assets_hash.hexdigest(),
'ParallelClusterConfigHash': self.parallel_cluster_config_yaml_hash,
}
)
self.update_head_node.node.add_dependency(self.parallel_cluster)
Expand Down Expand Up @@ -2929,6 +2992,9 @@ def create_parallel_cluster_config(self):
CfnOutput(self, "ParallelClusterConfigYamlS3Url",
value = self.parallel_cluster_config_yaml_s3_url
)
CfnOutput(self, "ParallelClusterConfigHash",
value = self.parallel_cluster_config_yaml_hash
)
CfnOutput(self, "PlaybookS3Url",
value = self.playbooks_asset.s3_object_url
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,49 +76,16 @@ def lambda_handler(event, context):
else:
raise KeyError(error_message)

s3_resource = boto3.resource('s3')

yaml_template_key = environ['ParallelClusterConfigYamlTemplateS3Key']
yaml_template_s3_url = f"s3://{environ['ParallelClusterConfigS3Bucket']}/{yaml_template_key}"
yaml_template_config_object = s3_resource.Object(
bucket_name = environ['ParallelClusterConfigS3Bucket'],
key = yaml_template_key
)
parallel_cluster_config_yaml_template = Template(yaml_template_config_object.get()['Body'].read().decode('utf-8'))

template_vars = {}
for template_var in environ:
template_vars[template_var] = environ[template_var]
parallel_cluster_config_yaml = parallel_cluster_config_yaml_template.render(**template_vars)
parallel_cluster_config = yaml.load(parallel_cluster_config_yaml, Loader=yaml.FullLoader)
logger.info(f"HeadNode config:\n{json.dumps(parallel_cluster_config['HeadNode'], indent=4)}")

yaml_key = f"{environ['ParallelClusterConfigYamlS3Key']}"
yaml_s3_url = f"s3://{environ['ParallelClusterConfigS3Bucket']}/{yaml_key}"
yaml_config_object = s3_resource.Object(
bucket_name = environ['ParallelClusterConfigS3Bucket'],
key = yaml_key
)
if requestType == 'Delete':
logging.info(f"Deleting Parallel Cluster yaml config in {yaml_s3_url}")
try:
yaml_config_object.delete()
except:
pass
else:
logging.info(f"Saving Parallel Cluster yaml config in {yaml_s3_url}")
yaml_config_object.put(Body=yaml.dump(parallel_cluster_config, sort_keys=False))

cluster_name = environ['ClusterName']
cluster_region = environ['Region']

logger.info(f"{requestType} request for {cluster_name} in {cluster_region}")

cluster_status = get_cluster_status(cluster_name, cluster_region)
if cluster_status:
valid_statuses = ['CREATE_COMPLETE', 'UPDATE_COMPLETE', 'UPDATE_ROLLBACK_COMPLETE']
invalid_statuses = ['CREATE_IN_PROGRESS', 'UPDATE_IN_PROGRESS', 'DELETE_IN_PROGRESS']
if cluster_status in invalid_statuses:
logger.error(f"{cluster_name} has invalid status: {cluster_status}")
cfnresponse.send(event, context, cfnresponse.FAILED, {'error': f"{cluster_name} in {cluster_status} state."}, physicalResourceId=cluster_name)
return
if requestType == 'Create':
Expand All @@ -135,6 +102,19 @@ def lambda_handler(event, context):
else:
logger.info(f"{cluster_name} doesn't exist.")

yaml_key = f"{environ['ParallelClusterConfigYamlS3Key']}"
yaml_s3_url = f"s3://{environ['ParallelClusterConfigS3Bucket']}/{yaml_key}"

logger.info(f"Getting Parallel Cluster yaml config from {yaml_s3_url}")
s3_client = boto3.client('s3')
parallel_cluster_config_yaml = s3_client.get_object(
Bucket = environ['ParallelClusterConfigS3Bucket'],
Key = yaml_key
)['Body'].read().decode('utf-8')

parallel_cluster_config = yaml.load(parallel_cluster_config_yaml, Loader=yaml.FullLoader)
logger.info(f"HeadNode config:\n{json.dumps(parallel_cluster_config['HeadNode'], indent=4)}")

if requestType == "Create":
logger.info(f"Creating {cluster_name}")
try:
Expand Down Expand Up @@ -277,4 +257,4 @@ def lambda_handler(event, context):
logger.info(f"Published error to {environ['ErrorSnsTopicArn']}")
raise

cfnresponse.send(event, context, cfnresponse.SUCCESS, {'ConfigTemplateYamlS3Url': yaml_template_s3_url, 'ConfigYamlS3Url': yaml_s3_url}, physicalResourceId=cluster_name)
cfnresponse.send(event, context, cfnresponse.SUCCESS, {}, physicalResourceId=cluster_name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify,
merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

'''
Create/update/delete ParallelCluster cluster config file and save to S3 as json and yaml.
'''
import boto3
import cfnresponse
from hashlib import sha512
from jinja2 import Template as Template
import json
import logging
from os import environ as environ
import pcluster.lib as pc
from pcluster.api.errors import BadRequestException, UpdateClusterBadRequestException
from time import sleep
import yaml

logger=logging.getLogger(__file__)
logger_formatter = logging.Formatter('%(levelname)s: %(message)s')
logger_streamHandler = logging.StreamHandler()
logger_streamHandler.setFormatter(logger_formatter)
logger.addHandler(logger_streamHandler)
logger.setLevel(logging.INFO)
logger.propagate = False

def lambda_handler(event, context):
try:
logger.info(f"event:\n{json.dumps(event, indent=4)}")
cluster_name = None
requestType = event['RequestType']
properties = event['ResourceProperties']
required_properties = [
'ParallelClusterConfigTemplateYamlHash'
]
error_message = ""
for property in required_properties:
try:
value = properties[property]
except:
error_message += f"Missing {property} property. "
if error_message:
logger.info(error_message)
if requestType == 'Delete':
cfnresponse.send(event, context, cfnresponse.SUCCESS, {}, physicalResourceId=cluster_name)
return
else:
raise KeyError(error_message)

s3_client = boto3.client('s3')

yaml_template_key = environ['ParallelClusterConfigYamlTemplateS3Key']
yaml_template_s3_url = f"s3://{environ['ParallelClusterConfigS3Bucket']}/{yaml_template_key}"

yaml_key = f"{environ['ParallelClusterConfigYamlS3Key']}"
yaml_s3_url = f"s3://{environ['ParallelClusterConfigS3Bucket']}/{yaml_key}"

parallel_cluster_config_hash = sha512()

if requestType == 'Delete':
logger.info(f"Deleting Parallel Cluster yaml config template in {yaml_template_s3_url}")
try:
s3_client.delete_object(
Bucket = environ['ParallelClusterConfigS3Bucket'],
Key = yaml_template_key
)
except:
pass

logger.info(f"Deleting Parallel Cluster yaml config in {yaml_s3_url}")
try:
s3_client.delete_object(
Bucket = environ['ParallelClusterConfigS3Bucket'],
Key = yaml_key
)
except:
pass
else: # Create or Update
parallel_cluster_config_yaml_template = Template(
s3_client.get_object(
Bucket = environ['ParallelClusterConfigS3Bucket'],
Key = yaml_template_key
)['Body'].read().decode('utf-8'))

template_vars = {}
for template_var in environ:
template_vars[template_var] = environ[template_var]
logger.info(f"template_vars:\n{json.dumps(template_vars, indent=4, sort_keys=True)}")
parallel_cluster_config_yaml = parallel_cluster_config_yaml_template.render(**template_vars)

parallel_cluster_config_hash.update(bytes(parallel_cluster_config_yaml, 'utf-8'))
logger.info(f"Config hash: {parallel_cluster_config_hash.hexdigest()}")

parallel_cluster_config = yaml.load(parallel_cluster_config_yaml, Loader=yaml.FullLoader)
logger.info(f"HeadNode config:\n{json.dumps(parallel_cluster_config['HeadNode'], indent=4)}")

logger.info(f"Saving Parallel Cluster yaml config in {yaml_s3_url}")
s3_client.put_object(
Bucket = environ['ParallelClusterConfigS3Bucket'],
Key = yaml_key,
Body = parallel_cluster_config_yaml
)

except Exception as e:
logger.exception(str(e))
cfnresponse.send(event, context, cfnresponse.FAILED, {'error': str(e)}, physicalResourceId=cluster_name)
sns_client = boto3.client('sns')
sns_client.publish(
TopicArn = environ['ErrorSnsTopicArn'],
Subject = f"{cluster_name} CreateParallelClusterConfig failed",
Message = str(e)
)
logger.info(f"Published error to {environ['ErrorSnsTopicArn']}")
raise

cfnresponse.send(event, context, cfnresponse.SUCCESS, {'ConfigTemplateYamlS3Url': yaml_template_s3_url, 'ConfigYamlS3Url': yaml_s3_url, 'ConfigYamlHash': parallel_cluster_config_hash.hexdigest()}, physicalResourceId=cluster_name)

0 comments on commit a8b6555

Please sign in to comment.