Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix dangling coroutines in request extraction handling cleanup #3735

Merged
merged 2 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions litestar/_kwargs/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
if TYPE_CHECKING:
from litestar._kwargs import KwargsModel
from litestar._kwargs.parameter_definition import ParameterDefinition
from litestar._kwargs.types import Extractor
from litestar.connection import ASGIConnection, Request
from litestar.dto import AbstractDTO
from litestar.typing import FieldDefinition
Expand Down Expand Up @@ -83,7 +84,7 @@ def create_connection_value_extractor(
connection_key: str,
expected_params: set[ParameterDefinition],
parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None,
) -> Callable[[dict[str, Any], ASGIConnection], None]:
) -> Extractor:
"""Create a kwargs extractor function.

Args:
Expand All @@ -98,7 +99,7 @@ def create_connection_value_extractor(

alias_and_key_tuples, alias_defaults, alias_to_params = _create_param_mappings(expected_params)

def extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
async def extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
data = parser(connection, kwargs_model) if parser else getattr(connection, connection_key, {})

try:
Expand Down Expand Up @@ -178,7 +179,7 @@ def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Head
return Headers.from_scope(connection.scope)


def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
async def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
"""Extract the app state from the connection and insert it to the kwargs injected to the handler.

Args:
Expand All @@ -191,7 +192,7 @@ def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
values["state"] = connection.app.state._state


def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
async def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
"""Extract the headers from the connection and insert them to the kwargs injected to the handler.

Args:
Expand All @@ -206,7 +207,7 @@ def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> Non
values["headers"] = dict(connection.headers.items())


def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
async def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
"""Extract the cookies from the connection and insert them to the kwargs injected to the handler.

Args:
Expand All @@ -219,7 +220,7 @@ def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> Non
values["cookies"] = connection.cookies


def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
async def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
"""Extract the query params from the connection and insert them to the kwargs injected to the handler.

Args:
Expand All @@ -232,7 +233,7 @@ def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
values["query"] = connection.query_params


def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
async def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
"""Extract the scope from the connection and insert it into the kwargs injected to the handler.

Args:
Expand All @@ -245,7 +246,7 @@ def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
values["scope"] = connection.scope


def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
async def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
"""Set the connection instance as the 'request' value in the kwargs injected to the handler.

Args:
Expand All @@ -258,7 +259,7 @@ def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> Non
values["request"] = connection


def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
async def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
"""Set the connection instance as the 'socket' value in the kwargs injected to the handler.

Args:
Expand All @@ -271,7 +272,7 @@ def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None
values["socket"] = connection


def body_extractor(
async def body_extractor(
values: dict[str, Any],
connection: Request[Any, Any, Any],
) -> None:
Expand All @@ -287,7 +288,7 @@ def body_extractor(
Returns:
The Body value.
"""
values["body"] = connection.body()
values["body"] = await connection.body()


async def json_extractor(connection: Request[Any, Any, Any]) -> Any:
Expand Down Expand Up @@ -441,7 +442,7 @@ async def extract_url_encoded_extractor(
)


def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any], ASGIConnection], None]:
def create_data_extractor(kwargs_model: KwargsModel) -> Extractor:
"""Create an extractor for a request's body.

Args:
Expand Down Expand Up @@ -476,11 +477,11 @@ def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any]
"Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", json_extractor
)

def extractor(
async def extractor(
values: dict[str, Any],
connection: ASGIConnection[Any, Any, Any, Any],
) -> None:
values["data"] = data_extractor(connection)
values["data"] = await data_extractor(connection)

return extractor

Expand Down
15 changes: 8 additions & 7 deletions litestar/_kwargs/kwargs_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

from anyio import create_task_group

Expand Down Expand Up @@ -40,6 +40,7 @@


if TYPE_CHECKING:
from litestar._kwargs.types import Extractor
from litestar._signature import SignatureModel
from litestar.connection import ASGIConnection
from litestar.di import Provide
Expand Down Expand Up @@ -124,11 +125,11 @@ def __init__(
)

self.is_data_optional = is_data_optional
self.extractors = self._create_extractors()
self.extractors: list[Extractor] = self._create_extractors()
self.dependency_batches = create_dependency_batches(expected_dependencies)

def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection], None]]:
reserved_kwargs_extractors: dict[str, Callable[[dict[str, Any], ASGIConnection], None]] = {
def _create_extractors(self) -> list[Extractor]:
reserved_kwargs_extractors: dict[str, Extractor] = {
"data": create_data_extractor(self),
"state": state_extractor,
"scope": scope_extractor,
Expand All @@ -140,7 +141,7 @@ def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection],
"body": body_extractor, # type: ignore[dict-item]
}

