From f4f8dc6d267ad5f0ccefbe57ef131420e0311908 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= <25355197+provinzkraut@users.noreply.github.com> Date: Sat, 14 Dec 2024 17:23:08 +0100 Subject: [PATCH] remove resolve_websocket_class --- .../handlers/websocket_handlers/listener.py | 3 +-- .../websocket_handlers/route_handler.py | 19 ++++++++++--------- litestar/routes/websocket.py | 2 +- tests/unit/test_websocket_class_resolution.py | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/litestar/handlers/websocket_handlers/listener.py b/litestar/handlers/websocket_handlers/listener.py index 6f71aebafa..19749f40ba 100644 --- a/litestar/handlers/websocket_handlers/listener.py +++ b/litestar/handlers/websocket_handlers/listener.py @@ -202,7 +202,6 @@ def __init__( self.on_disconnect = ensure_async_callable(on_disconnect) if on_disconnect else None self.type_decoders = type_decoders self.type_encoders = type_encoders - self.websocket_class = websocket_class listener_dependencies = dict(dependencies or {}) @@ -251,7 +250,7 @@ def merge(self, other: Controller | Router) -> WebsocketListenerRouteHandler: signature_types=getattr(other, "signature_types", None), type_decoders=(*(other.type_decoders or ()), *self.type_decoders), type_encoders={**(other.type_encoders or {}), **self.type_encoders}, - websocket_class=self.websocket_class or other.websocket_class, + websocket_class=self._websocket_class or other.websocket_class, parameters={**other.parameters, **self.parameters}, receive_mode=self._receive_mode, send_mode=self._send_mode, diff --git a/litestar/handlers/websocket_handlers/route_handler.py b/litestar/handlers/websocket_handlers/route_handler.py index 0664f8a259..506fa3da44 100644 --- a/litestar/handlers/websocket_handlers/route_handler.py +++ b/litestar/handlers/websocket_handlers/route_handler.py @@ -5,12 +5,11 @@ from litestar.connection import WebSocket from litestar.exceptions import ImproperlyConfiguredException from litestar.handlers import BaseRouteHandler -from litestar.plugins import PluginRegistry from litestar.types import AsyncAnyCallable from litestar.types import Empty from litestar.types import ParametersMap from litestar.types.builtin_types import NoneType -from litestar.utils import join_paths +from litestar.utils import join_paths, deprecated from litestar.utils.empty import value_or_default from litestar.utils.predicates import is_async_callable from litestar.utils.signature import merge_signature_namespaces @@ -24,7 +23,7 @@ class WebsocketRouteHandler(BaseRouteHandler): - __slots__ = ("_kwargs_model", "websocket_class") + __slots__ = ("_kwargs_model", "_websocket_class") def __init__( self, @@ -64,7 +63,7 @@ def __init__( default websocket class. **kwargs: Any additional kwarg - will be set in the opt dictionary. """ - self.websocket_class = websocket_class + self._websocket_class = websocket_class self._kwargs_model: KwargsModel | EmptyType = Empty super().__init__( @@ -97,10 +96,11 @@ def merge(self, other: Controller | Router) -> WebsocketRouteHandler: signature_types=getattr(other, "signature_types", None), type_decoders=(*(other.type_decoders or ()), *self.type_decoders), type_encoders={**(other.type_encoders or {}), **self.type_encoders}, - websocket_class=self.websocket_class or other.websocket_class, + websocket_class=self._websocket_class or other.websocket_class, parameters={**other.parameters, **self.parameters}, ) + @deprecated("3.0", removal_in="4.0", alternative=".websocket_class property") def resolve_websocket_class(self) -> type[WebSocket]: """Return the closest custom WebSocket class in the owner graph or the default Websocket class. @@ -109,10 +109,11 @@ def resolve_websocket_class(self) -> type[WebSocket]: Returns: The default :class:`WebSocket <.connection.WebSocket>` class for the route handler. """ - return next( - (layer.websocket_class for layer in reversed(self._ownership_layers) if layer.websocket_class is not None), - WebSocket, - ) + return self.websocket_class + + @property + def websocket_class(self) -> type[WebSocket]: + return self._websocket_class or WebSocket def _validate_handler_function(self, app: Litestar | None = None) -> None: """Validate the route handler function once it's set by inspecting its return annotations.""" diff --git a/litestar/routes/websocket.py b/litestar/routes/websocket.py index 477c1da208..5bca9a4674 100644 --- a/litestar/routes/websocket.py +++ b/litestar/routes/websocket.py @@ -42,5 +42,5 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N Returns: None """ - socket = self.route_handler.resolve_websocket_class()(scope=scope, receive=receive, send=send) + socket = self.route_handler.websocket_class(scope=scope, receive=receive, send=send) await self.route_handler.handle(connection=socket) diff --git a/tests/unit/test_websocket_class_resolution.py b/tests/unit/test_websocket_class_resolution.py index 98d565c43b..887ce7494b 100644 --- a/tests/unit/test_websocket_class_resolution.py +++ b/tests/unit/test_websocket_class_resolution.py @@ -54,7 +54,7 @@ def handler(self, data: str) -> None: route_handler = app.routes[0].route_handler # type: ignore[union-attr] - websocket_class = route_handler.resolve_websocket_class() # type: ignore[union-attr] + websocket_class = route_handler.websocket_class # type: ignore[union-attr] assert websocket_class is expected @@ -93,5 +93,5 @@ def on_receive(self, data: str) -> str: # pyright: ignore app = Litestar(route_handlers=[router], websocket_class=app_websocket_class) route_handler = app.routes[0].route_handler # type: ignore[union-attr] - websocket_class = route_handler.resolve_websocket_class() # type: ignore[union-attr] + websocket_class = route_handler.websocket_class # type: ignore[union-attr] assert websocket_class is expected