From b391be88627e5164a74b213f81b6d9ee4e98db3f Mon Sep 17 00:00:00 2001 From: Andreas Pelme Date: Sun, 15 Dec 2024 10:18:07 +0100 Subject: [PATCH] Make possible to use one-shot iterators as children nodes. Previously there were checks for generators specifically. We also support one-off iterables that are not based on generators such as itertools.chain() or any non-generator based object that implements __next__(). This regression was introduced in https://github.com/pelme/htpy/pull/56. Based on https://github.com/pelme/htpy/pull/71. --- htpy/__init__.py | 23 +++++++++++-------- tests/test_children.py | 52 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 61 insertions(+), 14 deletions(-) diff --git a/htpy/__init__.py b/htpy/__init__.py index 979bedf..b07d4dc 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -4,7 +4,7 @@ import functools import keyword import typing as t -from collections.abc import Callable, Generator, Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator from markupsafe import Markup as _Markup from markupsafe import escape as _escape @@ -297,12 +297,18 @@ def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: def _validate_children(children: t.Any) -> None: - if isinstance(children, _KnownValidChildren): + # Validate non-lazy iterables. + # list and tuple are iterables and part of _KnownValidChildren. Since we + # know they can be consumed multiple times, we can validate them now rather + # at render time to provider better error messages. + if isinstance(children, list | tuple): + for child in children: # pyright: ignore[reportUnknownVariableType] + _validate_children(child) return - if isinstance(children, Iterable) and not isinstance(children, _KnownInvalidChildren): - for child in children: # pyright: ignore [reportUnknownVariableType] - _validate_children(child) + if not isinstance(children, _KnownInvalidChildren) and isinstance( + children, _KnownValidChildren + ): return raise TypeError(f"{children!r} is not a valid child element") @@ -487,15 +493,14 @@ def __html__(self) -> str: ... _KnownInvalidChildren: UnionType = bytes | bytearray | memoryview - -_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType] +_KnownValidChildren: UnionType = ( None | BaseElement | ContextProvider # pyright: ignore [reportMissingTypeArgument] | ContextConsumer # pyright: ignore [reportMissingTypeArgument] - | Callable # pyright: ignore [reportMissingTypeArgument] | str | int - | Generator # pyright: ignore [reportMissingTypeArgument] | _HasHtml + | Callable + | Iterable ) diff --git a/tests/test_children.py b/tests/test_children.py index fcfb818..ba34ea0 100644 --- a/tests/test_children.py +++ b/tests/test_children.py @@ -6,6 +6,7 @@ import pathlib import re import typing as t +from collections.abc import Iterable import pytest from markupsafe import Markup @@ -23,6 +24,28 @@ from .conftest import RenderFixture, TraceFixture +T = t.TypeVar("T") + + +class SingleShotIterator(t.Generic[T]): + def __init__(self, value: T) -> None: + self.value = value + self.yielded = False + + def __iter__(self) -> SingleShotIterator[T]: + return self + + def __next__(self) -> T: + if not self.yielded: + self.yielded = True + return self.value + else: + raise StopIteration + + +assert isinstance(SingleShotIterator("foo"), Iterable) + + def test_void_element(render: RenderFixture) -> None: result = input(name="foo") assert_type(result, VoidElement) @@ -109,6 +132,16 @@ def test_generator_children(render: RenderFixture) -> None: assert render(result) == [""] +def test_non_generator_iterable(render: RenderFixture) -> None: + result = ul[SingleShotIterator("hello")] + + assert render(result) == [ + "", + ] + + def test_html_tag_with_doctype(render: RenderFixture) -> None: result = html(foo="bar")["hello"] assert render(result) == ["", '', "hello", ""] @@ -259,6 +292,18 @@ def test_invalid_child_direct(not_a_child: t.Any) -> None: div[not_a_child] +@pytest.mark.parametrize("not_a_child", _invalid_children) +def test_invalid_child_wrapped_in_list(not_a_child: t.Any) -> None: + with pytest.raises(TypeError, match="is not a valid child element"): + div[[not_a_child]] + + +@pytest.mark.parametrize("not_a_child", _invalid_children) +def test_invalid_child_wrapped_in_tuple(not_a_child: t.Any) -> None: + with pytest.raises(TypeError, match="is not a valid child element"): + div[(not_a_child,)] + + @pytest.mark.parametrize("not_a_child", _invalid_children) def test_invalid_child_nested_iterable(not_a_child: t.Any) -> None: with pytest.raises(TypeError, match="is not a valid child element"): @@ -276,14 +321,11 @@ def test_invalid_child_lazy_callable(not_a_child: t.Any, render: RenderFixture) @pytest.mark.parametrize("not_a_child", _invalid_children) -def test_invalid_child_lazy_generator(not_a_child: t.Any, render: RenderFixture) -> None: +def test_invalid_child_lazy_iterator(not_a_child: t.Any, render: RenderFixture) -> None: """ Ensure proper exception is raised for lazily evaluated invalid children. """ - def gen() -> t.Any: - yield not_a_child - - element = div[gen()] + element = div[SingleShotIterator(not_a_child)] with pytest.raises(TypeError, match="is not a valid child element"): render(element)