diff --git a/docs/assets/starlette.webm b/docs/assets/starlette.webm new file mode 100644 index 0000000..0587c96 Binary files /dev/null and b/docs/assets/starlette.webm differ diff --git a/docs/streaming.md b/docs/streaming.md index 403df6b..768c4e3 100644 --- a/docs/streaming.md +++ b/docs/streaming.md @@ -1,7 +1,7 @@ # Streaming of Contents Internally, htpy is built with generators. Most of the time, you would render -the full page with `str()`, but htpy can also incrementally generate pages which +the full page with `str()`, but htpy can also incrementally generate pages synchronously or asynchronous which can then be streamed to the browser. If your page uses a database or other services to retrieve data, you can sending the first part of the page to the client while the page is being generated. @@ -16,7 +16,6 @@ client while the page is being generated. This video shows what it looks like in the browser to generate a HTML table with [Django StreamingHttpResponse](https://docs.djangoproject.com/en/5.0/ref/request-response/#django.http.StreamingHttpResponse) ([source code](https://github.com/pelme/htpy/blob/main/examples/djangoproject/stream/views.py)): <video width="500" controls loop > - <source src="/assets/stream.webm" type="video/webm"> </video> @@ -111,3 +110,59 @@ print( # output: <div><h1>Fibonacci!</h1>fib(12)=6765</div> ``` + + +## Asynchronous streaming + +It is also possible to use htpy to stream fully asynchronous. This intended to be used +with ASGI/async web frameworks/servers such as Starlette and Django. You can +build htpy components using Python's `asyncio` module and the `async`/`await` +syntax. + +### Starlette, ASGI and uvicorn example + +```python +title="starlette_demo.py" +import asyncio +from collections.abc import AsyncIterator + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import StreamingResponse + +from htpy import Element, div, h1, li, p, ul + +app = Starlette(debug=True) + + +@app.route("/") +async def index(request: Request) -> StreamingResponse: + return StreamingResponse(await index_page(), media_type="text/html") + + +async def index_page() -> Element: + return div[ + h1["Starlette Async example"], + p["This page is generated asynchronously using Starlette and ASGI."], + ul[(li[str(num)] async for num in slow_numbers(1, 10))], + ] + + +async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]: + for number in range(minimum, maximum + 1): + yield number + await asyncio.sleep(0.5) + +``` + +Run with [uvicorn](https://www.uvicorn.org/): + + +``` +$ uvicorn starlette_demo:app +``` + +In the browser, it looks like this: +<video width="500" controls loop > + <source src="/assets/starlette.webm" type="video/webm"> +</video> diff --git a/examples/async_coroutine.py b/examples/async_coroutine.py new file mode 100644 index 0000000..cd0d9a2 --- /dev/null +++ b/examples/async_coroutine.py @@ -0,0 +1,24 @@ +import asyncio +import random + +from htpy import Element, b, div, h1 + + +async def magic_number() -> Element: + await asyncio.sleep(2) + return b[f"The Magic Number is: {random.randint(1, 100)}"] + + +async def my_component() -> Element: + return div[ + h1["The Magic Number"], + magic_number(), + ] + + +async def main() -> None: + async for chunk in await my_component(): + print(chunk) + + +asyncio.run(main()) diff --git a/examples/starlette_app.py b/examples/starlette_app.py new file mode 100644 index 0000000..f06b57c --- /dev/null +++ b/examples/starlette_app.py @@ -0,0 +1,35 @@ +import asyncio +from collections.abc import AsyncIterator + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import StreamingResponse +from starlette.routing import Route + +from htpy import Element, div, h1, li, p, ul + + +async def index(request: Request) -> StreamingResponse: + return StreamingResponse(await index_page(), media_type="text/html") + + +async def index_page() -> Element: + return div[ + h1["Starlette Async example"], + p["This page is generated asynchronously using Starlette and ASGI."], + ul[(li[str(num)] async for num in slow_numbers(1, 10))], + ] + + +async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]: + for number in range(minimum, maximum + 1): + yield number + await asyncio.sleep(0.5) + + +app = Starlette( + debug=True, + routes=[ + Route("/", index), + ], +) diff --git a/htpy/__init__.py b/htpy/__init__.py index 997a846..02c1308 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -6,7 +6,15 @@ import dataclasses import functools import typing as t -from collections.abc import Callable, Generator, Iterable, Iterator +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Generator, + Iterable, + Iterator, +) from markupsafe import Markup as _Markup from markupsafe import escape as _escape @@ -126,6 +134,9 @@ class ContextProvider(t.Generic[T]): def __iter__(self) -> Iterator[str]: return iter_node(self) + def __aiter__(self) -> AsyncIterator[str]: + return aiter_node(self) + def __str__(self) -> str: return render_node(self) @@ -137,6 +148,10 @@ class ContextConsumer(t.Generic[T]): func: Callable[[T], Node] +def _is_noop_node(x: Node) -> bool: + return x is None or x is True or x is False + + class _NO_DEFAULT: pass @@ -168,15 +183,8 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It while not isinstance(x, BaseElement) and callable(x): x = x() - if x is None: - return - - if x is True: - return - - if x is False: + if _is_noop_node(x): return - if isinstance(x, BaseElement): yield from x._iter_context(context_dict) # pyright: ignore [reportPrivateUsage] elif isinstance(x, ContextProvider): @@ -196,6 +204,68 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance] for child in x: yield from _iter_node_context(child, context_dict) + elif isinstance(x, Awaitable | AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance] + raise ValueError( + f"{x!r} is not a valid child element. " + "Use async iteration to retrieve element content: https://htpy.dev/streaming/" + ) + else: + raise TypeError(f"{x!r} is not a valid child element") + + +def aiter_node(x: Node) -> AsyncIterator[str]: + return _aiter_node_context(x, {}) + + +async def _aiter_node_context( + x: Node, context_dict: dict[Context[t.Any], t.Any] +) -> AsyncIterator[str]: + while True: + if isinstance(x, Awaitable): + x = await x + continue + + if not isinstance(x, BaseElement) and callable(x): + x = x() + continue + + break + + if _is_noop_node(x): + return + + if isinstance(x, BaseElement): + async for child in x._aiter_context(context_dict): # pyright: ignore [reportPrivateUsage] + yield child + elif isinstance(x, ContextProvider): + async for chunk in _aiter_node_context( + x.func(), + {**context_dict, x.context: x.value}, # pyright: ignore [reportUnknownMemberType] + ): + yield chunk + + elif isinstance(x, ContextConsumer): + context_value = context_dict.get(x.context, x.context.default) + if context_value is _NO_DEFAULT: + raise LookupError( + f'Context value for "{x.context.name}" does not exist, ' + f"requested by {x.debug_name}()." + ) + async for chunk in _aiter_node_context(x.func(context_value), context_dict): + yield chunk + + elif isinstance(x, str | _HasHtml): + yield str(_escape(x)) + elif isinstance(x, int): + yield str(x) + elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance] + for child in x: # type: ignore[assignment] + async for chunk in _aiter_node_context(child, context_dict): + yield chunk + elif isinstance(x, AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance] + async for child in x: # type: ignore[assignment] + async for chunk in _aiter_node_context(child, context_dict): # pyright: ignore[reportUnknownArgumentType] + yield chunk else: raise TypeError(f"{x!r} is not a valid child element") @@ -267,6 +337,15 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen self._children, ) + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield f"<{self._name}{self._attrs}>" + async for x in _aiter_node_context(self._children, context): + yield x + yield f"</{self._name}>" + + def __aiter__(self) -> AsyncIterator[str]: + return self._aiter_context({}) + def __iter__(self) -> Iterator[str]: return self._iter_context({}) @@ -275,9 +354,6 @@ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield from _iter_node_context(self._children, ctx) yield f"</{self._name}>" - def __repr__(self) -> str: - return f"<{self.__class__.__name__} '{self}'>" - # Allow starlette Response.render to directly render this element without # explicitly casting to str: # https://github.com/encode/starlette/blob/5ed55c441126687106109a3f5e051176f88cd3e6/starlette/responses.py#L44-L49 @@ -308,17 +384,31 @@ def __getitem__(self: ElementSelf, children: Node) -> ElementSelf: _validate_children(children) return self.__class__(self._name, self._attrs, children) # pyright: ignore [reportUnknownArgumentType] + def __repr__(self) -> str: + return f"<{self.__class__.__name__} '<{self._name}{self._attrs}>...</{self._name}>'>" + class HTMLElement(Element): def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield "<!doctype html>" yield from super()._iter_context(ctx) + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield "<!doctype html>" + async for x in super()._aiter_context(context): + yield x + class VoidElement(BaseElement): + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield f"<{self._name}{self._attrs}>" + def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield f"<{self._name}{self._attrs}>" + def __repr__(self) -> str: + return f"<{self.__class__.__name__} '<{self._name}{self._attrs}>'>" + def render_node(node: Node) -> _Markup: return _Markup("".join(iter_node(node))) @@ -347,6 +437,8 @@ def __html__(self) -> str: ... | Callable[[], "Node"] | ContextProvider[t.Any] | ContextConsumer[t.Any] + | AsyncIterable["Node"] + | Awaitable["Node"] ) Attribute: t.TypeAlias = None | bool | str | int | _HasHtml | _ClassNames @@ -480,6 +572,8 @@ def __html__(self) -> str: ... _KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType] None | BaseElement + | AsyncIterable # pyright: ignore [reportMissingTypeArgument] + | Awaitable # pyright: ignore [reportMissingTypeArgument] | ContextProvider # pyright: ignore [reportMissingTypeArgument] | ContextConsumer # pyright: ignore [reportMissingTypeArgument] | Callable # pyright: ignore [reportMissingTypeArgument] diff --git a/htpy/starlette.py b/htpy/starlette.py new file mode 100644 index 0000000..e448bd9 --- /dev/null +++ b/htpy/starlette.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import typing as t + +from starlette.responses import StreamingResponse + +from . import aiter_node + +if t.TYPE_CHECKING: + from starlette.background import BackgroundTask + + from . import Node + + +class HtpyResponse(StreamingResponse): + def __init__( + self, + content: Node, + status_code: int = 200, + headers: t.Mapping[str, str] | None = None, + media_type: str | None = "text/html", + background: BackgroundTask | None = None, + ): + super().__init__( + aiter_node(content), + status_code=status_code, + headers=headers, + media_type=media_type, + background=background, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 0451ea7..94d5274 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,11 @@ from __future__ import annotations +import asyncio import typing as t import pytest -from htpy import Node, iter_node +from htpy import Node, aiter_node, iter_node if t.TYPE_CHECKING: from collections.abc import Generator @@ -28,12 +29,31 @@ def django_env() -> None: @pytest.fixture -def render() -> Generator[RenderFixture, None, None]: +def render_async() -> RenderFixture: + def func(node: Node) -> list[str]: + async def get_list() -> list[str]: + return [chunk async for chunk in aiter_node(node)] + + return asyncio.run(get_list(), debug=True) + + return func + + +@pytest.fixture(params=["sync", "async"]) +def render( + request: pytest.FixtureRequest, render_async: RenderFixture +) -> Generator[RenderFixture, None, None]: called = False def func(node: Node) -> list[str]: nonlocal called called = True + + if request.param == "sync": + return list(iter_node(node)) + else: + return render_async(node) + return list(iter_node(node)) yield func diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 0000000..0c8951b --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import typing as t + +import pytest + +from htpy import Element, li, ul + +if t.TYPE_CHECKING: + from collections.abc import AsyncIterator + + from .conftest import RenderFixture + + +async def async_lis() -> AsyncIterator[Element]: + yield li["a"] + yield li["b"] + + +async def hi() -> Element: + return li["hi"] + + +def test_async_iterator(render_async: RenderFixture) -> None: + result = render_async(ul[async_lis()]) + assert result == ["<ul>", "<li>", "a", "</li>", "<li>", "b", "</li>", "</ul>"] + + +def test_async_function_children(render_async: RenderFixture) -> None: + result = render_async(ul[hi]) + assert result == ["<ul>", "<li>", "hi", "</li>", "</ul>"] + + +def test_awaitable_children(render_async: RenderFixture) -> None: + result = render_async(ul[hi()]) + assert result == ["<ul>", "<li>", "hi", "</li>", "</ul>"] + + +def test_sync_iteration_with_async_children() -> None: + with pytest.raises( + ValueError, + match=( + r"<async_generator object async_lis at .+> is not a valid child element\. " + r"Use async iteration to retrieve element content: https://htpy.dev/streaming/" + ), + ): + list(ul[async_lis()]) diff --git a/tests/test_context.py b/tests/test_context.py index 2144fc4..d76d40c 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,7 +1,9 @@ from __future__ import annotations +import asyncio import typing as t +import markupsafe import pytest from htpy import Context, Node, div @@ -33,6 +35,28 @@ def test_context_provider(render: RenderFixture) -> None: assert render(result) == ["<div>", "Hello: c!", "</div>"] +class Test_provider_outer_api: + """Ensure provider implements __iter__/__aiter__/__str__""" + + def test_iter(self) -> None: + result = letter_ctx.provider("c", lambda: div[display_letter("Hello")]) + assert list(result) == ["<div>", "Hello: c!", "</div>"] + + def test_aiter(self) -> None: + provider = letter_ctx.provider("c", lambda: div[display_letter("Hello")]) + + async def run() -> list[str]: + return [chunk async for chunk in provider] + + result = asyncio.run(run(), debug=True) + assert result == ["<div>", "Hello: c!", "</div>"] + + def test_str(self) -> None: + result = str(letter_ctx.provider("c", lambda: div[display_letter("Hello")])) + assert result == "<div>Hello: c!</div>" + assert isinstance(result, markupsafe.Markup) + + def test_no_default(render: RenderFixture) -> None: with pytest.raises( LookupError, diff --git a/tests/test_element.py b/tests/test_element.py index eac1c85..2579316 100644 --- a/tests/test_element.py +++ b/tests/test_element.py @@ -19,7 +19,7 @@ def test_invalid_element_name() -> None: def test_element_repr() -> None: - assert repr(htpy.div("#a")) == """<Element '<div id="a"></div>'>""" + assert repr(htpy.div("#a")) == """<Element '<div id="a">...</div>'>""" def test_void_element_repr() -> None: diff --git a/tests/test_starlette.py b/tests/test_starlette.py index 6ca9a3b..ec39229 100644 --- a/tests/test_starlette.py +++ b/tests/test_starlette.py @@ -7,7 +7,8 @@ from starlette.routing import Route from starlette.testclient import TestClient -from htpy import h1 +from htpy import Element, h1, p +from htpy.starlette import HtpyResponse if t.TYPE_CHECKING: from starlette.requests import Request @@ -17,11 +18,25 @@ async def html_response(request: Request) -> HTMLResponse: return HTMLResponse(h1["Hello, HTMLResponse!"]) +async def stuff() -> Element: + return p["stuff"] + + +async def htpy_response(request: Request) -> HtpyResponse: + return HtpyResponse( + ( + h1["Hello, HtpyResponse!"], + stuff(), + ) + ) + + client = TestClient( Starlette( debug=True, routes=[ Route("/html-response", html_response), + Route("/htpy-response", htpy_response), ], ) ) @@ -29,4 +44,10 @@ async def html_response(request: Request) -> HTMLResponse: def test_html_response() -> None: response = client.get("/html-response") - assert response.content == b"<h1>Hello, HTMLResponse!</h1>" + assert response.text == "<h1>Hello, HTMLResponse!</h1>" + + +def test_htpy_response() -> None: + response = client.get("/htpy-response") + assert response.headers["content-type"] == "text/html; charset=utf-8" + assert response.text == "<h1>Hello, HtpyResponse!</h1><p>stuff</p>"