Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/108 fix cors issues #109

Merged
merged 2 commits into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions backend/functions/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from firebase_functions import https_fn, options
from loguru import logger
from ml_logic import Chatbot, dataframe_to_text
from utils.cors_config import get_cors_headers, handle_cors_preflight


# Remove default handler
Expand Down Expand Up @@ -51,27 +52,19 @@ def chat_message(req: https_fn.Request) -> https_fn.Response:
"""
logger.info("Request received")

# TODO remove for prod setup
# Set CORS headers for the preflight request
if req.method == "OPTIONS":
# Allows GET and POST requests from any origin with the Content-Type
# header and caches preflight response for an hour
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}
return ("", 204, headers)

# Set CORS headers for the main request
headers = {"Access-Control-Allow-Origin": "http://localhost:3000"}
if not (req.method == "POST" or req.method == "GET"):
return handle_cors_preflight(req)

headers = get_cors_headers(req)

if not req.method == "POST":
return https_fn.Response(
response=json.dumps({"error": "Invalid request method; only POST or GET is supported"}),
status=405,
mimetype="application/json",
headers=headers,
)

chatbot = create_chatbot()

try:
Expand Down
48 changes: 17 additions & 31 deletions backend/functions/feedback/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
http://127.0.0.1:5001/schemessg-v3-dev/asia-southeast1/feedback
"""

import json
from datetime import datetime, timezone

from fb_manager.firebaseManager import FirebaseManager
from firebase_functions import https_fn, options
from datetime import datetime, timezone
import json
from utils.cors_config import get_cors_headers, handle_cors_preflight


# Firestore client
firebase_manager = FirebaseManager()
Expand All @@ -26,60 +29,45 @@ def feedback(req: https_fn.Request) -> https_fn.Response:
Returns:
https_fn.Response: response sent to client
"""
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}

if req.method == "OPTIONS":
return https_fn.Response(response="", status=204, headers=headers)
return handle_cors_preflight(req)

headers = get_cors_headers(req)

if req.method != "POST":
return https_fn.Response(
response=json.dumps(
{"success": False, "message": "Only POST requests are allowed"}
),
response=json.dumps({"error": "Method not allowed"}),
status=405,
mimetype="application/json",
headers=headers,
)

try:
# Parse the request data
request_json = req.get_json()
feedback_text = request_json.get("feedbackText")
userName = request_json.get("userName")
userEmail = request_json.get("userEmail")
timestamp = datetime.now(timezone.utc)
data = req.get_json()
feedback_text = data.get("feedbackText")
userName = data.get("userName", "Anonymous")
userEmail = data.get("userEmail", "Not provided")
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

if not feedback_text:
return https_fn.Response(
response=json.dumps(
{"success": False, "message": "Missing required fields"}
),
response=json.dumps({"error": "Feedback text is required"}),
status=400,
mimetype="application/json",
headers=headers,
)

# Prepare the data for Firestore
feedback_data = {
"feedbackText": feedback_text,
"timestamp": timestamp,
"userName": userName,
"userEmail": userEmail,
}

# Add the data to Firestore
firebase_manager.firestore_client.collection("userFeedback").add(feedback_data)

# Return a success response
return https_fn.Response(
response=json.dumps(
{"success": True, "message": "Feedback successfully added"}
),
response=json.dumps({"success": True, "message": "Feedback successfully added"}),
status=200,
mimetype="application/json",
headers=headers,
Expand All @@ -88,9 +76,7 @@ def feedback(req: https_fn.Request) -> https_fn.Response:
except Exception as e:
print(f"Error: {e}")
return https_fn.Response(
response=json.dumps(
{"success": False, "message": "Failed to add feedback"}
),
response=json.dumps({"success": False, "message": "Failed to add feedback"}),
status=500,
mimetype="application/json",
headers=headers,
Expand Down
28 changes: 13 additions & 15 deletions backend/functions/schemes/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fb_manager.firebaseManager import FirebaseManager
from firebase_functions import https_fn, options
from loguru import logger
from utils.cors_config import get_cors_headers, handle_cors_preflight


def create_firebase_manager() -> FirebaseManager:
Expand All @@ -30,21 +31,12 @@ def schemes(req: https_fn.Request) -> https_fn.Response:
Returns:
https_fn.Response: response sent to client
"""
# TODO remove for prod setup
# Set CORS headers for the preflight request
# Handle CORS preflight request
if req.method == "OPTIONS":
# Allows GET and POST requests from any origin with the Content-Type
# header and caches preflight response for an hour
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}
return ("", 204, headers)

# Set CORS headers for the main request
headers = {"Access-Control-Allow-Origin": "http://localhost:3000"}
return handle_cors_preflight(req)

# Get standard CORS headers for all other requests
headers = get_cors_headers(req)

firebase_manager = create_firebase_manager()

Expand All @@ -53,6 +45,7 @@ def schemes(req: https_fn.Request) -> https_fn.Response:
response=json.dumps({"error": "Invalid request method; only GET is supported"}),
status=405,
mimetype="application/json",
headers=headers,
)

splitted_path = req.path.split("/")
Expand Down Expand Up @@ -87,4 +80,9 @@ def schemes(req: https_fn.Request) -> https_fn.Response:
)

results = {"data": doc.to_dict()}
return https_fn.Response(response=json.dumps(results), status=200, mimetype="application/json", headers=headers)
return https_fn.Response(
response=json.dumps(results),
status=200,
mimetype="application/json",
headers=headers,
)
25 changes: 6 additions & 19 deletions backend/functions/schemes/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@
from firebase_functions import https_fn, options
from loguru import logger
from ml_logic import PredictParams, SearchModel
from utils.cors_config import get_cors_headers, handle_cors_preflight


