Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update linting and type checker setup #336

Merged
merged 6 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ jobs:
- name: Install dependencies
run: uv sync --python ${{ matrix.python-version }} --frozen

- name: Run linters
run: scripts/check

- name: Run tests
run: scripts/test

- name: Run linters
run: scripts/lint

# https://github.com/marketplace/actions/alls-green#why used for branch protection checks
check:
if: always()
Expand Down
38 changes: 13 additions & 25 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,20 @@
from __future__ import annotations

import logging
from itertools import chain
from contextlib import ExitStack
from typing import List, Optional, Type
from itertools import chain
from typing import Any

from mangum.protocols import HTTPCycle, LifespanCycle
from mangum.handlers import ALB, HTTPGateway, APIGateway, LambdaAtEdge
from mangum.exceptions import ConfigurationError
from mangum.types import (
ASGI,
LifespanMode,
LambdaConfig,
LambdaEvent,
LambdaContext,
LambdaHandler,
)

from mangum.handlers import ALB, APIGateway, HTTPGateway, LambdaAtEdge
from mangum.protocols import HTTPCycle, LifespanCycle
from mangum.types import ASGI, LambdaConfig, LambdaContext, LambdaEvent, LambdaHandler, LifespanMode

logger = logging.getLogger("mangum")

HANDLERS: list[type[LambdaHandler]] = [ALB, HTTPGateway, APIGateway, LambdaAtEdge]

HANDLERS: List[Type[LambdaHandler]] = [
ALB,
HTTPGateway,
APIGateway,
LambdaAtEdge,
]

DEFAULT_TEXT_MIME_TYPES: List[str] = [
DEFAULT_TEXT_MIME_TYPES: list[str] = [
"text/",
"application/json",
"application/javascript",
Expand All @@ -42,9 +30,9 @@ def __init__(
app: ASGI,
lifespan: LifespanMode = "auto",
api_gateway_base_path: str = "/",
custom_handlers: Optional[List[Type[LambdaHandler]]] = None,
text_mime_types: Optional[List[str]] = None,
exclude_headers: Optional[List[str]] = None,
custom_handlers: list[type[LambdaHandler]] | None = None,
text_mime_types: list[str] | None = None,
exclude_headers: list[str] | None = None,
) -> None:
if lifespan not in ("auto", "on", "off"):
raise ConfigurationError("Invalid argument supplied for `lifespan`. Choices are: auto|on|off")
Expand All @@ -70,7 +58,7 @@ def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:
"supported handler.)"
)

def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict:
def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict[str, Any]:
handler = self.infer(event, context)
with ExitStack() as stack:
if self.lifespan in ("auto", "on"):
Expand Down
3 changes: 1 addition & 2 deletions mangum/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from mangum.handlers.api_gateway import APIGateway, HTTPGateway
from mangum.handlers.alb import ALB
from mangum.handlers.api_gateway import APIGateway, HTTPGateway
from mangum.handlers.lambda_at_edge import LambdaAtEdge


__all__ = ["APIGateway", "HTTPGateway", "ALB", "LambdaAtEdge"]
24 changes: 13 additions & 11 deletions mangum/handlers/alb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from itertools import islice
from typing import Dict, Generator, List, Tuple
from urllib.parse import urlencode, unquote, unquote_plus
from typing import Any, Generator
from urllib.parse import unquote, unquote_plus, urlencode

from mangum.handlers.utils import (
get_server_and_port,
Expand All @@ -9,12 +11,12 @@
maybe_encode_body,
)
from mangum.types import (
Response,
Scope,
LambdaConfig,
LambdaEvent,
LambdaContext,
LambdaEvent,
QueryParams,
Response,
Scope,
)


Expand All @@ -37,9 +39,9 @@ def all_casings(input_string: str) -> Generator[str, None, None]:
yield first.upper() + sub_casing


def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]:
def case_mutated_headers(multi_value_headers: dict[str, list[str]]) -> dict[str, str]:
"""Create str/str key/value headers, with duplicate keys case mutated."""
headers: Dict[str, str] = {}
headers: dict[str, str] = {}
for key, values in multi_value_headers.items():
if len(values) > 0:
casings = list(islice(all_casings(key), len(values)))
Expand Down Expand Up @@ -68,8 +70,8 @@ def encode_query_string_for_alb(params: QueryParams) -> bytes:
return query_string


