From 9fd80e23958ce82eca62d499f98248fdfae8eea0 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Wed, 8 May 2024 19:32:42 +1000 Subject: [PATCH] fix: prevent starting multiple responses (#3479) * fix: starting multiple responses Prevents the app's exception handler middleware from starting a response after one has already started. When something in the middleware stack raises an exception after a "http.response.start" message has already been sent, we end up with log exception chains that obfuscate the original exception, such as: ```python-traceback ERROR: Exception in ASGI application Traceback (most recent call last): File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 134, in send_body await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive) File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 100, in _listen_for_disconnect await self._listen_for_disconnect(cancel_scope=cancel_scope, receive=receive) File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 94, in _listen_for_disconnect message = await receive() File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py", line 568, in receive await self.message_event.wait() File "/home/peter/.pyenv/versions/3.8.18/lib/python3.8/asyncio/locks.py", line 309, in wait await fut asyncio.exceptions.CancelledError During handling of the above exception, another exception occurred: + Exception Group Traceback (most recent call last): | File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 157, in __call__ | await self.app(scope, receive, send) | File "/home/peter/PycharmProjects/litestar/litestar/routes/http.py", line 84, in handle | await response(scope, receive, send) | File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 200, in __call__ | await self.send_body(send=send, receive=receive) | File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 134, in send_body | await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive) | File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/anyio/_backends/_asyncio.py", line 678, in __aexit__ | raise BaseExceptionGroup( | exceptiongroup.ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception) +-+---------------- 1 ---------------- | Traceback (most recent call last): | File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 117, in _stream | await send(stream_event) | File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 226, in send_wrapper | self.log_response(scope=scope) | File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 136, in log_response | extracted_data = self.extract_response_data(scope=scope) | File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 194, in extract_response_data | connection_state.log_context.pop(HTTP_RESPONSE_START), | KeyError: 'http.response.start' +------------------------------------ During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 157, in __call__ await self.app(scope, receive, send) File "/home/peter/PycharmProjects/litestar/litestar/_asgi/asgi_router.py", line 99, in __call__ await asgi_app(scope, receive, send) File "/home/peter/PycharmProjects/litestar/litestar/middleware/base.py", line 129, in wrapped_call await original__call__(self, scope, receive, send) # pyright: ignore File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 112, in __call__ await self.app(scope, receive, send) File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 174, in __call__ await self.handle_request_exception( File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 204, in handle_request_exception await response.to_asgi_response(app=None, request=request)(scope=scope, receive=receive, send=send) File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 194, in __call__ await self.start_response(send=send) File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 165, in start_response await send(event) File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 227, in send_wrapper await send(message) File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py", line 522, in send raise RuntimeError(msg % message_type) RuntimeError: Expected ASGI message 'http.response.body', but got 'http.response.start'. During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py", line 411, in run_asgi result = await app( # type: ignore[func-returns-value] File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/middleware/proxy_headers.py", line 69, in __call__ return await self.app(scope, receive, send) File "/home/peter/PycharmProjects/litestar/litestar/app.py", line 591, in __call__ await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type] File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 174, in __call__ await self.handle_request_exception( File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 204, in handle_request_exception await response.to_asgi_response(app=None, request=request)(scope=scope, receive=receive, send=send) File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 194, in __call__ await self.start_response(send=send) File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 165, in start_response await send(event) File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py", line 522, in send raise RuntimeError(msg % message_type) RuntimeError: Expected ASGI message 'http.response.body', but got 'http.response.start'. ``` This PR tracks whether a response has started, and if so, we immediately raise the exception instead of sending it through the usual exception handling code path. * refactor: raise LitestarException Raise a LitestarException chained from the original exception when exception caught after response started. * test: add test --- .../_internal/exceptions/middleware.py | 13 +++++++- litestar/utils/scope/state.py | 3 ++ .../test_exception_handler_middleware.py | 30 +++++++++++++++++-- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/litestar/middleware/_internal/exceptions/middleware.py b/litestar/middleware/_internal/exceptions/middleware.py index 14db20a9f0..b4460b97f8 100644 --- a/litestar/middleware/_internal/exceptions/middleware.py +++ b/litestar/middleware/_internal/exceptions/middleware.py @@ -29,6 +29,7 @@ ExceptionHandler, ExceptionHandlersMap, Logger, + Message, Receive, Scope, Send, @@ -146,9 +147,19 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: Returns: None """ + scope_state = ScopeState.from_scope(scope) + + async def capture_response_started(event: Message) -> None: + if event["type"] == "http.response.start": + scope_state.response_started = True + await send(event) + try: - await self.app(scope, receive, send) + await self.app(scope, receive, capture_response_started) except Exception as e: # noqa: BLE001 + if scope_state.response_started: + raise LitestarException("Exception caught after response started") from e + litestar_app = scope["app"] if litestar_app.logging_config and (logger := litestar_app.logger): diff --git a/litestar/utils/scope/state.py b/litestar/utils/scope/state.py index 1eefbf9e9a..cc9fd31d5d 100644 --- a/litestar/utils/scope/state.py +++ b/litestar/utils/scope/state.py @@ -44,6 +44,7 @@ class ScopeState: "msgpack", "parsed_query", "response_compressed", + "response_started", "session_id", "url", "_compat_ns", @@ -68,6 +69,7 @@ def __init__(self) -> None: self.msgpack = Empty self.parsed_query = Empty self.response_compressed = Empty + self.response_started = False self.session_id = Empty self.url = Empty self._compat_ns: dict[str, Any] = {} @@ -90,6 +92,7 @@ def __init__(self) -> None: msgpack: Any | EmptyType parsed_query: tuple[tuple[str, str], ...] | EmptyType response_compressed: bool | EmptyType + response_started: bool session_id: str | None | EmptyType url: URL | EmptyType _compat_ns: dict[str, Any] diff --git a/tests/unit/test_middleware/test_exception_handler_middleware.py b/tests/unit/test_middleware/test_exception_handler_middleware.py index e25752a2b6..cfa10fd7e8 100644 --- a/tests/unit/test_middleware/test_exception_handler_middleware.py +++ b/tests/unit/test_middleware/test_exception_handler_middleware.py @@ -9,7 +9,7 @@ from structlog.testing import capture_logs from litestar import Litestar, MediaType, Request, Response, get -from litestar.exceptions import HTTPException, InternalServerException, ValidationException +from litestar.exceptions import HTTPException, InternalServerException, LitestarException, ValidationException from litestar.exceptions.responses._debug_response import get_symbol_name from litestar.logging.config import LoggingConfig, StructLoggingConfig from litestar.middleware._internal.exceptions.middleware import ( @@ -20,7 +20,7 @@ from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from litestar.testing import TestClient, create_test_client from litestar.types import ExceptionHandlersMap -from litestar.types.asgi_types import HTTPScope +from litestar.types.asgi_types import HTTPReceiveMessage, HTTPScope, Message, Receive, Scope, Send from litestar.utils.scope.state import ScopeState from tests.helpers import cleanup_logging_impl @@ -400,3 +400,29 @@ def handler() -> None: with create_test_client([handler], type_encoders={Foo: lambda f: f.value}) as client: res = client.get("/") assert res.json()["extra"] == {"foo": "bar"} + + +async def test_exception_handler_middleware_response_already_started(scope: HTTPScope) -> None: + assert not ScopeState.from_scope(scope).response_started + + async def mock_receive() -> HTTPReceiveMessage: # type: ignore[empty-body] + pass + + mock = MagicMock() + + async def mock_send(message: Message) -> None: + mock(message) + + start_message: Message = {"type": "http.response.start", "status": 200, "headers": []} + + async def asgi_app(scope: Scope, receive: Receive, send: Send) -> None: + await send(start_message) + raise RuntimeError("Test exception") + + mw = ExceptionHandlerMiddleware(asgi_app, None) + + with pytest.raises(LitestarException): + await mw(scope, mock_receive, mock_send) + + mock.assert_called_once_with(start_message) + assert ScopeState.from_scope(scope).response_started