Skip to content

Commit

Permalink
Merge pull request #109 from bettersg/fix/108-fix-cors-issues
Browse files Browse the repository at this point in the history
Fix/108 fix cors issues
  • Loading branch information
longwind48 authored Dec 15, 2024
2 parents 34fd613 + 84a6697 commit 8ff083b
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 98 deletions.
23 changes: 8 additions & 15 deletions backend/functions/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from firebase_functions import https_fn, options
from loguru import logger
from ml_logic import Chatbot, dataframe_to_text
from utils.cors_config import get_cors_headers, handle_cors_preflight


# Remove default handler
Expand Down Expand Up @@ -51,27 +52,19 @@ def chat_message(req: https_fn.Request) -> https_fn.Response:
"""
logger.info("Request received")

# TODO remove for prod setup
# Set CORS headers for the preflight request
if req.method == "OPTIONS":
# Allows GET and POST requests from any origin with the Content-Type
# header and caches preflight response for an hour
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}
return ("", 204, headers)

# Set CORS headers for the main request
headers = {"Access-Control-Allow-Origin": "http://localhost:3000"}
if not (req.method == "POST" or req.method == "GET"):
return handle_cors_preflight(req)

headers = get_cors_headers(req)

if not req.method == "POST":
return https_fn.Response(
response=json.dumps({"error": "Invalid request method; only POST or GET is supported"}),
status=405,
mimetype="application/json",
headers=headers,
)

chatbot = create_chatbot()

try:
Expand Down
48 changes: 17 additions & 31 deletions backend/functions/feedback/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
http://127.0.0.1:5001/schemessg-v3-dev/asia-southeast1/feedback
"""

import json
from datetime import datetime, timezone

from fb_manager.firebaseManager import FirebaseManager
from firebase_functions import https_fn, options
from datetime import datetime, timezone
import json
from utils.cors_config import get_cors_headers, handle_cors_preflight


# Firestore client
firebase_manager = FirebaseManager()
Expand All @@ -26,60 +29,45 @@ def feedback(req: https_fn.Request) -> https_fn.Response:
Returns:
https_fn.Response: response sent to client
"""
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}

if req.method == "OPTIONS":
return https_fn.Response(response="", status=204, headers=headers)
return handle_cors_preflight(req)

headers = get_cors_headers(req)

if req.method != "POST":
return https_fn.Response(
response=json.dumps(
{"success": False, "message": "Only POST requests are allowed"}
),
response=json.dumps({"error": "Method not allowed"}),
status=405,
mimetype="application/json",
headers=headers,
)

try:
# Parse the request data
request_json = req.get_json()
feedback_text = request_json.get("feedbackText")
userName = request_json.get("userName")
userEmail = request_json.get("userEmail")
timestamp = datetime.now(timezone.utc)
data = req.get_json()
feedback_text = data.get("feedbackText")
userName = data.get("userName", "Anonymous")
userEmail = data.get("userEmail", "Not provided")
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

if not feedback_text:
return https_fn.Response(
response=json.dumps(
{"success": False, "message": "Missing required fields"}
),
response=json.dumps({"error": "Feedback text is required"}),
status=400,
mimetype="application/json",
headers=headers,
)

# Prepare the data for Firestore
feedback_data = {
"feedbackText": feedback_text,
"timestamp": timestamp,
"userName": userName,
"userEmail": userEmail,
}

# Add the data to Firestore
firebase_manager.firestore_client.collection("userFeedback").add(feedback_data)

# Return a success response
return https_fn.Response(
response=json.dumps(
{"success": True, "message": "Feedback successfully added"}
),
response=json.dumps({"success": True, "message": "Feedback successfully added"}),
status=200,
mimetype="application/json",
headers=headers,
Expand All @@ -88,9 +76,7 @@ def feedback(req: https_fn.Request) -> https_fn.Response:
except Exception as e:
print(f"Error: {e}")
return https_fn.Response(
response=json.dumps(
{"success": False, "message": "Failed to add feedback"}
),
response=json.dumps({"success": False, "message": "Failed to add feedback"}),
status=500,
mimetype="application/json",
headers=headers,
Expand Down
28 changes: 13 additions & 15 deletions backend/functions/schemes/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fb_manager.firebaseManager import FirebaseManager
from firebase_functions import https_fn, options
from loguru import logger
from utils.cors_config import get_cors_headers, handle_cors_preflight


def create_firebase_manager() -> FirebaseManager:
Expand All @@ -30,21 +31,12 @@ def schemes(req: https_fn.Request) -> https_fn.Response:
Returns:
https_fn.Response: response sent to client
"""
# TODO remove for prod setup
# Set CORS headers for the preflight request
# Handle CORS preflight request
if req.method == "OPTIONS":
# Allows GET and POST requests from any origin with the Content-Type
# header and caches preflight response for an hour
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}
return ("", 204, headers)

# Set CORS headers for the main request
headers = {"Access-Control-Allow-Origin": "http://localhost:3000"}
return handle_cors_preflight(req)

# Get standard CORS headers for all other requests
headers = get_cors_headers(req)

firebase_manager = create_firebase_manager()

Expand All @@ -53,6 +45,7 @@ def schemes(req: https_fn.Request) -> https_fn.Response:
response=json.dumps({"error": "Invalid request method; only GET is supported"}),
status=405,
mimetype="application/json",
headers=headers,
)

splitted_path = req.path.split("/")
Expand Down Expand Up @@ -87,4 +80,9 @@ def schemes(req: https_fn.Request) -> https_fn.Response:
)

