From b9da728680132b4a96228833a770917265d0b0ac Mon Sep 17 00:00:00 2001 From: Tobias Laundal Date: Mon, 4 Dec 2023 10:58:11 +0100 Subject: [PATCH] Implement api_gateway_infer_root_path The `api_gateway_infer_root_path` option instructs Mangum to infer the `root_path` ASGI scope property based on the AWS API Gateway event object. This enables applications to know what subpath they are being served from, without explicit configuration. Relates to #147. --- docs/adapter.md | 1 + mangum/adapter.py | 2 + mangum/handlers/api_gateway.py | 17 +++++- mangum/types.py | 1 + tests/handlers/test_api_gateway.py | 89 +++++++++++++++++++++++++++++- 5 files changed, 108 insertions(+), 2 deletions(-) diff --git a/docs/adapter.md b/docs/adapter.md index 04b179e..7af98ae 100644 --- a/docs/adapter.md +++ b/docs/adapter.md @@ -9,6 +9,7 @@ handler = Mangum( api_gateway_base_path=None, custom_handlers=None, text_mime_types=None, + api_gateway_infer_root_path=False ) ``` diff --git a/mangum/adapter.py b/mangum/adapter.py index bb99cfb..7c7bc0f 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -45,6 +45,7 @@ def __init__( custom_handlers: Optional[List[Type[LambdaHandler]]] = None, text_mime_types: Optional[List[str]] = None, exclude_headers: Optional[List[str]] = None, + api_gateway_infer_root_path: bool = False, ) -> None: if lifespan not in ("auto", "on", "off"): raise ConfigurationError( @@ -57,6 +58,7 @@ def __init__( exclude_headers = exclude_headers or [] self.config = LambdaConfig( api_gateway_base_path=api_gateway_base_path or "/", + api_gateway_infer_root_path=api_gateway_infer_root_path, text_mime_types=text_mime_types or [*DEFAULT_TEXT_MIME_TYPES], exclude_headers=[header.lower() for header in exclude_headers], ) diff --git a/mangum/handlers/api_gateway.py b/mangum/handlers/api_gateway.py index d9b30c0..8e7c766 100644 --- a/mangum/handlers/api_gateway.py +++ b/mangum/handlers/api_gateway.py @@ -64,6 +64,19 @@ def _combine_headers_v2( return output_headers, cookies +def _infer_root_path(event: LambdaEvent) -> str: + # This is the full path, as received by API Gateway + request_path = event.get("requestContext", {}).get("path", "") + # This is the relative path of the resource within API Gateway + resource_path = event.get("path", "") + + root_path = "" + if request_path.endswith(resource_path): + root_path = request_path[: -len(resource_path)] + + return root_path + + class APIGateway: @classmethod def infer( @@ -98,7 +111,9 @@ def scope(self) -> Scope: api_gateway_base_path=self.config["api_gateway_base_path"], ), "raw_path": None, - "root_path": "", + "root_path": _infer_root_path(self.event) + if self.config.get("api_gateway_infer_root_path", False) + else "", "scheme": headers.get("x-forwarded-proto", "https"), "query_string": _encode_query_string_for_apigw(self.event), "server": get_server_and_port(headers), diff --git a/mangum/types.py b/mangum/types.py index 0ff436c..4f1e710 100644 --- a/mangum/types.py +++ b/mangum/types.py @@ -116,6 +116,7 @@ class Response(TypedDict): class LambdaConfig(TypedDict): api_gateway_base_path: str + api_gateway_infer_root_path: bool text_mime_types: List[str] exclude_headers: List[str] diff --git a/tests/handlers/test_api_gateway.py b/tests/handlers/test_api_gateway.py index e2458c2..39eea97 100644 --- a/tests/handlers/test_api_gateway.py +++ b/tests/handlers/test_api_gateway.py @@ -3,7 +3,7 @@ import pytest from mangum import Mangum -from mangum.handlers.api_gateway import APIGateway +from mangum.handlers.api_gateway import APIGateway, _infer_root_path def get_mock_aws_api_gateway_event( @@ -36,6 +36,7 @@ def get_mock_aws_api_gateway_event( "accountId": "123456789012", "resourceId": "us4z18", "stage": "Prod", + "path": f"/Prod{path}", "requestId": "41b45ea3-70b5-11e6-b7bd-69b5aaebc7d9", "identity": { "cognitoIdentityPoolId": "", @@ -429,3 +430,89 @@ async def app(scope, receive, send): "multiValueHeaders": {}, "body": "Hello world", } + + +@pytest.mark.parametrize( + "method,path,multi_value_query_parameters,req_body,body_base64_encoded," + "query_string,scope_body", + [ + ("GET", "/test/hello", None, None, False, b"", None), + ], +) +def test_aws_api_gateway_root_path( + method, + path, + multi_value_query_parameters, + req_body, + body_base64_encoded, + query_string, + scope_body, +): + event = get_mock_aws_api_gateway_event( + method, path, multi_value_query_parameters, req_body, body_base64_encoded + ) + + # Test with root path inferred + async def app(scope, receive, send): + assert scope["root_path"] == "/Prod" + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello world!"}) + + handler = Mangum(app, lifespan="off", api_gateway_infer_root_path=True) + response = handler(event, {}) + + assert response == { + "body": "Hello world!", + "headers": {"content-type": "text/plain"}, + "multiValueHeaders": {}, + "isBase64Encoded": False, + "statusCode": 200, + } + + # Test without root path inferred + async def app(scope, receive, send): + assert scope["root_path"] == "" + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello world!"}) + + handler = Mangum(app, lifespan="off", api_gateway_infer_root_path=False) + response = handler(event, {}) + assert response == { + "body": "Hello world!", + "headers": {"content-type": "text/plain"}, + "multiValueHeaders": {}, + "isBase64Encoded": False, + "statusCode": 200, + } + + +@pytest.mark.parametrize( + "path,requestPath,root_path", + [ + ("/", "/", ""), + ("/", "/Prod", ""), + ("/", "/Prod/", "/Prod"), + ("/some/path", "/Prod/", ""), + ("/some/path", "/Prod/some/path", "/Prod"), + ("/baz", "/foo/bar/baz", "/foo/bar"), + ("/foo", "/fooo", ""), + ("/fooo", "/foo", ""), + ], +) +def test_infer_root_path(path, requestPath, root_path): + assert ( + _infer_root_path({"path": path, "requestContext": {"path": requestPath}}) + == root_path + )