Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(3063): Logging middleware with structlog causes application to return HTTP 500 when request body is malformed #3109

Merged
merged 4 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions litestar/data_extractors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, TypedDict, cast
import inspect
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable, Literal, TypedDict, cast

from litestar._parsers import parse_cookie_string
from litestar.connection.request import Request
Expand Down Expand Up @@ -70,6 +71,7 @@ class ConnectionDataExtractor:
"parse_query",
"obfuscate_headers",
"obfuscate_cookies",
"skip_parse_malformed_body",
)

def __init__(
Expand All @@ -88,6 +90,7 @@ def __init__(
obfuscate_headers: set[str] | None = None,
parse_body: bool = False,
parse_query: bool = False,
skip_parse_malformed_body: bool = False,
) -> None:
"""Initialize ``ConnectionDataExtractor``

Expand All @@ -106,9 +109,11 @@ def __init__(
obfuscate_cookies: cookie keys to obfuscate. Obfuscated values are replaced with '*****'.
parse_body: Whether to parse the body value or return the raw byte string, (for requests only).
parse_query: Whether to parse query parameters or return the raw byte string.
skip_parse_malformed_body: Whether to skip parsing the body if it is malformed
"""
self.parse_body = parse_body
self.parse_query = parse_query
self.skip_parse_malformed_body = skip_parse_malformed_body
self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())}
self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())}
self.connection_extractors: dict[str, Callable[[ASGIConnection[Any, Any, Any, Any]], Any]] = {}
Expand Down Expand Up @@ -153,6 +158,25 @@ def __call__(self, connection: ASGIConnection[Any, Any, Any, Any]) -> ExtractedR
)
return cast("ExtractedRequestData", {key: extractor(connection) for key, extractor in extractors.items()})

async def extract(
self, connection: ASGIConnection[Any, Any, Any, Any], fields: Iterable[str]
) -> ExtractedRequestData:
extractors = (
{**self.connection_extractors, **self.request_extractors} # type: ignore
if isinstance(connection, Request)
else self.connection_extractors
)
data = {}
for key, extractor in extractors.items():
if key not in fields:
continue
if inspect.iscoroutinefunction(extractor):
value = await extractor(connection)
else:
value = extractor(connection)
data[key] = value
return data

@staticmethod
def extract_scheme(connection: ASGIConnection[Any, Any, Any, Any]) -> str:
"""Extract the scheme from an ``ASGIConnection``
Expand Down Expand Up @@ -272,13 +296,20 @@ async def extract_body(self, request: Request[Any, Any, Any]) -> Any:
return None
if not self.parse_body:
return await request.body()
request_encoding_type = request.content_type[0]
if request_encoding_type == RequestEncodingType.JSON:
return await request.json()
form_data = await request.form()
if request_encoding_type == RequestEncodingType.URL_ENCODED:
return dict(form_data)
return {key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items()}
try:
request_encoding_type = request.content_type[0]
if request_encoding_type == RequestEncodingType.JSON:
return await request.json()
form_data = await request.form()
if request_encoding_type == RequestEncodingType.URL_ENCODED:
return dict(form_data)
return {
key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items()
}
except Exception as exc:
if self.skip_parse_on_exception:
return await request.body()
raise exc


class ExtractedResponseData(TypedDict, total=False):
Expand Down
11 changes: 5 additions & 6 deletions litestar/middleware/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass, field
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Iterable

from litestar.constants import (
Expand Down Expand Up @@ -81,6 +80,7 @@ def __init__(self, app: ASGIApp, config: LoggingMiddlewareConfig) -> None:
obfuscate_headers=self.config.request_headers_to_obfuscate,
parse_body=self.is_struct_logger,
parse_query=self.is_struct_logger,
skip_parse_malformed_body=True,
)
self.response_extractor = ResponseDataExtractor(
extract_body="body" in self.config.response_log_fields,
Expand Down Expand Up @@ -172,12 +172,11 @@ async def extract_request_data(self, request: Request) -> dict[str, Any]:

data: dict[str, Any] = {"message": self.config.request_log_message}
serializer = get_serializer_from_scope(request.scope)
extracted_data = self.request_extractor(connection=request)

extracted_data = await self.request_extractor.extract(connection=request, fields=self.config.request_log_fields)

for key in self.config.request_log_fields:
value = extracted_data.get(key)
if isawaitable(value):
value = await value
data[key] = self._serialize_value(serializer, value)
data[key] = self._serialize_value(serializer, extracted_data.get(key))
return data

def extract_response_data(self, scope: Scope) -> dict[str, Any]:
Expand Down
16 changes: 15 additions & 1 deletion tests/unit/test_middleware/test_logging_middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import INFO
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Any, Dict

import pytest
from structlog.testing import capture_logs
Expand Down Expand Up @@ -286,3 +286,17 @@ async def get_session() -> None:
assert response.status_code == HTTP_200_OK
assert "session" in client.cookies
assert client.cookies["session"] == session_id


def test_structlog_invalid_request_body_handled():
# https://github.com/litestar-org/litestar/issues/3063
@post("/")
async def hello_world(data: dict[str, Any]) -> dict[str, Any]:
return data

with create_test_client(
route_handlers=[hello_world],
logging_config=StructLoggingConfig(log_exceptions="always"),
middleware=[LoggingMiddlewareConfig().middleware],
) as client:
assert client.post("/", headers={"Content-Type": "application/json"}, content=b'{"a": "b",}').status_code == 400
Loading