Skip to content

Commit

Permalink
fix: ServerSentEvent typing error (#3048)
Browse files Browse the repository at this point in the history
  • Loading branch information
euri10 authored Jan 30, 2024
1 parent 747fb90 commit f90a12d
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
10 changes: 5 additions & 5 deletions litestar/response/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

if TYPE_CHECKING:
from litestar.background_tasks import BackgroundTask, BackgroundTasks
from litestar.types import ResponseCookies, ResponseHeaders, StreamType
from litestar.types import ResponseCookies, ResponseHeaders, SSEData, StreamType

_LINE_BREAK_RE = re.compile(r"\r\n|\r|\n")
DEFAULT_SEPARATOR = "\r\n"
Expand All @@ -21,11 +21,11 @@
class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]):
__slots__ = ("content_async_iterator", "event_id", "event_type", "retry_duration", "comment_message")

content_async_iterator: AsyncIteratorWrapper[bytes | str] | AsyncIterable[str | bytes] | AsyncIterator[str | bytes]
content_async_iterator: AsyncIterable[SSEData]

def __init__(
self,
content: str | bytes | StreamType[str | bytes],
content: str | bytes | StreamType[SSEData],
event_type: str | None = None,
event_id: int | str | None = None,
retry_duration: int | None = None,
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
if isinstance(content, (str, bytes)):
self.content_async_iterator = AsyncIteratorWrapper([content])
elif isinstance(content, (Iterable, Iterator)):
self.content_async_iterator = AsyncIteratorWrapper(content) # type: ignore[arg-type]
self.content_async_iterator = AsyncIteratorWrapper(content)
elif isinstance(content, (AsyncIterable, AsyncIterator, AsyncIteratorWrapper)):
self.content_async_iterator = content
else:
Expand Down Expand Up @@ -131,7 +131,7 @@ def encode(self) -> bytes:
class ServerSentEvent(Stream):
def __init__(
self,
content: str | bytes | StreamType[str | bytes],
content: str | bytes | StreamType[SSEData],
*,
background: BackgroundTask | BackgroundTasks | None = None,
cookies: ResponseCookies | None = None,
Expand Down
3 changes: 2 additions & 1 deletion litestar/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
)
from .empty import Empty, EmptyType
from .file_types import FileInfo, FileSystemProtocol
from .helper_types import AnyIOBackend, MaybePartial, OptionalSequence, StreamType, SyncOrAsyncUnion
from .helper_types import AnyIOBackend, MaybePartial, OptionalSequence, SSEData, StreamType, SyncOrAsyncUnion
from .internal_types import (
ControllerRouterHandler,
ReservedKwargs,
Expand Down Expand Up @@ -155,6 +155,7 @@
"Send",
"Serializer",
"StreamType",
"SSEData",
"SyncOrAsyncUnion",
"TypeDecodersSequence",
"TypeEncodersMap",
Expand Down
10 changes: 9 additions & 1 deletion litestar/types/helper_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Dict,
Iterable,
Iterator,
Literal,
Expand All @@ -18,9 +20,12 @@
if TYPE_CHECKING:
from typing_extensions import TypeAlias

from litestar.response.sse import ServerSentEventMessage


T = TypeVar("T")

__all__ = ("OptionalSequence", "SyncOrAsyncUnion", "AnyIOBackend", "StreamType", "MaybePartial")
__all__ = ("OptionalSequence", "SyncOrAsyncUnion", "AnyIOBackend", "StreamType", "MaybePartial", "SSEData")

OptionalSequence: TypeAlias = Optional[Sequence[T]]
"""Types 'T' as union of Sequence[T] and None."""
Expand All @@ -37,3 +42,6 @@

MaybePartial: TypeAlias = Union[T, partial]
"""A potentially partial callable."""

SSEData: TypeAlias = Union[int, str, bytes, Dict[str, Any], "ServerSentEventMessage"]
"""A type alias for SSE data."""
5 changes: 3 additions & 2 deletions tests/unit/test_response/test_sse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, AsyncIterator, Iterator, List
from typing import AsyncIterator, Iterator, List

import anyio
import pytest
Expand All @@ -10,6 +10,7 @@
from litestar.response import ServerSentEvent
from litestar.response.sse import ServerSentEventMessage
from litestar.testing import create_async_test_client
from litestar.types import SSEData


async def test_sse_steaming_response() -> None:
Expand Down Expand Up @@ -61,7 +62,7 @@ def numbers(minimum: int, maximum: int) -> Iterator[str]:
async def test_various_sse_inputs(input: str, expected_events: List[HTTPXServerSentEvent]) -> None:
@get("/testme")
async def handler() -> ServerSentEvent:
async def numbers() -> AsyncIterator[Any]:
async def numbers() -> AsyncIterator[SSEData]:
for i in range(1, 6):
await anyio.sleep(0.001)
if input == "integer":
Expand Down

0 comments on commit f90a12d

Please sign in to comment.