diff --git a/litestar/response/sse.py b/litestar/response/sse.py index 253a2fde34..7770929f9b 100644 --- a/litestar/response/sse.py +++ b/litestar/response/sse.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from litestar.background_tasks import BackgroundTask, BackgroundTasks - from litestar.types import ResponseCookies, ResponseHeaders, StreamType + from litestar.types import ResponseCookies, ResponseHeaders, SSEData, StreamType _LINE_BREAK_RE = re.compile(r"\r\n|\r|\n") DEFAULT_SEPARATOR = "\r\n" @@ -21,11 +21,11 @@ class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]): __slots__ = ("content_async_iterator", "event_id", "event_type", "retry_duration", "comment_message") - content_async_iterator: AsyncIteratorWrapper[bytes | str] | AsyncIterable[str | bytes] | AsyncIterator[str | bytes] + content_async_iterator: AsyncIterable[SSEData] def __init__( self, - content: str | bytes | StreamType[str | bytes], + content: str | bytes | StreamType[SSEData], event_type: str | None = None, event_id: int | str | None = None, retry_duration: int | None = None, @@ -56,7 +56,7 @@ def __init__( if isinstance(content, (str, bytes)): self.content_async_iterator = AsyncIteratorWrapper([content]) elif isinstance(content, (Iterable, Iterator)): - self.content_async_iterator = AsyncIteratorWrapper(content) # type: ignore[arg-type] + self.content_async_iterator = AsyncIteratorWrapper(content) elif isinstance(content, (AsyncIterable, AsyncIterator, AsyncIteratorWrapper)): self.content_async_iterator = content else: @@ -131,7 +131,7 @@ def encode(self) -> bytes: class ServerSentEvent(Stream): def __init__( self, - content: str | bytes | StreamType[str | bytes], + content: str | bytes | StreamType[SSEData], *, background: BackgroundTask | BackgroundTasks | None = None, cookies: ResponseCookies | None = None, diff --git a/litestar/types/__init__.py b/litestar/types/__init__.py index 35eaf014dc..6eea3f0ff1 100644 --- a/litestar/types/__init__.py +++ b/litestar/types/__init__.py @@ -73,7 +73,7 @@ ) from .empty import Empty, EmptyType from .file_types import FileInfo, FileSystemProtocol -from .helper_types import AnyIOBackend, MaybePartial, OptionalSequence, StreamType, SyncOrAsyncUnion +from .helper_types import AnyIOBackend, MaybePartial, OptionalSequence, SSEData, StreamType, SyncOrAsyncUnion from .internal_types import ( ControllerRouterHandler, ReservedKwargs, @@ -155,6 +155,7 @@ "Send", "Serializer", "StreamType", + "SSEData", "SyncOrAsyncUnion", "TypeDecodersSequence", "TypeEncodersMap", diff --git a/litestar/types/helper_types.py b/litestar/types/helper_types.py index c211c33896..588ae5409f 100644 --- a/litestar/types/helper_types.py +++ b/litestar/types/helper_types.py @@ -3,9 +3,11 @@ from functools import partial from typing import ( TYPE_CHECKING, + Any, AsyncIterable, AsyncIterator, Awaitable, + Dict, Iterable, Iterator, Literal, @@ -18,9 +20,12 @@ if TYPE_CHECKING: from typing_extensions import TypeAlias + from litestar.response.sse import ServerSentEventMessage + + T = TypeVar("T") -__all__ = ("OptionalSequence", "SyncOrAsyncUnion", "AnyIOBackend", "StreamType", "MaybePartial") +__all__ = ("OptionalSequence", "SyncOrAsyncUnion", "AnyIOBackend", "StreamType", "MaybePartial", "SSEData") OptionalSequence: TypeAlias = Optional[Sequence[T]] """Types 'T' as union of Sequence[T] and None.""" @@ -37,3 +42,6 @@ MaybePartial: TypeAlias = Union[T, partial] """A potentially partial callable.""" + +SSEData: TypeAlias = Union[int, str, bytes, Dict[str, Any], "ServerSentEventMessage"] +"""A type alias for SSE data.""" diff --git a/tests/unit/test_response/test_sse.py b/tests/unit/test_response/test_sse.py index eade2426d5..bff7af75c2 100644 --- a/tests/unit/test_response/test_sse.py +++ b/tests/unit/test_response/test_sse.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncIterator, Iterator, List +from typing import AsyncIterator, Iterator, List import anyio import pytest @@ -10,6 +10,7 @@ from litestar.response import ServerSentEvent from litestar.response.sse import ServerSentEventMessage from litestar.testing import create_async_test_client +from litestar.types import SSEData async def test_sse_steaming_response() -> None: @@ -61,7 +62,7 @@ def numbers(minimum: int, maximum: int) -> Iterator[str]: async def test_various_sse_inputs(input: str, expected_events: List[HTTPXServerSentEvent]) -> None: @get("/testme") async def handler() -> ServerSentEvent: - async def numbers() -> AsyncIterator[Any]: + async def numbers() -> AsyncIterator[SSEData]: for i in range(1, 6): await anyio.sleep(0.001) if input == "integer":