diff --git a/docs/conf.py b/docs/conf.py index be6d8017ac..bcfa4aadd7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -245,6 +245,7 @@ "litestar.concurrency.set_asyncio_executor": {"ThreadPoolExecutor"}, "litestar.concurrency.get_asyncio_executor": {"ThreadPoolExecutor"}, re.compile(r"litestar\.channels\.backends\.asyncpg.*"): {"asyncpg.connection.Connection", "asyncpg.Connection"}, + re.compile(r"litestar\.handlers\.websocket_handlers\.stream.*"): {"WebSocketMode"}, } # Do not warn about broken links to the following: diff --git a/docs/examples/websockets/stream_and_receive_listener.py b/docs/examples/websockets/stream_and_receive_listener.py new file mode 100644 index 0000000000..d01158e12d --- /dev/null +++ b/docs/examples/websockets/stream_and_receive_listener.py @@ -0,0 +1,26 @@ +import asyncio +import time +from typing import Any, AsyncGenerator + +from litestar import Litestar, WebSocket, websocket_listener +from litestar.handlers import send_websocket_stream + + +async def listener_lifespan(socket: WebSocket) -> None: + async def handle_stream() -> AsyncGenerator[dict[str, float], None]: + while True: + yield {"time": time.time()} + await asyncio.sleep(0.5) + + task = asyncio.create_task(send_websocket_stream(socket=socket, stream=handle_stream())) + yield + task.cancel() + await task + + +@websocket_listener("/", connection_lifespan=listener_lifespan) +def handler(socket: WebSocket, data: Any) -> None: + print(f"{socket.client}: {data}") + + +app = Litestar([handler]) diff --git a/docs/examples/websockets/stream_and_receive_raw.py b/docs/examples/websockets/stream_and_receive_raw.py new file mode 100644 index 0000000000..6e6966b0e7 --- /dev/null +++ b/docs/examples/websockets/stream_and_receive_raw.py @@ -0,0 +1,27 @@ +import asyncio +import time +from typing import AsyncGenerator + +from litestar import Litestar, WebSocket, websocket +from litestar.handlers import send_websocket_stream + + +@websocket("/") +async def handler(socket: WebSocket) -> None: + await socket.accept() + + async def handle_stream() -> AsyncGenerator[dict[str, float], None]: + while True: + yield {"time": time.time()} + await asyncio.sleep(0.5) + + async def handle_receive() -> None: + async for event in socket.iter_json(): + print(f"{socket.client}: {event}") + + async with asyncio.TaskGroup() as tg: + tg.create_task(send_websocket_stream(socket=socket, stream=handle_stream())) + tg.create_task(handle_receive()) + + +app = Litestar([handler]) diff --git a/docs/examples/websockets/stream_basic.py b/docs/examples/websockets/stream_basic.py new file mode 100644 index 0000000000..7ed6517331 --- /dev/null +++ b/docs/examples/websockets/stream_basic.py @@ -0,0 +1,15 @@ +import asyncio +import time +from typing import AsyncGenerator + +from litestar import Litestar, websocket_stream + + +@websocket_stream("/") +async def ping() -> AsyncGenerator[float, None]: + while True: + yield time.time() + await asyncio.sleep(0.5) + + +app = Litestar([ping]) diff --git a/docs/examples/websockets/stream_di_hog.py b/docs/examples/websockets/stream_di_hog.py new file mode 100644 index 0000000000..dca552bbc7 --- /dev/null +++ b/docs/examples/websockets/stream_di_hog.py @@ -0,0 +1,23 @@ +import asyncio +from typing import AsyncGenerator + +from app.lib import ping_external_resource +from litestar import Litestar, websocket_stream + +RESOURCE_LOCK = asyncio.Lock() + + +async def acquire_lock() -> AsyncGenerator[None, None]: + async with RESOURCE_LOCK: + yield + + +@websocket_stream("/") +async def ping(lock: asyncio.Lock) -> AsyncGenerator[float, None]: + while True: + alive = await ping_external_resource() + yield alive + await asyncio.sleep(1) + + +app = Litestar([ping], dependencies={"lock": acquire_lock}) diff --git a/docs/examples/websockets/stream_di_hog_fix.py b/docs/examples/websockets/stream_di_hog_fix.py new file mode 100644 index 0000000000..38716f3ef9 --- /dev/null +++ b/docs/examples/websockets/stream_di_hog_fix.py @@ -0,0 +1,19 @@ +import asyncio +from typing import AsyncGenerator + +from app.lib import ping_external_resource +from litestar import Litestar, websocket_stream + +RESOURCE_LOCK = asyncio.Lock() + + +@websocket_stream("/") +async def ping() -> AsyncGenerator[float, None]: + while True: + async with RESOURCE_LOCK: + alive = await ping_external_resource() + yield alive + await asyncio.sleep(1) + + +app = Litestar([ping]) diff --git a/docs/examples/websockets/stream_socket_access.py b/docs/examples/websockets/stream_socket_access.py new file mode 100644 index 0000000000..63d2c99d06 --- /dev/null +++ b/docs/examples/websockets/stream_socket_access.py @@ -0,0 +1,15 @@ +import asyncio +import time +from typing import Any, AsyncGenerator + +from litestar import Litestar, WebSocket, websocket_stream + + +@websocket_stream("/") +async def ping(socket: WebSocket) -> AsyncGenerator[dict[str, Any], None]: + while True: + yield {"time": time.time(), "client": socket.client} + await asyncio.sleep(0.5) + + +app = Litestar([ping]) diff --git a/docs/usage/websockets.rst b/docs/usage/websockets.rst index d1bb34aaa6..dc9dc7eeb5 100644 --- a/docs/usage/websockets.rst +++ b/docs/usage/websockets.rst @@ -1,20 +1,31 @@ WebSockets ========== +There are three ways to handle WebSockets in Litestar: -Handling WebSockets in an application often involves dealing with low level constructs -such as the socket itself, setting up a loop and listening for incoming data, handling -exceptions, and parsing incoming and serializing outgoing data. In addition to the -low-level :class:`WebSocket route handler <.handlers.websocket>`, Litestar offers two -high level interfaces: +1. The low-level :class:`~litestar.handlers.websocket` route handler, providing basic + abstractions over the ASGI WebSocket interface +2. :class:`~litestar.handlers.websocket_listener` and :class:`~litestar.handlers.WebsocketListener`\ : + Reactive, event-driven WebSockets with full serialization and DTO support and support + for a synchronous interface +3. :class:`~litestar.handlers.websocket_stream` and :func:`~litestar.handlers.send_websocket_stream`\ : + Proactive, stream oriented WebSockets with full serialization and DTO support -- :class:`websocket_listener <.handlers.websocket_listener>` -- :class:`WebSocketListener <.handlers.WebsocketListener>` +The main difference between the low and high level interfaces is that, dealing with low +level interface requires, setting up a loop and listening for incoming data, handling +exceptions, client disconnects, and parsing incoming and serializing outgoing data. -These treat a WebSocket handler like any other route handler: as a callable that takes -in incoming data in an already pre-processed form and returns data to be serialized and -sent over the connection. The low level details will be handled behind the curtains. + + +WebSocket Listeners +-------------------- + +WebSocket Listeners can be used to interact with a WebSocket in an event-driven manner, +using a callback style interface. They treat a WebSocket handler like any other route +handler: A callable that takes in incoming data in an already pre-processed form and +returns data to be serialized and sent over the connection. The low level details will +be handled behind the curtains. .. code-block:: python @@ -44,7 +55,7 @@ type of data which should be received, and it will be converted accordingly. Receiving data --------------- +++++++++++++++ Data can be received in the listener via the ``data`` parameter. The data passed to this will be converted / parsed according to the given type annotation and supports @@ -78,7 +89,7 @@ form of JSON. Sending data ------------- ++++++++++++++ Sending data is done by simply returning the value to be sent from the handler function. Similar to receiving data, type annotations configure how the data is being handled. @@ -86,7 +97,8 @@ Values that are not :class:`str` or :class:`bytes` are assumed to be JSON encoda will be serialized accordingly before being sent. This serialization is available for all data types currently supported by Litestar ( :doc:`dataclasses `\ , :class:`TypedDict `, -:class:`NamedTuple `, :class:`msgspec.Struct`, etc.). +:class:`NamedTuple `, :class:`msgspec.Struct`, etc.), including +DTOs. .. tab-set:: @@ -113,25 +125,12 @@ all data types currently supported by Litestar ( :language: python -Transport modes ---------------- - -WebSockets have two transport modes: Text and binary. These can be specified -individually for receiving and sending data. - -.. note:: - It may seem intuitive that ``text`` and ``binary`` should map to :class:`str` and - :class:`bytes` respectively, but this is not the case. Listeners can receive and - send data in any format, independently of the mode. The mode only affects how - data is encoded during transport (i.e. on the protocol level). In most cases the - default mode - ``text`` - is all that's needed. Binary transport is usually employed - when sending binary blobs that don't have a meaningful string representation, such - as images. - +Setting transport modes ++++++++++++++++++++++++ -Setting the receive mode -++++++++++++++++++++++++ +Receive mode +~~~~~~~~~~~~ .. tab-set:: @@ -156,8 +155,8 @@ Setting the receive mode it will not respond to WebSocket events sending data in the text channel. -Setting the send mode -++++++++++++++++++++++ +Send mode +~~~~~~~~~ .. tab-set:: @@ -179,10 +178,10 @@ Setting the send mode Dependency injection --------------------- +++++++++++++++++++++ -:doc:`dependency-injection` is available as well and generally works the same as with -regular route handlers: +:doc:`dependency-injection` is available and generally works the same as in regular +route handlers: .. literalinclude:: /examples/websockets/dependency_injection_simple.py :language: python @@ -203,7 +202,7 @@ the ``yield`` will only be executed after the connection has been closed. Interacting with the WebSocket directly ---------------------------------------- ++++++++++++++++++++++++++++++++++++++++ Sometimes access to the socket instance is needed, in which case the :class:`WebSocket <.connection.WebSocket>` instance can be injected into the handler @@ -220,7 +219,7 @@ function via the ``socket`` argument: Customising connection acceptance ---------------------------------- ++++++++++++++++++++++++++++++++++ By default, Litestar will accept all incoming connections by awaiting ``WebSocket.accept()`` without arguments. This behavior can be customized by passing a custom ``connection_accept_handler`` function. Litestar will await this @@ -231,7 +230,7 @@ function to accept the connection. Class based WebSocket handling ------------------------------- +++++++++++++++++++++++++++++++ In addition to using a simple function as in the examples above, a class based approach is made possible by extending the @@ -254,7 +253,7 @@ encapsulate more complex logic. Custom WebSocket ----------------- +++++++++++++++++ .. versionadded:: 2.7.0 @@ -273,3 +272,118 @@ The example below illustrates how to implement a custom WebSocket class for the class on multiple layers, the layer closest to the route handler will take precedence. You can read more about this in the :ref:`usage/applications:layered architecture` section + + +WebSocket Streams +----------------- + +WebSocket streams can be used to proactively push data to a client, using an +asynchronous generator function. Data will be sent via the socket every time the +generator ``yield``\ s, until it is either exhausted or the client disconnects. + +.. literalinclude:: /examples/websockets/stream_basic.py + :language: python + :caption: Streaming the current time in 0.5 second intervals + + +Serialization ++++++++++++++ + +Just like with route handlers, type annotations configure how the data is being handled. +:class:`str` or :class:`bytes` will be sent as-is, while everything else will be encoded +as JSON before being sent. This serialization is available for all data types currently +supported by Litestar (:doc:`dataclasses `, +:class:`TypedDict `, :class:`NamedTuple `, +:class:`msgspec.Struct`, etc.), including DTOs. + + +Dependency Injection +++++++++++++++++++++ + +Dependency injection is available and works analogous to regular route handlers. + +.. important:: + One thing to keep in mind, especially for long-lived streams, is that dependencies + are scoped to the lifetime of the handler. This means that if for example a + database connection is acquired in a dependency, it will be held until the generator + stops. This may not be desirable in all cases, and acquiring resources ad-hoc inside + the generator itself preferable + + .. literalinclude:: /examples/websockets/stream_di_hog.py + :language: python + :caption: Bad: The lock will be held until the client disconnects + + + .. literalinclude:: /examples/websockets/stream_di_hog_fix.py + :language: python + :caption: Good: The lock will only be acquired when it's needed + + +Interacting with the WebSocket directly ++++++++++++++++++++++++++++++++++++++++ + +To interact with the :class:`WebSocket <.connection.WebSocket>` directly, it can be +injected into the generator function via the ``socket`` argument: + +.. literalinclude:: /examples/websockets/stream_socket_access.py + :language: python + + +Receiving data while streaming +++++++++++++++++++++++++++++++ + +By default, a stream will listen for a client disconnect in the background, and stop +the generator once received. Since this requires receiving data from the socket, it can +lead to data loss if the application is attempting to read from the same socket +simultaneously. + +.. tip:: + To prevent data loss, by default, ``websocket_stream`` will raise an + exception if it receives any data while listening for client disconnects. If + incoming data should be ignored, ``allow_data_discard`` should be set to ``True`` + +If receiving data while streaming is desired, +:func:`~litestar.handlers.send_websocket_stream` can be configured to not listen for +disconnects by setting ``listen_for_disconnect=False``. + +.. important:: + When using ``listen_for_disconnect=False``, the application needs to ensure the + disconnect event is received elsewhere, otherwise the stream will only terminate + when the generator is exhausted + + +Combining streaming and receiving data +--------------------------------------- + +To stream and receive data concurrently, the stream can be set up manually using +:func:`~litestar.handlers.send_websocket_stream` in combination with either a regular +:class:`~litestar.handlers.websocket` handler or a WebSocket listener. + +.. tab-set:: + + .. tab-item:: websocket_listener + + .. literalinclude:: /examples/websockets/stream_and_receive_listener.py + :language: python + + .. tab-item:: websocket handler + + .. literalinclude:: /examples/websockets/stream_and_receive_raw.py + :language: python + + +Transport modes +--------------- + +WebSockets have two transport modes: ``text`` and ``binary``. They dictate how bytes are +transferred over the wire and can be set independently from another, i.e. a socket can +send ``binary`` and receive ``text`` + + +It may seem intuitive that ``text`` and ``binary`` should map to :class:`str` and +:class:`bytes` respectively, but this is not the case. WebSockets can receive and +send data in any format, independently of the mode. The mode only affects how the +bytes are handled during transport (i.e. on the protocol level). In most cases the +default mode - ``text`` - is all that's needed. Binary transport is usually employed +when sending binary blobs that don't have a meaningful string representation, such +as images. diff --git a/litestar/__init__.py b/litestar/__init__.py index 3235113fef..1303d932aa 100644 --- a/litestar/__init__.py +++ b/litestar/__init__.py @@ -2,7 +2,19 @@ from litestar.connection import Request, WebSocket from litestar.controller import Controller from litestar.enums import HttpMethod, MediaType -from litestar.handlers import asgi, delete, get, head, patch, post, put, route, websocket, websocket_listener +from litestar.handlers import ( + asgi, + delete, + get, + head, + patch, + post, + put, + route, + websocket, + websocket_listener, + websocket_stream, +) from litestar.response import Response from litestar.router import Router from litestar.utils.version import get_version @@ -30,4 +42,5 @@ "route", "websocket", "websocket_listener", + "websocket_stream", ) diff --git a/litestar/handlers/__init__.py b/litestar/handlers/__init__.py index 0abf8ee105..8cb1731f57 100644 --- a/litestar/handlers/__init__.py +++ b/litestar/handlers/__init__.py @@ -5,8 +5,10 @@ WebsocketListener, WebsocketListenerRouteHandler, WebsocketRouteHandler, + send_websocket_stream, websocket, websocket_listener, + websocket_stream, ) __all__ = ( @@ -24,6 +26,8 @@ "post", "put", "route", + "send_websocket_stream", "websocket", "websocket_listener", + "websocket_stream", ) diff --git a/litestar/handlers/websocket_handlers/__init__.py b/litestar/handlers/websocket_handlers/__init__.py index 5b24734948..a156c6c455 100644 --- a/litestar/handlers/websocket_handlers/__init__.py +++ b/litestar/handlers/websocket_handlers/__init__.py @@ -6,11 +6,14 @@ websocket_listener, ) from litestar.handlers.websocket_handlers.route_handler import WebsocketRouteHandler, websocket +from litestar.handlers.websocket_handlers.stream import send_websocket_stream, websocket_stream __all__ = ( "WebsocketListener", "WebsocketListenerRouteHandler", "WebsocketRouteHandler", + "send_websocket_stream", "websocket", "websocket_listener", + "websocket_stream", ) diff --git a/litestar/handlers/websocket_handlers/route_handler.py b/litestar/handlers/websocket_handlers/route_handler.py index 3b3b8f03bf..4356d618fc 100644 --- a/litestar/handlers/websocket_handlers/route_handler.py +++ b/litestar/handlers/websocket_handlers/route_handler.py @@ -49,9 +49,9 @@ def __init__( :class:`ASGI Scope <.types.Scope>`. signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. type_encoders: A mapping of types to callables that transform them into types supported for serialization. - **kwargs: Any additional kwarg - will be set in the opt dictionary. websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's default websocket class. + **kwargs: Any additional kwarg - will be set in the opt dictionary. """ self.websocket_class = websocket_class diff --git a/litestar/handlers/websocket_handlers/stream.py b/litestar/handlers/websocket_handlers/stream.py new file mode 100644 index 0000000000..b52a3eb8d2 --- /dev/null +++ b/litestar/handlers/websocket_handlers/stream.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import dataclasses +import functools +import warnings +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Mapping, cast + +import anyio +from msgspec.json import Encoder as JsonEncoder +from typing_extensions import Self + +from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning, WebSocketDisconnect +from litestar.handlers.websocket_handlers.route_handler import WebsocketRouteHandler +from litestar.types import Empty +from litestar.types.builtin_types import NoneType +from litestar.typing import FieldDefinition +from litestar.utils.signature import ParsedSignature + +if TYPE_CHECKING: + from litestar import Litestar, WebSocket + from litestar.dto import AbstractDTO + from litestar.types import Dependencies, EmptyType, ExceptionHandler, Guard, Middleware, TypeEncodersMap + from litestar.types.asgi_types import WebSocketMode + + +async def send_websocket_stream( + socket: WebSocket, + stream: AsyncGenerator[Any, Any], + *, + close: bool = True, + mode: WebSocketMode = "text", + send_handler: Callable[[WebSocket, Any], Awaitable[Any]] | None = None, + listen_for_disconnect: bool = False, + warn_on_data_discard: bool = True, +) -> None: + """Stream data to the ``socket`` from an asynchronous generator. + + Example: + Sending the current time to the connected client every 0.5 seconds: + + .. code-block:: python + + async def stream_current_time() -> AsyncGenerator[str, None]: + while True: + yield str(time.time()) + await asyncio.sleep(0.5) + + + @websocket("/time") + async def time_handler(socket: WebSocket) -> None: + await socket.accept() + await send_websocket_stream( + socket, + stream_current_time(), + listen_for_disconnect=True, + ) + + + Args: + socket: The :class:`~litestar.connection.WebSocket` to send to + stream: An asynchronous generator yielding data to send + close: If ``True``, close the socket after the generator is exhausted + mode: WebSocket mode to use for sending when no ``send_handler`` is specified + send_handler: Callable to handle the send process. If ``None``, defaults to ``type(socket).send_data`` + listen_for_disconnect: If ``True``, listen for client disconnects in the background. If a client disconnects, + stop the generator and cancel sending data. Should always be ``True`` unless disconnects are handled + elsewhere, for example by reading data from the socket concurrently. Should never be set to ``True`` when + reading data from socket concurrently, as it can lead to data loss + warn_on_data_discard: If ``True`` and ``listen_for_disconnect=True``, warn if during listening for client + disconnects, data is received from the socket + """ + if send_handler is None: + send_handler = functools.partial(type(socket).send_data, mode=mode) + + async def send_stream() -> None: + try: + # client might have disconnected elsewhere, so we stop sending + while socket.connection_state != "disconnect": + await send_handler(socket, await stream.__anext__()) + except StopAsyncIteration: + pass + + if listen_for_disconnect: + # wrap 'send_stream' and disconnect listener, so they'll cancel the other once + # one of the finishes + async def wrapped_stream() -> None: + await send_stream() + # stream exhausted, we can stop listening for a disconnect + tg.cancel_scope.cancel() + + async def disconnect_listener() -> None: + try: + # run this in a loop - we might receive other data than disconnects. + # listen_for_disconnect is explicitly not safe when consuming WS data + # in other places, so discarding that data here is fine + while True: + await socket.receive_data("text") + if warn_on_data_discard: + warnings.warn( + "received data from websocket while listening for client " + "disconnect in a websocket_stream. listen_for_disconnect " + "is not safe to use when attempting to receive data from " + "the same socket concurrently with a websocket_stream. set " + "listen_for_disconnect=False if you're attempting to " + "receive data from this socket or set " + "warn_on_data_discard=False to disable this warning", + stacklevel=2, + category=LitestarWarning, + ) + + except WebSocketDisconnect: + # client disconnected, we can stop streaming + tg.cancel_scope.cancel() + + async with anyio.create_task_group() as tg: + tg.start_soon(wrapped_stream) + tg.start_soon(disconnect_listener) + + else: + await send_stream() + + if close and socket.connection_state != "disconnect": + await socket.close() + + +def websocket_stream( + path: str | list[str] | None = None, + *, + dependencies: Dependencies | None = None, + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, + guards: list[Guard] | None = None, + middleware: list[Middleware] | None = None, + name: str | None = None, + opt: dict[str, Any] | None = None, + signature_namespace: Mapping[str, Any] | None = None, + websocket_class: type[WebSocket] | None = None, + mode: WebSocketMode = "text", + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + type_encoders: TypeEncodersMap | None = None, + listen_for_disconnect: bool = True, + warn_on_data_discard: bool = True, + **kwargs: Any, +) -> Callable[[Callable[..., AsyncGenerator[Any, Any]]], WebsocketRouteHandler]: + """Create a WebSocket handler that accepts a connection and sends data to it from an + async generator. + + Example: + Sending the current time to the connected client every 0.5 seconds: + + .. code-block:: python + + @websocket_stream("/time") + async def send_time() -> AsyncGenerator[str, None]: + while True: + yield str(time.time()) + await asyncio.sleep(0.5) + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. If not given defaults + to ``/`` + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :class:`ASGI Scope <.types.Scope>`. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. + mode: WebSocket mode used for sending + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing outbound response data. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + listen_for_disconnect: If ``True``, listen for client disconnects in the background. If a client disconnects, + stop the generator and cancel sending data. Should always be ``True`` unless disconnects are handled + elsewhere, for example by reading data from the socket concurrently. Should never be set to ``True`` when + reading data from socket concurrently, as it can lead to data loss + warn_on_data_discard: If ``True`` and ``listen_for_disconnect=True``, warn if during listening for client + disconnects, data is received from the socket + **kwargs: Any additional kwarg - will be set in the opt dictionary. + """ + + def decorator(fn: Callable[..., AsyncGenerator[Any, Any]]) -> WebsocketRouteHandler: + return WebSocketStreamHandler( + path=path, + dependencies=dependencies, + exception_handlers=exception_handlers, + guard=guards, + middleware=middleware, + name=name, + opt=opt, + signature_namespace=signature_namespace, + websocket_class=websocket_class, + return_dto=return_dto, + type_encoders=type_encoders, + **kwargs, + )( + _WebSocketStreamOptions( + generator_fn=fn, + send_mode=mode, + listen_for_disconnect=listen_for_disconnect, + warn_on_data_discard=warn_on_data_discard, + ) + ) + + return decorator + + +class WebSocketStreamHandler(WebsocketRouteHandler): + __slots__ = ("_ws_stream_options",) + _ws_stream_options: _WebSocketStreamOptions + + def __call__(self, fn: _WebSocketStreamOptions) -> Self: # type: ignore[override] + self._ws_stream_options = fn + self._fn = self._ws_stream_options.generator_fn # type: ignore[assignment] + return self + + def on_registration(self, app: Litestar) -> None: + parsed_handler_signature = parsed_stream_fn_signature = ParsedSignature.from_fn( + self.fn, self.resolve_signature_namespace() + ) + + if not parsed_stream_fn_signature.return_type.is_subclass_of(AsyncGenerator): + raise ImproperlyConfiguredException( + f"Route handler {self}: 'websocket_stream' handlers must return an " + f"'AsyncGenerator', not {type(parsed_stream_fn_signature.return_type.raw)!r}" + ) + + # important not to use 'self._ws_stream_options.generator_fn' here; This would + # break in cases the decorator has been used inside a controller, as it would + # be a reference to the unbound method. The bound method is patched in later + # after the controller has been initialized. This is a workaround that should + # go away with v3.0's static handlers + stream_fn = cast(Callable[..., AsyncGenerator[Any, Any]], self.fn) + + # construct a fake signature for the kwargs modelling, using the generator + # function passed to the handler as a base, to include all the dependencies, + # params, injection kwargs, etc. + 'socket', so DI works properly, but the + # signature looks to kwargs/signature modelling like a plain '@websocket' + # handler that returns 'None' + parsed_handler_signature = dataclasses.replace( + parsed_handler_signature, return_type=FieldDefinition.from_annotation(NoneType) + ) + receives_socket_parameter = "socket" in parsed_stream_fn_signature.parameters + + if not receives_socket_parameter: + parsed_handler_signature = dataclasses.replace( + parsed_handler_signature, + parameters={ + **parsed_handler_signature.parameters, + "socket": FieldDefinition.from_annotation("WebSocket", name="socket"), + }, + ) + + self._parsed_fn_signature = parsed_handler_signature + self._parsed_return_field = parsed_stream_fn_signature.return_type.inner_types[0] + + json_encoder = JsonEncoder(enc_hook=self.default_serializer) + return_dto = self.resolve_return_dto() + + # make sure the closure doesn't capture self._ws_stream / self + send_mode: WebSocketMode = self._ws_stream_options.send_mode # pyright: ignore + listen_for_disconnect = self._ws_stream_options.listen_for_disconnect + warn_on_data_discard = self._ws_stream_options.warn_on_data_discard + + async def send_handler(socket: WebSocket, data: Any) -> None: + if isinstance(data, (str, bytes)): + await socket.send_data(data=data, mode=send_mode) + return + + if return_dto: + encoded_data = return_dto(socket).data_to_encodable_type(data) + data = json_encoder.encode(encoded_data) + await socket.send_data(data=data, mode=send_mode) + return + + data = json_encoder.encode(data) + await socket.send_data(data=data, mode=send_mode) + + @functools.wraps(stream_fn) + async def handler_fn(*args: Any, socket: WebSocket, **kw: Any) -> None: + if receives_socket_parameter: + kw["socket"] = socket + + await send_websocket_stream( + socket=socket, + stream=stream_fn(*args, **kw), + mode=send_mode, + close=True, + listen_for_disconnect=listen_for_disconnect, + warn_on_data_discard=warn_on_data_discard, + send_handler=send_handler, + ) + + self._fn = handler_fn + + super().on_registration(app) + + +class _WebSocketStreamOptions: + def __init__( + self, + generator_fn: Callable[..., AsyncGenerator[Any, Any]], + listen_for_disconnect: bool, + warn_on_data_discard: bool, + send_mode: WebSocketMode, + ) -> None: + self.generator_fn = generator_fn + self.listen_for_disconnect = listen_for_disconnect + self.warn_on_data_discard = warn_on_data_discard + self.send_mode = send_mode diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_stream.py b/tests/unit/test_handlers/test_websocket_handlers/test_stream.py new file mode 100644 index 0000000000..aeb6222c8c --- /dev/null +++ b/tests/unit/test_handlers/test_websocket_handlers/test_stream.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import asyncio +import dataclasses +from typing import AsyncGenerator, Dict, Generator +from unittest.mock import MagicMock + +import pytest + +from litestar import Controller, Litestar, WebSocket +from litestar.dto import DataclassDTO, dto_field +from litestar.exceptions import ImproperlyConfiguredException +from litestar.handlers.websocket_handlers import websocket_stream +from litestar.testing import create_test_client + + +def test_websocket_stream() -> None: + @websocket_stream("/") + async def handler(socket: WebSocket) -> AsyncGenerator[str, None]: + yield "foo" + yield "bar" + + with create_test_client([handler]) as client, client.websocket_connect("/") as ws: + assert ws.receive_text(timeout=0.1) == "foo" + assert ws.receive_text(timeout=0.1) == "bar" + + +def test_websocket_stream_in_controller() -> None: + class MyController(Controller): + @websocket_stream("/") + async def handler(self, socket: WebSocket) -> AsyncGenerator[str, None]: + yield "foo" + + with create_test_client([MyController]) as client, client.websocket_connect("/") as ws: + assert ws.receive_text(timeout=0.1) == "foo" + + +def test_websocket_stream_without_socket() -> None: + @websocket_stream("/") + async def handler() -> AsyncGenerator[str, None]: + yield "foo" + + with create_test_client([handler]) as client, client.websocket_connect("/") as ws: + assert ws.receive_text(timeout=0.1) == "foo" + + +def test_websocket_stream_dependency_injection() -> None: + async def provide_hello() -> str: + return "hello" + + # ensure we can inject dependencies + @websocket_stream("/1", dependencies={"greeting": provide_hello}) + async def handler_one(greeting: str) -> AsyncGenerator[str, None]: + yield greeting + + # ensure dependency injection also works with 'socket' present + @websocket_stream("/2", dependencies={"greeting": provide_hello}) + async def handler_two(socket: WebSocket, greeting: str) -> AsyncGenerator[str, None]: + yield greeting + + with create_test_client([handler_one, handler_two]) as client: + with client.websocket_connect("/1") as ws: + assert ws.receive_text(timeout=0.1) == "hello" + + with client.websocket_connect("/2") as ws: + assert ws.receive_text(timeout=0.1) == "hello" + + +def test_websocket_stream_dependencies_cleaned_up_after_stream_close() -> None: + mock = MagicMock() + + async def dep() -> AsyncGenerator[str, None]: + yield "foo" + mock() + + @websocket_stream( + "/", + dependencies={"message": dep}, + listen_for_disconnect=False, + ) + async def handler(socket: WebSocket, message: str) -> AsyncGenerator[str, None]: + yield "one" + await socket.receive_text() + yield message + + with create_test_client([handler]) as client, client.websocket_connect("/") as ws: + assert ws.receive_text(timeout=0.1) == "one" + assert mock.call_count == 0 + ws.send_text("") + assert ws.receive_text(timeout=0.1) == "foo" + + assert mock.call_count == 1 + + +def test_websocket_stream_handle_disconnect() -> None: + @websocket_stream("/") + async def handler() -> AsyncGenerator[str, None]: + while True: + yield "foo" + # sleep for longer than our read-timeout to ensure we're disconnecting prematurely + await asyncio.sleep(1) + + with create_test_client([handler]) as client, client.websocket_connect("/") as ws: + assert ws.receive_text(timeout=0.1) == "foo" + + with create_test_client([handler]) as client, client.websocket_connect("/") as ws: + # ensure we still disconnect even after receiving some data + ws.send_text("") + assert ws.receive_text(timeout=0.1) == "foo" + + +def test_websocket_stream_send_json() -> None: + @websocket_stream("/") + async def handler() -> AsyncGenerator[Dict[str, str], None]: # noqa: UP006 + yield {"hello": "there"} + yield {"and": "goodbye"} + + with create_test_client([handler]) as client, client.websocket_connect("/") as ws: + assert ws.receive_json(timeout=0.1) == {"hello": "there"} + assert ws.receive_json(timeout=0.1) == {"and": "goodbye"} + + +def test_websocket_stream_send_json_with_dto() -> None: + @dataclasses.dataclass + class Event: + id: int = dataclasses.field(metadata=dto_field("private")) + content: str + + @websocket_stream("/", return_dto=DataclassDTO[Event]) + async def handler() -> AsyncGenerator[Event, None]: + yield Event(id=1, content="hello") + + with create_test_client([handler], signature_types=[Event]) as client, client.websocket_connect("/") as ws: + assert ws.receive_json(timeout=0.1) == {"content": "hello"} + + +def test_raises_if_stream_fn_does_not_return_async_generator() -> None: + with pytest.raises(ImproperlyConfiguredException): + + @websocket_stream("/") # type: ignore[arg-type] + def foo() -> Generator[bytes, None, None]: + yield b"" + + Litestar([foo]) + + with pytest.raises(ImproperlyConfiguredException): + + @websocket_stream("/") # type: ignore[arg-type] + def foo() -> bytes: + return b"" + + Litestar([foo])