From 11bf4ff7f818f73c98208465c37876fb4157b740 Mon Sep 17 00:00:00 2001 From: Andreas Pelme Date: Sun, 15 Dec 2024 10:18:07 +0100 Subject: [PATCH] Make possible to use one-shot iterators as children nodes. Previously there were checks for generators specifically and not iterators in general. We also support one-off iterators that are not based on generators such as itertools.chain(). This regression was introduced in https://github.com/pelme/htpy/pull/56. Based on https://github.com/pelme/htpy/pull/71. --- docs/changelog.md | 5 +++ htpy/__init__.py | 29 ++++++++++++------ tests/test_children.py | 69 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 82 insertions(+), 21 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 1d9ce34..a4f30d7 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,10 @@ # Changelog +## NEXT +- Fixed handling of non-generator iterators such as `itertools.chain()` as +children. Thanks to Aleksei Pirogov ([@astynax](https://github.com/astynax)). +[PR #72](https://github.com/pelme/htpy/pull/72). + ## 24.10.1 - 2024-10-24 - Fix handling of Python keywords such as `` in html2htpy. [PR #61](https://github.com/pelme/htpy/pull/61). diff --git a/htpy/__init__.py b/htpy/__init__.py index 979bedf..655ec43 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -4,7 +4,7 @@ import functools import keyword import typing as t -from collections.abc import Callable, Generator, Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator from markupsafe import Markup as _Markup from markupsafe import escape as _escape @@ -297,14 +297,26 @@ def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: def _validate_children(children: t.Any) -> None: - if isinstance(children, _KnownValidChildren): + # Non-lazy iterables: + # list and tuple are iterables and part of _KnownValidChildren. Since we + # know they can be consumed multiple times, we validate them recursively now + # rather than at render time to provide better error messages. + if isinstance(children, list | tuple): + for child in children: # pyright: ignore[reportUnknownVariableType] + _validate_children(child) return - if isinstance(children, Iterable) and not isinstance(children, _KnownInvalidChildren): - for child in children: # pyright: ignore [reportUnknownVariableType] - _validate_children(child) + # bytes, bytearray etc: + # These are Iterable (part of _KnownValidChildren) but still not + # useful as a child node. + if isinstance(children, _KnownInvalidChildren): + raise TypeError(f"{children!r} is not a valid child element") + + # Element, str, int and all other regular/valid types. + if isinstance(children, _KnownValidChildren): return + # Arbitrary objects that are not valid children. raise TypeError(f"{children!r} is not a valid child element") @@ -487,15 +499,14 @@ def __html__(self) -> str: ... _KnownInvalidChildren: UnionType = bytes | bytearray | memoryview - -_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType] +_KnownValidChildren: UnionType = ( None | BaseElement | ContextProvider # pyright: ignore [reportMissingTypeArgument] | ContextConsumer # pyright: ignore [reportMissingTypeArgument] - | Callable # pyright: ignore [reportMissingTypeArgument] | str | int - | Generator # pyright: ignore [reportMissingTypeArgument] | _HasHtml + | Callable + | Iterable ) diff --git a/tests/test_children.py b/tests/test_children.py index fcfb818..f9fd6bd 100644 --- a/tests/test_children.py +++ b/tests/test_children.py @@ -6,6 +6,7 @@ import pathlib import re import typing as t +from collections.abc import Iterator import pytest from markupsafe import Markup @@ -16,13 +17,36 @@ from .conftest import Trace if t.TYPE_CHECKING: - from collections.abc import Callable, Generator + from collections.abc import Callable from htpy import Node from .conftest import RenderFixture, TraceFixture +T = t.TypeVar("T") + + +class SingleShotIterator(Iterator[T]): + def __init__(self, value: T, trace: TraceFixture = lambda x: None) -> None: + self.value = value + self.trace = trace + self.consumed = False + + def __iter__(self) -> SingleShotIterator[T]: + return self + + def __next__(self) -> T: + if self.consumed: + self.trace("SingleShotIterator: StopIteration") + raise StopIteration + + self.consumed = True + self.trace("SingleShotIterator: returning value") + + return self.value + + def test_void_element(render: RenderFixture) -> None: result = input(name="foo") assert_type(result, VoidElement) @@ -88,12 +112,12 @@ def test_flatten_very_nested_children(render: RenderFixture) -> None: def test_flatten_nested_generators(render: RenderFixture) -> None: - def cols() -> Generator[str, None, None]: + def cols() -> Iterator[str]: yield "a" yield "b" yield "c" - def rows() -> Generator[Generator[str, None, None], None, None]: + def rows() -> Iterator[Iterator[str]]: yield cols() yield cols() yield cols() @@ -104,11 +128,23 @@ def rows() -> Generator[Generator[str, None, None], None, None]: def test_generator_children(render: RenderFixture) -> None: - gen: Generator[Element, None, None] = (li[x] for x in ["a", "b"]) + gen: Iterator[Element] = (li[x] for x in ["a", "b"]) result = ul[gen] assert render(result) == [""] +def test_non_generator_iterator(render: RenderFixture, trace: TraceFixture) -> None: + result = div[SingleShotIterator("hello", trace=trace)] + + assert render(result) == [ + "
", + Trace("SingleShotIterator: returning value"), + "hello", + Trace("SingleShotIterator: StopIteration"), + "
", + ] + + def test_html_tag_with_doctype(render: RenderFixture) -> None: result = html(foo="bar")["hello"] assert render(result) == ["", '', "hello", ""] @@ -137,7 +173,7 @@ def test_ignored(render: RenderFixture, ignored_value: t.Any) -> None: def test_lazy_iter(render: RenderFixture, trace: TraceFixture) -> None: - def generate_list() -> Generator[Element, None, None]: + def generate_list() -> Iterator[Element]: trace("before yield") yield li("#a") trace("after yield") @@ -202,7 +238,7 @@ def test_safe_children(render: RenderFixture) -> None: def test_nested_callable_generator(render: RenderFixture) -> None: - def func() -> Generator[str, None, None]: + def func() -> Iterator[str]: return (x for x in "abc") assert render(div[func]) == ["
", "a", "b", "c", "
"] @@ -260,7 +296,19 @@ def test_invalid_child_direct(not_a_child: t.Any) -> None: @pytest.mark.parametrize("not_a_child", _invalid_children) -def test_invalid_child_nested_iterable(not_a_child: t.Any) -> None: +def test_invalid_child_wrapped_in_list(not_a_child: t.Any) -> None: + with pytest.raises(TypeError, match="is not a valid child element"): + div[[not_a_child]] + + +@pytest.mark.parametrize("not_a_child", _invalid_children) +def test_invalid_child_wrapped_in_tuple(not_a_child: t.Any) -> None: + with pytest.raises(TypeError, match="is not a valid child element"): + div[(not_a_child,)] + + +@pytest.mark.parametrize("not_a_child", _invalid_children) +def test_invalid_child_nested_iterator(not_a_child: t.Any) -> None: with pytest.raises(TypeError, match="is not a valid child element"): div[[not_a_child]] @@ -276,14 +324,11 @@ def test_invalid_child_lazy_callable(not_a_child: t.Any, render: RenderFixture) @pytest.mark.parametrize("not_a_child", _invalid_children) -def test_invalid_child_lazy_generator(not_a_child: t.Any, render: RenderFixture) -> None: +def test_invalid_child_lazy_iterator(not_a_child: t.Any, render: RenderFixture) -> None: """ Ensure proper exception is raised for lazily evaluated invalid children. """ - def gen() -> t.Any: - yield not_a_child - - element = div[gen()] + element = div[SingleShotIterator(not_a_child)] with pytest.raises(TypeError, match="is not a valid child element"): render(element)