diff --git a/litestar/_kwargs/extractors.py b/litestar/_kwargs/extractors.py index 20f6a46cee..e95add74ae 100644 --- a/litestar/_kwargs/extractors.py +++ b/litestar/_kwargs/extractors.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from litestar._kwargs import KwargsModel from litestar._kwargs.parameter_definition import ParameterDefinition + from litestar._kwargs.types import Extractor from litestar.connection import ASGIConnection, Request from litestar.dto import AbstractDTO from litestar.typing import FieldDefinition @@ -83,7 +84,7 @@ def create_connection_value_extractor( connection_key: str, expected_params: set[ParameterDefinition], parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None, -) -> Callable[[dict[str, Any], ASGIConnection], None]: +) -> Extractor: """Create a kwargs extractor function. Args: @@ -98,7 +99,7 @@ def create_connection_value_extractor( alias_and_key_tuples, alias_defaults, alias_to_params = _create_param_mappings(expected_params) - def extractor(values: dict[str, Any], connection: ASGIConnection) -> None: + async def extractor(values: dict[str, Any], connection: ASGIConnection) -> None: data = parser(connection, kwargs_model) if parser else getattr(connection, connection_key, {}) try: @@ -178,7 +179,7 @@ def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Head return Headers.from_scope(connection.scope) -def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +async def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: """Extract the app state from the connection and insert it to the kwargs injected to the handler. Args: @@ -191,7 +192,7 @@ def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: values["state"] = connection.app.state._state -def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +async def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: """Extract the headers from the connection and insert them to the kwargs injected to the handler. Args: @@ -206,7 +207,7 @@ def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> Non values["headers"] = dict(connection.headers.items()) -def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +async def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: """Extract the cookies from the connection and insert them to the kwargs injected to the handler. Args: @@ -219,7 +220,7 @@ def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> Non values["cookies"] = connection.cookies -def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +async def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: """Extract the query params from the connection and insert them to the kwargs injected to the handler. Args: @@ -232,7 +233,7 @@ def query_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: values["query"] = connection.query_params -def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +async def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: """Extract the scope from the connection and insert it into the kwargs injected to the handler. Args: @@ -245,7 +246,7 @@ def scope_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: values["scope"] = connection.scope -def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +async def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: """Set the connection instance as the 'request' value in the kwargs injected to the handler. Args: @@ -258,7 +259,7 @@ def request_extractor(values: dict[str, Any], connection: ASGIConnection) -> Non values["request"] = connection -def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: +async def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: """Set the connection instance as the 'socket' value in the kwargs injected to the handler. Args: @@ -271,7 +272,7 @@ def socket_extractor(values: dict[str, Any], connection: ASGIConnection) -> None values["socket"] = connection -def body_extractor( +async def body_extractor( values: dict[str, Any], connection: Request[Any, Any, Any], ) -> None: @@ -287,7 +288,7 @@ def body_extractor( Returns: The Body value. """ - values["body"] = connection.body() + values["body"] = await connection.body() async def json_extractor(connection: Request[Any, Any, Any]) -> Any: @@ -441,7 +442,7 @@ async def extract_url_encoded_extractor( ) -def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any], ASGIConnection], None]: +def create_data_extractor(kwargs_model: KwargsModel) -> Extractor: """Create an extractor for a request's body. Args: @@ -476,11 +477,11 @@ def create_data_extractor(kwargs_model: KwargsModel) -> Callable[[dict[str, Any] "Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", json_extractor ) - def extractor( + async def extractor( values: dict[str, Any], connection: ASGIConnection[Any, Any, Any, Any], ) -> None: - values["data"] = data_extractor(connection) + values["data"] = await data_extractor(connection) return extractor diff --git a/litestar/_kwargs/kwargs_model.py b/litestar/_kwargs/kwargs_model.py index 01ed2e5aef..9d757e92f7 100644 --- a/litestar/_kwargs/kwargs_model.py +++ b/litestar/_kwargs/kwargs_model.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any from anyio import create_task_group @@ -40,6 +40,7 @@ if TYPE_CHECKING: + from litestar._kwargs.types import Extractor from litestar._signature import SignatureModel from litestar.connection import ASGIConnection from litestar.di import Provide @@ -124,11 +125,11 @@ def __init__( ) self.is_data_optional = is_data_optional - self.extractors = self._create_extractors() + self.extractors: list[Extractor] = self._create_extractors() self.dependency_batches = create_dependency_batches(expected_dependencies) - def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection], None]]: - reserved_kwargs_extractors: dict[str, Callable[[dict[str, Any], ASGIConnection], None]] = { + def _create_extractors(self) -> list[Extractor]: + reserved_kwargs_extractors: dict[str, Extractor] = { "data": create_data_extractor(self), "state": state_extractor, "scope": scope_extractor, @@ -140,7 +141,7 @@ def _create_extractors(self) -> list[Callable[[dict[str, Any], ASGIConnection], "body": body_extractor, # type: ignore[dict-item] } - extractors: list[Callable[[dict[str, Any], ASGIConnection], None]] = [ + extractors: list[Extractor] = [ reserved_kwargs_extractors[reserved_kwarg] for reserved_kwarg in self.expected_reserved_kwargs ] @@ -362,7 +363,7 @@ def create_for_signature_model( sequence_query_parameter_names=sequence_query_parameter_names, ) - def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]: + async def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]: """Return a dictionary of kwargs. Async values, i.e. CoRoutines, are not resolved to ensure this function is sync. @@ -376,7 +377,7 @@ def to_kwargs(self, connection: ASGIConnection) -> dict[str, Any]: output: dict[str, Any] = {} for extractor in self.extractors: - extractor(output, connection) + await extractor(output, connection) return output diff --git a/litestar/_kwargs/types.py b/litestar/_kwargs/types.py new file mode 100644 index 0000000000..5d9f343aa9 --- /dev/null +++ b/litestar/_kwargs/types.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from typing import Any, Awaitable, Callable, Dict + +from typing_extensions import TypeAlias + +from litestar.connection import ASGIConnection + +Extractor: TypeAlias = Callable[[Dict[str, Any], ASGIConnection], Awaitable[None]] diff --git a/litestar/routes/http.py b/litestar/routes/http.py index 99ef4afe78..a8e47f1b99 100644 --- a/litestar/routes/http.py +++ b/litestar/routes/http.py @@ -169,21 +169,13 @@ async def _get_response_data( cleanup_group: DependencyCleanupGroup | None = None if parameter_model.has_kwargs and route_handler.signature_model: - kwargs = parameter_model.to_kwargs(connection=request) + try: + kwargs = await parameter_model.to_kwargs(connection=request) + except SerializationException as e: + raise ClientException(str(e)) from e - if "data" in kwargs: - try: - data = await kwargs["data"] - except SerializationException as e: - raise ClientException(str(e)) from e - - if data is Empty: - del kwargs["data"] - else: - kwargs["data"] = data - - if "body" in kwargs: - kwargs["body"] = await kwargs["body"] + if kwargs.get("data") is Empty: + del kwargs["data"] if parameter_model.dependency_batches: cleanup_group = await parameter_model.resolve_dependencies(request, kwargs) diff --git a/litestar/routes/websocket.py b/litestar/routes/websocket.py index 3248e2a83c..6346e2e0e5 100644 --- a/litestar/routes/websocket.py +++ b/litestar/routes/websocket.py @@ -69,7 +69,7 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N cleanup_group: DependencyCleanupGroup | None = None if self.handler_parameter_model.has_kwargs and self.route_handler.signature_model: - parsed_kwargs = self.handler_parameter_model.to_kwargs(connection=websocket) + parsed_kwargs = await self.handler_parameter_model.to_kwargs(connection=websocket) if self.handler_parameter_model.dependency_batches: cleanup_group = await self.handler_parameter_model.resolve_dependencies(websocket, parsed_kwargs) diff --git a/pyproject.toml b/pyproject.toml index 223e941780..93b11561c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -201,7 +201,7 @@ skip = 'pdm.lock,docs/examples/contrib/sqlalchemy/us_state_lookup.json' [tool.coverage.run] concurrency = ["multiprocessing", "thread"] -omit = ["*/tests/*", "*/litestar/plugins/sqlalchemy.py"] +omit = ["*/tests/*", "*/litestar/plugins/sqlalchemy.py", "*/litestar/_kwargs/types.py"] parallel = true plugins = ["covdefaults"] source = ["litestar"] diff --git a/tests/unit/test_kwargs/test_cookie_params.py b/tests/unit/test_kwargs/test_cookie_params.py index 0a23c6eafa..7111fe0f70 100644 --- a/tests/unit/test_kwargs/test_cookie_params.py +++ b/tests/unit/test_kwargs/test_cookie_params.py @@ -1,8 +1,9 @@ from typing import Optional, Type import pytest +from typing_extensions import Annotated -from litestar import get +from litestar import get, post from litestar.params import Parameter, ParameterKwarg from litestar.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from litestar.testing import create_test_client @@ -36,3 +37,13 @@ def test_method(special_cookie: t_type = param) -> None: # type: ignore[valid-t client.cookies = param_dict # type: ignore[assignment] response = client.get(test_path) assert response.status_code == expected_code, response.json() + + +def test_cookie_param_with_post() -> None: + # https://github.com/litestar-org/litestar/issues/3734 + @post() + async def handler(data: str, secret: Annotated[str, Parameter(cookie="x-secret")]) -> None: + return None + + with create_test_client([handler], raise_server_exceptions=True) as client: + assert client.post("/", json={}).status_code == 400 diff --git a/tests/unit/test_kwargs/test_header_params.py b/tests/unit/test_kwargs/test_header_params.py index 8a654ba779..f281647bcc 100644 --- a/tests/unit/test_kwargs/test_header_params.py +++ b/tests/unit/test_kwargs/test_header_params.py @@ -1,8 +1,9 @@ from typing import Dict, Optional, Union import pytest +from typing_extensions import Annotated -from litestar import get +from litestar import get, post from litestar.params import Parameter, ParameterKwarg from litestar.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from litestar.testing import create_test_client @@ -37,3 +38,13 @@ def test_method(special_header: t_type = param) -> None: # type: ignore[valid-t assert response.status_code == HTTP_400_BAD_REQUEST, response.json() else: assert response.status_code == HTTP_200_OK, response.json() + + +def test_header_param_with_post() -> None: + # https://github.com/litestar-org/litestar/issues/3734 + @post() + async def handler(data: str, secret: Annotated[str, Parameter(header="x-secret")]) -> None: + return None + + with create_test_client([handler], raise_server_exceptions=True) as client: + assert client.post("/", json={}).status_code == 400 diff --git a/tests/unit/test_kwargs/test_query_params.py b/tests/unit/test_kwargs/test_query_params.py index a8541c844b..a2b60c4cb7 100644 --- a/tests/unit/test_kwargs/test_query_params.py +++ b/tests/unit/test_kwargs/test_query_params.py @@ -10,8 +10,9 @@ from urllib.parse import urlencode import pytest +from typing_extensions import Annotated -from litestar import MediaType, Request, get +from litestar import MediaType, Request, get, post from litestar.datastructures import MultiDict from litestar.di import Provide from litestar.params import Parameter @@ -221,3 +222,13 @@ def handler(page_size_dep: int) -> str: response = client.get("/?pageSize=1") assert response.status_code == HTTP_200_OK, response.text assert response.text == "1" + + +def test_query_params_with_post() -> None: + # https://github.com/litestar-org/litestar/issues/3734 + @post() + async def handler(data: str, secret: Annotated[str, Parameter(query="x-secret")]) -> None: + return None + + with create_test_client([handler], raise_server_exceptions=True) as client: + assert client.post("/", json={}).status_code == 400