Skip to content

Commit

Permalink
async iteration support
Browse files Browse the repository at this point in the history
  • Loading branch information
pelme committed Aug 10, 2024
1 parent 1e8ca53 commit 93f0dca
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 3 deletions.
58 changes: 55 additions & 3 deletions htpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@

__version__ = "24.8.0"
__all__: list[str] = []

import functools
from collections.abc import Callable, Iterable, Iterator
from collections.abc import (
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Iterable,
Iterator,
)
from typing import Any, Protocol, TypeAlias, TypeVar, overload

from markupsafe import Markup as _Markup
Expand Down Expand Up @@ -123,10 +129,42 @@ def iter_node(x: Node) -> Iterator[str]:
elif isinstance(x, Iterable):
for child in x:
yield from iter_node(child)
elif isinstance(x, Awaitable | AsyncIterable):
raise ValueError(
f"{x!r} is not a valid child element. "
"Use async iteration to retrieve element content: https://htpy.dev/async/"
)
else:
raise ValueError(f"{x!r} is not a valid child element")


async def aiter_node(x: Node) -> AsyncIterator[str]:
while isinstance(x, Awaitable) or (not isinstance(x, BaseElement) and callable(x)):
if isinstance(x, Awaitable):
x = await x
else:
x = x()

if x is None:
return

if isinstance(x, BaseElement):
async for child in x:
yield child
elif isinstance(x, str) or hasattr(x, "__html__"):
yield str(_escape(x))
elif isinstance(x, AsyncIterable):
async for child in x: # type: ignore
async for chunk in aiter_node(child): # pyright: ignore[reportUnknownArgumentType]
yield chunk
elif isinstance(x, Iterable):
for child in x: # type: ignore
async for chunk in aiter_node(child):
yield chunk
else:
raise ValueError(f"{x!r} is not a valid async child element")


@functools.lru_cache(maxsize=300)
def _get_element(name: str) -> Element:
if not name.islower():
Expand Down Expand Up @@ -192,6 +230,12 @@ def __call__(self: BaseElementSelf, *args: Any, **kwargs: Any) -> BaseElementSel
self._children,
)

async def __aiter__(self) -> AsyncIterator[str]:
yield f"<{self._name}{_attrs_string(self._attrs)}>"
async for x in aiter_node(self._children):
yield x
yield f"</{self._name}>"

def __iter__(self) -> Iterator[str]:
yield f"<{self._name}{_attrs_string(self._attrs)}>"
yield from iter_node(self._children)
Expand Down Expand Up @@ -240,7 +284,15 @@ def __html__(self) -> str: ...
_ClassNamesDict: TypeAlias = dict[str, bool]
_ClassNames: TypeAlias = Iterable[str | None | bool | _ClassNamesDict] | _ClassNamesDict
Node: TypeAlias = (
None | bool | str | BaseElement | _HasHtml | Iterable["Node"] | Callable[[], "Node"]
None
| bool
| str
| BaseElement
| _HasHtml
| Iterable["Node"]
| Callable[[], "Node"]
| AsyncIterable["Node"]
| Awaitable["Node"]
)

Attribute: TypeAlias = None | bool | str | _HasHtml | _ClassNames
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ optional-dependencies.dev = [
"mypy",
"pyright",
"pytest",
"pytest-asyncio",
"black",
"ruff",
"django",
Expand Down
43 changes: 43 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from collections.abc import AsyncIterator

import pytest

from htpy import Element, li, ul


async def async_lis() -> AsyncIterator[Element]:
yield li["a"]
yield li["b"]


async def hi() -> Element:
return li["hi"]


@pytest.mark.asyncio
async def test_async_iterator() -> None:
result = [chunk async for chunk in ul[async_lis()]]
assert result == ["<ul>", "<li>", "a", "</li>", "<li>", "b", "</li>", "</ul>"]


@pytest.mark.asyncio
async def test_cororoutinefunction_children() -> None:
result = [chunk async for chunk in ul[hi]]
assert result == ["<ul>", "<li>", "hi", "</li>", "</ul>"]


@pytest.mark.asyncio
async def test_cororoutine_children() -> None:
result = [chunk async for chunk in ul[hi()]]
assert result == ["<ul>", "<li>", "hi", "</li>", "</ul>"]


def test_sync_iteration_with_async_children() -> None:
with pytest.raises(
ValueError,
match=(
r"<async_generator object async_lis at .+> is not a valid child element\. "
r"Use async iteration to retrieve element content: https://htpy.dev/async/"
),
):
str(ul[async_lis()])

0 comments on commit 93f0dca

Please sign in to comment.