diff --git a/docs/assets/starlette.webm b/docs/assets/starlette.webm new file mode 100644 index 0000000..0587c96 Binary files /dev/null and b/docs/assets/starlette.webm differ diff --git a/docs/streaming.md b/docs/streaming.md index 403df6b..bcc3389 100644 --- a/docs/streaming.md +++ b/docs/streaming.md @@ -1,7 +1,7 @@ # Streaming of Contents Internally, htpy is built with generators. Most of the time, you would render -the full page with `str()`, but htpy can also incrementally generate pages which +the full page with `str()`, but htpy can also incrementally generate pages synchronously or asynchronous which can then be streamed to the browser. If your page uses a database or other services to retrieve data, you can sending the first part of the page to the client while the page is being generated. @@ -111,3 +111,59 @@ print( # output:

Fibonacci!

fib(12)=6765
``` + + +## Asynchronous streaming + +It is also possible to use htpy to stream fully asynchronous. This intended to be used +with ASGI/async web frameworks/servers such as Starlette and Django. You can +build htpy components using Python's `asyncio` module and the `async`/`await` +syntax. + +### Starlette, ASGI and uvicorn example + +```python +title="starlette_demo.py" +import asyncio +from collections.abc import AsyncIterator + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import StreamingResponse + +from htpy import Element, div, h1, li, p, ul + +app = Starlette(debug=True) + + +@app.route("/") +async def index(request: Request) -> StreamingResponse: + return StreamingResponse(await index_page(), media_type="text/html") + + +async def index_page() -> Element: + return div[ + h1["Starlette Async example"], + p["This page is generated asynchronously using Starlette and ASGI."], + ul[(li[str(num)] async for num in slow_numbers(1, 10))], + ] + + +async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]: + for number in range(minimum, maximum + 1): + yield number + await asyncio.sleep(0.5) + +``` + +Run with [uvicorn](https://www.uvicorn.org/): + + +``` +$ uvicorn starlette_demo:app +``` + +In the browser, it looks like this: + diff --git a/examples/async_example.py b/examples/async_example.py new file mode 100644 index 0000000..cd0d9a2 --- /dev/null +++ b/examples/async_example.py @@ -0,0 +1,24 @@ +import asyncio +import random + +from htpy import Element, b, div, h1 + + +async def magic_number() -> Element: + await asyncio.sleep(2) + return b[f"The Magic Number is: {random.randint(1, 100)}"] + + +async def my_component() -> Element: + return div[ + h1["The Magic Number"], + magic_number(), + ] + + +async def main() -> None: + async for chunk in await my_component(): + print(chunk) + + +asyncio.run(main()) diff --git a/examples/starlette_app.py b/examples/starlette_app.py new file mode 100644 index 0000000..f06b57c --- /dev/null +++ b/examples/starlette_app.py @@ -0,0 +1,35 @@ +import asyncio +from collections.abc import AsyncIterator + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import StreamingResponse +from starlette.routing import Route + +from htpy import Element, div, h1, li, p, ul + + +async def index(request: Request) -> StreamingResponse: + return StreamingResponse(await index_page(), media_type="text/html") + + +async def index_page() -> Element: + return div[ + h1["Starlette Async example"], + p["This page is generated asynchronously using Starlette and ASGI."], + ul[(li[str(num)] async for num in slow_numbers(1, 10))], + ] + + +async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]: + for number in range(minimum, maximum + 1): + yield number + await asyncio.sleep(0.5) + + +app = Starlette( + debug=True, + routes=[ + Route("/", index), + ], +) diff --git a/htpy/__init__.py b/htpy/__init__.py index 1a343e1..02c1308 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -6,7 +6,15 @@ import dataclasses import functools import typing as t -from collections.abc import Callable, Generator, Iterable, Iterator +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Generator, + Iterable, + Iterator, +) from markupsafe import Markup as _Markup from markupsafe import escape as _escape @@ -126,6 +134,9 @@ class ContextProvider(t.Generic[T]): def __iter__(self) -> Iterator[str]: return iter_node(self) + def __aiter__(self) -> AsyncIterator[str]: + return aiter_node(self) + def __str__(self) -> str: return render_node(self) @@ -137,6 +148,10 @@ class ContextConsumer(t.Generic[T]): func: Callable[[T], Node] +def _is_noop_node(x: Node) -> bool: + return x is None or x is True or x is False + + class _NO_DEFAULT: pass @@ -168,15 +183,8 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It while not isinstance(x, BaseElement) and callable(x): x = x() - if x is None: - return - - if x is True: + if _is_noop_node(x): return - - if x is False: - return - if isinstance(x, BaseElement): yield from x._iter_context(context_dict) # pyright: ignore [reportPrivateUsage] elif isinstance(x, ContextProvider): @@ -196,6 +204,68 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance] for child in x: yield from _iter_node_context(child, context_dict) + elif isinstance(x, Awaitable | AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance] + raise ValueError( + f"{x!r} is not a valid child element. " + "Use async iteration to retrieve element content: https://htpy.dev/streaming/" + ) + else: + raise TypeError(f"{x!r} is not a valid child element") + + +def aiter_node(x: Node) -> AsyncIterator[str]: + return _aiter_node_context(x, {}) + + +async def _aiter_node_context( + x: Node, context_dict: dict[Context[t.Any], t.Any] +) -> AsyncIterator[str]: + while True: + if isinstance(x, Awaitable): + x = await x + continue + + if not isinstance(x, BaseElement) and callable(x): + x = x() + continue + + break + + if _is_noop_node(x): + return + + if isinstance(x, BaseElement): + async for child in x._aiter_context(context_dict): # pyright: ignore [reportPrivateUsage] + yield child + elif isinstance(x, ContextProvider): + async for chunk in _aiter_node_context( + x.func(), + {**context_dict, x.context: x.value}, # pyright: ignore [reportUnknownMemberType] + ): + yield chunk + + elif isinstance(x, ContextConsumer): + context_value = context_dict.get(x.context, x.context.default) + if context_value is _NO_DEFAULT: + raise LookupError( + f'Context value for "{x.context.name}" does not exist, ' + f"requested by {x.debug_name}()." + ) + async for chunk in _aiter_node_context(x.func(context_value), context_dict): + yield chunk + + elif isinstance(x, str | _HasHtml): + yield str(_escape(x)) + elif isinstance(x, int): + yield str(x) + elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance] + for child in x: # type: ignore[assignment] + async for chunk in _aiter_node_context(child, context_dict): + yield chunk + elif isinstance(x, AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance] + async for child in x: # type: ignore[assignment] + async for chunk in _aiter_node_context(child, context_dict): # pyright: ignore[reportUnknownArgumentType] + yield chunk else: raise TypeError(f"{x!r} is not a valid child element") @@ -267,6 +337,15 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen self._children, ) + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield f"<{self._name}{self._attrs}>" + async for x in _aiter_node_context(self._children, context): + yield x + yield f"" + + def __aiter__(self) -> AsyncIterator[str]: + return self._aiter_context({}) + def __iter__(self) -> Iterator[str]: return self._iter_context({}) @@ -314,8 +393,16 @@ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield "" yield from super()._iter_context(ctx) + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield "" + async for x in super()._aiter_context(context): + yield x + class VoidElement(BaseElement): + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield f"<{self._name}{self._attrs}>" + def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield f"<{self._name}{self._attrs}>" @@ -350,6 +437,8 @@ def __html__(self) -> str: ... | Callable[[], "Node"] | ContextProvider[t.Any] | ContextConsumer[t.Any] + | AsyncIterable["Node"] + | Awaitable["Node"] ) Attribute: t.TypeAlias = None | bool | str | int | _HasHtml | _ClassNames @@ -483,6 +572,8 @@ def __html__(self) -> str: ... _KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType] None | BaseElement + | AsyncIterable # pyright: ignore [reportMissingTypeArgument] + | Awaitable # pyright: ignore [reportMissingTypeArgument] | ContextProvider # pyright: ignore [reportMissingTypeArgument] | ContextConsumer # pyright: ignore [reportMissingTypeArgument] | Callable # pyright: ignore [reportMissingTypeArgument] diff --git a/tests/conftest.py b/tests/conftest.py index 48a7eb7..6cbc26b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,12 @@ from __future__ import annotations +import asyncio import dataclasses import typing as t import pytest -from htpy import Node, iter_node +from htpy import Node, aiter_node, iter_node if t.TYPE_CHECKING: from collections.abc import Callable, Generator @@ -49,9 +50,31 @@ def func(description: str) -> None: @pytest.fixture -def render(render_result: RenderResult) -> Generator[RenderFixture, None, None]: +def render_async(render_result: RenderResult) -> RenderFixture: + def func(node: Node) -> RenderResult: + async def run() -> RenderResult: + async for chunk in aiter_node(node): + render_result.append(chunk) + return render_result + + return asyncio.run(run(), debug=True) + + return func + + +@pytest.fixture(params=["sync", "async"]) +def render( + request: pytest.FixtureRequest, + render_async: RenderFixture, + render_result: RenderResult, +) -> Generator[RenderFixture, None, None]: called = False + def render_sync(node: Node) -> RenderResult: + for chunk in iter_node(node): + render_result.append(chunk) + return render_result + def func(node: Node) -> RenderResult: nonlocal called @@ -59,10 +82,11 @@ def func(node: Node) -> RenderResult: raise AssertionError("render() must only be called once per test") called = True - for chunk in iter_node(node): - render_result.append(chunk) - return render_result + if request.param == "sync": + return render_sync(node) + else: + return render_async(node) yield func diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 0000000..a31908c --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import typing as t + +import pytest + +from htpy import Element, li, ul + +from .conftest import Trace + +if t.TYPE_CHECKING: + from collections.abc import AsyncIterator + + from .conftest import RenderFixture, TraceFixture + + +def test_async_iterator(render_async: RenderFixture, trace: TraceFixture) -> None: + async def lis() -> AsyncIterator[Element]: + trace("pre a") + yield li["a"] + trace("pre b") + yield li["b"] + trace("post b") + + result = render_async(ul[lis()]) + assert result == [ + "", + ] + + +def test_awaitable(render_async: RenderFixture, trace: TraceFixture) -> None: + async def hi() -> Element: + trace("in hi()") + return li["hi"] + + result = render_async(ul[hi()]) + assert result == [ + "", + ] + + +@pytest.mark.filterwarnings(r"ignore:coroutine '.*\.coroutine' was never awaited") +def test_sync_iteration_coroutine() -> None: + async def coroutine() -> None: + return None + + with pytest.raises( + ValueError, + match=( + r" is not a valid child element\. " + r"Use async iteration to retrieve element content: https://htpy.dev/streaming/" + ), + ): + list(ul[coroutine()]) + + +def test_sync_iteration_async_generator() -> None: + async def generator() -> AsyncIterator[None]: + return + yield + + with pytest.raises( + ValueError, + match=( + r" is not a valid child element\. " + r"Use async iteration to retrieve element content: https://htpy.dev/streaming/" + ), + ): + list(ul[generator()]) diff --git a/tests/test_context.py b/tests/test_context.py index 56c210b..d76d40c 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import typing as t import markupsafe @@ -35,12 +36,21 @@ def test_context_provider(render: RenderFixture) -> None: class Test_provider_outer_api: - """Ensure provider implements __iter__/__str__""" + """Ensure provider implements __iter__/__aiter__/__str__""" def test_iter(self) -> None: result = letter_ctx.provider("c", lambda: div[display_letter("Hello")]) assert list(result) == ["
", "Hello: c!", "
"] + def test_aiter(self) -> None: + provider = letter_ctx.provider("c", lambda: div[display_letter("Hello")]) + + async def run() -> list[str]: + return [chunk async for chunk in provider] + + result = asyncio.run(run(), debug=True) + assert result == ["
", "Hello: c!", "
"] + def test_str(self) -> None: result = str(letter_ctx.provider("c", lambda: div[display_letter("Hello")])) assert result == "
Hello: c!
"