From ebd2e558275bfd6bdb70af32c01e9d79acacacb9 Mon Sep 17 00:00:00 2001 From: Richard Nguyen Date: Thu, 7 Dec 2023 15:21:07 -0800 Subject: [PATCH] fix: fix cors problem on requesting API endpoints --- cursus/apis/__init__.py | 78 ++++++++++++++++++++++++++++++++------- cursus/apis/university.py | 5 --- 2 files changed, 64 insertions(+), 19 deletions(-) diff --git a/cursus/apis/__init__.py b/cursus/apis/__init__.py index e1ee4a7..7a03724 100644 --- a/cursus/apis/__init__.py +++ b/cursus/apis/__init__.py @@ -24,6 +24,42 @@ ) +def check_preflight_request(request: flask.Request) -> bool: + """Check if the request is a preflight request + + A preflight request is a CORS request that checks if the API endpoint is + allowed to be accessed outside of the domain. This function checks if the + request is a preflight request by checking if the request method is OPTIONS + and if the request headers contain the `Origin` header. + + Args: + request (flask.Request): The request object + + Returns: + bool: True if the request is a preflight request, False otherwise + """ + + if "Origin" not in request.headers: + return False + + if "Access-Control-Request-Method" not in request.headers: + return False + + if request.headers["Access-Control-Request-Method"] != "GET": + return False + + if "Access-Control-Request-Headers" not in request.headers: + return False + + if ( + "x-cursus-api-token" + != request.headers["Access-Control-Request-Headers"].lower() + ): + return False + + return True + + api_bp: Blueprint = Blueprint( name="api", import_name=__name__, url_prefix="/api/v1/" ) @@ -61,23 +97,27 @@ def swagger(): def before_request(): """Process actions all requests that are made to the API endpoints""" - # Missing API token - if "X-CURSUS-API-TOKEN" not in request.headers: - raise CursusException.BadRequestError( - "API endpoints require an authorized API token" + if request.method != "OPTIONS" and request.method != "GET": + raise CursusException.MethodNotAllowedError( + "Only GET requests are allowed" ) + # Prelight request to check if the API endpoint is allowed to be accessed + # outside of the domain if request.method == "OPTIONS": - response = flask.make_response("", 200) + if check_preflight_request(request): + # Returning a non-None value from a before_request handler will + # cause Flask to skip the normal request handling and continue to + # the after request handler. + return flask.make_response() - response.headers.add("Content-Type", "application/json") - response.headers.add("Access-Control-Allow-Origin", "*") - response.headers.add("Access-Control-Allow-Methods", "GET, OPTIONS") - response.headers.add( - "Access-Control-Allow-Headers", "origin, x-cursus-api-token" - ) + raise CursusException.BadRequestError("Invalid preflight request") - return response + # Missing API token + if "X-CURSUS-API-TOKEN" not in request.headers: + raise CursusException.BadRequestError( + "API endpoints require an authorized API token" + ) token = request.headers["X-CURSUS-API-TOKEN"] @@ -126,8 +166,6 @@ def before_request(): def after_request(response: flask.Response): """Perform actions after a request has been processed""" - response.headers.add("Access-Control-Allow-Origin", "*") - # A response that made it to the endpoint handler either succeeded (200) or # failed (404) to retrieve the requested resource. In both cases, we want # to increment the request count for the API token. @@ -136,6 +174,18 @@ def after_request(response: flask.Response): if response.status_code != 200 and response.status_code != 404: return response + response.headers.add("Access-Control-Allow-Origin", "*") + response.headers.add( + "Access-Control-Allow-Headers", + "X-CURSUS-API-TOKEN, Content-Type, Accept, Origin", + ) + response.headers.add("Access-Control-Allow-Methods", "GET, OPTIONS") + response.headers.add("Access-Control-Allow-Credentials", "true") + response.headers.add("Access-Control-Max-Age", "86400") + + if request.method == "OPTIONS": + return response + token = request.headers["X-CURSUS-API-TOKEN"] cache_item = cache.get(token) diff --git a/cursus/apis/university.py b/cursus/apis/university.py index a6222bb..5bea7b9 100644 --- a/cursus/apis/university.py +++ b/cursus/apis/university.py @@ -21,11 +21,6 @@ def university_index(): def university_find(): """University find endpoint with query string""" - if flask.request.method != "GET" or flask.request.method != "OPTIONS": - raise CursusException.MethodNotAllowedError( - "This endpoint only accepts GET requests" - ) - # Get request query string query_string = flask.request.query_string.decode("utf-8")