Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow before_request and after_request handlers to accept a parent argument #3748

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions litestar/handlers/http_handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import functools
import inspect
from enum import Enum
from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast

Expand Down Expand Up @@ -62,6 +64,15 @@
__all__ = ("HTTPRouteHandler", "route")


def _wrap_layered_hooks(hooks: list[AsyncAnyCallable]) -> AsyncAnyCallable | None:
"""Given a list of callables, starting from the end, set the parent= keyword argument of each to default to the preceding hook should any preceding hook exist and should that argument be accepted."""
if not hooks:
return None
if "parent" in inspect.signature(hooks[-1]).parameters:
return functools.partial(hooks[-1], parent=_wrap_layered_hooks(hooks[:-1]))
return hooks[-1]


class ResponseHandlerMap(TypedDict):
default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType
response_type_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType
Expand Down Expand Up @@ -260,9 +271,9 @@ def __init__(
)

self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore
self.after_response = ensure_async_callable(after_response) if after_response else None
self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore
self.background = background
self.before_request = ensure_async_callable(before_request) if before_request else None
self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore
self.cache = cache
self.cache_control = cache_control
self.cache_key_builder = cache_key_builder
Expand Down Expand Up @@ -400,7 +411,7 @@ def resolve_before_request(self) -> AsyncAnyCallable | None:
"""
if self._resolved_before_request is Empty:
before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request]
self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None
self._resolved_before_request = _wrap_layered_hooks(before_request_handlers)
return cast("AsyncAnyCallable | None", self._resolved_before_request)

def resolve_after_response(self) -> AsyncAnyCallable | None:
Expand All @@ -418,7 +429,7 @@ def resolve_after_response(self) -> AsyncAnyCallable | None:
for layer in self.ownership_layers
if layer.after_response
]
self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None
self._resolved_after_response = _wrap_layered_hooks(after_response_handlers)

return cast("AsyncAnyCallable | None", self._resolved_after_response)

Expand Down
5 changes: 3 additions & 2 deletions litestar/handlers/http_handlers/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ def __init__(
:class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
Defaults to ``None``.
before_request: A sync or async function called immediately before calling the route handler. Receives
the :class:`.connection.Request` instance and any non-``None`` return value is used for the response,
bypassing the route handler.
the :class:`.connection.Request` instance (and, if it accepts a keyword argument named `parent`, the
outer scope's before_request handler if any exists). Any non-``None`` return value is used for the
response, bypassing the route handler.
cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number
of seconds (e.g. ``120``) to cache the response.
cache_control: A ``cache-control`` header of type
Expand Down
4 changes: 2 additions & 2 deletions litestar/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def __init__(
"""

self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore
self.after_response = ensure_async_callable(after_response) if after_response else None
self.before_request = ensure_async_callable(before_request) if before_request else None
self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore
self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore
self.cache_control = cache_control
self.dto = dto
self.etag = etag
Expand Down
23 changes: 20 additions & 3 deletions litestar/types/callable_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, Protocol, TypeVar

if TYPE_CHECKING:
from typing_extensions import TypeAlias
Expand All @@ -23,12 +23,29 @@
AfterRequestHookHandler: TypeAlias = (
"Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]"
)
AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"

AfterResponseHookHandlerSimple: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"


class AfterResponseHookHandlerWithParent(Protocol):
async def __call__(self, request: Request, /, *, parent: AfterResponseHookHandler | None = None) -> None: ...


AfterResponseHookHandler: TypeAlias = "AfterResponseHookHandlerSimple | AfterResponseHookHandlerWithParent"

AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]]
AnyCallable: TypeAlias = Callable[..., Any]
AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]"
BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]"
BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"


class BeforeRequestHookHandlerWithParent(Protocol):
async def __call__(self, request: Request, /, *, parent: BeforeRequestHookHandler | None = None) -> Any: ...


BeforeRequestHookHandlerSimple: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"
BeforeRequestHookHandler: TypeAlias = "BeforeRequestHookHandlerSimple | BeforeRequestHookHandlerWithParent"

CacheKeyBuilder: TypeAlias = "Callable[[Request], str]"
ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]"
ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]"
Expand Down
36 changes: 35 additions & 1 deletion tests/e2e/test_life_cycle_hooks/test_before_request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional
import logging
from typing import Any, Dict, Optional, Union

import pytest

Expand All @@ -7,6 +8,18 @@
from litestar.testing import create_test_client
from litestar.types import AnyCallable, BeforeRequestHookHandler

logger = logging.getLogger(__name__)


async def async_before_request_handler_with_parent(
request: Request[Any, Any, State], /, *, parent: Optional[BeforeRequestHookHandler] = None
) -> Optional[Dict[str, Union[str, int]]]:
assert isinstance(request, Request)
retval: Dict[str, Union[str, int]] = (None if parent is None else await parent(request)) or {}
retval.setdefault("amended_count", 0)
retval["amended_count"] += 1 # type: ignore[operator]
return retval


def sync_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]:
assert isinstance(request, Request)
Expand Down Expand Up @@ -88,6 +101,27 @@ def handler() -> Dict[str, str]:
{"hello": "world"},
],
[None, None, None, async_before_request_handler_without_return_value, {"hello": "world"}],
[
sync_before_request_handler_with_return_value,
None,
None,
async_before_request_handler_with_parent,
{"hello": "moon", "amended_count": 1},
],
[
sync_before_request_handler_with_return_value,
None,
async_before_request_handler_with_parent,
async_before_request_handler_with_parent,
{"hello": "moon", "amended_count": 2},
],
[
sync_before_request_handler_with_return_value,
async_before_request_handler_with_parent,
async_before_request_handler_with_parent,
async_before_request_handler_with_parent,
{"hello": "moon", "amended_count": 3},
],
],
)
def test_before_request_handler_resolution(
Expand Down
Loading