diff --git a/litestar/handlers/websocket_handlers/listener.py b/litestar/handlers/websocket_handlers/listener.py index 8e702ea1aa..e4a2df7825 100644 --- a/litestar/handlers/websocket_handlers/listener.py +++ b/litestar/handlers/websocket_handlers/listener.py @@ -335,10 +335,6 @@ class WebsocketListener(ABC): """A sequence of :class:`Guard <.types.Guard>` callables.""" middleware: list[Middleware] | None = None """A sequence of :class:`Middleware <.types.Middleware>`.""" - on_accept: AnyCallable | None = None - """Called after a :class:`WebSocket <.connection.WebSocket>` connection has been accepted. Can receive any dependencies""" - on_disconnect: AnyCallable | None = None - """Called after a :class:`WebSocket <.connection.WebSocket>` connection has been disconnected. Can receive any dependencies""" receive_mode: WebSocketMode = "text" """:class:`WebSocket <.connection.WebSocket>` mode to receive data in, either ``text`` or ``binary``.""" send_mode: WebSocketMode = "text" @@ -380,6 +376,9 @@ def __init__(self, owner: Router) -> None: self._owner = owner def to_handler(self) -> WebsocketListenerRouteHandler: + on_accept = self.on_accept if self.on_accept != WebsocketListener.on_accept else None + on_disconnect = self.on_disconnect if self.on_disconnect != WebsocketListener.on_disconnect else None + handler = WebsocketListenerRouteHandler( dependencies=self.dependencies, dto=self.dto, @@ -389,8 +388,8 @@ def to_handler(self) -> WebsocketListenerRouteHandler: send_mode=self.send_mode, receive_mode=self.receive_mode, name=self.name, - on_accept=self.on_accept, - on_disconnect=self.on_disconnect, + on_accept=on_accept, + on_disconnect=on_disconnect, opt=self.opt, path=self.path, return_dto=self.return_dto, @@ -402,6 +401,16 @@ def to_handler(self) -> WebsocketListenerRouteHandler: handler.owner = self._owner return handler + def on_accept(self, *args: Any, **kwargs: Any) -> Any: + """Called after a :class:`WebSocket <.connection.WebSocket>` connection + has been accepted. Can receive any dependencies + """ + + def on_disconnect(self, *args: Any, **kwargs: Any) -> Any: + """Called after a :class:`WebSocket <.connection.WebSocket>` connection + has been disconnected. Can receive any dependencies + """ + @abstractmethod def on_receive(self, *args: Any, **kwargs: Any) -> Any: """Called after data has been received from the WebSocket. diff --git a/pyproject.toml b/pyproject.toml index fa0149f0e1..27eb2f0df8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -460,6 +460,7 @@ known-first-party = ["litestar", "tests", "examples"] "litestar/_openapi/schema_generation/schema.py" = ["C901"] "litestar/exceptions/*.*" = ["N818"] "litestar/handlers/**/*.*" = ["N801"] +"litestar/handlers/websocket_handlers/listener.py" = ["B027"] "litestar/params.py" = ["N802"] "test_apps/**/*.*" = ["D", "TRY", "EM", "S", "PTH"] "tests/**/*.*" = [ diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py b/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py index 08c74690d4..f6afec0a2f 100644 --- a/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py +++ b/tests/unit/test_handlers/test_websocket_handlers/test_listeners.py @@ -394,10 +394,10 @@ def some_dependency() -> str: class Listener(WebsocketListener): path = "/{name: str}" - def on_accept(self, name: str, state: State, query: dict, some: str) -> None: # type: ignore[override] + def on_accept(self, name: str, state: State, query: dict, some: str) -> None: # pyright: ignore on_accept_mock(name=name, state=state, query=query, some=some) - def on_disconnect(self, name: str, state: State, query: dict, some: str) -> None: # type: ignore[override] + def on_disconnect(self, name: str, state: State, query: dict, some: str) -> None: # pyright: ignore on_disconnect_mock(name=name, state=state, query=query, some=some) def on_receive(self, data: bytes) -> None: # pyright: ignore