diff --git a/docs/conf.py b/docs/conf.py index be6d8017ac..bcfa4aadd7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -245,6 +245,7 @@ "litestar.concurrency.set_asyncio_executor": {"ThreadPoolExecutor"}, "litestar.concurrency.get_asyncio_executor": {"ThreadPoolExecutor"}, re.compile(r"litestar\.channels\.backends\.asyncpg.*"): {"asyncpg.connection.Connection", "asyncpg.Connection"}, + re.compile(r"litestar\.handlers\.websocket_handlers\.stream.*"): {"WebSocketMode"}, } # Do not warn about broken links to the following: diff --git a/docs/examples/websockets/stream_and_receive_listener.py b/docs/examples/websockets/stream_and_receive_listener.py new file mode 100644 index 0000000000..d01158e12d --- /dev/null +++ b/docs/examples/websockets/stream_and_receive_listener.py @@ -0,0 +1,26 @@ +import asyncio +import time +from typing import Any, AsyncGenerator + +from litestar import Litestar, WebSocket, websocket_listener +from litestar.handlers import send_websocket_stream + + +async def listener_lifespan(socket: WebSocket) -> None: + async def handle_stream() -> AsyncGenerator[dict[str, float], None]: + while True: + yield {"time": time.time()} + await asyncio.sleep(0.5) + + task = asyncio.create_task(send_websocket_stream(socket=socket, stream=handle_stream())) + yield + task.cancel() + await task + + +@websocket_listener("/", connection_lifespan=listener_lifespan) +def handler(socket: WebSocket, data: Any) -> None: + print(f"{socket.client}: {data}") + + +app = Litestar([handler]) diff --git a/docs/examples/websockets/stream_and_receive_raw.py b/docs/examples/websockets/stream_and_receive_raw.py new file mode 100644 index 0000000000..6e6966b0e7 --- /dev/null +++ b/docs/examples/websockets/stream_and_receive_raw.py @@ -0,0 +1,27 @@ +import asyncio +import time +from typing import AsyncGenerator + +from litestar import Litestar, WebSocket, websocket +from litestar.handlers import send_websocket_stream + + +@websocket("/") +async def handler(socket: WebSocket) -> None: + await socket.accept() + + async def handle_stream() -> AsyncGenerator[dict[str, float], None]: + while True: + yield {"time": time.time()} + await asyncio.sleep(0.5) + + async def handle_receive() -> None: + async for event in socket.iter_json(): + print(f"{socket.client}: {event}") + + async with asyncio.TaskGroup() as tg: + tg.create_task(send_websocket_stream(socket=socket, stream=handle_stream())) + tg.create_task(handle_receive()) + + +app = Litestar([handler]) diff --git a/docs/examples/websockets/stream_basic.py b/docs/examples/websockets/stream_basic.py new file mode 100644 index 0000000000..7ed6517331 --- /dev/null +++ b/docs/examples/websockets/stream_basic.py @@ -0,0 +1,15 @@ +import asyncio +import time +from typing import AsyncGenerator + +from litestar import Litestar, websocket_stream + + +@websocket_stream("/") +async def ping() -> AsyncGenerator[float, None]: + while True: + yield time.time() + await asyncio.sleep(0.5) + + +app = Litestar([ping]) diff --git a/docs/examples/websockets/stream_di_hog.py b/docs/examples/websockets/stream_di_hog.py new file mode 100644 index 0000000000..dca552bbc7 --- /dev/null +++ b/docs/examples/websockets/stream_di_hog.py @@ -0,0 +1,23 @@ +import asyncio +from typing import AsyncGenerator + +from app.lib import ping_external_resource +from litestar import Litestar, websocket_stream + +RESOURCE_LOCK = asyncio.Lock() + + +async def acquire_lock() -> AsyncGenerator[None, None]: + async with RESOURCE_LOCK: + yield + + +@websocket_stream("/") +async def ping(lock: asyncio.Lock) -> AsyncGenerator[float, None]: + while True: + alive = await ping_external_resource() + yield alive + await asyncio.sleep(1) + + +app = Litestar([ping], dependencies={"lock": acquire_lock}) diff --git a/docs/examples/websockets/stream_di_hog_fix.py b/docs/examples/websockets/stream_di_hog_fix.py new file mode 100644 index 0000000000..38716f3ef9 --- /dev/null +++ b/docs/examples/websockets/stream_di_hog_fix.py @@ -0,0 +1,19 @@ +import asyncio +from typing import AsyncGenerator + +from app.lib import ping_external_resource +from litestar import Litestar, websocket_stream + +RESOURCE_LOCK = asyncio.Lock() + + +@websocket_stream("/") +async def ping() -> AsyncGenerator[float, None]: + while True: + async with RESOURCE_LOCK: + alive = await ping_external_resource() + yield alive + await asyncio.sleep(1) + + +app = Litestar([ping]) diff --git a/docs/examples/websockets/stream_socket_access.py b/docs/examples/websockets/stream_socket_access.py new file mode 100644 index 0000000000..63d2c99d06 --- /dev/null +++ b/docs/examples/websockets/stream_socket_access.py @@ -0,0 +1,15 @@ +import asyncio +import time +from typing import Any, AsyncGenerator + +from litestar import Litestar, WebSocket, websocket_stream + + +@websocket_stream("/") +async def ping(socket: WebSocket) -> AsyncGenerator[dict[str, Any], None]: + while True: + yield {"time": time.time(), "client": socket.client} + await asyncio.sleep(0.5) + + +app = Litestar([ping]) diff --git a/docs/usage/websockets.rst b/docs/usage/websockets.rst index d1bb34aaa6..0450ace8ee 100644 --- a/docs/usage/websockets.rst +++ b/docs/usage/websockets.rst @@ -1,20 +1,31 @@ WebSockets ========== +There are three ways to handle WebSockets in Litestar: -Handling WebSockets in an application often involves dealing with low level constructs -such as the socket itself, setting up a loop and listening for incoming data, handling -exceptions, and parsing incoming and serializing outgoing data. In addition to the -low-level :class:`WebSocket route handler <.handlers.websocket>`, Litestar offers two -high level interfaces: +1. The low-level :class:`~litestar.handlers.websocket` route handler, providing basic + abstractions over the ASGI WebSocket interface +2. :class:`~litestar.handlers.websocket_listener` and :class:`~litestar.handlers.WebsocketListener`\ : + Reactive, event-driven WebSockets with full serialization and DTO support and support + for a synchronous interface +3. :class:`~litestar.handlers.websocket_stream` and :func:`~litestar.handlers.send_websocket_stream`\ : + Proactive, stream oriented WebSockets with full serialization and DTO support -- :class:`websocket_listener <.handlers.websocket_listener>` -- :class:`WebSocketListener <.handlers.WebsocketListener>` +The main difference between the low and high level interfaces is that, dealing with low +level interface requires, setting up a loop and listening for incoming data, handling +exceptions, client disconnects, and parsing incoming and serializing outgoing data. -These treat a WebSocket handler like any other route handler: as a callable that takes -in incoming data in an already pre-processed form and returns data to be serialized and -sent over the connection. The low level details will be handled behind the curtains. + + +WebSocket Listeners +-------------------- + +WebSocket Listeners can be used to interact with a WebSocket in an event-driven manner, +using a callback style interface. They treat a WebSocket handler like any other route +handler: A callable that takes in incoming data in an already pre-processed form and +returns data to be serialized and sent over the connection. The low level details will +be handled behind the curtains. .. code-block:: python @@ -44,7 +55,7 @@ type of data which should be received, and it will be converted accordingly. Receiving data --------------- +++++++++++++++ Data can be received in the listener via the ``data`` parameter. The data passed to this will be converted / parsed according to the given type annotation and supports @@ -78,7 +89,7 @@ form of JSON. Sending data ------------- ++++++++++++++ Sending data is done by simply returning the value to be sent from the handler function. Similar to receiving data, type annotations configure how the data is being handled. @@ -86,7 +97,8 @@ Values that are not :class:`str` or :class:`bytes` are assumed to be JSON encoda will be serialized accordingly before being sent. This serialization is available for all data types currently supported by Litestar ( :doc:`dataclasses `\ , :class:`TypedDict `, -:class:`NamedTuple `, :class:`msgspec.Struct`, etc.). +:class:`NamedTuple `, :class:`msgspec.Struct`, etc.), including +DTOs. .. tab-set:: @@ -113,25 +125,12 @@ all data types currently supported by Litestar ( :language: python -Transport modes ---------------- - -WebSockets have two transport modes: Text and binary. These can be specified -individually for receiving and sending data. - -.. note:: - It may seem intuitive that ``text`` and ``binary`` should map to :class:`str` and - :class:`bytes` respectively, but this is not the case. Listeners can receive and - send data in any format, independently of the mode. The mode only affects how - data is encoded during transport (i.e. on the protocol level). In most cases the - default mode - ``text`` - is all that's needed. Binary transport is usually employed - when sending binary blobs that don't have a meaningful string representation, such - as images. - +Setting transport modes ++++++++++++++++++++++++ -Setting the receive mode -++++++++++++++++++++++++ +Receive mode +~~~~~~~~~~~~ .. tab-set:: @@ -156,8 +155,8 @@ Setting the receive mode it will not respond to WebSocket events sending data in the text channel. -Setting the send mode -++++++++++++++++++++++ +Send mode +~~~~~~~~~ .. tab-set:: @@ -179,10 +178,10 @@ Setting the send mode Dependency injection --------------------- +++++++++++++++++++++ -:doc:`dependency-injection` is available as well and generally works the same as with -regular route handlers: +:doc:`dependency-injection` is available and generally works the same as in regular +route handlers: .. literalinclude:: /examples/websockets/dependency_injection_simple.py :language: python @@ -203,7 +202,7 @@ the ``yield`` will only be executed after the connection has been closed. Interacting with the WebSocket directly ---------------------------------------- ++++++++++++++++++++++++++++++++++++++++ Sometimes access to the socket instance is needed, in which case the :class:`WebSocket <.connection.WebSocket>` instance can be injected into the handler @@ -220,7 +219,7 @@ function via the ``socket`` argument: Customising connection acceptance ---------------------------------- ++++++++++++++++++++++++++++++++++ By default, Litestar will accept all incoming connections by awaiting ``WebSocket.accept()`` without arguments. This behavior can be customized by passing a custom ``connection_accept_handler`` function. Litestar will await this @@ -231,7 +230,7 @@ function to accept the connection. Class based WebSocket handling ------------------------------- +++++++++++++++++++++++++++++++ In addition to using a simple function as in the examples above, a class based approach is made possible by extending the @@ -254,7 +253,7 @@ encapsulate more complex logic. Custom WebSocket ----------------- +++++++++++++++++ .. versionadded:: 2.7.0 @@ -273,3 +272,118 @@ The example below illustrates how to implement a custom WebSocket class for the class on multiple layers, the layer closest to the route handler will take precedence. You can read more about this in the :ref:`usage/applications:layered architecture` section + + +WebSocket Streams +----------------- + +WebSocket streams can be used to proactively push data to a client, using an +asynchronous generator function. Data will be sent via the socket every time the +generator ``yield``\ s, until it is either exhausted or the client disconnects. + +.. literalinclude:: /examples/websockets/stream_basic.py + :language: python + :caption: Streaming the current time in 0.5 second intervals + + +Serialization ++++++++++++++ + +Just like with route handlers, type annotations configure how the data is being handled. +:class:`str` or :class:`bytes` will be sent as-is, while everything else will be encoded +as JSON before being sent. This serialization is available for all data types currently +supported by Litestar (:doc:`dataclasses `, +:class:`TypedDict `, :class:`NamedTuple `, +:class:`msgspec.Struct`, etc.), including DTOs. + + +Dependency Injection +++++++++++++++++++++ + +Dependency injection is available and works analogous to regular route handlers. + +.. important:: + One thing to keep in mind, especially for long-lived streams, is that dependencies + are scoped to the lifetime of the handler. This means that if for example a + database connection is acquired in a dependency, it will be held until the generator + stops. This may not be desirable in all cases, and acquiring resources ad-hoc inside + the generator itself preferable + + .. literalinclude:: /examples/websockets/stream_di_hog.py + :language: python + :caption: Bad: The lock will be held until the client disconnects + + + .. literalinclude:: /examples/websockets/stream_di_hog_fix.py + :language: python + :caption: Good: The lock will only be acquired when it's needed + + +Interacting with the WebSocket directly ++++++++++++++++++++++++++++++++++++++++ + +To interact with the :class:`WebSocket <.connection.WebSocket>` directly, it can be +injected into the generator function via the ``socket`` argument: + +.. literalinclude:: /examples/websockets/stream_socket_access.py + :language: python + + +Receiving data while streaming +++++++++++++++++++++++++++++++ + +By default, a stream will listen for a client disconnect in the background, and stop +the generator once received. Since this requires receiving data from the socket, it can +lead to data loss if the application is attempting to read from the same socket +simultaneously. + +.. tip:: + To prevent data loss, by default, ``websocket_stream`` will raise an + exception if it receives any data while listening for client disconnects. If + incoming data should be ignored, ``allow_data_discard`` should be set to ``True`` + +If receiving data while streaming is desired, +:func:`~litestar.handlers.send_websocket_stream` can be configured to not listen for +disconnects by setting ``listen_for_disconnect=False``. + +.. important:: + When using ``listen_for_disconnect=False``, the application needs to ensure the + disconnect event is received elsewhere, otherwise the stream will only terminate + when the generator is exhausted + + +Combining streaming and receiving data +--------------------------------------- + +To stream and receive data concurrently, the stream can be set up manually using +:func:`~litestar.handlers.send_websocket_stream` in combination with either a regular +:class:`~litestar.handlers.websocket` handler or a WebSocket listener. + +.. tab-set:: + + .. tab-item:: websocket_listener + + .. literalinclude:: /examples/websockets/stream_and_receive_listener.py + :language: python + + .. tab-item:: websocket handler + + .. literalinclude:: /examples/websockets/stream_and_receive_raw.py + :language: python + + +Transport modes +--------------- + +WebSockets have two transport modes: ``text`` and ``binary``. They dictate how bytes are +transferred over the wire and can be set independently from another, i.e. a socket can +send ``binary`` and receive ``text`` + + +It may seem intuitive that ``text`` and ``binary`` should map to :class:`str` and +:class:`bytes` respectively, but this is not the case. WebSockets can receive and +send data in any format, independently of the mode. The mode only affects how the +bytes are handled during transport (i.e. on the protocol level). In most cases the +default mode - ``text`` - is all that's needed. Binary transport is usually employed +when sending binary blobs that don't have a meaningful string representation, such +as images. \ No newline at end of file diff --git a/litestar/__init__.py b/litestar/__init__.py index 3235113fef..1303d932aa 100644 --- a/litestar/__init__.py +++ b/litestar/__init__.py @@ -2,7 +2,19 @@ from litestar.connection import Request, WebSocket from litestar.controller import Controller from litestar.enums import HttpMethod, MediaType -from litestar.handlers import asgi, delete, get, head, patch, post, put, route, websocket, websocket_listener +from litestar.handlers import ( + asgi, + delete, + get, + head, + patch, + post, + put, + route, + websocket, + websocket_listener, + websocket_stream, +) from litestar.response import Response from litestar.router import Router from litestar.utils.version import get_version @@ -30,4 +42,5 @@ "route", "websocket", "websocket_listener", + "websocket_stream", ) diff --git a/litestar/handlers/__init__.py b/litestar/handlers/__init__.py index 0abf8ee105..8cb1731f57 100644 --- a/litestar/handlers/__init__.py +++ b/litestar/handlers/__init__.py @@ -5,8 +5,10 @@ WebsocketListener, WebsocketListenerRouteHandler, WebsocketRouteHandler, + send_websocket_stream, websocket, websocket_listener, + websocket_stream, ) __all__ = ( @@ -24,6 +26,8 @@ "post", "put", "route", + "send_websocket_stream", "websocket", "websocket_listener", + "websocket_stream", ) diff --git a/litestar/handlers/websocket_handlers/stream.py b/litestar/handlers/websocket_handlers/stream.py index fe3fb82997..c4e689b4e7 100644 --- a/litestar/handlers/websocket_handlers/stream.py +++ b/litestar/handlers/websocket_handlers/stream.py @@ -25,7 +25,8 @@ async def send_websocket_stream( socket: WebSocket, - stream: AsyncGenerator[str | bytes, Any], + stream: AsyncGenerator[Any, Any], + *, close: bool = True, mode: WebSocketMode = "text", send_handler: Callable[[WebSocket, Any], Awaitable[Any]] | None = None, @@ -72,8 +73,12 @@ async def time_handler(socket: WebSocket) -> None: send_handler = functools.partial(type(socket).send_data, mode=mode) async def send_stream() -> None: - async for event in stream: - await send_handler(socket, event) + try: + # client might have disconnected elsewhere, so we stop sending + while socket.connection_state != "disconnect": + await send_handler(socket, await stream.__anext__()) + except StopAsyncIteration: + pass if listen_for_disconnect: # wrap 'send_stream' and disconnect listener, so they'll cancel the other once @@ -92,16 +97,18 @@ async def disconnect_listener() -> None: await socket.receive_data("text") if warn_on_data_discard: warnings.warn( - "received data from websocket while listening for client" - "disconnect in a websocket_stream. listen_for_disconnect is" - "not safe to use when attempting to receive data from the " - "same socket concurrently with a websocket_stream. set " + "received data from websocket while listening for client " + "disconnect in a websocket_stream. listen_for_disconnect " + "is not safe to use when attempting to receive data from " + "the same socket concurrently with a websocket_stream. set " "listen_for_disconnect=False if you're attempting to " "receive data from this socket or set " "warn_on_data_discard=False to disable this warning", stacklevel=2, category=LitestarWarning, ) + await socket.close(4500) + except WebSocketDisconnect: # client disconnected, we can stop streaming tg.cancel_scope.cancel() @@ -142,6 +149,7 @@ def websocket_stream( Sending the current time to the connected client every 0.5 seconds: .. code-block:: python + @websocket_stream("/time") async def send_time() -> AsyncGenerator[str, None]: while True: @@ -165,8 +173,7 @@ async def send_time() -> AsyncGenerator[str, None]: mode: WebSocket mode used for sending return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing outbound response data. type_encoders: A mapping of types to callables that transform them into types supported for serialization. - listen_for_disconnect: If ``True``, listen for client - disconnects in the background. If a client disconnects, + listen_for_disconnect: If ``True``, listen for client disconnects in the background. If a client disconnects, stop the generator and cancel sending data. Should always be ``True`` unless disconnects are handled elsewhere, for example by reading data from the socket concurrently. Should never be set to ``True`` when reading data from socket concurrently, as it can lead to data loss @@ -254,7 +261,9 @@ def on_registration(self, app: Litestar) -> None: return_dto = self.resolve_return_dto() # make sure the closure doesn't capture self._ws_stream / self - send_mode = self._ws_stream_options.send_mode + send_mode = cast( # pyright doesn't track the 'Literal' here for some reason + "WebSocketMode", self._ws_stream_options.send_mode + ) listen_for_disconnect = self._ws_stream_options.listen_for_disconnect warn_on_data_discard = self._ws_stream_options.warn_on_data_discard diff --git a/tests/unit/test_handlers/test_websocket_handlers/test_stream.py b/tests/unit/test_handlers/test_websocket_handlers/test_stream.py index 060964e9a5..d731f37d20 100644 --- a/tests/unit/test_handlers/test_websocket_handlers/test_stream.py +++ b/tests/unit/test_handlers/test_websocket_handlers/test_stream.py @@ -2,14 +2,14 @@ import asyncio import dataclasses -from typing import AsyncGenerator, Generator +from typing import AsyncGenerator, Dict, Generator from unittest.mock import MagicMock import pytest from litestar import Controller, Litestar, WebSocket from litestar.dto import DataclassDTO, dto_field -from litestar.exceptions import ImproperlyConfiguredException +from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning from litestar.handlers.websocket_handlers import websocket_stream from litestar.testing import create_test_client @@ -93,7 +93,7 @@ async def handler(socket: WebSocket, message: str) -> AsyncGenerator[str, None]: def test_websocket_stream_handle_disconnect() -> None: - @websocket_stream("/", warn_on_data_discard=False) + @websocket_stream("/") async def handler() -> AsyncGenerator[str, None]: while True: yield "foo" @@ -111,7 +111,7 @@ async def handler() -> AsyncGenerator[str, None]: def test_websocket_stream_send_json() -> None: @websocket_stream("/") - async def handler() -> AsyncGenerator[dict[str, str], None]: + async def handler() -> AsyncGenerator[Dict[str, str], None]: # noqa: UP006 yield {"hello": "there"} yield {"and": "goodbye"} @@ -150,3 +150,16 @@ def foo() -> bytes: return b"" Litestar([foo]) + + +def test_websocket_stream_raise_if_data_receive_on_listen_for_disconnect() -> None: + @websocket_stream("/") + async def handler() -> AsyncGenerator[str, None]: + while True: + yield "foo" + await asyncio.sleep(0.1) + + with pytest.warns(LitestarWarning, match="received data from websocket"): + with create_test_client([handler], raise_server_exceptions=True) as client, client.websocket_connect("/") as ws: + ws.send_text("foo") + assert ws.receive_text(timeout=0.1) == "foo"