diff --git a/docs/changelog.md b/docs/changelog.md index 68877f7..8540d50 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,9 @@ #42](https://github.com/pelme/htpy/issues/42). - Run tests on Python 3.13 RC (no changes were required, earlier versions should work fine too). [PR #45](https://github.com/pelme/htpy/pull/45). + - Attributes that are not strings will now be rejected runtime. Attributes have + been typed as strings previously but this is now also enforced during runtime. + If you need to pass non-strings as attribute values, wrap them in str() calls. ## 24.8.0 - 2024-08-03 - Allow conditional rendering based on `bool`. [PR #40](https://github.com/pelme/htpy/pull/41). diff --git a/htpy/__init__.py b/htpy/__init__.py index 3c03930..4689f72 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -80,7 +80,10 @@ def _kwarg_attribute_name(name: str) -> str: def _generate_attrs(raw_attrs: dict[str, Attribute]) -> Iterable[tuple[str, Attribute]]: for key, value in raw_attrs.items(): - if value in (False, None): + if not isinstance(key, str): # pyright: ignore [reportUnnecessaryIsInstance] + raise ValueError("Attribute key must be a string") + + if value is False or value is None: continue if key == "class": @@ -91,6 +94,9 @@ def _generate_attrs(raw_attrs: dict[str, Attribute]) -> Iterable[tuple[str, Attr yield _force_escape(key), True else: + if not isinstance(value, str | _HasHtml): + raise ValueError(f"Attribute value must be a string , got {value!r}") + yield _force_escape(key), _force_escape(value) @@ -118,9 +124,9 @@ def iter_node(x: Node) -> Iterator[str]: if isinstance(x, BaseElement): yield from x - elif isinstance(x, str) or hasattr(x, "__html__"): + elif isinstance(x, str | _HasHtml): yield str(_escape(x)) - elif isinstance(x, Iterable): + elif isinstance(x, Iterable): # pyright: ignore [reportUnnecessaryIsInstance] for child in x: yield from iter_node(child) else: @@ -234,6 +240,7 @@ def comment(text: str) -> _Markup: return _Markup(f"") +@t.runtime_checkable class _HasHtml(t.Protocol): def __html__(self) -> str: ... diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 6d3975c..13c09eb 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -1,3 +1,5 @@ +import typing as t + import pytest from markupsafe import Markup @@ -177,3 +179,18 @@ def test_class_priority() -> None: def test_attribute_priority() -> None: result = div({"foo": "a"}, foo="b") assert str(result) == """
""" + + +@pytest.mark.parametrize("not_an_attr", [1234, b"foo", object(), object, 1, 0, None]) +def test_invalid_attribute_key(not_an_attr: t.Any) -> None: + with pytest.raises(ValueError, match="Attribute key must be a string"): + str(div({not_an_attr: "foo"})) + + +@pytest.mark.parametrize( + "not_an_attr", + [1234, b"foo", object(), object, 1, 0], +) +def test_invalid_attribute_value(not_an_attr: t.Any) -> None: + with pytest.raises(ValueError, match="Attribute value must be a string"): + str(div(foo=not_an_attr))