results = {"data": doc.to_dict()}
return https_fn.Response(response=json.dumps(results), status=200, mimetype="application/json", headers=headers)
return https_fn.Response(
response=json.dumps(results),
status=200,
mimetype="application/json",
headers=headers,
)
25 changes: 6 additions & 19 deletions backend/functions/schemes/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@
from firebase_functions import https_fn, options
from loguru import logger
from ml_logic import PredictParams, SearchModel
from utils.cors_config import get_cors_headers, handle_cors_preflight


def create_search_model() -> SearchModel:
"""Factory function to create a SearchModel instance."""

firebase_manager = FirebaseManager()
return SearchModel(firebase_manager)


@https_fn.on_request(
region="asia-southeast1",
memory=options.MemoryOption.GB_2, # Increases memory to 1GB
memory=options.MemoryOption.GB_2,
)
def schemes_search(req: https_fn.Request) -> https_fn.Response:
"""
Expand All @@ -32,27 +32,15 @@ def schemes_search(req: https_fn.Request) -> https_fn.Response:
Returns:
https_fn.Response: response sent to client
"""
# TODO remove for prod setup
# Set CORS headers for the preflight request
if req.method == "OPTIONS":
# Allows GET and POST requests from any origin with the Content-Type
# header and caches preflight response for an hour
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}
return ("", 204, headers)

# Set CORS headers for the main request
headers = {"Access-Control-Allow-Origin": "http://localhost:3000"}
return handle_cors_preflight(req)

headers = get_cors_headers(req)
search_model = create_search_model()

if not (req.method == "POST" or req.method == "GET"):
if not req.method == "POST":
return https_fn.Response(
response=json.dumps({"error": "Invalid request method; only POST or GET is supported"}),
response=json.dumps({"error": "Invalid request method; only POST is supported"}),
status=405,
mimetype="application/json",
headers=headers,
Expand All @@ -63,7 +51,6 @@ def schemes_search(req: https_fn.Request) -> https_fn.Response:
query = body.get("query", None)
top_k = body.get("top_k", 20)
similarity_threshold = body.get("similarity_threshold", 0)
# print(query, top_k, similarity_threshold)
except Exception:
return https_fn.Response(
response=json.dumps({"error": "Invalid request body"}),
Expand Down
35 changes: 17 additions & 18 deletions backend/functions/update_scheme/update_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@
http://127.0.0.1:5001/schemessg-v3-dev/asia-southeast1/update_scheme
"""

import json
from datetime import datetime, timezone

from fb_manager.firebaseManager import FirebaseManager
from firebase_functions import https_fn, options
from datetime import datetime, timezone
import json
from utils.cors_config import get_cors_headers, handle_cors_preflight


# Firestore client
firebase_manager = FirebaseManager()


@https_fn.on_request(
region="asia-southeast1",
memory=options.MemoryOption.GB_1,
)
region="asia-southeast1",
memory=options.MemoryOption.GB_1,
)
def update_scheme(req: https_fn.Request) -> https_fn.Response:
"""
Handler for users seeking to add new schemes or request an edit on an existing scheme
Expand All @@ -25,22 +29,17 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response:
Returns:
https_fn.Response: response sent to client
"""
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}

if req.method == "OPTIONS":
return https_fn.Response(response="", status=204, headers=headers)

return handle_cors_preflight(req)

headers = get_cors_headers(req)

if req.method != "POST":
return https_fn.Response(
response=json.dumps({"success": False, "message": "Only POST requests are allowed"}),
status=405,
mimetype="application/json",
headers=headers
headers=headers,
)

try:
Expand Down Expand Up @@ -77,7 +76,7 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response:
"timestamp": timestamp,
"userName": userName,
"userEmail": userEmail,
"typeOfRequest": typeOfRequest
"typeOfRequest": typeOfRequest,
}

# Add the data to Firestore
Expand All @@ -88,7 +87,7 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response:
response=json.dumps({"success": True, "message": "Request for scheme update successfully added"}),
status=200,
mimetype="application/json",
headers=headers
headers=headers,
)

except Exception as e:
Expand All @@ -97,5 +96,5 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response:
response=json.dumps({"success": False, "message": "Failed to add request for scheme update"}),
status=500,
mimetype="application/json",
headers=headers
headers=headers,
)
58 changes: 58 additions & 0 deletions backend/functions/utils/cors_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Dict, Union

from firebase_functions import https_fn


# Define allowed origins
ALLOWED_ORIGINS = [
"http://localhost:3000", # Local development
"https://schemessg-v3-dev.web.app", # Staging frontend
]


def get_cors_headers(request: https_fn.Request) -> Dict[str, str]:
"""
Return CORS headers based on the request origin.
Only allows specific origins.
Args:
request: The incoming request containing origin information
"""
origin = request.headers.get("Origin", "")

# Only allow specified origins
if origin in ALLOWED_ORIGINS:
return {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}

# If origin not allowed, return headers without Access-Control-Allow-Origin
return {
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}


def handle_cors_preflight(
request: https_fn.Request, allowed_methods: str = "GET, POST, OPTIONS"
) -> Union[https_fn.Response, tuple]:
"""
Handle CORS preflight requests with origin validation
Args:
request: The incoming request
allowed_methods: Comma-separated string of allowed HTTP methods
"""
headers = get_cors_headers(request)
headers["Access-Control-Allow-Methods"] = allowed_methods

# If origin wasn't in allowed list, get_cors_headers won't include Allow-Origin
# In that case, return 403 Forbidden
if "Access-Control-Allow-Origin" not in headers:
return https_fn.Response(response="Origin not allowed", status=403, headers=headers)

return https_fn.Response(response="", status=204, headers=headers)

0 comments on commit 8ff083b

Please sign in to comment.