diff --git a/config/stacks/clientWebsocketApi.ts b/config/stacks/clientWebsocketApi.ts index 4585b9464..1ac34c193 100644 --- a/config/stacks/clientWebsocketApi.ts +++ b/config/stacks/clientWebsocketApi.ts @@ -1,5 +1,5 @@ import { WebSocketApiStackProps } from '../../lib/workload/stateless/stacks/client-websocket-conn/deploy'; -import { AppStage, vpcProps } from '../constants'; +import { AppStage, vpcProps, region, cognitoUserPoolIdParameterName } from '../constants'; export const getWebSocketApiStackProps = (stage: AppStage): WebSocketApiStackProps => { return { @@ -9,5 +9,7 @@ export const getWebSocketApiStackProps = (stage: AppStage): WebSocketApiStackPro vpcProps: vpcProps, websocketApiEndpointParameterName: `/orcabus/client-websocket-api-endpoint`, websocketStageName: stage, + cognitoRegion: region, + cognitoUserPoolIdParameterName: cognitoUserPoolIdParameterName, }; }; diff --git a/lib/workload/stateless/stacks/client-websocket-conn/deploy/index.ts b/lib/workload/stateless/stacks/client-websocket-conn/deploy/index.ts index b46b3d629..cd079b65b 100644 --- a/lib/workload/stateless/stacks/client-websocket-conn/deploy/index.ts +++ b/lib/workload/stateless/stacks/client-websocket-conn/deploy/index.ts @@ -9,6 +9,11 @@ import { Runtime, Architecture } from 'aws-cdk-lib/aws-lambda'; import { Construct } from 'constructs'; import * as path from 'path'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; +// import { +// WebSocketIamAuthorizer, +// WebSocketLambdaAuthorizer, +// WebSocketLambdaAuthorizerProps, +// } from 'aws-cdk-lib/aws-apigatewayv2-authorizers'; export interface WebSocketApiStackProps extends StackProps { connectionTableName: string; @@ -19,6 +24,10 @@ export interface WebSocketApiStackProps extends StackProps { websocketApiEndpointParameterName: string; websocketStageName: string; + + // Cognito configuration for the authorizer + cognitoRegion: string; + cognitoUserPoolIdParameterName: string; } export class WebSocketApiStack extends Stack { @@ -81,6 +90,23 @@ export class WebSocketApiStack extends Stack { timeout: Duration.minutes(2), }); + const userPoolId = StringParameter.fromStringParameterName( + this, + 'CognitoUserPoolIdParameter', + props.cognitoUserPoolIdParameterName + ).stringValue; + + // authorizer function to check the client token based on the JWT token + const connectAuthorizer = this.createPythonFunction('connectAuthorizer', { + index: 'auth.py', + handler: 'lambda_handler', + timeout: Duration.minutes(2), + environment: { + COGNITO_REGION: props.cognitoRegion, + COGNITO_USER_POOL_ID: userPoolId, + }, + }); + // Grant permissions to Lambda functions connectionTable.grantReadWriteData(connectHandler); connectionTable.grantReadWriteData(disconnectHandler); @@ -93,6 +119,10 @@ export class WebSocketApiStack extends Stack { apiName: props.websocketApigatewayName, connectRouteOptions: { integration: new WebSocketLambdaIntegration('ConnectIntegration', connectHandler), + // authorizer: new WebSocketLambdaAuthorizer('ConnectAuthorizer', connectAuthorizer, { + // authorizerName: 'ConnectAuthorizer', + // identitySource: ['route.request.header.Authorization'], + // }), }, disconnectRouteOptions: { integration: new WebSocketLambdaIntegration('DisconnectIntegration', disconnectHandler), @@ -102,6 +132,7 @@ export class WebSocketApiStack extends Stack { }, }); + // Add a route for sending messages for sending messages to the client api.addRoute('sendMessage', { integration: new WebSocketLambdaIntegration('SendMessageIntegration', messageHandler), }); diff --git a/lib/workload/stateless/stacks/client-websocket-conn/deps/requirements.txt b/lib/workload/stateless/stacks/client-websocket-conn/deps/requirements.txt new file mode 100644 index 000000000..f93f1f67b --- /dev/null +++ b/lib/workload/stateless/stacks/client-websocket-conn/deps/requirements.txt @@ -0,0 +1,2 @@ +PyJWT==2.8.0 +requests==2.31.0 \ No newline at end of file diff --git a/lib/workload/stateless/stacks/client-websocket-conn/lambda/auth.py b/lib/workload/stateless/stacks/client-websocket-conn/lambda/auth.py new file mode 100644 index 000000000..76dc43757 --- /dev/null +++ b/lib/workload/stateless/stacks/client-websocket-conn/lambda/auth.py @@ -0,0 +1,74 @@ +import os +import logging +import jwt +import requests +from typing import Dict, Any + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +# Get environment variables +COGNITO_USER_POOL_ID = os.environ['COGNITO_USER_POOL_ID'] +COGNITO_REGION = os.environ.get('COGNITO_REGION', 'ap-southeast-2') + +def get_public_key(): + """Get Cognito public key for JWT verification""" + url = f'https://cognito-idp.{COGNITO_REGION}.amazonaws.com/{COGNITO_USER_POOL_ID}/.well-known/jwks.json' + try: + response = requests.get(url) + return response.json()['keys'][0] # Get the first key + except Exception as e: + logger.error(f"Error getting public key: {str(e)}") + raise + +def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """Simple Lambda authorizer for WebSocket""" + logger.info("WebSocket authorization request") + + try: + # Get token from headers + token = event.get('headers', {}).get('Authorization', '').replace('Bearer ', '') + if not token: + raise Exception('No token provided') + + # Get public key + public_key = get_public_key() + + # Verify token + decoded = jwt.decode( + token, + public_key, + algorithms=['RS256'], + issuer=f'https://cognito-idp.{COGNITO_REGION}.amazonaws.com/{COGNITO_USER_POOL_ID}' + ) + + # Generate allow policy + return { + 'principalId': decoded['sub'], + 'policyDocument': { + 'Version': '2012-10-17', + 'Statement': [{ + 'Action': 'execute-api:Invoke', + 'Effect': 'Allow', + 'Resource': event['methodArn'] + }] + }, + 'context': { + 'userId': decoded['sub'] + } + } + + except Exception as e: + logger.error(f"Authorization failed: {str(e)}") + # Return deny policy + return { + 'principalId': 'unauthorized', + 'policyDocument': { + 'Version': '2012-10-17', + 'Statement': [{ + 'Action': 'execute-api:Invoke', + 'Effect': 'Deny', + 'Resource': event['methodArn'] + }] + } + } \ No newline at end of file