def transform_headers(event: LambdaEvent) -> List[Tuple[bytes, bytes]]:
headers: List[Tuple[bytes, bytes]] = []
def transform_headers(event: LambdaEvent) -> list[tuple[bytes, bytes]]:
headers: list[tuple[bytes, bytes]] = []
if "multiValueHeaders" in event:
for k, v in event["multiValueHeaders"].items():
for inner_v in v:
Expand Down Expand Up @@ -139,8 +141,8 @@ def scope(self) -> Scope:

return scope

def __call__(self, response: Response) -> dict:
multi_value_headers: Dict[str, List[str]] = {}
def __call__(self, response: Response) -> dict[str, Any]:
multi_value_headers: dict[str, list[str]] = {}
for key, value in response["headers"]:
lower_key = key.decode().lower()
if lower_key not in multi_value_headers:
Expand Down
22 changes: 12 additions & 10 deletions mangum/handlers/api_gateway.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Dict, List, Tuple
from __future__ import annotations

from typing import Any
from urllib.parse import urlencode

from mangum.handlers.utils import (
Expand All @@ -10,12 +12,12 @@
strip_api_gateway_path,
)
from mangum.types import (
Response,
LambdaConfig,
Headers,
LambdaEvent,
LambdaConfig,
LambdaContext,
LambdaEvent,
QueryParams,
Response,
Scope,
)

Expand All @@ -30,7 +32,7 @@ def _encode_query_string_for_apigw(event: LambdaEvent) -> bytes:
return urlencode(params, doseq=True).encode()


def _handle_multi_value_headers_for_request(event: LambdaEvent) -> Dict[str, str]:
def _handle_multi_value_headers_for_request(event: LambdaEvent) -> dict[str, str]:
headers = event.get("headers", {}) or {}
headers = {k.lower(): v for k, v in headers.items()}
if event.get("multiValueHeaders"):
Expand All @@ -46,9 +48,9 @@ def _handle_multi_value_headers_for_request(event: LambdaEvent) -> Dict[str, str

def _combine_headers_v2(
input_headers: Headers,
) -> Tuple[Dict[str, str], List[str]]:
output_headers: Dict[str, str] = {}
cookies: List[str] = []
) -> tuple[dict[str, str], list[str]]:
output_headers: dict[str, str] = {}
cookies: list[str] = []
for key, value in input_headers:
normalized_key: str = key.decode().lower()
normalized_value: str = value.decode()
Expand Down Expand Up @@ -105,7 +107,7 @@ def scope(self) -> Scope:
"aws.context": self.context,
}

def __call__(self, response: Response) -> dict:
def __call__(self, response: Response) -> dict[str, Any]:
finalized_headers, multi_value_headers = handle_multi_value_headers(response["headers"])
finalized_body, is_base64_encoded = handle_base64_response_body(
response["body"], finalized_headers, self.config["text_mime_types"]
Expand Down Expand Up @@ -185,7 +187,7 @@ def scope(self) -> Scope:
"aws.context": self.context,
}

def __call__(self, response: Response) -> dict:
def __call__(self, response: Response) -> dict[str, Any]:
if self.scope["aws.event"]["version"] == "2.0":
finalized_headers, cookies = _combine_headers_v2(response["headers"])

Expand Down
10 changes: 6 additions & 4 deletions mangum/handlers/lambda_at_edge.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Dict, List
from __future__ import annotations

from typing import Any

from mangum.handlers.utils import (
handle_base64_response_body,
handle_exclude_headers,
handle_multi_value_headers,
maybe_encode_body,
)
from mangum.types import Scope, Response, LambdaConfig, LambdaEvent, LambdaContext
from mangum.types import LambdaConfig, LambdaContext, LambdaEvent, Response, Scope


class LambdaAtEdge:
Expand Down Expand Up @@ -66,12 +68,12 @@ def scope(self) -> Scope:
"aws.context": self.context,
}

