Skip to content

Commit

Permalink
Make possible to use one-shot iterators as children nodes.
Browse files Browse the repository at this point in the history
Previously there were checks for generators specifically. We also support
one-off iterables that are not based on generators such as itertools.chain()
or any non-generator based object that implements __next__().

This regression was introduced in #56.

Based on #71.
  • Loading branch information
pelme committed Dec 15, 2024
1 parent ac4b0bd commit b391be8
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 14 deletions.
23 changes: 14 additions & 9 deletions htpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import keyword
import typing as t
from collections.abc import Callable, Generator, Iterable, Iterator
from collections.abc import Callable, Iterable, Iterator

from markupsafe import Markup as _Markup
from markupsafe import escape as _escape
Expand Down Expand Up @@ -297,12 +297,18 @@ def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:


def _validate_children(children: t.Any) -> None:
if isinstance(children, _KnownValidChildren):
# Validate non-lazy iterables.
# list and tuple are iterables and part of _KnownValidChildren. Since we
# know they can be consumed multiple times, we can validate them now rather
# at render time to provider better error messages.
if isinstance(children, list | tuple):
for child in children: # pyright: ignore[reportUnknownVariableType]
_validate_children(child)
return

if isinstance(children, Iterable) and not isinstance(children, _KnownInvalidChildren):
for child in children: # pyright: ignore [reportUnknownVariableType]
_validate_children(child)
if not isinstance(children, _KnownInvalidChildren) and isinstance(
children, _KnownValidChildren
):
return

raise TypeError(f"{children!r} is not a valid child element")
Expand Down Expand Up @@ -487,15 +493,14 @@ def __html__(self) -> str: ...


_KnownInvalidChildren: UnionType = bytes | bytearray | memoryview

_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType]
_KnownValidChildren: UnionType = (
None
| BaseElement
| ContextProvider # pyright: ignore [reportMissingTypeArgument]
| ContextConsumer # pyright: ignore [reportMissingTypeArgument]
| Callable # pyright: ignore [reportMissingTypeArgument]
| str
| int
| Generator # pyright: ignore [reportMissingTypeArgument]
| _HasHtml
| Callable
| Iterable
)
52 changes: 47 additions & 5 deletions tests/test_children.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
import re
import typing as t
from collections.abc import Iterable

import pytest
from markupsafe import Markup
Expand All @@ -23,6 +24,28 @@
from .conftest import RenderFixture, TraceFixture


T = t.TypeVar("T")


class SingleShotIterator(t.Generic[T]):
def __init__(self, value: T) -> None:
self.value = value
self.yielded = False

def __iter__(self) -> SingleShotIterator[T]:
return self

def __next__(self) -> T:
if not self.yielded:
self.yielded = True
return self.value
else:
raise StopIteration


assert isinstance(SingleShotIterator("foo"), Iterable)


def test_void_element(render: RenderFixture) -> None:
result = input(name="foo")
assert_type(result, VoidElement)
Expand Down Expand Up @@ -109,6 +132,16 @@ def test_generator_children(render: RenderFixture) -> None:
assert render(result) == ["<ul>", "<li>", "a", "</li>", "<li>", "b", "</li>", "</ul>"]


def test_non_generator_iterable(render: RenderFixture) -> None:
result = ul[SingleShotIterator("hello")]

assert render(result) == [
"<ul>",
"hello",
"</ul>",
]


def test_html_tag_with_doctype(render: RenderFixture) -> None:
result = html(foo="bar")["hello"]
assert render(result) == ["<!doctype html>", '<html foo="bar">', "hello", "</html>"]
Expand Down Expand Up @@ -259,6 +292,18 @@ def test_invalid_child_direct(not_a_child: t.Any) -> None:
div[not_a_child]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_wrapped_in_list(not_a_child: t.Any) -> None:
with pytest.raises(TypeError, match="is not a valid child element"):
div[[not_a_child]]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_wrapped_in_tuple(not_a_child: t.Any) -> None:
with pytest.raises(TypeError, match="is not a valid child element"):
div[(not_a_child,)]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_nested_iterable(not_a_child: t.Any) -> None:
with pytest.raises(TypeError, match="is not a valid child element"):
Expand All @@ -276,14 +321,11 @@ def test_invalid_child_lazy_callable(not_a_child: t.Any, render: RenderFixture)


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_lazy_generator(not_a_child: t.Any, render: RenderFixture) -> None:
def test_invalid_child_lazy_iterator(not_a_child: t.Any, render: RenderFixture) -> None:
"""
Ensure proper exception is raised for lazily evaluated invalid children.
"""

def gen() -> t.Any:
yield not_a_child

element = div[gen()]
element = div[SingleShotIterator(not_a_child)]
with pytest.raises(TypeError, match="is not a valid child element"):
render(element)

0 comments on commit b391be8

Please sign in to comment.