Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support async iteration #38

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/assets/starlette.webm
Binary file not shown.
58 changes: 57 additions & 1 deletion docs/streaming.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Streaming of Contents

Internally, htpy is built with generators. Most of the time, you would render
the full page with `str()`, but htpy can also incrementally generate pages which
the full page with `str()`, but htpy can also incrementally generate pages synchronously or asynchronous which
can then be streamed to the browser. If your page uses a database or other
services to retrieve data, you can sending the first part of the page to the
client while the page is being generated.
@@ -111,3 +111,59 @@ print(
# output: <div><h1>Fibonacci!</h1>fib(12)=6765</div>

```


## Asynchronous streaming

It is also possible to use htpy to stream fully asynchronous. This intended to be used
with ASGI/async web frameworks/servers such as Starlette and Django. You can
build htpy components using Python's `asyncio` module and the `async`/`await`
syntax.

### Starlette, ASGI and uvicorn example

```python
title="starlette_demo.py"
import asyncio
from collections.abc import AsyncIterator

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import StreamingResponse

from htpy import Element, div, h1, li, p, ul

app = Starlette(debug=True)


@app.route("/")
async def index(request: Request) -> StreamingResponse:
return StreamingResponse(await index_page(), media_type="text/html")


async def index_page() -> Element:
return div[
h1["Starlette Async example"],
p["This page is generated asynchronously using Starlette and ASGI."],
ul[(li[str(num)] async for num in slow_numbers(1, 10))],
]


async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]:
for number in range(minimum, maximum + 1):
yield number
await asyncio.sleep(0.5)

