diff --git a/litestar/controller.py b/litestar/controller.py index 967454b168..3893acdf98 100644 --- a/litestar/controller.py +++ b/litestar/controller.py @@ -12,7 +12,7 @@ from litestar.handlers.http_handlers import HTTPRouteHandler from litestar.handlers.websocket_handlers import WebsocketRouteHandler from litestar.types.empty import Empty -from litestar.utils import ensure_async_callable, normalize_path +from litestar.utils import normalize_path from litestar.utils.signature import add_types_to_signature_namespace __all__ = ("Controller",) @@ -51,6 +51,7 @@ class Controller: "after_request", "after_response", "before_request", + "cache_control", "dependencies", "dto", "etag", @@ -69,6 +70,7 @@ class Controller: "return_dto", "security", "signature_namespace", + "signature_types", "tags", "type_encoders", "type_decoders", @@ -174,12 +176,11 @@ def __init__(self, owner: Router) -> None: Args: owner: An instance of :class:`Router <.router.Router>` """ - # Since functions set on classes are bound, we need replace the bound instance with the class version and wrap - # it to ensure it does not get bound. + # Since functions set on classes are bound, we need replace the bound instance with the class version for key in ("after_request", "after_response", "before_request"): cls_value = getattr(type(self), key, None) if callable(cls_value): - setattr(self, key, ensure_async_callable(cls_value)) + setattr(self, key, cls_value) if not hasattr(self, "dto"): self.dto = Empty @@ -203,6 +204,41 @@ def __init__(self, owner: Router) -> None: self.path = normalize_path(self.path or "/") self.owner = owner + def as_router(self) -> Router: + from litestar.router import Router + + router = Router( + path=self.path, + route_handlers=self.get_route_handlers(), + after_request=self.after_request, + after_response=self.after_response, + before_request=self.before_request, + cache_control=self.cache_control, + dependencies=self.dependencies, + dto=self.dto, + etag=self.etag, + exception_handlers=self.exception_handlers, + guards=self.guards, + include_in_schema=self.include_in_schema, + middleware=self.middleware, + opt=self.opt, + parameters=self.parameters, + request_class=self.request_class, + response_class=self.response_class, + response_cookies=self.response_cookies, + response_headers=self.response_headers, + return_dto=self.return_dto, + security=self.security, + signature_types=self.signature_types, + signature_namespace=self.signature_namespace, + tags=self.tags, + type_encoders=self.type_encoders, + type_decoders=self.type_decoders, + websocket_class=self.websocket_class, + ) + router.owner = self.owner + return router + def get_route_handlers(self) -> list[BaseRouteHandler]: """Get a controller's route handlers and set the controller as the handlers' owner. diff --git a/litestar/router.py b/litestar/router.py index 85346d8c1d..88ac0fd567 100644 --- a/litestar/router.py +++ b/litestar/router.py @@ -280,40 +280,25 @@ def route_handler_method_map(self) -> dict[str, RouteHandlerMapItem]: @classmethod def get_route_handler_map( cls, - value: Controller | RouteHandlerType | Router, + value: RouteHandlerType | Router, ) -> dict[str, RouteHandlerMapItem]: """Map route handlers to HTTP methods.""" if isinstance(value, Router): return value.route_handler_method_map - if isinstance(value, (HTTPRouteHandler, ASGIRouteHandler, WebsocketRouteHandler)): - copied_value = copy(value) - if isinstance(value, HTTPRouteHandler): - return {path: {http_method: copied_value for http_method in value.http_methods} for path in value.paths} - - return { - path: {"websocket" if isinstance(value, WebsocketRouteHandler) else "asgi": copied_value} - for path in value.paths - } - - handlers_map: defaultdict[str, RouteHandlerMapItem] = defaultdict(dict) - for route_handler in value.get_route_handlers(): - for handler_path in route_handler.paths: - path = join_paths([value.path, handler_path]) if handler_path else value.path - if isinstance(route_handler, HTTPRouteHandler): - for http_method in route_handler.http_methods: - handlers_map[path][http_method] = route_handler - else: - handlers_map[path]["websocket" if isinstance(route_handler, WebsocketRouteHandler) else "asgi"] = ( - cast("WebsocketRouteHandler | ASGIRouteHandler", route_handler) - ) + copied_value = copy(value) + if isinstance(value, HTTPRouteHandler): + return {path: {http_method: copied_value for http_method in value.http_methods} for path in value.paths} - return handlers_map + return { + path: {"websocket" if isinstance(value, WebsocketRouteHandler) else "asgi": copied_value} + for path in value.paths + } - def _validate_registration_value(self, value: ControllerRouterHandler) -> Controller | RouteHandlerType | Router: + def _validate_registration_value(self, value: ControllerRouterHandler) -> RouteHandlerType | Router: """Ensure values passed to the register method are supported.""" if is_class_and_subclass(value, Controller): - return value(owner=self) + return value(owner=self).as_router() # this narrows down to an ABC, but we assume a non-abstract subclass of the ABC superclass if is_class_and_subclass(value, WebsocketListener): diff --git a/tests/e2e/test_life_cycle_hooks/test_after_response.py b/tests/e2e/test_life_cycle_hooks/test_after_response.py index 39400c0b27..c9a7554d82 100644 --- a/tests/e2e/test_life_cycle_hooks/test_after_response.py +++ b/tests/e2e/test_life_cycle_hooks/test_after_response.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Dict +from unittest.mock import MagicMock, call import pytest @@ -6,44 +6,35 @@ from litestar.status_codes import HTTP_200_OK from litestar.testing import create_test_client -state: Dict[str, str] = {} -if TYPE_CHECKING: - from litestar.types import AfterResponseHookHandler +@pytest.mark.parametrize("sync", [True, False]) +@pytest.mark.parametrize("layer", ["app", "router", "controller", "handler"]) +def test_after_response_resolution(layer: str, sync: bool) -> None: + mock = MagicMock() + if sync: -def create_sync_test_handler(msg: str) -> "AfterResponseHookHandler": - def handler(_: Request) -> None: - state["msg"] = msg + def handler(_: Request) -> None: # pyright: ignore + mock(layer) - return handler + else: + async def handler(_: Request) -> None: # type: ignore[misc] + mock(layer) -def create_async_test_handler(msg: str) -> "AfterResponseHookHandler": - async def handler(_: Request) -> None: - state["msg"] = msg + class MyController(Controller): + path = "/controller" + after_response = handler if layer == "controller" else None - return handler + @get("/", after_response=handler if layer == "handler" else None) + def my_handler(self) -> None: + return None + router = Router( + path="/router", route_handlers=[MyController], after_response=handler if layer == "router" else None + ) -@pytest.mark.parametrize("layer", ["app", "router", "controller", "handler"]) -def test_after_response_resolution(layer: str) -> None: - for handler in (create_sync_test_handler(layer), create_async_test_handler(layer)): - state.pop("msg", None) - - class MyController(Controller): - path = "/controller" - after_response = handler if layer == "controller" else None - - @get("/", after_response=handler if layer == "handler" else None) - def my_handler(self) -> None: - return None - - router = Router( - path="/router", route_handlers=[MyController], after_response=handler if layer == "router" else None - ) - - with create_test_client(route_handlers=[router], after_response=handler if layer == "app" else None) as client: - response = client.get("/router/controller/") - assert response.status_code == HTTP_200_OK - assert state["msg"] == layer + with create_test_client(route_handlers=[router], after_response=handler if layer == "app" else None) as client: + response = client.get("/router/controller/") + assert response.status_code == HTTP_200_OK + assert all(c == call(layer) for c in mock.call_args_list)