extractors: list[Callable[[dict[str, Any], ASGIConnection], None]] = [
extractors: list[Extractor] = [
reserved_kwargs_extractors[reserved_kwarg] for reserved_kwarg in self.expected_reserved_kwargs
]

Expand Down Expand Up @@ -362,7 +363,7 @@ def create_for_signature_model(
sequence_query_parameter_names=sequence_query_parameter_names,
)

def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]:
async def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]:
"""Return a dictionary of kwargs. Async values, i.e. CoRoutines, are not resolved to ensure this function is
sync.

Expand All @@ -376,7 +377,7 @@ def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]:
output: dict[str, Any] = {}

for extractor in self.extractors:
extractor(output, connection)
await extractor(output, connection)

return output

Expand Down
9 changes: 9 additions & 0 deletions litestar/_kwargs/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from typing import Any, Awaitable, Callable, Dict

from typing_extensions import TypeAlias

from litestar.connection import ASGIConnection

Extractor: TypeAlias = Callable[[Dict[str, Any], ASGIConnection], Awaitable[None]]
20 changes: 6 additions & 14 deletions litestar/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,13 @@ async def _get_response_data(
cleanup_group: DependencyCleanupGroup | None = None

if parameter_model.has_kwargs and route_handler.signature_model:
kwargs = parameter_model.to_kwargs(connection=request)
try:
kwargs = await parameter_model.to_kwargs(connection=request)
except SerializationException as e:
raise ClientException(str(e)) from e

if "data" in kwargs:
try:
data = await kwargs["data"]
except SerializationException as e:
raise ClientException(str(e)) from e

if data is Empty:
del kwargs["data"]
else:
kwargs["data"] = data

if "body" in kwargs:
kwargs["body"] = await kwargs["body"]
if kwargs.get("data") is Empty:
del kwargs["data"]

if parameter_model.dependency_batches:
cleanup_group = await parameter_model.resolve_dependencies(request, kwargs)
Expand Down
2 changes: 1 addition & 1 deletion litestar/routes/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N
cleanup_group: DependencyCleanupGroup | None = None

if self.handler_parameter_model.has_kwargs and self.route_handler.signature_model:
parsed_kwargs = self.handler_parameter_model.to_kwargs(connection=websocket)
parsed_kwargs = await self.handler_parameter_model.to_kwargs(connection=websocket)

if self.handler_parameter_model.dependency_batches:
cleanup_group = await self.handler_parameter_model.resolve_dependencies(websocket, parsed_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ skip = 'pdm.lock,docs/examples/contrib/sqlalchemy/us_state_lookup.json'

[tool.coverage.run]
concurrency = ["multiprocessing", "thread"]
omit = ["*/tests/*", "*/litestar/plugins/sqlalchemy.py"]
omit = ["*/tests/*", "*/litestar/plugins/sqlalchemy.py", "*/litestar/_kwargs/types.py"]
parallel = true
plugins = ["covdefaults"]
source = ["litestar"]
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/test_kwargs/test_cookie_params.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Optional, Type

import pytest
from typing_extensions import Annotated

from litestar import get
from litestar import get, post
from litestar.params import Parameter, ParameterKwarg
from litestar.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST
from litestar.testing import create_test_client
Expand Down Expand Up @@ -36,3 +37,13 @@ def test_method(special_cookie: t_type = param) -> None: # type: ignore[valid-t
client.cookies = param_dict # type: ignore[assignment]
response = client.get(test_path)
assert response.status_code == expected_code, response.json()


def test_cookie_param_with_post() -> None:
# https://github.com/litestar-org/litestar/issues/3734
@post()
async def handler(data: str, secret: Annotated[str, Parameter(cookie="x-secret")]) -> None:
return None

with create_test_client([handler], raise_server_exceptions=True) as client:
assert client.post("/", json={}).status_code == 400
13 changes: 12 additions & 1 deletion tests/unit/test_kwargs/test_header_params.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Dict, Optional, Union

import pytest
from typing_extensions import Annotated

from litestar import get
from litestar import get, post
from litestar.params import Parameter, ParameterKwarg
from litestar.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST
from litestar.testing import create_test_client
Expand Down Expand Up @@ -37,3 +38,13 @@ def test_method(special_header: t_type = param) -> None: # type: ignore[valid-t
assert response.status_code == HTTP_400_BAD_REQUEST, response.json()
else:
assert response.status_code == HTTP_200_OK, response.json()


def test_header_param_with_post() -> None:
# https://github.com/litestar-org/litestar/issues/3734
@post()
async def handler(data: str, secret: Annotated[str, Parameter(header="x-secret")]) -> None:
return None

with create_test_client([handler], raise_server_exceptions=True) as client:
assert client.post("/", json={}).status_code == 400
13 changes: 12 additions & 1 deletion tests/unit/test_kwargs/test_query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from urllib.parse import urlencode

import pytest
from typing_extensions import Annotated

from litestar import MediaType, Request, get
from litestar import MediaType, Request, get, post
from litestar.datastructures import MultiDict
from litestar.di import Provide
from litestar.params import Parameter
Expand Down Expand Up @@ -221,3 +222,13 @@ def handler(page_size_dep: int) -> str:
response = client.get("/?pageSize=1")
assert response.status_code == HTTP_200_OK, response.text
assert response.text == "1"


def test_query_params_with_post() -> None:
# https://github.com/litestar-org/litestar/issues/3734
@post()
async def handler(data: str, secret: Annotated[str, Parameter(query="x-secret")]) -> None:
return None

with create_test_client([handler], raise_server_exceptions=True) as client:
assert client.post("/", json={}).status_code == 400
Loading