```

Run with [uvicorn](https://www.uvicorn.org/):


```
$ uvicorn starlette_demo:app
```

In the browser, it looks like this:
<video width="500" controls loop >
<source src="/assets/starlette.webm" type="video/webm">
</video>
24 changes: 24 additions & 0 deletions examples/async_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import asyncio
import random

from htpy import Element, b, div, h1


async def magic_number() -> Element:
await asyncio.sleep(2)
return b[f"The Magic Number is: {random.randint(1, 100)}"]


async def my_component() -> Element:
return div[
h1["The Magic Number"],
magic_number(),
]


async def main() -> None:
async for chunk in await my_component():
print(chunk)


asyncio.run(main())
35 changes: 35 additions & 0 deletions examples/starlette_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio
from collections.abc import AsyncIterator

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.routing import Route

from htpy import Element, div, h1, li, p, ul


async def index(request: Request) -> StreamingResponse:
return StreamingResponse(await index_page(), media_type="text/html")


async def index_page() -> Element:
return div[
h1["Starlette Async example"],
p["This page is generated asynchronously using Starlette and ASGI."],
ul[(li[str(num)] async for num in slow_numbers(1, 10))],
]


async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]:
for number in range(minimum, maximum + 1):
yield number
await asyncio.sleep(0.5)


app = Starlette(
debug=True,
routes=[
Route("/", index),
],
)
109 changes: 100 additions & 9 deletions htpy/__init__.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,15 @@
import dataclasses
import functools
import typing as t
from collections.abc import Callable, Generator, Iterable, Iterator
from collections.abc import (
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Generator,
Iterable,
Iterator,
)

from markupsafe import Markup as _Markup
from markupsafe import escape as _escape
@@ -126,6 +134,9 @@ class ContextProvider(t.Generic[T]):
def __iter__(self) -> Iterator[str]:
return iter_node(self)

def __aiter__(self) -> AsyncIterator[str]:
return aiter_node(self)

def __str__(self) -> str:
return render_node(self)

@@ -137,6 +148,10 @@ class ContextConsumer(t.Generic[T]):
func: Callable[[T], Node]


def _is_noop_node(x: Node) -> bool:
return x is None or x is True or x is False


class _NO_DEFAULT:
pass

@@ -168,15 +183,8 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
while not isinstance(x, BaseElement) and callable(x):
x = x()

if x is None:
return

if x is True:
if _is_noop_node(x):
return

if x is False:
return

if isinstance(x, BaseElement):
yield from x._iter_context(context_dict) # pyright: ignore [reportPrivateUsage]
elif isinstance(x, ContextProvider):
@@ -196,6 +204,68 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance]
for child in x:
yield from _iter_node_context(child, context_dict)
elif isinstance(x, Awaitable | AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance]
raise ValueError(
f"{x!r} is not a valid child element. "
"Use async iteration to retrieve element content: https://htpy.dev/streaming/"
)
else:
raise TypeError(f"{x!r} is not a valid child element")


def aiter_node(x: Node) -> AsyncIterator[str]:
return _aiter_node_context(x, {})


async def _aiter_node_context(
x: Node, context_dict: dict[Context[t.Any], t.Any]
) -> AsyncIterator[str]:
while True:
if isinstance(x, Awaitable):
x = await x
continue

if not isinstance(x, BaseElement) and callable(x):
x = x()
continue

break

if _is_noop_node(x):
return

if isinstance(x, BaseElement):
async for child in x._aiter_context(context_dict): # pyright: ignore [reportPrivateUsage]
yield child
elif isinstance(x, ContextProvider):
async for chunk in _aiter_node_context(
x.func(),
{**context_dict, x.context: x.value}, # pyright: ignore [reportUnknownMemberType]
):
yield chunk

elif isinstance(x, ContextConsumer):
context_value = context_dict.get(x.context, x.context.default)
if context_value is _NO_DEFAULT:
raise LookupError(
f'Context value for "{x.context.name}" does not exist, '
f"requested by {x.debug_name}()."
)
async for chunk in _aiter_node_context(x.func(context_value), context_dict):
yield chunk

elif isinstance(x, str | _HasHtml):
yield str(_escape(x))
elif isinstance(x, int):
yield str(x)
elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance]
for child in x: # type: ignore[assignment]
async for chunk in _aiter_node_context(child, context_dict):
yield chunk
elif isinstance(x, AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance]
async for child in x: # type: ignore[assignment]
async for chunk in _aiter_node_context(child, context_dict): # pyright: ignore[reportUnknownArgumentType]
yield chunk
else:
raise TypeError(f"{x!r} is not a valid child element")

@@ -267,6 +337,15 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen
self._children,
)

async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
yield f"<{self._name}{self._attrs}>"
async for x in _aiter_node_context(self._children, context):
yield x
yield f"</{self._name}>"

def __aiter__(self) -> AsyncIterator[str]:
return self._aiter_context({})

def __iter__(self) -> Iterator[str]:
return self._iter_context({})

@@ -314,8 +393,16 @@ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield "<!doctype html>"
yield from super()._iter_context(ctx)

async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
yield "<!doctype html>"
async for x in super()._aiter_context(context):
yield x


class VoidElement(BaseElement):
async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]:
yield f"<{self._name}{self._attrs}>"

def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield f"<{self._name}{self._attrs}>"

@@ -350,6 +437,8 @@ def __html__(self) -> str: ...
| Callable[[], "Node"]
| ContextProvider[t.Any]
| ContextConsumer[t.Any]
| AsyncIterable["Node"]
| Awaitable["Node"]
)

Attribute: t.TypeAlias = None | bool | str | int | _HasHtml | _ClassNames
@@ -483,6 +572,8 @@ def __html__(self) -> str: ...
_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType]
None
| BaseElement
| AsyncIterable # pyright: ignore [reportMissingTypeArgument]
| Awaitable # pyright: ignore [reportMissingTypeArgument]
| ContextProvider # pyright: ignore [reportMissingTypeArgument]
| ContextConsumer # pyright: ignore [reportMissingTypeArgument]
| Callable # pyright: ignore [reportMissingTypeArgument]
30 changes: 30 additions & 0 deletions htpy/starlette.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

import typing as t

from starlette.responses import StreamingResponse

from . import aiter_node

if t.TYPE_CHECKING:
from starlette.background import BackgroundTask

from . import Node


class HtpyResponse(StreamingResponse):
def __init__(
self,
content: Node,
status_code: int = 200,
headers: t.Mapping[str, str] | None = None,
media_type: str | None = "text/html",
background: BackgroundTask | None = None,
):
super().__init__(
aiter_node(content),
status_code=status_code,
headers=headers,
media_type=media_type,
background=background,
)
34 changes: 29 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import asyncio
import dataclasses
import typing as t

import pytest

from htpy import Node, iter_node
from htpy import Node, aiter_node, iter_node

if t.TYPE_CHECKING:
from collections.abc import Callable, Generator
@@ -49,20 +50,43 @@ def func(description: str) -> None:


@pytest.fixture
def render(render_result: RenderResult) -> Generator[RenderFixture, None, None]:
def render_async(render_result: RenderResult) -> RenderFixture:
def func(node: Node) -> RenderResult:
async def run() -> RenderResult:
async for chunk in aiter_node(node):
render_result.append(chunk)
return render_result

return asyncio.run(run(), debug=True)

return func


@pytest.fixture(params=["sync", "async"])
def render(
request: pytest.FixtureRequest,
render_async: RenderFixture,
render_result: RenderResult,
) -> Generator[RenderFixture, None, None]:
called = False

def render_sync(node: Node) -> RenderResult:
for chunk in iter_node(node):
render_result.append(chunk)
return render_result

def func(node: Node) -> RenderResult:
nonlocal called

if called:
raise AssertionError("render() must only be called once per test")

called = True
for chunk in iter_node(node):
render_result.append(chunk)

return render_result
if request.param == "sync":
return render_sync(node)
else:
return render_async(node)

yield func

84 changes: 84 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

import typing as t

import pytest

from htpy import Element, li, ul

from .conftest import Trace

if t.TYPE_CHECKING:
from collections.abc import AsyncIterator

from .conftest import RenderFixture, TraceFixture


def test_async_iterator(render_async: RenderFixture, trace: TraceFixture) -> None:
async def lis() -> AsyncIterator[Element]:
trace("pre a")
yield li["a"]
trace("pre b")
yield li["b"]
trace("post b")

result = render_async(ul[lis()])
assert result == [
"<ul>",
Trace("pre a"),
"<li>",
"a",
"</li>",
Trace("pre b"),
"<li>",
"b",
"</li>",
Trace("post b"),
"</ul>",
]


def test_awaitable(render_async: RenderFixture, trace: TraceFixture) -> None:
async def hi() -> Element:
trace("in hi()")
return li["hi"]

result = render_async(ul[hi()])
assert result == [
"<ul>",
Trace("in hi()"),
"<li>",
"hi",
"</li>",
"</ul>",
]


@pytest.mark.filterwarnings(r"ignore:coroutine '.*\.coroutine' was never awaited")
def test_sync_iteration_coroutine() -> None:
async def coroutine() -> None:
return None

with pytest.raises(
ValueError,
match=(
r"<coroutine object .+ at .+> is not a valid child element\. "
r"Use async iteration to retrieve element content: https://htpy.dev/streaming/"
),
):
list(ul[coroutine()])


def test_sync_iteration_async_generator() -> None:
async def generator() -> AsyncIterator[None]:
return
yield

with pytest.raises(
ValueError,
match=(
r"<async_generator object .+ at .+> is not a valid child element\. "
r"Use async iteration to retrieve element content: https://htpy.dev/streaming/"
),
):
list(ul[generator()])
12 changes: 11 additions & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import typing as t

import markupsafe
@@ -35,12 +36,21 @@ def test_context_provider(render: RenderFixture) -> None:


class Test_provider_outer_api:
"""Ensure provider implements __iter__/__str__"""
"""Ensure provider implements __iter__/__aiter__/__str__"""

def test_iter(self) -> None:
result = letter_ctx.provider("c", lambda: div[display_letter("Hello")])
assert list(result) == ["<div>", "Hello: c!", "</div>"]

def test_aiter(self) -> None:
provider = letter_ctx.provider("c", lambda: div[display_letter("Hello")])

async def run() -> list[str]:
return [chunk async for chunk in provider]

result = asyncio.run(run(), debug=True)
assert result == ["<div>", "Hello: c!", "</div>"]

def test_str(self) -> None:
result = str(letter_ctx.provider("c", lambda: div[display_letter("Hello")]))
assert result == "<div>Hello: c!</div>"
25 changes: 23 additions & 2 deletions tests/test_starlette.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,8 @@
from starlette.routing import Route
from starlette.testclient import TestClient

from htpy import h1
from htpy import Element, h1, p
from htpy.starlette import HtpyResponse

if t.TYPE_CHECKING:
from starlette.requests import Request
@@ -17,16 +18,36 @@ async def html_response(request: Request) -> HTMLResponse:
return HTMLResponse(h1["Hello, HTMLResponse!"])


async def stuff() -> Element:
return p["stuff"]


async def htpy_response(request: Request) -> HtpyResponse:
return HtpyResponse(
(
h1["Hello, HtpyResponse!"],
stuff(),
)
)


client = TestClient(
Starlette(
debug=True,
routes=[
Route("/html-response", html_response),
Route("/htpy-response", htpy_response),
],
)
)


def test_html_response() -> None:
response = client.get("/html-response")
assert response.content == b"<h1>Hello, HTMLResponse!</h1>"
assert response.text == "<h1>Hello, HTMLResponse!</h1>"


def test_htpy_response() -> None:
response = client.get("/htpy-response")
assert response.headers["content-type"] == "text/html; charset=utf-8"
assert response.text == "<h1>Hello, HtpyResponse!</h1><p>stuff</p>"