diff --git a/litestar/connection/websocket.py b/litestar/connection/websocket.py index b849a56b06..0c7bc04404 100644 --- a/litestar/connection/websocket.py +++ b/litestar/connection/websocket.py @@ -341,7 +341,3 @@ async def send_msgpack( None """ await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding) - - async def send_stream(self, stream: AsyncGenerator[str | bytes, Any], mode: WebSocketMode) -> None: - async for event in stream: - await self.send_data(event, mode=mode) diff --git a/litestar/handlers/websocket_handlers/stream.py b/litestar/handlers/websocket_handlers/stream.py index 1e99a90a35..fe3fb82997 100644 --- a/litestar/handlers/websocket_handlers/stream.py +++ b/litestar/handlers/websocket_handlers/stream.py @@ -1,18 +1,25 @@ from __future__ import annotations +import dataclasses import functools -import inspect import warnings -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Mapping +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Mapping, cast import anyio +from msgspec.json import Encoder as JsonEncoder +from typing_extensions import Self -from litestar.exceptions import LitestarWarning, WebSocketDisconnect +from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning, WebSocketDisconnect from litestar.handlers.websocket_handlers.route_handler import WebsocketRouteHandler +from litestar.types import Empty +from litestar.types.builtin_types import NoneType +from litestar.typing import FieldDefinition +from litestar.utils.signature import ParsedSignature if TYPE_CHECKING: - from litestar import WebSocket - from litestar.types import Dependencies, ExceptionHandler, Guard, Middleware + from litestar import Litestar, WebSocket + from litestar.dto import AbstractDTO + from litestar.types import Dependencies, EmptyType, ExceptionHandler, Guard, Middleware, TypeEncodersMap from litestar.types.asgi_types import WebSocketMode @@ -21,11 +28,11 @@ async def send_websocket_stream( stream: AsyncGenerator[str | bytes, Any], close: bool = True, mode: WebSocketMode = "text", + send_handler: Callable[[WebSocket, Any], Awaitable[Any]] | None = None, listen_for_disconnect: bool = False, warn_on_data_discard: bool = True, ) -> None: - """ - Stream data to the ``socket`` from an asynchronous generator. + """Stream data to the ``socket`` from an asynchronous generator. Example: Sending the current time to the connected client every 0.5 seconds: @@ -52,7 +59,8 @@ async def time_handler(socket: WebSocket) -> None: socket: The :class:`~litestar.connection.WebSocket` to send to stream: An asynchronous generator yielding data to send close: If ``True``, close the socket after the generator is exhausted - mode: WebSocket mode to use for sending + mode: WebSocket mode to use for sending when no ``send_handler`` is specified + send_handler: Callable to handle the send process. If ``None``, defaults to ``type(socket).send_data`` listen_for_disconnect: If ``True``, listen for client disconnects in the background. If a client disconnects, stop the generator and cancel sending data. Should always be ``True`` unless disconnects are handled elsewhere, for example by reading data from the socket concurrently. Should never be set to ``True`` when @@ -60,11 +68,18 @@ async def time_handler(socket: WebSocket) -> None: warn_on_data_discard: If ``True`` and ``listen_for_disconnect=True``, warn if during listening for client disconnects, data is received from the socket """ + if send_handler is None: + send_handler = functools.partial(type(socket).send_data, mode=mode) + + async def send_stream() -> None: + async for event in stream: + await send_handler(socket, event) + if listen_for_disconnect: # wrap 'send_stream' and disconnect listener, so they'll cancel the other once # one of the finishes async def wrapped_stream() -> None: - await socket.send_stream(stream, mode=mode) + await send_stream() # stream exhausted, we can stop listening for a disconnect tg.cancel_scope.cancel() @@ -96,7 +111,7 @@ async def disconnect_listener() -> None: tg.start_soon(disconnect_listener) else: - await socket.send_stream(stream=stream, mode=mode) + await send_stream() if close and socket.connection_state != "disconnect": await socket.close() @@ -114,12 +129,13 @@ def websocket_stream( signature_namespace: Mapping[str, Any] | None = None, websocket_class: type[WebSocket] | None = None, mode: WebSocketMode = "text", + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + type_encoders: TypeEncodersMap | None = None, listen_for_disconnect: bool = True, warn_on_data_discard: bool = True, **kwargs: Any, -) -> Callable[[Callable[..., AsyncGenerator[str | bytes, Any]]], WebsocketRouteHandler]: - """ - Create a WebSocket handler that accepts a connection and sends data to it from an +) -> Callable[[Callable[..., AsyncGenerator[Any, Any]]], WebsocketRouteHandler]: + """Create a WebSocket handler that accepts a connection and sends data to it from an async generator. Example: @@ -147,7 +163,10 @@ async def send_time() -> AsyncGenerator[str, None]: websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's default websocket class. mode: WebSocket mode used for sending - listen_for_disconnect: If ``True``, listen for client disconnects in the background. If a client disconnects, + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing outbound response data. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + listen_for_disconnect: If ``True``, listen for client + disconnects in the background. If a client disconnects, stop the generator and cancel sending data. Should always be ``True`` unless disconnects are handled elsewhere, for example by reading data from the socket concurrently. Should never be set to ``True`` when reading data from socket concurrently, as it can lead to data loss @@ -156,53 +175,132 @@ async def send_time() -> AsyncGenerator[str, None]: **kwargs: Any additional kwarg - will be set in the opt dictionary. """ - def decorator(fn: Callable[..., AsyncGenerator[str | bytes, Any]]) -> WebsocketRouteHandler: - signature = inspect.signature(fn) - generator_receives_socket = "socket" in signature.parameters + def decorator(fn: Callable[..., AsyncGenerator[Any, Any]]) -> WebsocketRouteHandler: + return WebSocketStreamHandler( + path=path, + dependencies=dependencies, + exception_handlers=exception_handlers, + guard=guards, + middleware=middleware, + name=name, + opt=opt, + signature_namespace=signature_namespace, + websocket_class=websocket_class, + return_dto=return_dto, + type_encoders=type_encoders, + **kwargs, + )( + _WebSocketStreamOptions( + generator_fn=fn, + send_mode=mode, + listen_for_disconnect=listen_for_disconnect, + warn_on_data_discard=warn_on_data_discard, + ) + ) + + return decorator + + +class WebSocketStreamHandler(WebsocketRouteHandler): + __slots__ = ("_ws_stream_options",) + _ws_stream_options: _WebSocketStreamOptions + + def __call__(self, fn: _WebSocketStreamOptions) -> Self: # type: ignore[override] + self._ws_stream_options = fn + self._fn = self._ws_stream_options.generator_fn # type: ignore[assignment] + return self + + def on_registration(self, app: Litestar) -> None: + parsed_handler_signature = parsed_stream_fn_signature = ParsedSignature.from_fn( + self.fn, self.resolve_signature_namespace() + ) + + if not parsed_stream_fn_signature.return_type.is_subclass_of(AsyncGenerator): + raise ImproperlyConfiguredException( + f"Route handler {self}: 'websocket_stream' handlers must return an " + f"'AsyncGenerator', not {type(parsed_stream_fn_signature.return_type.raw)!r}" + ) + + # important not to use 'self._ws_stream_options.generator_fn' here; This would + # break in cases the decorator has been used inside a controller, as it would + # be a reference to the unbound method. The bound method is patched in later + # after the controller has been initialized. This is a workaround that should + # go away with v3.0's static handlers + stream_fn = cast(Callable[..., AsyncGenerator[Any, Any]], self.fn) + + # construct a fake signature for the kwargs modelling, using the generator + # function passed to the handler as a base, to include all the dependencies, + # params, injection kwargs, etc. + 'socket', so DI works properly, but the + # signature looks to kwargs/signature modelling like a plain '@websocket' + # handler that returns 'None' + parsed_handler_signature = dataclasses.replace( + parsed_handler_signature, return_type=FieldDefinition.from_annotation(NoneType) + ) + receives_socket_parameter = "socket" in parsed_stream_fn_signature.parameters + + if not receives_socket_parameter: + parsed_handler_signature = dataclasses.replace( + parsed_handler_signature, + parameters={ + **parsed_handler_signature.parameters, + "socket": FieldDefinition.from_annotation("WebSocket", name="socket"), + }, + ) + + self._parsed_fn_signature = parsed_handler_signature + self._parsed_return_field = parsed_stream_fn_signature.return_type.inner_types[0] - @functools.wraps(fn) - async def handler(*args: Any, socket: WebSocket, **kw: Any) -> None: - if generator_receives_socket: + json_encoder = JsonEncoder(enc_hook=self.default_serializer) + return_dto = self.resolve_return_dto() + + # make sure the closure doesn't capture self._ws_stream / self + send_mode = self._ws_stream_options.send_mode + listen_for_disconnect = self._ws_stream_options.listen_for_disconnect + warn_on_data_discard = self._ws_stream_options.warn_on_data_discard + + async def send_handler(socket: WebSocket, data: Any) -> None: + if isinstance(data, (str, bytes)): + await socket.send_data(data=data, mode=send_mode) + return + + if return_dto: + encoded_data = return_dto(socket).data_to_encodable_type(data) + data = json_encoder.encode(encoded_data) + await socket.send_data(data=data, mode=send_mode) + return + + data = json_encoder.encode(data) + await socket.send_data(data=data, mode=send_mode) + + @functools.wraps(stream_fn) + async def handler_fn(*args: Any, socket: WebSocket, **kw: Any) -> None: + if receives_socket_parameter: kw["socket"] = socket await send_websocket_stream( socket=socket, - stream=fn(*args, **kw), - mode=mode, + stream=stream_fn(*args, **kw), + mode=send_mode, close=True, listen_for_disconnect=listen_for_disconnect, warn_on_data_discard=warn_on_data_discard, + send_handler=send_handler, ) - handler.__annotations__ = fn.__annotations__ - handler.__annotations__["return"] = None - - if not generator_receives_socket: - handler.__annotations__["socket"] = "WebSocket" + self._fn = handler_fn - new_signature = signature - if not generator_receives_socket: - new_signature = new_signature.replace( - parameters=[ - *signature.parameters.values(), - inspect.Parameter(name="socket", annotation="WebSocket", kind=inspect.Parameter.KEYWORD_ONLY), - ], - return_annotation=None, - ) - - handler.__signature__ = new_signature # type: ignore[attr-defined] + super().on_registration(app) - return WebsocketRouteHandler( - path=path, - dependencies=dependencies, - exception_handlers=exception_handlers, - guard=guards, - middleware=middleware, - name=name, - oprt=opt, - signature_namespace=signature_namespace, - websocket_class=websocket_class, - **kwargs, - )(handler) - return decorator +class _WebSocketStreamOptions: + def __init__( + self, + generator_fn: Callable[..., AsyncGenerator[Any, Any]], + listen_for_disconnect: bool, + warn_on_data_discard: bool, + send_mode: WebSocketMode, + ) -> None: + self.generator_fn = generator_fn + self.listen_for_disconnect = listen_for_disconnect + self.warn_on_data_discard = warn_on_data_discard + self.send_mode = send_mode diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_stream.py b/tests/unit/test_handlers/test_websocket_handlers/test_stream.py index b8220b53ce..060964e9a5 100644 --- a/tests/unit/test_handlers/test_websocket_handlers/test_stream.py +++ b/tests/unit/test_handlers/test_websocket_handlers/test_stream.py @@ -1,8 +1,15 @@ +from __future__ import annotations + import asyncio -from typing import AsyncGenerator +import dataclasses +from typing import AsyncGenerator, Generator from unittest.mock import MagicMock -from litestar import WebSocket, Controller +import pytest + +from litestar import Controller, Litestar, WebSocket +from litestar.dto import DataclassDTO, dto_field +from litestar.exceptions import ImproperlyConfiguredException from litestar.handlers.websocket_handlers import websocket_stream from litestar.testing import create_test_client @@ -100,3 +107,46 @@ async def handler() -> AsyncGenerator[str, None]: # ensure we still disconnect even after receiving some data ws.send_text("") assert ws.receive_text(timeout=0.1) == "foo" + + +def test_websocket_stream_send_json() -> None: + @websocket_stream("/") + async def handler() -> AsyncGenerator[dict[str, str], None]: + yield {"hello": "there"} + yield {"and": "goodbye"} + + with create_test_client([handler]) as client, client.websocket_connect("/") as ws: + assert ws.receive_json(timeout=0.1) == {"hello": "there"} + assert ws.receive_json(timeout=0.1) == {"and": "goodbye"} + + +def test_websocket_stream_send_json_with_dto() -> None: + @dataclasses.dataclass + class Event: + id: int = dataclasses.field(metadata=dto_field("private")) + content: str + + @websocket_stream("/", return_dto=DataclassDTO[Event]) + async def handler() -> AsyncGenerator[Event, None]: + yield Event(id=1, content="hello") + + with create_test_client([handler], signature_types=[Event]) as client, client.websocket_connect("/") as ws: + assert ws.receive_json(timeout=0.1) == {"content": "hello"} + + +def test_raises_if_stream_fn_does_not_return_async_generator() -> None: + with pytest.raises(ImproperlyConfiguredException): + + @websocket_stream("/") # type: ignore[arg-type] + def foo() -> Generator[bytes, None, None]: + yield b"" + + Litestar([foo]) + + with pytest.raises(ImproperlyConfiguredException): + + @websocket_stream("/") # type: ignore[arg-type] + def foo() -> bytes: + return b"" + + Litestar([foo])