Skip to content

Commit

Permalink
Make possible to use one-shot iterators as children nodes.
Browse files Browse the repository at this point in the history
Previously there were checks for generators specifically and not iterators in
general. We also support one-off iterators that are not based on generators such
as itertools.chain().

This regression was introduced in #56.

Based on #71.
  • Loading branch information
pelme committed Dec 15, 2024
1 parent e774cd7 commit 0e7372a
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 21 deletions.
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## NEXT
- Fixed handling of non-generator iterators such as `itertools.chain()` as
children. Thanks to Aleksei Pirogov ([@astynax](https://github.com/astynax)).
[PR #72](https://github.com/pelme/htpy/pull/72).

## 24.10.1 - 2024-10-24
- Fix handling of Python keywords such as `<del>` in html2htpy. [PR #61](https://github.com/pelme/htpy/pull/61).

Expand Down
29 changes: 20 additions & 9 deletions htpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import keyword
import typing as t
from collections.abc import Callable, Generator, Iterable, Iterator
from collections.abc import Callable, Iterable, Iterator

from markupsafe import Markup as _Markup
from markupsafe import escape as _escape
Expand Down Expand Up @@ -297,14 +297,26 @@ def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:


def _validate_children(children: t.Any) -> None:
if isinstance(children, _KnownValidChildren):
# Non-lazy iterables:
# list and tuple are iterables and part of _KnownValidChildren. Since we
# know they can be consumed multiple times, we validate them recursively now
# rather than at render time to provide better error messages.
if isinstance(children, list | tuple):
for child in children: # pyright: ignore[reportUnknownVariableType]
_validate_children(child)
return

if isinstance(children, Iterable) and not isinstance(children, _KnownInvalidChildren):
for child in children: # pyright: ignore [reportUnknownVariableType]
_validate_children(child)
# bytes, bytearray etc:
# These are Iterable (part of _KnownValidChildren) but still not
# useful as a child node.
if isinstance(children, _KnownInvalidChildren):
raise TypeError(f"{children!r} is not a valid child element")

# Element, str, int and all other regular/valid types.
if isinstance(children, _KnownValidChildren):
return

# Arbitrary objects that are not valid children.
raise TypeError(f"{children!r} is not a valid child element")


Expand Down Expand Up @@ -487,15 +499,14 @@ def __html__(self) -> str: ...


_KnownInvalidChildren: UnionType = bytes | bytearray | memoryview

_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType]
_KnownValidChildren: UnionType = (
None
| BaseElement
| ContextProvider # pyright: ignore [reportMissingTypeArgument]
| ContextConsumer # pyright: ignore [reportMissingTypeArgument]
| Callable # pyright: ignore [reportMissingTypeArgument]
| str
| int
| Generator # pyright: ignore [reportMissingTypeArgument]
| _HasHtml
| Callable
| Iterable
)
69 changes: 57 additions & 12 deletions tests/test_children.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
import re
import typing as t
from collections.abc import Iterator

import pytest
from markupsafe import Markup
Expand All @@ -16,13 +17,36 @@
from .conftest import Trace

if t.TYPE_CHECKING:
from collections.abc import Callable, Generator
from collections.abc import Callable

from htpy import Node

from .conftest import RenderFixture, TraceFixture


T = t.TypeVar("T")


class SingleShotIterator(Iterator[T]):
def __init__(self, value: T, trace: TraceFixture = lambda x: None) -> None:
self.value = value
self.trace = trace
self.consumed = False

def __iter__(self) -> SingleShotIterator[T]:
return self

def __next__(self) -> T:
if self.consumed:
self.trace("SingleShotIterator: StopIteration")
raise StopIteration

self.consumed = True
self.trace("SingleShotIterator: returning value")

return self.value


def test_void_element(render: RenderFixture) -> None:
result = input(name="foo")
assert_type(result, VoidElement)
Expand Down Expand Up @@ -88,12 +112,12 @@ def test_flatten_very_nested_children(render: RenderFixture) -> None:


def test_flatten_nested_generators(render: RenderFixture) -> None:
def cols() -> Generator[str, None, None]:
def cols() -> Iterator[str]:
yield "a"
yield "b"
yield "c"

def rows() -> Generator[Generator[str, None, None], None, None]:
def rows() -> Iterator[Iterator[str]]:
yield cols()
yield cols()
yield cols()
Expand All @@ -104,11 +128,23 @@ def rows() -> Generator[Generator[str, None, None], None, None]:


def test_generator_children(render: RenderFixture) -> None:
gen: Generator[Element, None, None] = (li[x] for x in ["a", "b"])
gen: Iterator[Element] = (li[x] for x in ["a", "b"])
result = ul[gen]
assert render(result) == ["<ul>", "<li>", "a", "</li>", "<li>", "b", "</li>", "</ul>"]


def test_non_generator_iterator(render: RenderFixture, trace: TraceFixture) -> None:
result = ul[SingleShotIterator("hello", trace=trace)]

assert render(result) == [
"<ul>",
Trace("SingleShotIterator: returning value"),
"hello",
Trace("SingleShotIterator: StopIteration"),
"</ul>",
]


def test_html_tag_with_doctype(render: RenderFixture) -> None:
result = html(foo="bar")["hello"]
assert render(result) == ["<!doctype html>", '<html foo="bar">', "hello", "</html>"]
Expand Down Expand Up @@ -137,7 +173,7 @@ def test_ignored(render: RenderFixture, ignored_value: t.Any) -> None:


def test_lazy_iter(render: RenderFixture, trace: TraceFixture) -> None:
def generate_list() -> Generator[Element, None, None]:
def generate_list() -> Iterator[Element]:
trace("before yield")
yield li("#a")
trace("after yield")
Expand Down Expand Up @@ -202,7 +238,7 @@ def test_safe_children(render: RenderFixture) -> None:


def test_nested_callable_generator(render: RenderFixture) -> None:
def func() -> Generator[str, None, None]:
def func() -> Iterator[str]:
return (x for x in "abc")

assert render(div[func]) == ["<div>", "a", "b", "c", "</div>"]
Expand Down Expand Up @@ -260,7 +296,19 @@ def test_invalid_child_direct(not_a_child: t.Any) -> None:


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_nested_iterable(not_a_child: t.Any) -> None:
def test_invalid_child_wrapped_in_list(not_a_child: t.Any) -> None:
with pytest.raises(TypeError, match="is not a valid child element"):
div[[not_a_child]]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_wrapped_in_tuple(not_a_child: t.Any) -> None:
with pytest.raises(TypeError, match="is not a valid child element"):
div[(not_a_child,)]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_nested_iterator(not_a_child: t.Any) -> None:
with pytest.raises(TypeError, match="is not a valid child element"):
div[[not_a_child]]

Expand All @@ -276,14 +324,11 @@ def test_invalid_child_lazy_callable(not_a_child: t.Any, render: RenderFixture)


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_lazy_generator(not_a_child: t.Any, render: RenderFixture) -> None:
def test_invalid_child_lazy_iterator(not_a_child: t.Any, render: RenderFixture) -> None:
"""
Ensure proper exception is raised for lazily evaluated invalid children.
"""

def gen() -> t.Any:
yield not_a_child

element = div[gen()]
element = div[SingleShotIterator(not_a_child)]
with pytest.raises(TypeError, match="is not a valid child element"):
render(element)

0 comments on commit 0e7372a

Please sign in to comment.