From 7252ce34a16833397c6dcd89087b06442eb76a17 Mon Sep 17 00:00:00 2001 From: longwind48 Date: Mon, 16 Dec 2024 00:21:06 +0800 Subject: [PATCH 1/2] refactor: add cors config utils --- backend/functions/utils/cors_config.py | 58 ++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 backend/functions/utils/cors_config.py diff --git a/backend/functions/utils/cors_config.py b/backend/functions/utils/cors_config.py new file mode 100644 index 0000000..8814a27 --- /dev/null +++ b/backend/functions/utils/cors_config.py @@ -0,0 +1,58 @@ +from typing import Dict, Union + +from firebase_functions import https_fn + + +# Define allowed origins +ALLOWED_ORIGINS = [ + "http://localhost:3000", # Local development + "https://schemessg-v3-dev.web.app", # Staging frontend +] + + +def get_cors_headers(request: https_fn.Request) -> Dict[str, str]: + """ + Return CORS headers based on the request origin. + Only allows specific origins. + + Args: + request: The incoming request containing origin information + """ + origin = request.headers.get("Origin", "") + + # Only allow specified origins + if origin in ALLOWED_ORIGINS: + return { + "Access-Control-Allow-Origin": origin, + "Access-Control-Allow-Methods": "GET, POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Max-Age": "3600", + } + + # If origin not allowed, return headers without Access-Control-Allow-Origin + return { + "Access-Control-Allow-Methods": "GET, POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Max-Age": "3600", + } + + +def handle_cors_preflight( + request: https_fn.Request, allowed_methods: str = "GET, POST, OPTIONS" +) -> Union[https_fn.Response, tuple]: + """ + Handle CORS preflight requests with origin validation + + Args: + request: The incoming request + allowed_methods: Comma-separated string of allowed HTTP methods + """ + headers = get_cors_headers(request) + headers["Access-Control-Allow-Methods"] = allowed_methods + + # If origin wasn't in allowed list, get_cors_headers won't include Allow-Origin + # In that case, return 403 Forbidden + if "Access-Control-Allow-Origin" not in headers: + return https_fn.Response(response="Origin not allowed", status=403, headers=headers) + + return https_fn.Response(response="", status=204, headers=headers) From 84a6697a922cd32d6434e8a61f193cf4f85a71af Mon Sep 17 00:00:00 2001 From: longwind48 Date: Mon, 16 Dec 2024 00:21:31 +0800 Subject: [PATCH 2/2] refactor: update cors config of all endpoints --- backend/functions/chat/chat.py | 23 ++++----- backend/functions/feedback/feedback.py | 48 +++++++------------ backend/functions/schemes/schemes.py | 28 +++++------ backend/functions/schemes/search.py | 25 +++------- .../functions/update_scheme/update_scheme.py | 35 +++++++------- 5 files changed, 61 insertions(+), 98 deletions(-) diff --git a/backend/functions/chat/chat.py b/backend/functions/chat/chat.py index 646ddc5..c6f2a89 100644 --- a/backend/functions/chat/chat.py +++ b/backend/functions/chat/chat.py @@ -12,6 +12,7 @@ from firebase_functions import https_fn, options from loguru import logger from ml_logic import Chatbot, dataframe_to_text +from utils.cors_config import get_cors_headers, handle_cors_preflight # Remove default handler @@ -51,27 +52,19 @@ def chat_message(req: https_fn.Request) -> https_fn.Response: """ logger.info("Request received") - # TODO remove for prod setup - # Set CORS headers for the preflight request if req.method == "OPTIONS": - # Allows GET and POST requests from any origin with the Content-Type - # header and caches preflight response for an hour - headers = { - "Access-Control-Allow-Origin": "http://localhost:3000", - "Access-Control-Allow-Methods": "POST", - "Access-Control-Allow-Headers": "Content-Type", - "Access-Control-Max-Age": "3600", - } - return ("", 204, headers) - - # Set CORS headers for the main request - headers = {"Access-Control-Allow-Origin": "http://localhost:3000"} - if not (req.method == "POST" or req.method == "GET"): + return handle_cors_preflight(req) + + headers = get_cors_headers(req) + + if not req.method == "POST": return https_fn.Response( response=json.dumps({"error": "Invalid request method; only POST or GET is supported"}), status=405, mimetype="application/json", + headers=headers, ) + chatbot = create_chatbot() try: diff --git a/backend/functions/feedback/feedback.py b/backend/functions/feedback/feedback.py index 51b4a6d..e00fb2e 100644 --- a/backend/functions/feedback/feedback.py +++ b/backend/functions/feedback/feedback.py @@ -3,10 +3,13 @@ http://127.0.0.1:5001/schemessg-v3-dev/asia-southeast1/feedback """ +import json +from datetime import datetime, timezone + from fb_manager.firebaseManager import FirebaseManager from firebase_functions import https_fn, options -from datetime import datetime, timezone -import json +from utils.cors_config import get_cors_headers, handle_cors_preflight + # Firestore client firebase_manager = FirebaseManager() @@ -26,45 +29,34 @@ def feedback(req: https_fn.Request) -> https_fn.Response: Returns: https_fn.Response: response sent to client """ - headers = { - "Access-Control-Allow-Origin": "http://localhost:3000", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - "Access-Control-Max-Age": "3600", - } - if req.method == "OPTIONS": - return https_fn.Response(response="", status=204, headers=headers) + return handle_cors_preflight(req) + + headers = get_cors_headers(req) if req.method != "POST": return https_fn.Response( - response=json.dumps( - {"success": False, "message": "Only POST requests are allowed"} - ), + response=json.dumps({"error": "Method not allowed"}), status=405, mimetype="application/json", headers=headers, ) try: - # Parse the request data - request_json = req.get_json() - feedback_text = request_json.get("feedbackText") - userName = request_json.get("userName") - userEmail = request_json.get("userEmail") - timestamp = datetime.now(timezone.utc) + data = req.get_json() + feedback_text = data.get("feedbackText") + userName = data.get("userName", "Anonymous") + userEmail = data.get("userEmail", "Not provided") + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") if not feedback_text: return https_fn.Response( - response=json.dumps( - {"success": False, "message": "Missing required fields"} - ), + response=json.dumps({"error": "Feedback text is required"}), status=400, mimetype="application/json", headers=headers, ) - # Prepare the data for Firestore feedback_data = { "feedbackText": feedback_text, "timestamp": timestamp, @@ -72,14 +64,10 @@ def feedback(req: https_fn.Request) -> https_fn.Response: "userEmail": userEmail, } - # Add the data to Firestore firebase_manager.firestore_client.collection("userFeedback").add(feedback_data) - # Return a success response return https_fn.Response( - response=json.dumps( - {"success": True, "message": "Feedback successfully added"} - ), + response=json.dumps({"success": True, "message": "Feedback successfully added"}), status=200, mimetype="application/json", headers=headers, @@ -88,9 +76,7 @@ def feedback(req: https_fn.Request) -> https_fn.Response: except Exception as e: print(f"Error: {e}") return https_fn.Response( - response=json.dumps( - {"success": False, "message": "Failed to add feedback"} - ), + response=json.dumps({"success": False, "message": "Failed to add feedback"}), status=500, mimetype="application/json", headers=headers, diff --git a/backend/functions/schemes/schemes.py b/backend/functions/schemes/schemes.py index 82cee1e..07be2b6 100644 --- a/backend/functions/schemes/schemes.py +++ b/backend/functions/schemes/schemes.py @@ -8,6 +8,7 @@ from fb_manager.firebaseManager import FirebaseManager from firebase_functions import https_fn, options from loguru import logger +from utils.cors_config import get_cors_headers, handle_cors_preflight def create_firebase_manager() -> FirebaseManager: @@ -30,21 +31,12 @@ def schemes(req: https_fn.Request) -> https_fn.Response: Returns: https_fn.Response: response sent to client """ - # TODO remove for prod setup - # Set CORS headers for the preflight request + # Handle CORS preflight request if req.method == "OPTIONS": - # Allows GET and POST requests from any origin with the Content-Type - # header and caches preflight response for an hour - headers = { - "Access-Control-Allow-Origin": "http://localhost:3000", - "Access-Control-Allow-Methods": "POST", - "Access-Control-Allow-Headers": "Content-Type", - "Access-Control-Max-Age": "3600", - } - return ("", 204, headers) - - # Set CORS headers for the main request - headers = {"Access-Control-Allow-Origin": "http://localhost:3000"} + return handle_cors_preflight(req) + + # Get standard CORS headers for all other requests + headers = get_cors_headers(req) firebase_manager = create_firebase_manager() @@ -53,6 +45,7 @@ def schemes(req: https_fn.Request) -> https_fn.Response: response=json.dumps({"error": "Invalid request method; only GET is supported"}), status=405, mimetype="application/json", + headers=headers, ) splitted_path = req.path.split("/") @@ -87,4 +80,9 @@ def schemes(req: https_fn.Request) -> https_fn.Response: ) results = {"data": doc.to_dict()} - return https_fn.Response(response=json.dumps(results), status=200, mimetype="application/json", headers=headers) + return https_fn.Response( + response=json.dumps(results), + status=200, + mimetype="application/json", + headers=headers, + ) diff --git a/backend/functions/schemes/search.py b/backend/functions/schemes/search.py index 1526c10..6a4618e 100644 --- a/backend/functions/schemes/search.py +++ b/backend/functions/schemes/search.py @@ -9,18 +9,18 @@ from firebase_functions import https_fn, options from loguru import logger from ml_logic import PredictParams, SearchModel +from utils.cors_config import get_cors_headers, handle_cors_preflight def create_search_model() -> SearchModel: """Factory function to create a SearchModel instance.""" - firebase_manager = FirebaseManager() return SearchModel(firebase_manager) @https_fn.on_request( region="asia-southeast1", - memory=options.MemoryOption.GB_2, # Increases memory to 1GB + memory=options.MemoryOption.GB_2, ) def schemes_search(req: https_fn.Request) -> https_fn.Response: """ @@ -32,27 +32,15 @@ def schemes_search(req: https_fn.Request) -> https_fn.Response: Returns: https_fn.Response: response sent to client """ - # TODO remove for prod setup - # Set CORS headers for the preflight request if req.method == "OPTIONS": - # Allows GET and POST requests from any origin with the Content-Type - # header and caches preflight response for an hour - headers = { - "Access-Control-Allow-Origin": "http://localhost:3000", - "Access-Control-Allow-Methods": "POST", - "Access-Control-Allow-Headers": "Content-Type", - "Access-Control-Max-Age": "3600", - } - return ("", 204, headers) - - # Set CORS headers for the main request - headers = {"Access-Control-Allow-Origin": "http://localhost:3000"} + return handle_cors_preflight(req) + headers = get_cors_headers(req) search_model = create_search_model() - if not (req.method == "POST" or req.method == "GET"): + if not req.method == "POST": return https_fn.Response( - response=json.dumps({"error": "Invalid request method; only POST or GET is supported"}), + response=json.dumps({"error": "Invalid request method; only POST is supported"}), status=405, mimetype="application/json", headers=headers, @@ -63,7 +51,6 @@ def schemes_search(req: https_fn.Request) -> https_fn.Response: query = body.get("query", None) top_k = body.get("top_k", 20) similarity_threshold = body.get("similarity_threshold", 0) - # print(query, top_k, similarity_threshold) except Exception: return https_fn.Response( response=json.dumps({"error": "Invalid request body"}), diff --git a/backend/functions/update_scheme/update_scheme.py b/backend/functions/update_scheme/update_scheme.py index dbdb544..de92a86 100644 --- a/backend/functions/update_scheme/update_scheme.py +++ b/backend/functions/update_scheme/update_scheme.py @@ -3,18 +3,22 @@ http://127.0.0.1:5001/schemessg-v3-dev/asia-southeast1/update_scheme """ +import json +from datetime import datetime, timezone + from fb_manager.firebaseManager import FirebaseManager from firebase_functions import https_fn, options -from datetime import datetime, timezone -import json +from utils.cors_config import get_cors_headers, handle_cors_preflight + # Firestore client firebase_manager = FirebaseManager() + @https_fn.on_request( - region="asia-southeast1", - memory=options.MemoryOption.GB_1, - ) + region="asia-southeast1", + memory=options.MemoryOption.GB_1, +) def update_scheme(req: https_fn.Request) -> https_fn.Response: """ Handler for users seeking to add new schemes or request an edit on an existing scheme @@ -25,22 +29,17 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response: Returns: https_fn.Response: response sent to client """ - headers = { - "Access-Control-Allow-Origin": "http://localhost:3000", - "Access-Control-Allow-Methods": "POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", - "Access-Control-Max-Age": "3600", - } - if req.method == "OPTIONS": - return https_fn.Response(response="", status=204, headers=headers) - + return handle_cors_preflight(req) + + headers = get_cors_headers(req) + if req.method != "POST": return https_fn.Response( response=json.dumps({"success": False, "message": "Only POST requests are allowed"}), status=405, mimetype="application/json", - headers=headers + headers=headers, ) try: @@ -77,7 +76,7 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response: "timestamp": timestamp, "userName": userName, "userEmail": userEmail, - "typeOfRequest": typeOfRequest + "typeOfRequest": typeOfRequest, } # Add the data to Firestore @@ -88,7 +87,7 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response: response=json.dumps({"success": True, "message": "Request for scheme update successfully added"}), status=200, mimetype="application/json", - headers=headers + headers=headers, ) except Exception as e: @@ -97,5 +96,5 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response: response=json.dumps({"success": False, "message": "Failed to add request for scheme update"}), status=500, mimetype="application/json", - headers=headers + headers=headers, )