Skip to content

Commit

Permalink
refactor: Remove special casing of controllers during registration (#…
Browse files Browse the repository at this point in the history
…3527)

* Convert controllers to routers on registration
  • Loading branch information
provinzkraut authored May 26, 2024
1 parent b0c8f02 commit 057f813
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 62 deletions.
44 changes: 40 additions & 4 deletions litestar/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from litestar.handlers.http_handlers import HTTPRouteHandler
from litestar.handlers.websocket_handlers import WebsocketRouteHandler
from litestar.types.empty import Empty
from litestar.utils import ensure_async_callable, normalize_path
from litestar.utils import normalize_path
from litestar.utils.signature import add_types_to_signature_namespace

__all__ = ("Controller",)
Expand Down Expand Up @@ -51,6 +51,7 @@ class Controller:
"after_request",
"after_response",
"before_request",
"cache_control",
"dependencies",
"dto",
"etag",
Expand All @@ -69,6 +70,7 @@ class Controller:
"return_dto",
"security",
"signature_namespace",
"signature_types",
"tags",
"type_encoders",
"type_decoders",
Expand Down Expand Up @@ -174,12 +176,11 @@ def __init__(self, owner: Router) -> None:
Args:
owner: An instance of :class:`Router <.router.Router>`
"""
# Since functions set on classes are bound, we need replace the bound instance with the class version and wrap
# it to ensure it does not get bound.
# Since functions set on classes are bound, we need replace the bound instance with the class version
for key in ("after_request", "after_response", "before_request"):
cls_value = getattr(type(self), key, None)
if callable(cls_value):
setattr(self, key, ensure_async_callable(cls_value))
setattr(self, key, cls_value)

if not hasattr(self, "dto"):
self.dto = Empty
Expand All @@ -203,6 +204,41 @@ def __init__(self, owner: Router) -> None:
self.path = normalize_path(self.path or "/")
self.owner = owner

def as_router(self) -> Router:
from litestar.router import Router

router = Router(
path=self.path,
route_handlers=self.get_route_handlers(),
after_request=self.after_request,
after_response=self.after_response,
before_request=self.before_request,
cache_control=self.cache_control,
dependencies=self.dependencies,
dto=self.dto,
etag=self.etag,
exception_handlers=self.exception_handlers,
guards=self.guards,
include_in_schema=self.include_in_schema,
middleware=self.middleware,
opt=self.opt,
parameters=self.parameters,
request_class=self.request_class,
response_class=self.response_class,
response_cookies=self.response_cookies,
response_headers=self.response_headers,
return_dto=self.return_dto,
security=self.security,
signature_types=self.signature_types,
signature_namespace=self.signature_namespace,
tags=self.tags,
type_encoders=self.type_encoders,
type_decoders=self.type_decoders,
websocket_class=self.websocket_class,
)
router.owner = self.owner
return router

def get_route_handlers(self) -> list[BaseRouteHandler]:
"""Get a controller's route handlers and set the controller as the handlers' owner.
Expand Down
35 changes: 10 additions & 25 deletions litestar/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,40 +280,25 @@ def route_handler_method_map(self) -> dict[str, RouteHandlerMapItem]:
@classmethod
def get_route_handler_map(
cls,
value: Controller | RouteHandlerType | Router,
value: RouteHandlerType | Router,
) -> dict[str, RouteHandlerMapItem]:
"""Map route handlers to HTTP methods."""
if isinstance(value, Router):
return value.route_handler_method_map

if isinstance(value, (HTTPRouteHandler, ASGIRouteHandler, WebsocketRouteHandler)):
copied_value = copy(value)
if isinstance(value, HTTPRouteHandler):
return {path: {http_method: copied_value for http_method in value.http_methods} for path in value.paths}

return {
path: {"websocket" if isinstance(value, WebsocketRouteHandler) else "asgi": copied_value}
for path in value.paths
}

handlers_map: defaultdict[str, RouteHandlerMapItem] = defaultdict(dict)
for route_handler in value.get_route_handlers():
for handler_path in route_handler.paths:
path = join_paths([value.path, handler_path]) if handler_path else value.path
if isinstance(route_handler, HTTPRouteHandler):
for http_method in route_handler.http_methods:
handlers_map[path][http_method] = route_handler
else:
handlers_map[path]["websocket" if isinstance(route_handler, WebsocketRouteHandler) else "asgi"] = (
cast("WebsocketRouteHandler | ASGIRouteHandler", route_handler)
)
copied_value = copy(value)
if isinstance(value, HTTPRouteHandler):
return {path: {http_method: copied_value for http_method in value.http_methods} for path in value.paths}

return handlers_map
return {
path: {"websocket" if isinstance(value, WebsocketRouteHandler) else "asgi": copied_value}
for path in value.paths
}

def _validate_registration_value(self, value: ControllerRouterHandler) -> Controller | RouteHandlerType | Router:
def _validate_registration_value(self, value: ControllerRouterHandler) -> RouteHandlerType | Router:
"""Ensure values passed to the register method are supported."""
if is_class_and_subclass(value, Controller):
return value(owner=self)
return value(owner=self).as_router()

# this narrows down to an ABC, but we assume a non-abstract subclass of the ABC superclass
if is_class_and_subclass(value, WebsocketListener):
Expand Down
57 changes: 24 additions & 33 deletions tests/e2e/test_life_cycle_hooks/test_after_response.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,40 @@
from typing import TYPE_CHECKING, Dict
from unittest.mock import MagicMock, call

import pytest

from litestar import Controller, Request, Router, get
from litestar.status_codes import HTTP_200_OK
from litestar.testing import create_test_client

state: Dict[str, str] = {}

if TYPE_CHECKING:
from litestar.types import AfterResponseHookHandler
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize("layer", ["app", "router", "controller", "handler"])
def test_after_response_resolution(layer: str, sync: bool) -> None:
mock = MagicMock()

if sync:

def create_sync_test_handler(msg: str) -> "AfterResponseHookHandler":
def handler(_: Request) -> None:
state["msg"] = msg
def handler(_: Request) -> None: # pyright: ignore
mock(layer)

return handler
else:

async def handler(_: Request) -> None: # type: ignore[misc]
mock(layer)

def create_async_test_handler(msg: str) -> "AfterResponseHookHandler":
async def handler(_: Request) -> None:
state["msg"] = msg
class MyController(Controller):
path = "/controller"
after_response = handler if layer == "controller" else None

return handler
@get("/", after_response=handler if layer == "handler" else None)
def my_handler(self) -> None:
return None

router = Router(
path="/router", route_handlers=[MyController], after_response=handler if layer == "router" else None
)

@pytest.mark.parametrize("layer", ["app", "router", "controller", "handler"])
def test_after_response_resolution(layer: str) -> None:
for handler in (create_sync_test_handler(layer), create_async_test_handler(layer)):
state.pop("msg", None)

class MyController(Controller):
path = "/controller"
after_response = handler if layer == "controller" else None

@get("/", after_response=handler if layer == "handler" else None)
def my_handler(self) -> None:
return None

router = Router(
path="/router", route_handlers=[MyController], after_response=handler if layer == "router" else None
)

with create_test_client(route_handlers=[router], after_response=handler if layer == "app" else None) as client:
response = client.get("/router/controller/")
assert response.status_code == HTTP_200_OK
assert state["msg"] == layer
with create_test_client(route_handlers=[router], after_response=handler if layer == "app" else None) as client:
response = client.get("/router/controller/")
assert response.status_code == HTTP_200_OK
assert all(c == call(layer) for c in mock.call_args_list)

0 comments on commit 057f813

Please sign in to comment.