Skip to content

Commit

Permalink
addd basic stack cdk code
Browse files Browse the repository at this point in the history
  • Loading branch information
raylrui committed Nov 21, 2024
1 parent 6c88075 commit 993c411
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 0 deletions.
31 changes: 31 additions & 0 deletions lib/workload/stateless/stacks/client-websocket-conn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# WebSocket API Stack

A serverless WebSocket API implementation using AWS CDK, API Gateway WebSocket APIs, Lambda, and DynamoDB for real-time communication.

## Architecture

![Architecture Diagram](./websocket-api-arch.png)


### Components

- **API Gateway WebSocket API**: Handles WebSocket connections
- **Lambda Functions**: Process WebSocket events
- **DynamoDB**: Stores connection information

## Features

- Real-time bidirectional communication
- Connection management
- Message broadcasting
- Secure VPC deployment
- Automatic scaling
- Connection cleanup

## Prerequisites

- AWS CDK CLI
- Node.js & npm
- Python 3.12
- AWS Account and configured credentials
- VPC with private subnets
153 changes: 153 additions & 0 deletions lib/workload/stateless/stacks/client-websocket-conn/deploy/stack.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import { Stack, RemovalPolicy, StackProps, Duration } from 'aws-cdk-lib';
import { Table, AttributeType } from 'aws-cdk-lib/aws-dynamodb';
import { Vpc, SecurityGroup, VpcLookupOptions, IVpc, ISecurityGroup } from 'aws-cdk-lib/aws-ec2';
import { WebSocketApi, WebSocketStage } from 'aws-cdk-lib/aws-apigatewayv2';
import { WebSocketLambdaIntegration } from 'aws-cdk-lib/aws-apigatewayv2-integrations';
import { PolicyStatement } from 'aws-cdk-lib/aws-iam';
import { PythonFunction } from '@aws-cdk/aws-lambda-python-alpha';
import { Runtime, Architecture } from 'aws-cdk-lib/aws-lambda';
import { Construct } from 'constructs';
import * as path from 'path';

export interface WebSocketApiStackProps extends StackProps {
connectionTableName: string;
websocketApigatewayName: string;
connectionFunctionName: string;
disconnectFunctionName: string;
messageFunctionName: string;

lambdaSecurityGroupName: string;
vpcProps: VpcLookupOptions;
}