def create_search_model() -> SearchModel:
"""Factory function to create a SearchModel instance."""

firebase_manager = FirebaseManager()
return SearchModel(firebase_manager)


@https_fn.on_request(
region="asia-southeast1",
memory=options.MemoryOption.GB_2, # Increases memory to 1GB
memory=options.MemoryOption.GB_2,
)
def schemes_search(req: https_fn.Request) -> https_fn.Response:
"""
Expand All @@ -32,27 +32,15 @@ def schemes_search(req: https_fn.Request) -> https_fn.Response:
Returns:
https_fn.Response: response sent to client
"""
# TODO remove for prod setup
# Set CORS headers for the preflight request
if req.method == "OPTIONS":
# Allows GET and POST requests from any origin with the Content-Type
# header and caches preflight response for an hour
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}
return ("", 204, headers)

# Set CORS headers for the main request
headers = {"Access-Control-Allow-Origin": "http://localhost:3000"}
return handle_cors_preflight(req)

headers = get_cors_headers(req)
search_model = create_search_model()

if not (req.method == "POST" or req.method == "GET"):
if not req.method == "POST":
return https_fn.Response(
response=json.dumps({"error": "Invalid request method; only POST or GET is supported"}),
response=json.dumps({"error": "Invalid request method; only POST is supported"}),
status=405,
mimetype="application/json",
headers=headers,
Expand All @@ -63,7 +51,6 @@ def schemes_search(req: https_fn.Request) -> https_fn.Response:
query = body.get("query", None)
top_k = body.get("top_k", 20)
similarity_threshold = body.get("similarity_threshold", 0)
# print(query, top_k, similarity_threshold)
except Exception:
return https_fn.Response(
response=json.dumps({"error": "Invalid request body"}),
Expand Down
35 changes: 17 additions & 18 deletions backend/functions/update_scheme/update_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@
http://127.0.0.1:5001/schemessg-v3-dev/asia-southeast1/update_scheme
"""

import json
from datetime import datetime, timezone

from fb_manager.firebaseManager import FirebaseManager
from firebase_functions import https_fn, options
from datetime import datetime, timezone
import json
from utils.cors_config import get_cors_headers, handle_cors_preflight


# Firestore client
firebase_manager = FirebaseManager()


@https_fn.on_request(
region="asia-southeast1",
memory=options.MemoryOption.GB_1,
)
region="asia-southeast1",
memory=options.MemoryOption.GB_1,
)
def update_scheme(req: https_fn.Request) -> https_fn.Response:
"""
Handler for users seeking to add new schemes or request an edit on an existing scheme
Expand All @@ -25,22 +29,17 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response:
Returns:
https_fn.Response: response sent to client
"""
headers = {
"Access-Control-Allow-Origin": "http://localhost:3000",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}

if req.method == "OPTIONS":
return https_fn.Response(response="", status=204, headers=headers)

return handle_cors_preflight(req)

headers = get_cors_headers(req)

if req.method != "POST":
return https_fn.Response(
response=json.dumps({"success": False, "message": "Only POST requests are allowed"}),
status=405,
mimetype="application/json",
headers=headers
headers=headers,
)

try:
Expand Down Expand Up @@ -77,7 +76,7 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response:
"timestamp": timestamp,
"userName": userName,
"userEmail": userEmail,
"typeOfRequest": typeOfRequest
"typeOfRequest": typeOfRequest,
}

# Add the data to Firestore
Expand All @@ -88,7 +87,7 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response:
response=json.dumps({"success": True, "message": "Request for scheme update successfully added"}),
status=200,
mimetype="application/json",
headers=headers
headers=headers,
)

except Exception as e:
Expand All @@ -97,5 +96,5 @@ def update_scheme(req: https_fn.Request) -> https_fn.Response:
response=json.dumps({"success": False, "message": "Failed to add request for scheme update"}),
status=500,
mimetype="application/json",
headers=headers
headers=headers,
)
58 changes: 58 additions & 0 deletions backend/functions/utils/cors_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Dict, Union

from firebase_functions import https_fn


# Define allowed origins
ALLOWED_ORIGINS = [
"http://localhost:3000", # Local development
"https://schemessg-v3-dev.web.app", # Staging frontend
]


def get_cors_headers(request: https_fn.Request) -> Dict[str, str]:
"""
Return CORS headers based on the request origin.
Only allows specific origins.

Args:
request: The incoming request containing origin information
"""
origin = request.headers.get("Origin", "")

# Only allow specified origins
if origin in ALLOWED_ORIGINS:
return {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}

# If origin not allowed, return headers without Access-Control-Allow-Origin
return {
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "3600",
}


def handle_cors_preflight(
request: https_fn.Request, allowed_methods: str = "GET, POST, OPTIONS"
) -> Union[https_fn.Response, tuple]:
"""
Handle CORS preflight requests with origin validation

Args:
request: The incoming request
allowed_methods: Comma-separated string of allowed HTTP methods
"""
headers = get_cors_headers(request)
headers["Access-Control-Allow-Methods"] = allowed_methods

# If origin wasn't in allowed list, get_cors_headers won't include Allow-Origin
# In that case, return 403 Forbidden
if "Access-Control-Allow-Origin" not in headers:
return https_fn.Response(response="Origin not allowed", status=403, headers=headers)

return https_fn.Response(response="", status=204, headers=headers)