From 876d6f8bc815fbf4ff257bc560cf7f3b6266dbec Mon Sep 17 00:00:00 2001 From: Andreas Pelme Date: Sun, 15 Dec 2024 10:18:07 +0100 Subject: [PATCH] Make possible to use one-shot iterators as children nodes. Previously there were checks for generators specifically. We also support one-off iterables that are not based on generators such as itertools.chain() or any non-generator based object that implements __next__(). This regression was introduced in https://github.com/pelme/htpy/pull/56. Based on https://github.com/pelme/htpy/pull/71. --- docs/changelog.md | 5 ++++ htpy/__init__.py | 29 +++++++++++++++-------- tests/test_children.py | 52 ++++++++++++++++++++++++++++++++++++++---- 3 files changed, 72 insertions(+), 14 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 1d9ce34..f75d5c4 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,10 @@ # Changelog +## NEXT +- Fixed handling of non-generator iterables such as `itertools.chain()` as +children. Thanks to Aleksei Pirogov ([@astynax](https://github.com/astynax)). +[PR #72](https://github.com/pelme/htpy/pull/72). + ## 24.10.1 - 2024-10-24 - Fix handling of Python keywords such as `` in html2htpy. [PR #61](https://github.com/pelme/htpy/pull/61). diff --git a/htpy/__init__.py b/htpy/__init__.py index 979bedf..655ec43 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -4,7 +4,7 @@ import functools import keyword import typing as t -from collections.abc import Callable, Generator, Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator from markupsafe import Markup as _Markup from markupsafe import escape as _escape @@ -297,14 +297,26 @@ def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: def _validate_children(children: t.Any) -> None: - if isinstance(children, _KnownValidChildren): + # Non-lazy iterables: + # list and tuple are iterables and part of _KnownValidChildren. Since we + # know they can be consumed multiple times, we validate them recursively now + # rather than at render time to provide better error messages. + if isinstance(children, list | tuple): + for child in children: # pyright: ignore[reportUnknownVariableType] + _validate_children(child) return - if isinstance(children, Iterable) and not isinstance(children, _KnownInvalidChildren): - for child in children: # pyright: ignore [reportUnknownVariableType] - _validate_children(child) + # bytes, bytearray etc: + # These are Iterable (part of _KnownValidChildren) but still not + # useful as a child node. + if isinstance(children, _KnownInvalidChildren): + raise TypeError(f"{children!r} is not a valid child element") + + # Element, str, int and all other regular/valid types. + if isinstance(children, _KnownValidChildren): return + # Arbitrary objects that are not valid children. raise TypeError(f"{children!r} is not a valid child element") @@ -487,15 +499,14 @@ def __html__(self) -> str: ... _KnownInvalidChildren: UnionType = bytes | bytearray | memoryview - -_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType] +_KnownValidChildren: UnionType = ( None | BaseElement | ContextProvider # pyright: ignore [reportMissingTypeArgument] | ContextConsumer # pyright: ignore [reportMissingTypeArgument] - | Callable # pyright: ignore [reportMissingTypeArgument] | str | int - | Generator # pyright: ignore [reportMissingTypeArgument] | _HasHtml + | Callable + | Iterable ) diff --git a/tests/test_children.py b/tests/test_children.py index fcfb818..f96e041 100644 --- a/tests/test_children.py +++ b/tests/test_children.py @@ -6,6 +6,7 @@ import pathlib import re import typing as t +from collections.abc import Iterable import pytest from markupsafe import Markup @@ -23,6 +24,28 @@ from .conftest import RenderFixture, TraceFixture +T = t.TypeVar("T") + + +class SingleShotIterable(t.Generic[T]): + def __init__(self, value: T) -> None: + self.value = value + self.consumed = False + + def __iter__(self) -> SingleShotIterable[T]: + return self + + def __next__(self) -> T: + if self.consumed: + raise StopIteration + + self.consumed = True + return self.value + + +assert isinstance(SingleShotIterable("foo"), Iterable) + + def test_void_element(render: RenderFixture) -> None: result = input(name="foo") assert_type(result, VoidElement) @@ -109,6 +132,16 @@ def test_generator_children(render: RenderFixture) -> None: assert render(result) == [""] +def test_non_generator_iterable(render: RenderFixture) -> None: + result = ul[SingleShotIterable("hello")] + + assert render(result) == [ + "", + ] + + def test_html_tag_with_doctype(render: RenderFixture) -> None: result = html(foo="bar")["hello"] assert render(result) == ["", '', "hello", ""] @@ -259,6 +292,18 @@ def test_invalid_child_direct(not_a_child: t.Any) -> None: div[not_a_child] +@pytest.mark.parametrize("not_a_child", _invalid_children) +def test_invalid_child_wrapped_in_list(not_a_child: t.Any) -> None: + with pytest.raises(TypeError, match="is not a valid child element"): + div[[not_a_child]] + + +@pytest.mark.parametrize("not_a_child", _invalid_children) +def test_invalid_child_wrapped_in_tuple(not_a_child: t.Any) -> None: + with pytest.raises(TypeError, match="is not a valid child element"): + div[(not_a_child,)] + + @pytest.mark.parametrize("not_a_child", _invalid_children) def test_invalid_child_nested_iterable(not_a_child: t.Any) -> None: with pytest.raises(TypeError, match="is not a valid child element"): @@ -276,14 +321,11 @@ def test_invalid_child_lazy_callable(not_a_child: t.Any, render: RenderFixture) @pytest.mark.parametrize("not_a_child", _invalid_children) -def test_invalid_child_lazy_generator(not_a_child: t.Any, render: RenderFixture) -> None: +def test_invalid_child_lazy_iterable(not_a_child: t.Any, render: RenderFixture) -> None: """ Ensure proper exception is raised for lazily evaluated invalid children. """ - def gen() -> t.Any: - yield not_a_child - - element = div[gen()] + element = div[SingleShotIterable(not_a_child)] with pytest.raises(TypeError, match="is not a valid child element"): render(element)