def __call__(self, response: Response) -> dict:
def __call__(self, response: Response) -> dict[str, Any]:
multi_value_headers, _ = handle_multi_value_headers(response["headers"])
response_body, is_base64_encoded = handle_base64_response_body(
response["body"], multi_value_headers, self.config["text_mime_types"]
)
finalized_headers: Dict[str, List[Dict[str, str]]] = {
finalized_headers: dict[str, list[dict[str, str]]] = {
key.decode().lower(): [{"key": key.decode().lower(), "value": val.decode()}]
for key, val in response["headers"]
}
Expand Down
22 changes: 12 additions & 10 deletions mangum/handlers/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import base64
from typing import Any, Dict, List, Tuple, Union
from typing import Any
from urllib.parse import unquote

from mangum.types import Headers, LambdaConfig


def maybe_encode_body(body: Union[str, bytes], *, is_base64: bool) -> bytes:
def maybe_encode_body(body: str | bytes, *, is_base64: bool) -> bytes:
body = body or b""
if is_base64:
body = base64.b64decode(body)
Expand All @@ -15,7 +17,7 @@ def maybe_encode_body(body: Union[str, bytes], *, is_base64: bool) -> bytes:
return body


def get_server_and_port(headers: dict) -> Tuple[str, int]:
def get_server_and_port(headers: dict[str, Any]) -> tuple[str, int]:
server_name = headers.get("host", "mangum")
if ":" not in server_name:
server_port = headers.get("x-forwarded-port", 80)
Expand All @@ -41,9 +43,9 @@ def strip_api_gateway_path(path: str, *, api_gateway_base_path: str) -> str:

def handle_multi_value_headers(
response_headers: Headers,
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
headers: Dict[str, str] = {}
multi_value_headers: Dict[str, List[str]] = {}
) -> tuple[dict[str, str], dict[str, list[str]]]:
headers: dict[str, str] = {}
multi_value_headers: dict[str, list[str]] = {}
for key, value in response_headers:
lower_key = key.decode().lower()
if lower_key in multi_value_headers:
Expand All @@ -62,9 +64,9 @@ def handle_multi_value_headers(

def handle_base64_response_body(
body: bytes,
headers: Dict[str, str],
text_mime_types: List[str],
) -> Tuple[str, bool]:
headers: dict[str, str],
text_mime_types: list[str],
) -> tuple[str, bool]:
is_base64_encoded = False
output_body = ""
if body != b"":
Expand All @@ -83,7 +85,7 @@ def handle_base64_response_body(
return output_body, is_base64_encoded


def handle_exclude_headers(headers: Dict[str, Any], config: LambdaConfig) -> Dict[str, Any]:
def handle_exclude_headers(headers: dict[str, Any], config: LambdaConfig) -> dict[str, Any]:
finalized_headers = {}
for header_key, header_value in headers.items():
if header_key in config["exclude_headers"]:
Expand Down
2 changes: 1 addition & 1 deletion mangum/protocols/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .http import HTTPCycle
from .lifespan import LifespanCycleState, LifespanCycle
from .lifespan import LifespanCycle, LifespanCycleState

__all__ = ["HTTPCycle", "LifespanCycleState", "LifespanCycle"]
2 changes: 1 addition & 1 deletion mangum/protocols/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
from io import BytesIO

from mangum.types import ASGI, Message, Scope, Response
from mangum.exceptions import UnexpectedMessage
from mangum.types import ASGI, Message, Response, Scope


class HTTPCycleState(enum.Enum):
Expand Down
15 changes: 8 additions & 7 deletions mangum/protocols/lifespan.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import asyncio
import enum
import logging
from types import TracebackType
from typing import Optional, Type

from mangum.exceptions import LifespanFailure, LifespanUnsupported, UnexpectedMessage
from mangum.types import ASGI, LifespanMode, Message
from mangum.exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage


class LifespanCycleState(enum.Enum):
Expand All @@ -21,7 +22,7 @@ class LifespanCycleState(enum.Enum):
* **FAILED** - A lifespan failure has been detected, and the connection will be
closed with an error.
* **UNSUPPORTED** - An application attempted to send a message before receiving
the lifepan startup event. If the lifespan argument is "on", then the connection
the lifespan startup event. If the lifespan argument is "on", then the connection
will be closed with an error.
"""

Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(self, app: ASGI, lifespan: LifespanMode) -> None:
self.app = app
self.lifespan = lifespan
self.state: LifespanCycleState = LifespanCycleState.CONNECTING
self.exception: Optional[BaseException] = None
self.exception: BaseException | None = None
self.loop = asyncio.get_event_loop()
self.app_queue: asyncio.Queue[Message] = asyncio.Queue()
self.startup_event: asyncio.Event = asyncio.Event()
Expand All @@ -70,9 +71,9 @@ def __enter__(self) -> None:

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Runs the event loop for application shutdown."""
self.loop.run_until_complete(self.shutdown())
Expand Down
Loading