diff --git a/litestar/middleware/_internal/exceptions/middleware.py b/litestar/middleware/_internal/exceptions/middleware.py index 14db20a9f0..b4460b97f8 100644 --- a/litestar/middleware/_internal/exceptions/middleware.py +++ b/litestar/middleware/_internal/exceptions/middleware.py @@ -29,6 +29,7 @@ ExceptionHandler, ExceptionHandlersMap, Logger, + Message, Receive, Scope, Send, @@ -146,9 +147,19 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: Returns: None """ + scope_state = ScopeState.from_scope(scope) + + async def capture_response_started(event: Message) -> None: + if event["type"] == "http.response.start": + scope_state.response_started = True + await send(event) + try: - await self.app(scope, receive, send) + await self.app(scope, receive, capture_response_started) except Exception as e: # noqa: BLE001 + if scope_state.response_started: + raise LitestarException("Exception caught after response started") from e + litestar_app = scope["app"] if litestar_app.logging_config and (logger := litestar_app.logger): diff --git a/litestar/utils/scope/state.py b/litestar/utils/scope/state.py index 1eefbf9e9a..cc9fd31d5d 100644 --- a/litestar/utils/scope/state.py +++ b/litestar/utils/scope/state.py @@ -44,6 +44,7 @@ class ScopeState: "msgpack", "parsed_query", "response_compressed", + "response_started", "session_id", "url", "_compat_ns", @@ -68,6 +69,7 @@ def __init__(self) -> None: self.msgpack = Empty self.parsed_query = Empty self.response_compressed = Empty + self.response_started = False self.session_id = Empty self.url = Empty self._compat_ns: dict[str, Any] = {} @@ -90,6 +92,7 @@ def __init__(self) -> None: msgpack: Any | EmptyType parsed_query: tuple[tuple[str, str], ...] | EmptyType response_compressed: bool | EmptyType + response_started: bool session_id: str | None | EmptyType url: URL | EmptyType _compat_ns: dict[str, Any] diff --git a/tests/unit/test_middleware/test_exception_handler_middleware.py b/tests/unit/test_middleware/test_exception_handler_middleware.py index e25752a2b6..cfa10fd7e8 100644 --- a/tests/unit/test_middleware/test_exception_handler_middleware.py +++ b/tests/unit/test_middleware/test_exception_handler_middleware.py @@ -9,7 +9,7 @@ from structlog.testing import capture_logs from litestar import Litestar, MediaType, Request, Response, get -from litestar.exceptions import HTTPException, InternalServerException, ValidationException +from litestar.exceptions import HTTPException, InternalServerException, LitestarException, ValidationException from litestar.exceptions.responses._debug_response import get_symbol_name from litestar.logging.config import LoggingConfig, StructLoggingConfig from litestar.middleware._internal.exceptions.middleware import ( @@ -20,7 +20,7 @@ from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from litestar.testing import TestClient, create_test_client from litestar.types import ExceptionHandlersMap -from litestar.types.asgi_types import HTTPScope +from litestar.types.asgi_types import HTTPReceiveMessage, HTTPScope, Message, Receive, Scope, Send from litestar.utils.scope.state import ScopeState from tests.helpers import cleanup_logging_impl @@ -400,3 +400,29 @@ def handler() -> None: with create_test_client([handler], type_encoders={Foo: lambda f: f.value}) as client: res = client.get("/") assert res.json()["extra"] == {"foo": "bar"} + + +async def test_exception_handler_middleware_response_already_started(scope: HTTPScope) -> None: + assert not ScopeState.from_scope(scope).response_started + + async def mock_receive() -> HTTPReceiveMessage: # type: ignore[empty-body] + pass + + mock = MagicMock() + + async def mock_send(message: Message) -> None: + mock(message) + + start_message: Message = {"type": "http.response.start", "status": 200, "headers": []} + + async def asgi_app(scope: Scope, receive: Receive, send: Send) -> None: + await send(start_message) + raise RuntimeError("Test exception") + + mw = ExceptionHandlerMiddleware(asgi_app, None) + + with pytest.raises(LitestarException): + await mw(scope, mock_receive, mock_send) + + mock.assert_called_once_with(start_message) + assert ScopeState.from_scope(scope).response_started