Skip to content

Commit

Permalink
dto support, docstrings, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Dec 11, 2024
1 parent 19c1966 commit 0b8d878
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 57 deletions.
4 changes: 0 additions & 4 deletions litestar/connection/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,3 @@ async def send_msgpack(
None
"""
await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding)

async def send_stream(self, stream: AsyncGenerator[str | bytes, Any], mode: WebSocketMode) -> None:
async for event in stream:
await self.send_data(event, mode=mode)
200 changes: 149 additions & 51 deletions litestar/handlers/websocket_handlers/stream.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
from __future__ import annotations

import dataclasses
import functools
import inspect
import warnings
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Mapping
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Mapping, cast

import anyio
from msgspec.json import Encoder as JsonEncoder
from typing_extensions import Self

from litestar.exceptions import LitestarWarning, WebSocketDisconnect
from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning, WebSocketDisconnect
from litestar.handlers.websocket_handlers.route_handler import WebsocketRouteHandler
from litestar.types import Empty
from litestar.types.builtin_types import NoneType
from litestar.typing import FieldDefinition
from litestar.utils.signature import ParsedSignature

if TYPE_CHECKING:
from litestar import WebSocket
from litestar.types import Dependencies, ExceptionHandler, Guard, Middleware
from litestar import Litestar, WebSocket
from litestar.dto import AbstractDTO
from litestar.types import Dependencies, EmptyType, ExceptionHandler, Guard, Middleware, TypeEncodersMap
from litestar.types.asgi_types import WebSocketMode


Expand All @@ -21,11 +28,11 @@ async def send_websocket_stream(
stream: AsyncGenerator[str | bytes, Any],
close: bool = True,
mode: WebSocketMode = "text",
send_handler: Callable[[WebSocket, Any], Awaitable[Any]] | None = None,
listen_for_disconnect: bool = False,
warn_on_data_discard: bool = True,
) -> None:
"""
Stream data to the ``socket`` from an asynchronous generator.
"""Stream data to the ``socket`` from an asynchronous generator.
Example:
Sending the current time to the connected client every 0.5 seconds:
Expand All @@ -52,19 +59,27 @@ async def time_handler(socket: WebSocket) -> None:
socket: The :class:`~litestar.connection.WebSocket` to send to
stream: An asynchronous generator yielding data to send
close: If ``True``, close the socket after the generator is exhausted
mode: WebSocket mode to use for sending
mode: WebSocket mode to use for sending when no ``send_handler`` is specified
send_handler: Callable to handle the send process. If ``None``, defaults to ``type(socket).send_data``
listen_for_disconnect: If ``True``, listen for client disconnects in the background. If a client disconnects,
stop the generator and cancel sending data. Should always be ``True`` unless disconnects are handled
elsewhere, for example by reading data from the socket concurrently. Should never be set to ``True`` when
reading data from socket concurrently, as it can lead to data loss
warn_on_data_discard: If ``True`` and ``listen_for_disconnect=True``, warn if during listening for client
disconnects, data is received from the socket
"""
if send_handler is None:
send_handler = functools.partial(type(socket).send_data, mode=mode)

async def send_stream() -> None:
async for event in stream:
await send_handler(socket, event)

if listen_for_disconnect:
# wrap 'send_stream' and disconnect listener, so they'll cancel the other once
# one of the finishes
async def wrapped_stream() -> None:
await socket.send_stream(stream, mode=mode)
await send_stream()
# stream exhausted, we can stop listening for a disconnect
tg.cancel_scope.cancel()

Expand Down Expand Up @@ -96,7 +111,7 @@ async def disconnect_listener() -> None:
tg.start_soon(disconnect_listener)

else:
await socket.send_stream(stream=stream, mode=mode)
await send_stream()

if close and socket.connection_state != "disconnect":
await socket.close()
Expand All @@ -114,12 +129,13 @@ def websocket_stream(
signature_namespace: Mapping[str, Any] | None = None,
websocket_class: type[WebSocket] | None = None,
mode: WebSocketMode = "text",
return_dto: type[AbstractDTO] | None | EmptyType = Empty,
type_encoders: TypeEncodersMap | None = None,
listen_for_disconnect: bool = True,
warn_on_data_discard: bool = True,
**kwargs: Any,
) -> Callable[[Callable[..., AsyncGenerator[str | bytes, Any]]], WebsocketRouteHandler]:
"""
Create a WebSocket handler that accepts a connection and sends data to it from an
) -> Callable[[Callable[..., AsyncGenerator[Any, Any]]], WebsocketRouteHandler]:
"""Create a WebSocket handler that accepts a connection and sends data to it from an
async generator.
Example:
Expand Down Expand Up @@ -147,7 +163,10 @@ async def send_time() -> AsyncGenerator[str, None]:
websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's
default websocket class.
mode: WebSocket mode used for sending
listen_for_disconnect: If ``True``, listen for client disconnects in the background. If a client disconnects,
return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing outbound response data.
type_encoders: A mapping of types to callables that transform them into types supported for serialization.
listen_for_disconnect: If ``True``, listen for client
disconnects in the background. If a client disconnects,
stop the generator and cancel sending data. Should always be ``True`` unless disconnects are handled
elsewhere, for example by reading data from the socket concurrently. Should never be set to ``True`` when
reading data from socket concurrently, as it can lead to data loss
Expand All @@ -156,53 +175,132 @@ async def send_time() -> AsyncGenerator[str, None]:
**kwargs: Any additional kwarg - will be set in the opt dictionary.
"""

def decorator(fn: Callable[..., AsyncGenerator[str | bytes, Any]]) -> WebsocketRouteHandler:
signature = inspect.signature(fn)
generator_receives_socket = "socket" in signature.parameters
def decorator(fn: Callable[..., AsyncGenerator[Any, Any]]) -> WebsocketRouteHandler:
return WebSocketStreamHandler(
path=path,
dependencies=dependencies,
exception_handlers=exception_handlers,
guard=guards,
middleware=middleware,
name=name,
opt=opt,
signature_namespace=signature_namespace,
websocket_class=websocket_class,
return_dto=return_dto,
type_encoders=type_encoders,
**kwargs,
)(
_WebSocketStreamOptions(
generator_fn=fn,
send_mode=mode,
listen_for_disconnect=listen_for_disconnect,
warn_on_data_discard=warn_on_data_discard,
)
)

return decorator


class WebSocketStreamHandler(WebsocketRouteHandler):
__slots__ = ("_ws_stream_options",)
_ws_stream_options: _WebSocketStreamOptions

def __call__(self, fn: _WebSocketStreamOptions) -> Self: # type: ignore[override]
self._ws_stream_options = fn
self._fn = self._ws_stream_options.generator_fn # type: ignore[assignment]
return self

def on_registration(self, app: Litestar) -> None:
parsed_handler_signature = parsed_stream_fn_signature = ParsedSignature.from_fn(
self.fn, self.resolve_signature_namespace()
)

if not parsed_stream_fn_signature.return_type.is_subclass_of(AsyncGenerator):
raise ImproperlyConfiguredException(
f"Route handler {self}: 'websocket_stream' handlers must return an "
f"'AsyncGenerator', not {type(parsed_stream_fn_signature.return_type.raw)!r}"
)

# important not to use 'self._ws_stream_options.generator_fn' here; This would
# break in cases the decorator has been used inside a controller, as it would
# be a reference to the unbound method. The bound method is patched in later
# after the controller has been initialized. This is a workaround that should
# go away with v3.0's static handlers
stream_fn = cast(Callable[..., AsyncGenerator[Any, Any]], self.fn)

# construct a fake signature for the kwargs modelling, using the generator
# function passed to the handler as a base, to include all the dependencies,
# params, injection kwargs, etc. + 'socket', so DI works properly, but the
# signature looks to kwargs/signature modelling like a plain '@websocket'
# handler that returns 'None'
parsed_handler_signature = dataclasses.replace(
parsed_handler_signature, return_type=FieldDefinition.from_annotation(NoneType)
)
receives_socket_parameter = "socket" in parsed_stream_fn_signature.parameters

if not receives_socket_parameter:
parsed_handler_signature = dataclasses.replace(
parsed_handler_signature,
parameters={
**parsed_handler_signature.parameters,
"socket": FieldDefinition.from_annotation("WebSocket", name="socket"),
},
)

self._parsed_fn_signature = parsed_handler_signature
self._parsed_return_field = parsed_stream_fn_signature.return_type.inner_types[0]

@functools.wraps(fn)
async def handler(*args: Any, socket: WebSocket, **kw: Any) -> None:
if generator_receives_socket:
json_encoder = JsonEncoder(enc_hook=self.default_serializer)
return_dto = self.resolve_return_dto()

# make sure the closure doesn't capture self._ws_stream / self
send_mode = self._ws_stream_options.send_mode
listen_for_disconnect = self._ws_stream_options.listen_for_disconnect
warn_on_data_discard = self._ws_stream_options.warn_on_data_discard

async def send_handler(socket: WebSocket, data: Any) -> None:
if isinstance(data, (str, bytes)):
await socket.send_data(data=data, mode=send_mode)
return

if return_dto:
encoded_data = return_dto(socket).data_to_encodable_type(data)
data = json_encoder.encode(encoded_data)
await socket.send_data(data=data, mode=send_mode)
return

data = json_encoder.encode(data)
await socket.send_data(data=data, mode=send_mode)

@functools.wraps(stream_fn)
async def handler_fn(*args: Any, socket: WebSocket, **kw: Any) -> None:
if receives_socket_parameter:
kw["socket"] = socket

await send_websocket_stream(
socket=socket,
stream=fn(*args, **kw),
mode=mode,
stream=stream_fn(*args, **kw),
mode=send_mode,
close=True,
listen_for_disconnect=listen_for_disconnect,
warn_on_data_discard=warn_on_data_discard,
send_handler=send_handler,
)

handler.__annotations__ = fn.__annotations__
handler.__annotations__["return"] = None

if not generator_receives_socket:
handler.__annotations__["socket"] = "WebSocket"
self._fn = handler_fn

new_signature = signature
if not generator_receives_socket:
new_signature = new_signature.replace(
parameters=[
*signature.parameters.values(),
inspect.Parameter(name="socket", annotation="WebSocket", kind=inspect.Parameter.KEYWORD_ONLY),
],
return_annotation=None,
)

handler.__signature__ = new_signature # type: ignore[attr-defined]
super().on_registration(app)

return WebsocketRouteHandler(
path=path,
dependencies=dependencies,
exception_handlers=exception_handlers,
guard=guards,
middleware=middleware,
name=name,
oprt=opt,
signature_namespace=signature_namespace,
websocket_class=websocket_class,
**kwargs,
)(handler)

return decorator
class _WebSocketStreamOptions:
def __init__(
self,
generator_fn: Callable[..., AsyncGenerator[Any, Any]],
listen_for_disconnect: bool,
warn_on_data_discard: bool,
send_mode: WebSocketMode,
) -> None:
self.generator_fn = generator_fn
self.listen_for_disconnect = listen_for_disconnect
self.warn_on_data_discard = warn_on_data_discard
self.send_mode = send_mode
54 changes: 52 additions & 2 deletions tests/unit/test_handlers/test_websocket_handlers/test_stream.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from __future__ import annotations

import asyncio
from typing import AsyncGenerator
import dataclasses
from typing import AsyncGenerator, Generator
from unittest.mock import MagicMock

from litestar import WebSocket, Controller
import pytest

from litestar import Controller, Litestar, WebSocket
from litestar.dto import DataclassDTO, dto_field
from litestar.exceptions import ImproperlyConfiguredException
from litestar.handlers.websocket_handlers import websocket_stream
from litestar.testing import create_test_client

Expand Down Expand Up @@ -100,3 +107,46 @@ async def handler() -> AsyncGenerator[str, None]:
# ensure we still disconnect even after receiving some data
ws.send_text("")
assert ws.receive_text(timeout=0.1) == "foo"


def test_websocket_stream_send_json() -> None:
@websocket_stream("/")
async def handler() -> AsyncGenerator[dict[str, str], None]:
yield {"hello": "there"}
yield {"and": "goodbye"}

with create_test_client([handler]) as client, client.websocket_connect("/") as ws:
assert ws.receive_json(timeout=0.1) == {"hello": "there"}
assert ws.receive_json(timeout=0.1) == {"and": "goodbye"}


def test_websocket_stream_send_json_with_dto() -> None:
@dataclasses.dataclass
class Event:
id: int = dataclasses.field(metadata=dto_field("private"))
content: str

@websocket_stream("/", return_dto=DataclassDTO[Event])
async def handler() -> AsyncGenerator[Event, None]:
yield Event(id=1, content="hello")

with create_test_client([handler], signature_types=[Event]) as client, client.websocket_connect("/") as ws:
assert ws.receive_json(timeout=0.1) == {"content": "hello"}


def test_raises_if_stream_fn_does_not_return_async_generator() -> None:
with pytest.raises(ImproperlyConfiguredException):

@websocket_stream("/") # type: ignore[arg-type]
def foo() -> Generator[bytes, None, None]:
yield b""

Litestar([foo])

with pytest.raises(ImproperlyConfiguredException):

@websocket_stream("/") # type: ignore[arg-type]
def foo() -> bytes:
return b""

Litestar([foo])

0 comments on commit 0b8d878

Please sign in to comment.