Skip to content

Commit

Permalink
refactor: exception handler middleware optimizations (#3389)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterschutt authored Apr 26, 2024
1 parent 7814d59 commit 5f4d627
Show file tree
Hide file tree
Showing 43 changed files with 1,087 additions and 581 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- id: check-case-conflict
- id: check-toml
- id: debug-statements
exclude: "litestar/middleware/exceptions/middleware.py"
exclude: "litestar/middleware/_internal/exceptions/middleware.py"
- id: end-of-file-fixer
- id: mixed-line-ending
- id: trailing-whitespace
Expand Down
3 changes: 3 additions & 0 deletions docs/reference/exceptions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ exceptions

.. automodule:: litestar.exceptions
:members:

.. automodule:: litestar.exceptions.responses
:members:
22 changes: 16 additions & 6 deletions litestar/_asgi/asgi_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
from litestar._asgi.utils import get_route_handlers
from litestar.exceptions import ImproperlyConfiguredException
from litestar.utils import normalize_path

__all__ = ("ASGIRouter",)

from litestar.utils.scope.state import ScopeState

if TYPE_CHECKING:
from litestar._asgi.routing_trie.types import RouteTrieNode
Expand All @@ -24,6 +22,7 @@
from litestar.routes.base import BaseRoute
from litestar.types import (
ASGIApp,
ExceptionHandlersMap,
LifeSpanReceive,
LifeSpanSend,
LifeSpanShutdownCompleteEvent,
Expand All @@ -37,6 +36,8 @@
Send,
)

__all__ = ("ASGIRouter",)


class ASGIRouter:
"""Litestar ASGI router.
Expand All @@ -45,6 +46,7 @@ class ASGIRouter:
"""

__slots__ = (
"_app_exception_handlers",
"_mount_paths_regex",
"_mount_routes",
"_plain_routes",
Expand All @@ -62,6 +64,7 @@ def __init__(self, app: Litestar) -> None:
Args:
app: The Litestar app instance
"""
self._app_exception_handlers: ExceptionHandlersMap = app.exception_handlers
self._mount_paths_regex: Pattern | None = None
self._mount_routes: dict[str, RouteTrieNode] = {}
self._plain_routes: set[str] = set()
Expand All @@ -83,9 +86,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
path = path.split(root_path, maxsplit=1)[-1]
normalized_path = normalize_path(path)

asgi_app, scope["route_handler"], scope["path"], scope["path_params"] = self.handle_routing(
path=normalized_path, method=scope.get("method")
)
try:
asgi_app, route_handler, scope["path"], scope["path_params"] = self.handle_routing(
path=normalized_path, method=scope.get("method")
)
except Exception:
ScopeState.from_scope(scope).exception_handlers = self._app_exception_handlers
raise
else:
ScopeState.from_scope(scope).exception_handlers = route_handler.resolve_exception_handlers()
scope["route_handler"] = route_handler
await asgi_app(scope, receive, send)

@lru_cache(1024) # noqa: B019
Expand Down
49 changes: 26 additions & 23 deletions litestar/_asgi/routing_trie/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,33 +189,36 @@ def build_route_middleware_stack(
from litestar.middleware.response_cache import ResponseCacheMiddleware
from litestar.routes import HTTPRoute

# we wrap the route.handle method in the ExceptionHandlerMiddleware
asgi_handler = wrap_in_exception_handler(
app=route.handle, # type: ignore[arg-type]
exception_handlers=route_handler.resolve_exception_handlers(),
asgi_handler: ASGIApp = route.handle # type: ignore[assignment]
handler_middleware = route_handler.resolve_middleware()
has_cached_route = isinstance(route, HTTPRoute) and any(r.cache for r in route.route_handlers)
has_middleware = (
app.csrf_config or app.compression_config or has_cached_route or app.allowed_hosts or handler_middleware
)

if app.csrf_config:
asgi_handler = CSRFMiddleware(app=asgi_handler, config=app.csrf_config)
if has_middleware:
# If there is an exception raised from the handler, the first ExceptionHandlerMiddleware that catches the
# exception will create the response and call send(). As middleware may wrap the send() callable, we need there
# to be an instance of ExceptionHandlerMiddleware in between the handler and the middleware so that any send
# wrappers instated by middleware are called. If there is no middleware, we can skip this step.
asgi_handler = wrap_in_exception_handler(app=asgi_handler)

if app.compression_config:
asgi_handler = CompressionMiddleware(app=asgi_handler, config=app.compression_config)
if app.csrf_config:
asgi_handler = CSRFMiddleware(app=asgi_handler, config=app.csrf_config)

if isinstance(route, HTTPRoute) and any(r.cache for r in route.route_handlers):
asgi_handler = ResponseCacheMiddleware(app=asgi_handler, config=app.response_cache_config)
if app.compression_config:
asgi_handler = CompressionMiddleware(app=asgi_handler, config=app.compression_config)

if app.allowed_hosts:
asgi_handler = AllowedHostsMiddleware(app=asgi_handler, config=app.allowed_hosts)
if has_cached_route:
asgi_handler = ResponseCacheMiddleware(app=asgi_handler, config=app.response_cache_config)

for middleware in route_handler.resolve_middleware():
if hasattr(middleware, "__iter__"):
handler, kwargs = cast("tuple[Any, dict[str, Any]]", middleware)
asgi_handler = handler(app=asgi_handler, **kwargs)
else:
asgi_handler = middleware(app=asgi_handler) # type: ignore[call-arg]
if app.allowed_hosts:
asgi_handler = AllowedHostsMiddleware(app=asgi_handler, config=app.allowed_hosts)

# we wrap the entire stack again in ExceptionHandlerMiddleware
return wrap_in_exception_handler(
app=cast("ASGIApp", asgi_handler),
exception_handlers=route_handler.resolve_exception_handlers(),
) # pyright: ignore
for middleware in handler_middleware:
if hasattr(middleware, "__iter__"):
handler, kwargs = cast("tuple[Any, dict[str, Any]]", middleware)
asgi_handler = handler(app=asgi_handler, **kwargs)
else:
asgi_handler = middleware(app=asgi_handler) # type: ignore[call-arg]
return asgi_handler
14 changes: 6 additions & 8 deletions litestar/_asgi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,26 @@

from typing import TYPE_CHECKING, cast

__all__ = ("get_route_handlers", "wrap_in_exception_handler")


if TYPE_CHECKING:
from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute
from litestar.routes.base import BaseRoute
from litestar.types import ASGIApp, ExceptionHandlersMap, RouteHandlerType
from litestar.types import ASGIApp, RouteHandlerType

__all__ = ("get_route_handlers", "wrap_in_exception_handler")


def wrap_in_exception_handler(app: ASGIApp, exception_handlers: ExceptionHandlersMap) -> ASGIApp:
def wrap_in_exception_handler(app: ASGIApp) -> ASGIApp:
"""Wrap the given ASGIApp in an instance of ExceptionHandlerMiddleware.
Args:
app: The ASGI app that is being wrapped.
exception_handlers: A mapping of exceptions to handler functions.
Returns:
A wrapped ASGIApp.
"""
from litestar.middleware.exceptions import ExceptionHandlerMiddleware
from litestar.middleware._internal.exceptions import ExceptionHandlerMiddleware

return ExceptionHandlerMiddleware(app=app, exception_handlers=exception_handlers, debug=None)
return ExceptionHandlerMiddleware(app=app, debug=None)


def get_route_handlers(route: BaseRoute) -> list[RouteHandlerType]:
Expand Down
12 changes: 5 additions & 7 deletions litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
NoRouteMatchFoundException,
)
from litestar.logging.config import LoggingConfig, get_logger_placeholder
from litestar.middleware._internal import CORSMiddleware
from litestar.middleware._internal.cors import CORSMiddleware
from litestar.openapi.config import OpenAPIConfig
from litestar.plugins import (
CLIPluginProtocol,
Expand Down Expand Up @@ -412,7 +412,6 @@ def __init__(
self.get_logger: GetLogger = get_logger_placeholder
self.logger: Logger | None = None
self.routes: list[HTTPRoute | ASGIRoute | WebSocketRoute] = []
self.asgi_router = ASGIRouter(app=self)

self.after_exception = [ensure_async_callable(h) for h in config.after_exception]
self.allowed_hosts = cast("AllowedHostsConfig | None", config.allowed_hosts)
Expand Down Expand Up @@ -442,7 +441,7 @@ def __init__(
try:
from starlette.exceptions import HTTPException as StarletteHTTPException

from litestar.middleware.exceptions.middleware import _starlette_exception_handler
from litestar.middleware._internal.exceptions.middleware import _starlette_exception_handler

config.exception_handlers.setdefault(StarletteHTTPException, _starlette_exception_handler)
except ImportError:
Expand Down Expand Up @@ -479,6 +478,8 @@ def __init__(
websocket_class=self.websocket_class,
)

self.asgi_router = ASGIRouter(app=self)

for route_handler in config.route_handlers:
self.register(route_handler)

Expand Down Expand Up @@ -839,10 +840,7 @@ def _create_asgi_handler(self) -> ASGIApp:
If CORS or TrustedHost configs are provided to the constructor, they will wrap the router as well.
"""
asgi_handler = wrap_in_exception_handler(
app=self.asgi_router,
exception_handlers=self.exception_handlers or {}, # pyright: ignore
)
asgi_handler = wrap_in_exception_handler(app=self.asgi_router)

if self.cors_config:
return CORSMiddleware(app=asgi_handler, config=self.cors_config)
Expand Down
6 changes: 4 additions & 2 deletions litestar/datastructures/state.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

from copy import copy, deepcopy
from copy import deepcopy
from threading import RLock
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Iterator, Mapping, MutableMapping

from litestar.utils.scope.state import CONNECTION_STATE_KEY

if TYPE_CHECKING:
from typing_extensions import Self

Expand Down Expand Up @@ -143,7 +145,7 @@ def dict(self) -> dict[str, Any]:
Returns:
A dict
"""
return copy(self._state)
return {k: v for k, v in self._state.items() if k != CONNECTION_STATE_KEY}

@classmethod
def __get_validators__(
Expand Down
115 changes: 115 additions & 0 deletions litestar/exceptions/responses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import Any

from litestar import MediaType, Request, Response
from litestar.exceptions import HTTPException, LitestarException
from litestar.exceptions.responses import _debug_response
from litestar.serialization import encode_json, get_serializer
from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR

__all__ = (
"ExceptionResponseContent",
"create_exception_response",
"create_debug_response",
)


@dataclass
class ExceptionResponseContent:
"""Represent the contents of an exception-response."""

status_code: int
"""Exception status code."""
detail: str
"""Exception details or message."""
media_type: MediaType | str
"""Media type of the response."""
headers: dict[str, str] | None = field(default=None)
"""Headers to attach to the response."""
extra: dict[str, Any] | list[Any] | None = field(default=None)
"""An extra mapping to attach to the exception."""

def to_response(self, request: Request | None = None) -> Response:
"""Create a response from the model attributes.
Returns:
A response instance.
"""
from litestar.response import Response

content: Any = {k: v for k, v in asdict(self).items() if k not in ("headers", "media_type") and v is not None}
type_encoders = _debug_response._get_type_encoders_for_request(request) if request is not None else None

if self.media_type != MediaType.JSON:
content = encode_json(content, get_serializer(type_encoders))

return Response(
content=content,
headers=self.headers,
status_code=self.status_code,
media_type=self.media_type,
type_encoders=type_encoders,
)


def create_exception_response(request: Request[Any, Any, Any], exc: Exception) -> Response:
"""Construct a response from an exception.
Notes:
- For instances of :class:`HTTPException <litestar.exceptions.HTTPException>` or other exception classes that have a
``status_code`` attribute (e.g. Starlette exceptions), the status code is drawn from the exception, otherwise
response status is ``HTTP_500_INTERNAL_SERVER_ERROR``.
Args:
request: The request that triggered the exception.
exc: An exception.
Returns:
Response: HTTP response constructed from exception details.
"""
headers: dict[str, Any] | None
extra: dict[str, Any] | list | None

if isinstance(exc, HTTPException):
status_code = exc.status_code
headers = exc.headers
extra = exc.extra
else:
status_code = HTTP_500_INTERNAL_SERVER_ERROR
headers = None
extra = None

detail = (
exc.detail
if isinstance(exc, LitestarException) and status_code != HTTP_500_INTERNAL_SERVER_ERROR
else "Internal Server Error"
)

try:
media_type = request.route_handler.media_type
except (KeyError, AttributeError):
media_type = MediaType.JSON

content = ExceptionResponseContent(
status_code=status_code,
detail=detail,
headers=headers,
extra=extra,
media_type=media_type,
)
return content.to_response(request=request)


def create_debug_response(request: Request, exc: Exception) -> Response:
"""Create a debug response from an exception.
Args:
request: The request that triggered the exception.
exc: An exception.
Returns:
Response: Debug response constructed from exception details.
"""
return _debug_response.create_debug_response(request, exc)
Loading

0 comments on commit 5f4d627

Please sign in to comment.