Skip to content

Commit

Permalink
Async iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
pelme committed Sep 15, 2024
1 parent f4b80a2 commit 8970b1a
Show file tree
Hide file tree
Showing 8 changed files with 340 additions and 16 deletions.
Binary file added docs/assets/starlette.webm
Binary file not shown.
58 changes: 57 additions & 1 deletion docs/streaming.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Streaming of Contents

Internally, htpy is built with generators. Most of the time, you would render
the full page with `str()`, but htpy can also incrementally generate pages which
the full page with `str()`, but htpy can also incrementally generate pages synchronously or asynchronous which
can then be streamed to the browser. If your page uses a database or other
services to retrieve data, you can sending the first part of the page to the
client while the page is being generated.
Expand Down Expand Up @@ -111,3 +111,59 @@ print(
# output: <div><h1>Fibonacci!</h1>fib(12)=6765</div>

```


## Asynchronous streaming

It is also possible to use htpy to stream fully asynchronous. This intended to be used
with ASGI/async web frameworks/servers such as Starlette and Django. You can
build htpy components using Python's `asyncio` module and the `async`/`await`
syntax.

### Starlette, ASGI and uvicorn example

```python
title="starlette_demo.py"
import asyncio
from collections.abc import AsyncIterator

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import StreamingResponse

from htpy import Element, div, h1, li, p, ul

app = Starlette(debug=True)


@app.route("/")
async def index(request: Request) -> StreamingResponse:
return StreamingResponse(await index_page(), media_type="text/html")


async def index_page() -> Element:
return div[
h1["Starlette Async example"],
p["This page is generated asynchronously using Starlette and ASGI."],
ul[(li[str(num)] async for num in slow_numbers(1, 10))],
]


async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]:
for number in range(minimum, maximum + 1):
yield number
await asyncio.sleep(0.5)