export class WebSocketApiStack extends Stack {
private readonly lambdaRuntimePythonVersion = Runtime.PYTHON_3_12;
private readonly props: WebSocketApiStackProps;
private vpc: IVpc;
private lambdaSG: ISecurityGroup;

constructor(scope: Construct, id: string, props: WebSocketApiStackProps) {
super(scope, id, props);

this.props = props;

this.vpc = Vpc.fromLookup(this, 'MainVpc', props.vpcProps);
this.lambdaSG = SecurityGroup.fromLookupByName(
this,
'LambdaSecurityGroup',
props.lambdaSecurityGroupName,
this.vpc
);

// DynamoDB Table for storing connection IDs
const connectionTable = new Table(this, 'WebSocketConnections', {
tableName: props.connectionTableName,
partitionKey: {
name: 'ConnectionId',
type: AttributeType.STRING,
},
removalPolicy: RemovalPolicy.DESTROY, // For demo purposes, not recommended for production
});

// DynamoDB Table for message history
// const messageHistoryTable = new Table(this, "WebSocketMessageHistory", {
// partitionKey: {
// name: "messageId",
// type: AttributeType.STRING,
// },
// timeToLiveAttribute: "ttl", // Enable TTL
// removalPolicy: RemovalPolicy.DESTROY,
// });

// Lambda function for $connect
const connectHandler = this.createPythonFunction(props.connectionFunctionName, {
index: 'connect.py',
handler: 'lambda_handler',
timeout: Duration.minutes(2),
});

// Lambda function for $disconnect
const disconnectHandler = this.createPythonFunction(props.disconnectFunctionName, {
index: 'disconnect.py',
handler: 'lambda_handler',
timeout: Duration.minutes(2),
});

// Lambda function for $default (broadcast messages)
const messageHandler = this.createPythonFunction(props.messageFunctionName, {
index: 'message.py',
handler: 'lambda_handler',
timeout: Duration.minutes(2),
});

// Grant permissions to Lambda functions
connectionTable.grantReadWriteData(connectHandler);
connectionTable.grantReadWriteData(disconnectHandler);
connectionTable.grantReadWriteData(messageHandler);
// messageHistoryTable.grantReadData(connectHandler);
// messageHistoryTable.grantReadWriteData(messageHandler);

// WebSocket API
const api = new WebSocketApi(this, props.websocketApigatewayName, {
apiName: props.websocketApigatewayName,
connectRouteOptions: {
integration: new WebSocketLambdaIntegration('ConnectIntegration', connectHandler),
},
disconnectRouteOptions: {
integration: new WebSocketLambdaIntegration('DisconnectIntegration', disconnectHandler),
},
defaultRouteOptions: {
integration: new WebSocketLambdaIntegration('DefaultIntegration', messageHandler),
},
});

api.addRoute('sendMessage', {
integration: new WebSocketLambdaIntegration('SendMessageIntegration', messageHandler),
});

// Deploy WebSocket API to a stage
const stage = new WebSocketStage(this, 'WebSocketStage', {
webSocketApi: api,
stageName: 'dev',
autoDeploy: true,
});

// Create the WebSocket API endpoint URL
const webSocketApiEndpoint = `${api.apiEndpoint}/${stage.stageName}`;

const commonEnvironment = {
CONNECTION_TABLE: connectionTable.tableName,
// MESSAGE_HISTORY_TABLE: messageHistoryTable.tableName,
WEBSOCKET_API_ENDPOINT: webSocketApiEndpoint,
};

// Add environment variables individually
for (const [key, value] of Object.entries(commonEnvironment)) {
connectHandler.addEnvironment(key, value);
disconnectHandler.addEnvironment(key, value);
messageHandler.addEnvironment(key, value);
}

// Grant permissions to the message handler
messageHandler.addToRolePolicy(
new PolicyStatement({
actions: ['execute-api:ManageConnections'],
resources: [
`arn:aws:execute-api:${this.region}:${this.account}:${api.apiId}/dev/POST/@connections/*`,
],
})
);
}

private createPythonFunction(name: string, props: object): PythonFunction {
return new PythonFunction(this, name, {
entry: path.join(__dirname, '../lambda'),
runtime: this.lambdaRuntimePythonVersion,
securityGroups: [this.lambdaSG],
vpc: this.vpc,
vpcSubnets: { subnets: this.vpc.privateSubnets },
architecture: Architecture.ARM_64,
...props,
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import boto3
import os

def lambda_handler(event, context):
# Get table names from environment variables
connections_table_name = os.environ['CONNECTION_TABLE']

dynamodb = boto3.resource('dynamodb')
connections_table = dynamodb.Table(connections_table_name)

connection_id = event['requestContext']['connectionId']

try:
# Store connection
connections_table.put_item(
Item={'ConnectionId': connection_id}
)
except Exception as e:
return {'statusCode': 500, 'body': str(e)}

return {'statusCode': 200}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import boto3
import os

def lambda_handler(event, context):
# Get table name from environment variable
connections_table_name = os.environ['CONNECTION_TABLE']

dynamodb = boto3.resource('dynamodb')
table = dynamodb.Table(connections_table_name)

connection_id = event['requestContext']['connectionId']

try:
table.delete_item(Key={'ConnectionId': connection_id})
return {'statusCode': 200}
except Exception as e:
return {'statusCode': 500, 'body': str(e)}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import boto3
import json
import os

def lambda_handler(event, context):

assert os.environ['CONNECTION_TABLE'] is not None, "CONNECTION_TABLE environment variable is not set"
assert os.environ['WEBSOCKET_API_ENDPOINT'] is not None, "WEBSOCKET_API_ENDPOINT environment variable is not set"

# Get environment variables
connections_table_name = os.environ['CONNECTION_TABLE']

# connections URL with replace wss:// header to https
websocket_endpoint = os.environ['WEBSOCKET_API_ENDPOINT'].replace('wss://', 'https://')

dynamodb = boto3.resource('dynamodb')
connections_table = dynamodb.Table(connections_table_name)

# Initialize API Gateway client
apigw_client = boto3.client('apigatewaymanagementapi',
endpoint_url=websocket_endpoint)

print(f"Received event: {event}, websocket endpoint: {websocket_endpoint}")

try:
# Initialize response data
data = event
response_data = {
'type': data.get('type', ''),
'message': data.get('message', '')
}

# Broadcast to all connections
connections = connections_table.scan()['Items']

for connection in connections:
connection_id = connection['ConnectionId']
try:
apigw_client.post_to_connection(
ConnectionId=connection_id,
Data=json.dumps(response_data)
)
except apigw_client.exceptions.GoneException:
# Remove stale connection
connections_table.delete_item(Key={'connectionId': connection_id})
except Exception as e:
print(f"Failed to post message to {connection_id}: {e}")

return {'statusCode': 200}

except json.JSONDecodeError:
return {
'statusCode': 400,
'body': json.dumps({'error': 'Invalid JSON in request body'})
}
except KeyError as e:
return {
'statusCode': 400,
'body': json.dumps({'error': f'Missing required field: {str(e)}'})
}
except Exception as e:
print(f"Error: {e}")
return {
'statusCode': 500,
'body': json.dumps({'error': 'Internal server error'})
}


# test case
# curl -X POST https://<api-id>.execute-api.<region>.amazonaws.com/Prod/message -H "Content-Type: application/json" -d '{"type": "test", "message": "Hello, world!"}'
# invoke lambda function from aws console, cmd: aws lambda invoke --function-name <function-name> --payload '{"type": "test", "message": "Hello, world!"}' response.json
# check cloudwatch logs for response
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 993c411

Please sign in to comment.