Skip to content

Commit

Permalink
add auth func
Browse files Browse the repository at this point in the history
  • Loading branch information
raylrui committed Nov 21, 2024
1 parent f89503a commit 5d1e48f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 1 deletion.
4 changes: 3 additions & 1 deletion config/stacks/clientWebsocketApi.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { WebSocketApiStackProps } from '../../lib/workload/stateless/stacks/client-websocket-conn/deploy';
import { AppStage, vpcProps } from '../constants';
import { AppStage, vpcProps, region, cognitoUserPoolIdParameterName } from '../constants';

export const getWebSocketApiStackProps = (stage: AppStage): WebSocketApiStackProps => {
return {
Expand All @@ -9,5 +9,7 @@ export const getWebSocketApiStackProps = (stage: AppStage): WebSocketApiStackPro
vpcProps: vpcProps,
websocketApiEndpointParameterName: `/orcabus/client-websocket-api-endpoint`,
websocketStageName: stage,
cognitoRegion: region,
cognitoUserPoolIdParameterName: cognitoUserPoolIdParameterName,
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import { Runtime, Architecture } from 'aws-cdk-lib/aws-lambda';
import { Construct } from 'constructs';
import * as path from 'path';
import { StringParameter } from 'aws-cdk-lib/aws-ssm';
// import {
// WebSocketIamAuthorizer,
// WebSocketLambdaAuthorizer,
// WebSocketLambdaAuthorizerProps,
// } from 'aws-cdk-lib/aws-apigatewayv2-authorizers';

export interface WebSocketApiStackProps extends StackProps {
connectionTableName: string;
Expand All @@ -19,6 +24,10 @@ export interface WebSocketApiStackProps extends StackProps {

websocketApiEndpointParameterName: string;
websocketStageName: string;

// Cognito configuration for the authorizer
cognitoRegion: string;
cognitoUserPoolIdParameterName: string;
}

export class WebSocketApiStack extends Stack {
Expand Down Expand Up @@ -81,6 +90,23 @@ export class WebSocketApiStack extends Stack {
timeout: Duration.minutes(2),
});

const userPoolId = StringParameter.fromStringParameterName(
this,
'CognitoUserPoolIdParameter',
props.cognitoUserPoolIdParameterName
).stringValue;

// authorizer function to check the client token based on the JWT token
const connectAuthorizer = this.createPythonFunction('connectAuthorizer', {
index: 'auth.py',
handler: 'lambda_handler',
timeout: Duration.minutes(2),
environment: {
COGNITO_REGION: props.cognitoRegion,
COGNITO_USER_POOL_ID: userPoolId,
},
});

// Grant permissions to Lambda functions
connectionTable.grantReadWriteData(connectHandler);
connectionTable.grantReadWriteData(disconnectHandler);
Expand All @@ -93,6 +119,10 @@ export class WebSocketApiStack extends Stack {
apiName: props.websocketApigatewayName,
connectRouteOptions: {
integration: new WebSocketLambdaIntegration('ConnectIntegration', connectHandler),
// authorizer: new WebSocketLambdaAuthorizer('ConnectAuthorizer', connectAuthorizer, {
// authorizerName: 'ConnectAuthorizer',
// identitySource: ['route.request.header.Authorization'],
// }),
},
disconnectRouteOptions: {
integration: new WebSocketLambdaIntegration('DisconnectIntegration', disconnectHandler),
Expand All @@ -102,6 +132,7 @@ export class WebSocketApiStack extends Stack {
},
});

// Add a route for sending messages for sending messages to the client
api.addRoute('sendMessage', {
integration: new WebSocketLambdaIntegration('SendMessageIntegration', messageHandler),
});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
PyJWT==2.8.0
requests==2.31.0
74 changes: 74 additions & 0 deletions lib/workload/stateless/stacks/client-websocket-conn/lambda/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os
import logging
import jwt
import requests
from typing import Dict, Any

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Get environment variables
COGNITO_USER_POOL_ID = os.environ['COGNITO_USER_POOL_ID']
COGNITO_REGION = os.environ.get('COGNITO_REGION', 'ap-southeast-2')

def get_public_key():
"""Get Cognito public key for JWT verification"""
url = f'https://cognito-idp.{COGNITO_REGION}.amazonaws.com/{COGNITO_USER_POOL_ID}/.well-known/jwks.json'
try:
response = requests.get(url)
return response.json()['keys'][0] # Get the first key
except Exception as e:
logger.error(f"Error getting public key: {str(e)}")
raise

def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
"""Simple Lambda authorizer for WebSocket"""
logger.info("WebSocket authorization request")

try:
# Get token from headers
token = event.get('headers', {}).get('Authorization', '').replace('Bearer ', '')
if not token:
raise Exception('No token provided')

# Get public key
public_key = get_public_key()

# Verify token
decoded = jwt.decode(
token,
public_key,
algorithms=['RS256'],
issuer=f'https://cognito-idp.{COGNITO_REGION}.amazonaws.com/{COGNITO_USER_POOL_ID}'
)

# Generate allow policy
return {
'principalId': decoded['sub'],
'policyDocument': {
'Version': '2012-10-17',
'Statement': [{
'Action': 'execute-api:Invoke',
'Effect': 'Allow',
'Resource': event['methodArn']
}]
},
'context': {
'userId': decoded['sub']
}
}

except Exception as e:
logger.error(f"Authorization failed: {str(e)}")
# Return deny policy
return {
'principalId': 'unauthorized',
'policyDocument': {
'Version': '2012-10-17',
'Statement': [{
'Action': 'execute-api:Invoke',
'Effect': 'Deny',
'Resource': event['methodArn']
}]
}
}

0 comments on commit 5d1e48f

Please sign in to comment.