```

Run with [uvicorn](https://www.uvicorn.org/):


```
$ uvicorn starlette_demo:app
```

In the browser, it looks like this:
<video width="500" controls loop >
<source src="/assets/starlette.webm" type="video/webm">
</video>
24 changes: 24 additions & 0 deletions examples/async_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import asyncio
import random

from htpy import Element, b, div, h1


async def magic_number() -> Element:
await asyncio.sleep(2)
return b[f"The Magic Number is: {random.randint(1, 100)}"]


async def my_component() -> Element:
return div[
h1["The Magic Number"],
magic_number(),
]


async def main() -> None:
async for chunk in await my_component():
print(chunk)


asyncio.run(main())
35 changes: 35 additions & 0 deletions examples/starlette_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio
from collections.abc import AsyncIterator

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.routing import Route

from htpy import Element, div, h1, li, p, ul


async def index(request: Request) -> StreamingResponse:
return StreamingResponse(await index_page(), media_type="text/html")


async def index_page() -> Element:
return div[
h1["Starlette Async example"],
p["This page is generated asynchronously using Starlette and ASGI."],
ul[(li[str(num)] async for num in slow_numbers(1, 10))],
]


async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]:
for number in range(minimum, maximum + 1):
yield number
await asyncio.sleep(0.5)


app = Starlette(
debug=True,
routes=[
Route("/", index),
],
)
109 changes: 100 additions & 9 deletions htpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
import dataclasses
import functools
import typing as t
from collections.abc import Callable, Generator, Iterable, Iterator
from collections.abc import (
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Generator,
Iterable,
Iterator,
)

from markupsafe import Markup as _Markup
from markupsafe import escape as _escape
Expand Down Expand Up @@ -126,6 +134,9 @@ class ContextProvider(t.Generic[T]):
def __iter__(self) -> Iterator[str]:
return iter_node(self)

def __aiter__(self) -> AsyncIterator[str]:
return aiter_node(self)

def __str__(self) -> str:
return render_node(self)

Expand All @@ -137,6 +148,10 @@ class ContextConsumer(t.Generic[T]):
func: Callable[[T], Node]


def _is_noop_node(x: Node) -> bool:
return x is None or x is True or x is False


class _NO_DEFAULT:
pass

Expand Down Expand Up @@ -168,15 +183,8 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
while not isinstance(x, BaseElement) and callable(x):
x = x()

if x is None:
return

if x is True:
if _is_noop_node(x):
return

if x is False:
return

if isinstance(x, BaseElement):
yield from x._iter_context(context_dict) # pyright: ignore [reportPrivateUsage]
elif isinstance(x, ContextProvider):
Expand All @@ -196,6 +204,68 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance]
for child in x:
yield from _iter_node_context(child, context_dict)
elif isinstance(x, Awaitable | AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance]
raise ValueError(
f"{x!r} is not a valid child element. "
"Use async iteration to retrieve element content: https://htpy.dev/streaming/"
)
else:
raise TypeError(f"{x!r} is not a valid child element")


def aiter_node(x: Node) -> AsyncIterator[str]:
return _aiter_node_context(x, {})


async def _aiter_node_context(
x: Node, context_dict: dict[Context[t.Any], t.Any]
) -> AsyncIterator[str]:
while True:
if isinstance(x, Awaitable):
x = await x
continue

if not isinstance(x, BaseElement) and callable(x):
x = x()
continue

break

if _is_noop_node(x):
return

if isinstance(x, BaseElement):
async for child in x._aiter_context(context_dict): # pyright: ignore [reportPrivateUsage]
yield child
elif isinstance(x, ContextProvider):
async for chunk in _aiter_node_context(
x.func(),
{**context_dict, x.context: x.value}, # pyright: ignore [reportUnknownMemberType]
):
yield chunk

elif isinstance(x, ContextConsumer):
context_value = context_dict.get(x.context, x.context.default)
if context_value is _NO_DEFAULT:
raise LookupError(
f'Context value for "{x.context.name}" does not exist, '
f"requested by {x.debug_name}()."
)
async for chunk in _aiter_node_context(x.func(context_value), context_dict):
yield chunk

elif isinstance(x, str | _HasHtml):
yield str(_escape(x))
elif isinstance(x, int):
yield str(x)
elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance]
for child in x: # type: ignore[assignment]
async for chunk in _aiter_node_context(child, context_dict):
yield chunk
elif isinstance(x, AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance]
async for child in x: # type: ignore[assignment]
async for chunk in _aiter_node_context(child, context_dict): # pyright: ignore[reportUnknownArgumentType]
yield chunk
else:
raise TypeError(f"{x!r} is not a valid child element")

Expand Down Expand Up @@ -267,6 +337,15 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen
self._children,
)

async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
yield f"<{self._name}{self._attrs}>"
async for x in _aiter_node_context(self._children, context):
yield x
yield f"</{self._name}>"

def __aiter__(self) -> AsyncIterator[str]:
return self._aiter_context({})

def __iter__(self) -> Iterator[str]:
return self._iter_context({})

Expand Down Expand Up @@ -314,8 +393,16 @@ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield "<!doctype html>"
yield from super()._iter_context(ctx)

async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
yield "<!doctype html>"
async for x in super()._aiter_context(context):
yield x


class VoidElement(BaseElement):
async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
yield f"<{self._name}{self._attrs}>"

def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield f"<{self._name}{self._attrs}>"

Expand Down Expand Up @@ -350,6 +437,8 @@ def __html__(self) -> str: ...
| Callable[[], "Node"]
| ContextProvider[t.Any]
| ContextConsumer[t.Any]
| AsyncIterable["Node"]
| Awaitable["Node"]
)

Attribute: t.TypeAlias = None | bool | str | int | _HasHtml | _ClassNames
Expand Down Expand Up @@ -483,6 +572,8 @@ def __html__(self) -> str: ...
_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType]
None
| BaseElement
| AsyncIterable # pyright: ignore [reportMissingTypeArgument]
| Awaitable # pyright: ignore [reportMissingTypeArgument]
| ContextProvider # pyright: ignore [reportMissingTypeArgument]
| ContextConsumer # pyright: ignore [reportMissingTypeArgument]
| Callable # pyright: ignore [reportMissingTypeArgument]
Expand Down
34 changes: 29 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import asyncio
import dataclasses
import typing as t

import pytest

from htpy import Node, iter_node
from htpy import Node, aiter_node, iter_node

if t.TYPE_CHECKING:
from collections.abc import Callable, Generator
Expand Down Expand Up @@ -49,20 +50,43 @@ def func(description: str) -> None:


@pytest.fixture
def render(render_result: RenderResult) -> Generator[RenderFixture, None, None]:
def render_async(render_result: RenderResult) -> RenderFixture:
def func(node: Node) -> RenderResult:
async def run() -> RenderResult:
async for chunk in aiter_node(node):
render_result.append(chunk)
return render_result

return asyncio.run(run(), debug=True)

return func


@pytest.fixture(params=["sync", "async"])
def render(
request: pytest.FixtureRequest,
render_async: RenderFixture,
render_result: RenderResult,
) -> Generator[RenderFixture, None, None]:
called = False

def render_sync(node: Node) -> RenderResult:
for chunk in iter_node(node):
render_result.append(chunk)
return render_result

def func(node: Node) -> RenderResult:
nonlocal called

if called:
raise AssertionError("render() must only be called once per test")

called = True
for chunk in iter_node(node):
render_result.append(chunk)

return render_result
if request.param == "sync":
return render_sync(node)
else:
return render_async(node)

yield func

Expand Down
Loading

0 comments on commit 8970b1a

Please sign in to comment.