Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement api_gateway_infer_root_path option #312

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/adapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ handler = Mangum(
api_gateway_base_path=None,
custom_handlers=None,
text_mime_types=None,
api_gateway_infer_root_path=False
)
```

Expand Down
2 changes: 2 additions & 0 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
custom_handlers: Optional[List[Type[LambdaHandler]]] = None,
text_mime_types: Optional[List[str]] = None,
exclude_headers: Optional[List[str]] = None,
api_gateway_infer_root_path: bool = False,
) -> None:
if lifespan not in ("auto", "on", "off"):
raise ConfigurationError(
Expand All @@ -57,6 +58,7 @@ def __init__(
exclude_headers = exclude_headers or []
self.config = LambdaConfig(
api_gateway_base_path=api_gateway_base_path or "/",
api_gateway_infer_root_path=api_gateway_infer_root_path,
text_mime_types=text_mime_types or [*DEFAULT_TEXT_MIME_TYPES],
exclude_headers=[header.lower() for header in exclude_headers],
)
Expand Down
17 changes: 16 additions & 1 deletion mangum/handlers/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ def _combine_headers_v2(
return output_headers, cookies


def _infer_root_path(event: LambdaEvent) -> str:
# This is the full path, as received by API Gateway
request_path = event.get("requestContext", {}).get("path", "")
# This is the relative path of the resource within API Gateway
resource_path = event.get("path", "")

root_path = ""
if request_path.endswith(resource_path):
root_path = request_path[: -len(resource_path)]

return root_path


class APIGateway:
@classmethod
def infer(
Expand Down Expand Up @@ -98,7 +111,9 @@ def scope(self) -> Scope:
api_gateway_base_path=self.config["api_gateway_base_path"],
),
"raw_path": None,
"root_path": "",
"root_path": _infer_root_path(self.event)
if self.config.get("api_gateway_infer_root_path", False)
else "",
"scheme": headers.get("x-forwarded-proto", "https"),
"query_string": _encode_query_string_for_apigw(self.event),
"server": get_server_and_port(headers),
Expand Down
1 change: 1 addition & 0 deletions mangum/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class Response(TypedDict):

class LambdaConfig(TypedDict):
api_gateway_base_path: str
api_gateway_infer_root_path: bool
text_mime_types: List[str]
exclude_headers: List[str]

Expand Down
89 changes: 88 additions & 1 deletion tests/handlers/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from mangum import Mangum
from mangum.handlers.api_gateway import APIGateway
from mangum.handlers.api_gateway import APIGateway, _infer_root_path


def get_mock_aws_api_gateway_event(
Expand Down Expand Up @@ -36,6 +36,7 @@ def get_mock_aws_api_gateway_event(
"accountId": "123456789012",
"resourceId": "us4z18",
"stage": "Prod",
"path": f"/Prod{path}",
"requestId": "41b45ea3-70b5-11e6-b7bd-69b5aaebc7d9",
"identity": {
"cognitoIdentityPoolId": "",
Expand Down Expand Up @@ -429,3 +430,89 @@ async def app(scope, receive, send):
"multiValueHeaders": {},
"body": "Hello world",
}


@pytest.mark.parametrize(
"method,path,multi_value_query_parameters,req_body,body_base64_encoded,"
"query_string,scope_body",
[
("GET", "/test/hello", None, None, False, b"", None),
],
)
def test_aws_api_gateway_root_path(
method,
path,
multi_value_query_parameters,
req_body,
body_base64_encoded,
query_string,
scope_body,
):
event = get_mock_aws_api_gateway_event(
method, path, multi_value_query_parameters, req_body, body_base64_encoded
)

# Test with root path inferred
async def app(scope, receive, send):
assert scope["root_path"] == "/Prod"
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello world!"})

handler = Mangum(app, lifespan="off", api_gateway_infer_root_path=True)
response = handler(event, {})

assert response == {
"body": "Hello world!",
"headers": {"content-type": "text/plain"},
"multiValueHeaders": {},
"isBase64Encoded": False,
"statusCode": 200,
}

# Test without root path inferred
async def app(scope, receive, send):
assert scope["root_path"] == ""
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello world!"})

handler = Mangum(app, lifespan="off", api_gateway_infer_root_path=False)
response = handler(event, {})
assert response == {
"body": "Hello world!",
"headers": {"content-type": "text/plain"},
"multiValueHeaders": {},
"isBase64Encoded": False,
"statusCode": 200,
}


@pytest.mark.parametrize(
"path,requestPath,root_path",
[
("/", "/", ""),
("/", "/Prod", ""),
("/", "/Prod/", "/Prod"),
("/some/path", "/Prod/", ""),
("/some/path", "/Prod/some/path", "/Prod"),
("/baz", "/foo/bar/baz", "/foo/bar"),
("/foo", "/fooo", ""),
("/fooo", "/foo", ""),
],
)
def test_infer_root_path(path, requestPath, root_path):
assert (
_infer_root_path({"path": path, "requestContext": {"path": requestPath}})
== root_path
)