diff --git a/docs/assets/starlette.webm b/docs/assets/starlette.webm
new file mode 100644
index 0000000..0587c96
Binary files /dev/null and b/docs/assets/starlette.webm differ
diff --git a/docs/streaming.md b/docs/streaming.md
index 403df6b..bcc3389 100644
--- a/docs/streaming.md
+++ b/docs/streaming.md
@@ -1,7 +1,7 @@
# Streaming of Contents
Internally, htpy is built with generators. Most of the time, you would render
-the full page with `str()`, but htpy can also incrementally generate pages which
+the full page with `str()`, but htpy can also incrementally generate pages synchronously or asynchronous which
can then be streamed to the browser. If your page uses a database or other
services to retrieve data, you can sending the first part of the page to the
client while the page is being generated.
@@ -111,3 +111,59 @@ print(
# output:
Fibonacci!
fib(12)=6765
```
+
+
+## Asynchronous streaming
+
+It is also possible to use htpy to stream fully asynchronous. This intended to be used
+with ASGI/async web frameworks/servers such as Starlette and Django. You can
+build htpy components using Python's `asyncio` module and the `async`/`await`
+syntax.
+
+### Starlette, ASGI and uvicorn example
+
+```python
+title="starlette_demo.py"
+import asyncio
+from collections.abc import AsyncIterator
+
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.responses import StreamingResponse
+
+from htpy import Element, div, h1, li, p, ul
+
+app = Starlette(debug=True)
+
+
+@app.route("/")
+async def index(request: Request) -> StreamingResponse:
+ return StreamingResponse(await index_page(), media_type="text/html")
+
+
+async def index_page() -> Element:
+ return div[
+ h1["Starlette Async example"],
+ p["This page is generated asynchronously using Starlette and ASGI."],
+ ul[(li[str(num)] async for num in slow_numbers(1, 10))],
+ ]
+
+
+async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]:
+ for number in range(minimum, maximum + 1):
+ yield number
+ await asyncio.sleep(0.5)
+
+```
+
+Run with [uvicorn](https://www.uvicorn.org/):
+
+
+```
+$ uvicorn starlette_demo:app
+```
+
+In the browser, it looks like this:
+
diff --git a/examples/async_example.py b/examples/async_example.py
new file mode 100644
index 0000000..cd0d9a2
--- /dev/null
+++ b/examples/async_example.py
@@ -0,0 +1,24 @@
+import asyncio
+import random
+
+from htpy import Element, b, div, h1
+
+
+async def magic_number() -> Element:
+ await asyncio.sleep(2)
+ return b[f"The Magic Number is: {random.randint(1, 100)}"]
+
+
+async def my_component() -> Element:
+ return div[
+ h1["The Magic Number"],
+ magic_number(),
+ ]
+
+
+async def main() -> None:
+ async for chunk in await my_component():
+ print(chunk)
+
+
+asyncio.run(main())
diff --git a/examples/starlette_app.py b/examples/starlette_app.py
new file mode 100644
index 0000000..f06b57c
--- /dev/null
+++ b/examples/starlette_app.py
@@ -0,0 +1,35 @@
+import asyncio
+from collections.abc import AsyncIterator
+
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.responses import StreamingResponse
+from starlette.routing import Route
+
+from htpy import Element, div, h1, li, p, ul
+
+
+async def index(request: Request) -> StreamingResponse:
+ return StreamingResponse(await index_page(), media_type="text/html")
+
+
+async def index_page() -> Element:
+ return div[
+ h1["Starlette Async example"],
+ p["This page is generated asynchronously using Starlette and ASGI."],
+ ul[(li[str(num)] async for num in slow_numbers(1, 10))],
+ ]
+
+
+async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]:
+ for number in range(minimum, maximum + 1):
+ yield number
+ await asyncio.sleep(0.5)
+
+
+app = Starlette(
+ debug=True,
+ routes=[
+ Route("/", index),
+ ],
+)
diff --git a/htpy/__init__.py b/htpy/__init__.py
index 1a343e1..02c1308 100644
--- a/htpy/__init__.py
+++ b/htpy/__init__.py
@@ -6,7 +6,15 @@
import dataclasses
import functools
import typing as t
-from collections.abc import Callable, Generator, Iterable, Iterator
+from collections.abc import (
+ AsyncIterable,
+ AsyncIterator,
+ Awaitable,
+ Callable,
+ Generator,
+ Iterable,
+ Iterator,
+)
from markupsafe import Markup as _Markup
from markupsafe import escape as _escape
@@ -126,6 +134,9 @@ class ContextProvider(t.Generic[T]):
def __iter__(self) -> Iterator[str]:
return iter_node(self)
+ def __aiter__(self) -> AsyncIterator[str]:
+ return aiter_node(self)
+
def __str__(self) -> str:
return render_node(self)
@@ -137,6 +148,10 @@ class ContextConsumer(t.Generic[T]):
func: Callable[[T], Node]
+def _is_noop_node(x: Node) -> bool:
+ return x is None or x is True or x is False
+
+
class _NO_DEFAULT:
pass
@@ -168,15 +183,8 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
while not isinstance(x, BaseElement) and callable(x):
x = x()
- if x is None:
- return
-
- if x is True:
+ if _is_noop_node(x):
return
-
- if x is False:
- return
-
if isinstance(x, BaseElement):
yield from x._iter_context(context_dict) # pyright: ignore [reportPrivateUsage]
elif isinstance(x, ContextProvider):
@@ -196,6 +204,68 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance]
for child in x:
yield from _iter_node_context(child, context_dict)
+ elif isinstance(x, Awaitable | AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance]
+ raise ValueError(
+ f"{x!r} is not a valid child element. "
+ "Use async iteration to retrieve element content: https://htpy.dev/streaming/"
+ )
+ else:
+ raise TypeError(f"{x!r} is not a valid child element")
+
+
+def aiter_node(x: Node) -> AsyncIterator[str]:
+ return _aiter_node_context(x, {})
+
+
+async def _aiter_node_context(
+ x: Node, context_dict: dict[Context[t.Any], t.Any]
+) -> AsyncIterator[str]:
+ while True:
+ if isinstance(x, Awaitable):
+ x = await x
+ continue
+
+ if not isinstance(x, BaseElement) and callable(x):
+ x = x()
+ continue
+
+ break
+
+ if _is_noop_node(x):
+ return
+
+ if isinstance(x, BaseElement):
+ async for child in x._aiter_context(context_dict): # pyright: ignore [reportPrivateUsage]
+ yield child
+ elif isinstance(x, ContextProvider):
+ async for chunk in _aiter_node_context(
+ x.func(),
+ {**context_dict, x.context: x.value}, # pyright: ignore [reportUnknownMemberType]
+ ):
+ yield chunk
+
+ elif isinstance(x, ContextConsumer):
+ context_value = context_dict.get(x.context, x.context.default)
+ if context_value is _NO_DEFAULT:
+ raise LookupError(
+ f'Context value for "{x.context.name}" does not exist, '
+ f"requested by {x.debug_name}()."
+ )
+ async for chunk in _aiter_node_context(x.func(context_value), context_dict):
+ yield chunk
+
+ elif isinstance(x, str | _HasHtml):
+ yield str(_escape(x))
+ elif isinstance(x, int):
+ yield str(x)
+ elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance]
+ for child in x: # type: ignore[assignment]
+ async for chunk in _aiter_node_context(child, context_dict):
+ yield chunk
+ elif isinstance(x, AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance]
+ async for child in x: # type: ignore[assignment]
+ async for chunk in _aiter_node_context(child, context_dict): # pyright: ignore[reportUnknownArgumentType]
+ yield chunk
else:
raise TypeError(f"{x!r} is not a valid child element")
@@ -267,6 +337,15 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen
self._children,
)
+ async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
+ yield f"<{self._name}{self._attrs}>"
+ async for x in _aiter_node_context(self._children, context):
+ yield x
+ yield f"{self._name}>"
+
+ def __aiter__(self) -> AsyncIterator[str]:
+ return self._aiter_context({})
+
def __iter__(self) -> Iterator[str]:
return self._iter_context({})
@@ -314,8 +393,16 @@ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield ""
yield from super()._iter_context(ctx)
+ async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
+ yield ""
+ async for x in super()._aiter_context(context):
+ yield x
+
class VoidElement(BaseElement):
+ async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
+ yield f"<{self._name}{self._attrs}>"
+
def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield f"<{self._name}{self._attrs}>"
@@ -350,6 +437,8 @@ def __html__(self) -> str: ...
| Callable[[], "Node"]
| ContextProvider[t.Any]
| ContextConsumer[t.Any]
+ | AsyncIterable["Node"]
+ | Awaitable["Node"]
)
Attribute: t.TypeAlias = None | bool | str | int | _HasHtml | _ClassNames
@@ -483,6 +572,8 @@ def __html__(self) -> str: ...
_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType]
None
| BaseElement
+ | AsyncIterable # pyright: ignore [reportMissingTypeArgument]
+ | Awaitable # pyright: ignore [reportMissingTypeArgument]
| ContextProvider # pyright: ignore [reportMissingTypeArgument]
| ContextConsumer # pyright: ignore [reportMissingTypeArgument]
| Callable # pyright: ignore [reportMissingTypeArgument]
diff --git a/tests/conftest.py b/tests/conftest.py
index 48a7eb7..6cbc26b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,12 @@
from __future__ import annotations
+import asyncio
import dataclasses
import typing as t
import pytest
-from htpy import Node, iter_node
+from htpy import Node, aiter_node, iter_node
if t.TYPE_CHECKING:
from collections.abc import Callable, Generator
@@ -49,9 +50,31 @@ def func(description: str) -> None:
@pytest.fixture
-def render(render_result: RenderResult) -> Generator[RenderFixture, None, None]:
+def render_async(render_result: RenderResult) -> RenderFixture:
+ def func(node: Node) -> RenderResult:
+ async def run() -> RenderResult:
+ async for chunk in aiter_node(node):
+ render_result.append(chunk)
+ return render_result
+
+ return asyncio.run(run(), debug=True)
+
+ return func
+
+
+@pytest.fixture(params=["sync", "async"])
+def render(
+ request: pytest.FixtureRequest,
+ render_async: RenderFixture,
+ render_result: RenderResult,
+) -> Generator[RenderFixture, None, None]:
called = False
+ def render_sync(node: Node) -> RenderResult:
+ for chunk in iter_node(node):
+ render_result.append(chunk)
+ return render_result
+
def func(node: Node) -> RenderResult:
nonlocal called
@@ -59,10 +82,11 @@ def func(node: Node) -> RenderResult:
raise AssertionError("render() must only be called once per test")
called = True
- for chunk in iter_node(node):
- render_result.append(chunk)
- return render_result
+ if request.param == "sync":
+ return render_sync(node)
+ else:
+ return render_async(node)
yield func
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 0000000..a31908c
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import typing as t
+
+import pytest
+
+from htpy import Element, li, ul
+
+from .conftest import Trace
+
+if t.TYPE_CHECKING:
+ from collections.abc import AsyncIterator
+
+ from .conftest import RenderFixture, TraceFixture
+
+
+def test_async_iterator(render_async: RenderFixture, trace: TraceFixture) -> None:
+ async def lis() -> AsyncIterator[Element]:
+ trace("pre a")
+ yield li["a"]
+ trace("pre b")
+ yield li["b"]
+ trace("post b")
+
+ result = render_async(ul[lis()])
+ assert result == [
+ "