diff --git a/litestar/data_extractors.py b/litestar/data_extractors.py
index 6d4b182133..5a6f6607f0 100644
--- a/litestar/data_extractors.py
+++ b/litestar/data_extractors.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, TypedDict, cast
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable, Literal, TypedDict, cast
 
 from litestar._parsers import parse_cookie_string
 from litestar.connection.request import Request
@@ -70,6 +71,7 @@ class ConnectionDataExtractor:
         "parse_query",
         "obfuscate_headers",
         "obfuscate_cookies",
+        "skip_parse_malformed_body",
     )
 
     def __init__(
@@ -88,6 +90,7 @@ def __init__(
         obfuscate_headers: set[str] | None = None,
         parse_body: bool = False,
         parse_query: bool = False,
+        skip_parse_malformed_body: bool = False,
     ) -> None:
         """Initialize ``ConnectionDataExtractor``
 
@@ -106,9 +109,11 @@ def __init__(
             obfuscate_cookies: cookie keys to obfuscate. Obfuscated values are replaced with '*****'.
             parse_body: Whether to parse the body value or return the raw byte string, (for requests only).
             parse_query: Whether to parse query parameters or return the raw byte string.
+            skip_parse_malformed_body: Whether to return the raw body instead of raising when parsing it fails.
         """
         self.parse_body = parse_body
         self.parse_query = parse_query
+        self.skip_parse_malformed_body = skip_parse_malformed_body
         self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())}
         self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())}
         self.connection_extractors: dict[str, Callable[[ASGIConnection[Any, Any, Any, Any]], Any]] = {}
@@ -153,6 +158,36 @@ def __call__(self, connection: ASGIConnection[Any, Any, Any, Any]) -> ExtractedR
         )
         return cast("ExtractedRequestData", {key: extractor(connection) for key, extractor in extractors.items()})
 
