Skip to content

Commit

Permalink
fix: prevent starting multiple responses (#3479)
Browse files Browse the repository at this point in the history
* fix: starting multiple responses

Prevents the app's exception handler middleware from starting a response after one has already started.

When something in the middleware stack raises an exception after a "http.response.start" message has already been sent, we end up with log exception chains that obfuscate the original exception, such as:

```python-traceback
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 134, in send_body
    await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive)
  File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 100, in _listen_for_disconnect
    await self._listen_for_disconnect(cancel_scope=cancel_scope, receive=receive)
  File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 94, in _listen_for_disconnect
    message = await receive()
  File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py", line 568, in receive
    await self.message_event.wait()
  File "/home/peter/.pyenv/versions/3.8.18/lib/python3.8/asyncio/locks.py", line 309, in wait
    await fut
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

  + Exception Group Traceback (most recent call last):
  |   File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 157, in __call__
  |     await self.app(scope, receive, send)
  |   File "/home/peter/PycharmProjects/litestar/litestar/routes/http.py", line 84, in handle
  |     await response(scope, receive, send)
  |   File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 200, in __call__
  |     await self.send_body(send=send, receive=receive)
  |   File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 134, in send_body
  |     await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive)
  |   File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/anyio/_backends/_asyncio.py", line 678, in __aexit__
  |     raise BaseExceptionGroup(
  | exceptiongroup.ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
  +-+---------------- 1 ----------------
    | Traceback (most recent call last):
    |   File "/home/peter/PycharmProjects/litestar/litestar/response/streaming.py", line 117, in _stream
    |     await send(stream_event)
    |   File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 226, in send_wrapper
    |     self.log_response(scope=scope)
    |   File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 136, in log_response
    |     extracted_data = self.extract_response_data(scope=scope)
    |   File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 194, in extract_response_data
    |     connection_state.log_context.pop(HTTP_RESPONSE_START),
    | KeyError: 'http.response.start'
    +------------------------------------

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 157, in __call__
    await self.app(scope, receive, send)
  File "/home/peter/PycharmProjects/litestar/litestar/_asgi/asgi_router.py", line 99, in __call__
    await asgi_app(scope, receive, send)
  File "/home/peter/PycharmProjects/litestar/litestar/middleware/base.py", line 129, in wrapped_call
    await original__call__(self, scope, receive, send)  # pyright: ignore
  File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 112, in __call__
    await self.app(scope, receive, send)
  File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 174, in __call__
    await self.handle_request_exception(
  File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 204, in handle_request_exception
    await response.to_asgi_response(app=None, request=request)(scope=scope, receive=receive, send=send)
  File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 194, in __call__
    await self.start_response(send=send)
  File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 165, in start_response
    await send(event)
  File "/home/peter/PycharmProjects/litestar/litestar/middleware/logging.py", line 227, in send_wrapper
    await send(message)
  File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py", line 522, in send
    raise RuntimeError(msg % message_type)
RuntimeError: Expected ASGI message 'http.response.body', but got 'http.response.start'.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py", line 411, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/middleware/proxy_headers.py", line 69, in __call__
    return await self.app(scope, receive, send)
  File "/home/peter/PycharmProjects/litestar/litestar/app.py", line 591, in __call__
    await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope))  # type: ignore[arg-type]
  File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 174, in __call__
    await self.handle_request_exception(
  File "/home/peter/PycharmProjects/litestar/litestar/middleware/_internal/exceptions/middleware.py", line 204, in handle_request_exception
    await response.to_asgi_response(app=None, request=request)(scope=scope, receive=receive, send=send)
  File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 194, in __call__
    await self.start_response(send=send)
  File "/home/peter/PycharmProjects/litestar/litestar/response/base.py", line 165, in start_response
    await send(event)
  File "/home/peter/.local/share/pdm/venvs/litestar-dj-FOhMr-3.8/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py", line 522, in send
    raise RuntimeError(msg % message_type)
RuntimeError: Expected ASGI message 'http.response.body', but got 'http.response.start'.
```

This PR tracks whether a response has started, and if so, we immediately raise the exception instead of sending it through the usual exception handling code path.

* refactor: raise LitestarException

Raise a LitestarException chained from the original exception when exception caught after response started.

* test: add test
  • Loading branch information
peterschutt authored May 8, 2024
1 parent 3125235 commit 9fd80e2
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
13 changes: 12 additions & 1 deletion litestar/middleware/_internal/exceptions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ExceptionHandler,
ExceptionHandlersMap,
Logger,
Message,
Receive,
Scope,
Send,
Expand Down Expand Up @@ -146,9 +147,19 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Returns:
None
"""
scope_state = ScopeState.from_scope(scope)

async def capture_response_started(event: Message) -> None:
if event["type"] == "http.response.start":
scope_state.response_started = True
await send(event)

try:
await self.app(scope, receive, send)
await self.app(scope, receive, capture_response_started)
except Exception as e: # noqa: BLE001
if scope_state.response_started:
raise LitestarException("Exception caught after response started") from e

litestar_app = scope["app"]

if litestar_app.logging_config and (logger := litestar_app.logger):
Expand Down
3 changes: 3 additions & 0 deletions litestar/utils/scope/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class ScopeState:
"msgpack",
"parsed_query",
"response_compressed",
"response_started",
"session_id",
"url",
"_compat_ns",
Expand All @@ -68,6 +69,7 @@ def __init__(self) -> None:
self.msgpack = Empty
self.parsed_query = Empty
self.response_compressed = Empty
self.response_started = False
self.session_id = Empty
self.url = Empty
self._compat_ns: dict[str, Any] = {}
Expand All @@ -90,6 +92,7 @@ def __init__(self) -> None:
msgpack: Any | EmptyType
parsed_query: tuple[tuple[str, str], ...] | EmptyType
response_compressed: bool | EmptyType
response_started: bool
session_id: str | None | EmptyType
url: URL | EmptyType
_compat_ns: dict[str, Any]
Expand Down
30 changes: 28 additions & 2 deletions tests/unit/test_middleware/test_exception_handler_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from structlog.testing import capture_logs

from litestar import Litestar, MediaType, Request, Response, get
from litestar.exceptions import HTTPException, InternalServerException, ValidationException
from litestar.exceptions import HTTPException, InternalServerException, LitestarException, ValidationException
from litestar.exceptions.responses._debug_response import get_symbol_name
from litestar.logging.config import LoggingConfig, StructLoggingConfig
from litestar.middleware._internal.exceptions.middleware import (
Expand All @@ -20,7 +20,7 @@
from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
from litestar.testing import TestClient, create_test_client
from litestar.types import ExceptionHandlersMap
from litestar.types.asgi_types import HTTPScope
from litestar.types.asgi_types import HTTPReceiveMessage, HTTPScope, Message, Receive, Scope, Send
from litestar.utils.scope.state import ScopeState
from tests.helpers import cleanup_logging_impl

Expand Down Expand Up @@ -400,3 +400,29 @@ def handler() -> None:
with create_test_client([handler], type_encoders={Foo: lambda f: f.value}) as client:
res = client.get("/")
assert res.json()["extra"] == {"foo": "bar"}


async def test_exception_handler_middleware_response_already_started(scope: HTTPScope) -> None:
assert not ScopeState.from_scope(scope).response_started

async def mock_receive() -> HTTPReceiveMessage: # type: ignore[empty-body]
pass

mock = MagicMock()

async def mock_send(message: Message) -> None:
mock(message)

start_message: Message = {"type": "http.response.start", "status": 200, "headers": []}

async def asgi_app(scope: Scope, receive: Receive, send: Send) -> None:
await send(start_message)
raise RuntimeError("Test exception")

mw = ExceptionHandlerMiddleware(asgi_app, None)

with pytest.raises(LitestarException):
await mw(scope, mock_receive, mock_send)

mock.assert_called_once_with(start_message)
assert ScopeState.from_scope(scope).response_started

0 comments on commit 9fd80e2

Please sign in to comment.