+    async def extract(
+        self, connection: ASGIConnection[Any, Any, Any, Any], fields: Iterable[str]
+    ) -> ExtractedRequestData:
+        """Extract only the requested fields from a connection, awaiting async extractors.
+
+        Args:
+            connection: An ASGI connection or its subclasses.
+            fields: Names of the fields to extract. Names without a configured extractor are ignored.
+
+        Returns:
+            A dictionary containing the extracted values of the requested fields.
+        """
+        # ``fields`` may be a one-shot iterator; materialize it so repeated membership tests stay correct.
+        requested = set(fields)
+        extractors = (
+            {**self.connection_extractors, **self.request_extractors}  # type: ignore
+            if isinstance(connection, Request)
+            else self.connection_extractors
+        )
+        data: dict[str, Any] = {}
+        for key, extractor in extractors.items():
+            if key not in requested:
+                continue
+            if inspect.iscoroutinefunction(extractor):
+                value = await extractor(connection)
+            else:
+                value = extractor(connection)
+            data[key] = value
+        return cast("ExtractedRequestData", data)
+
     @staticmethod
     def extract_scheme(connection: ASGIConnection[Any, Any, Any, Any]) -> str:
         """Extract the scheme from an ``ASGIConnection``
@@ -272,13 +307,21 @@ async def extract_body(self, request: Request[Any, Any, Any]) -> Any:
             return None
         if not self.parse_body:
             return await request.body()
-        request_encoding_type = request.content_type[0]
-        if request_encoding_type == RequestEncodingType.JSON:
-            return await request.json()
-        form_data = await request.form()
-        if request_encoding_type == RequestEncodingType.URL_ENCODED:
-            return dict(form_data)
-        return {key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items()}
+        try:
+            request_encoding_type = request.content_type[0]
+            if request_encoding_type == RequestEncodingType.JSON:
+                return await request.json()
+            form_data = await request.form()
+            if request_encoding_type == RequestEncodingType.URL_ENCODED:
+                return dict(form_data)
+            return {
+                key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items()
+            }
+        except Exception:
+            # Best-effort mode: hand back the raw bytes rather than failing on an unparsable body.
+            if self.skip_parse_malformed_body:
+                return await request.body()
+            raise
 
 
 class ExtractedResponseData(TypedDict, total=False):
diff --git a/litestar/middleware/logging.py b/litestar/middleware/logging.py
index d52963b60e..dc827e303e 100644
--- a/litestar/middleware/logging.py
+++ b/litestar/middleware/logging.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from inspect import isawaitable
 from typing import TYPE_CHECKING, Any, Iterable
 
 from litestar.constants import (
@@ -81,6 +80,7 @@ def __init__(self, app: ASGIApp, config: LoggingMiddlewareConfig) -> None:
             obfuscate_headers=self.config.request_headers_to_obfuscate,
             parse_body=self.is_struct_logger,
             parse_query=self.is_struct_logger,
+            skip_parse_malformed_body=True,
         )
         self.response_extractor = ResponseDataExtractor(
             extract_body="body" in self.config.response_log_fields,
@@ -172,12 +172,11 @@ async def extract_request_data(self, request: Request) -> dict[str, Any]:
 
         data: dict[str, Any] = {"message": self.config.request_log_message}
         serializer = get_serializer_from_scope(request.scope)
-        extracted_data = self.request_extractor(connection=request)
+
+        extracted_data = await self.request_extractor.extract(connection=request, fields=self.config.request_log_fields)
+
         for key in self.config.request_log_fields:
-            value = extracted_data.get(key)
-            if isawaitable(value):
-                value = await value
-            data[key] = self._serialize_value(serializer, value)
+            data[key] = self._serialize_value(serializer, extracted_data.get(key))
         return data
 
     def extract_response_data(self, scope: Scope) -> dict[str, Any]:
diff --git a/tests/unit/test_data_extractors.py b/tests/unit/test_data_extractors.py
index 31c6103ee8..b204707bd1 100644
--- a/tests/unit/test_data_extractors.py
+++ b/tests/unit/test_data_extractors.py
@@ -1,6 +1,8 @@
 from typing import Any, List
+from unittest.mock import AsyncMock
 
 import pytest
+from pytest_mock import MockFixture
 
 from litestar import Request
 from litestar.connection.base import empty_receive
@@ -108,3 +110,18 @@ async def send(message: "Any") -> None:
     assert extracted_data.get("body") == b'{"hello":"world"}'
     assert extracted_data.get("headers") == {**headers, "content-length": "17"}
     assert extracted_data.get("cookies") == {"Path": "/", "SameSite": "lax", "auth": "", "regular": ""}
+
+
+async def test_request_data_extractor_skip_keys() -> None:
+    req = factory.get()
+    extractor = ConnectionDataExtractor()
+    assert (await extractor.extract(req, {"body"})).keys() == {"body"}
+
+
+async def test_skip_parse_malformed_body_false_raises(mocker: MockFixture) -> None:
+    mocker.patch("litestar.testing.request_factory.Request.json", new=AsyncMock(side_effect=ValueError()))
+    req = factory.post(headers={"Content-Type": "application/json"})
+    extractor = ConnectionDataExtractor(parse_body=True, skip_parse_malformed_body=False)
+
+    with pytest.raises(ValueError):
+        await extractor.extract(req, {"body"})
diff --git a/tests/unit/test_middleware/test_logging_middleware.py b/tests/unit/test_middleware/test_logging_middleware.py
index 98b1bee3e1..761a088c62 100644
--- a/tests/unit/test_middleware/test_logging_middleware.py
+++ b/tests/unit/test_middleware/test_logging_middleware.py
@@ -1,5 +1,5 @@
 from logging import INFO
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Any, Dict
 
 import pytest
 from structlog.testing import capture_logs
@@ -286,3 +286,17 @@ async def get_session() -> None:
         assert response.status_code == HTTP_200_OK
         assert "session" in client.cookies
         assert client.cookies["session"] == session_id
+
+
+def test_structlog_invalid_request_body_handled() -> None:
+    # https://github.com/litestar-org/litestar/issues/3063
+    @post("/")
+    async def hello_world(data: Dict[str, Any]) -> Dict[str, Any]:
+        return data
+
+    with create_test_client(
+        route_handlers=[hello_world],
+        logging_config=StructLoggingConfig(log_exceptions="always"),
+        middleware=[LoggingMiddlewareConfig().middleware],
+    ) as client:
+        assert client.post("/", headers={"Content-Type": "application/json"}, content=b'{"a": "b",}').status_code == 400