diff --git a/conftest.py b/conftest.py index 02aa6eb..25f3ed5 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,6 @@ +import asyncio +import typing as t + import pytest @@ -13,3 +16,28 @@ def django_env() -> None: ] ) django.setup() + + +@pytest.fixture(params=["sync", "async"]) +def to_list(request: pytest.FixtureRequest) -> t.Any: + from htpy import Node, aiter_node, iter_node + + def func(node: Node) -> t.Any: + if request.param == "sync": + return list(iter_node(node)) + else: + + async def get_list() -> t.Any: + result = [] + async for chunk in aiter_node(node): + result.append(chunk) + return result + + return asyncio.run(get_list(), debug=True) + + return func + + +@pytest.fixture +def to_str(to_list): + return lambda node: "".join(to_list(node)) diff --git a/docs/assets/starlette.webm b/docs/assets/starlette.webm new file mode 100644 index 0000000..0587c96 Binary files /dev/null and b/docs/assets/starlette.webm differ diff --git a/docs/streaming.md b/docs/streaming.md index 403df6b..768c4e3 100644 --- a/docs/streaming.md +++ b/docs/streaming.md @@ -1,7 +1,7 @@ # Streaming of Contents Internally, htpy is built with generators. Most of the time, you would render -the full page with `str()`, but htpy can also incrementally generate pages which +the full page with `str()`, but htpy can also incrementally generate pages synchronously or asynchronous which can then be streamed to the browser. If your page uses a database or other services to retrieve data, you can sending the first part of the page to the client while the page is being generated. @@ -16,7 +16,6 @@ client while the page is being generated. This video shows what it looks like in the browser to generate a HTML table with [Django StreamingHttpResponse](https://docs.djangoproject.com/en/5.0/ref/request-response/#django.http.StreamingHttpResponse) ([source code](https://github.com/pelme/htpy/blob/main/examples/djangoproject/stream/views.py)): @@ -111,3 +110,59 @@ print( # output:

Fibonacci!

fib(12)=6765
``` + + +## Asynchronous streaming + +It is also possible to use htpy to stream fully asynchronous. This intended to be used +with ASGI/async web frameworks/servers such as Starlette and Django. You can +build htpy components using Python's `asyncio` module and the `async`/`await` +syntax. + +### Starlette, ASGI and uvicorn example + +```python +title="starlette_demo.py" +import asyncio +from collections.abc import AsyncIterator + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import StreamingResponse + +from htpy import Element, div, h1, li, p, ul + +app = Starlette(debug=True) + + +@app.route("/") +async def index(request: Request) -> StreamingResponse: + return StreamingResponse(await index_page(), media_type="text/html") + + +async def index_page() -> Element: + return div[ + h1["Starlette Async example"], + p["This page is generated asynchronously using Starlette and ASGI."], + ul[(li[str(num)] async for num in slow_numbers(1, 10))], + ] + + +async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]: + for number in range(minimum, maximum + 1): + yield number + await asyncio.sleep(0.5) + +``` + +Run with [uvicorn](https://www.uvicorn.org/): + + +``` +$ uvicorn starlette_demo:app +``` + +In the browser, it looks like this: + diff --git a/examples/async_coroutine.py b/examples/async_coroutine.py new file mode 100644 index 0000000..26e737b --- /dev/null +++ b/examples/async_coroutine.py @@ -0,0 +1,26 @@ +import asyncio +import random + +from htpy import Element, b, div, h1 + + +async def magic_number() -> Element: + await asyncio.sleep(1) + return b[f"The Magic Number is: {random.randint(1, 100)}"] + + +async def my_component() -> Element: + return div[ + h1["The Magic Number"], + await magic_number(), + ] + + +async def main() -> None: + import time + + async for chunk in my_component(): + print(f"got: chunk") + + +asyncio.run(main()) diff --git a/examples/starlette_app.py b/examples/starlette_app.py new file mode 100644 index 0000000..eab1d15 --- /dev/null +++ b/examples/starlette_app.py @@ -0,0 +1,29 @@ +import asyncio +from collections.abc import AsyncIterator + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import StreamingResponse + +from htpy import Element, div, h1, li, p, ul + +app = Starlette(debug=True) + + +@app.route("/") +async def index(request: Request) -> StreamingResponse: + return StreamingResponse(await index_page(), media_type="text/html") + + +async def index_page() -> Element: + return div[ + h1["Starlette Async example"], + p["This page is generated asynchronously using Starlette and ASGI."], + ul[(li[str(num)] async for num in slow_numbers(1, 10))], + ] + + +async def slow_numbers(minimum: int, maximum: int) -> AsyncIterator[int]: + for number in range(minimum, maximum + 1): + yield number + await asyncio.sleep(0.5) diff --git a/htpy/__init__.py b/htpy/__init__.py index ceb8af2..fe7e33d 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -6,7 +6,14 @@ import dataclasses import functools import typing as t -from collections.abc import Callable, Iterable, Iterator +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Iterable, + Iterator, +) from markupsafe import Markup as _Markup from markupsafe import escape as _escape @@ -191,10 +198,73 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It elif isinstance(x, Iterable): # pyright: ignore [reportUnnecessaryIsInstance] for child in x: yield from _iter_node_context(child, context_dict) + elif isinstance(x, Awaitable | AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance] + raise ValueError( + f"{x!r} is not a valid child element. " + "Use async iteration to retrieve element content: https://htpy.dev/async/" + ) else: raise ValueError(f"{x!r} is not a valid child element") +def aiter_node(x: Node) -> AsyncIterator[str]: + return _aiter_node_context(x, {}) + + +async def _aiter_node_context( + x: Node, context_dict: dict[Context[t.Any], t.Any] +) -> AsyncIterator[str]: + while True: + if isinstance(x, Awaitable): + x = await x + continue + + if not isinstance(x, BaseElement) and callable(x): + x = x() + continue + + break + + if x is None: + return + + if x is True: + return + + if x is False: + return + + if isinstance(x, BaseElement): + async for child in x._aiter_context(context_dict): # pyright: ignore [reportPrivateUsage] + yield child + elif isinstance(x, ContextProvider): + async for chunk in _aiter_node_context(x.func(), {**context_dict, x.context: x.value}): # pyright: ignore [reportUnknownMemberType] + yield chunk + + elif isinstance(x, ContextConsumer): + context_value = context_dict.get(x.context, x.context.default) + if context_value is _NO_DEFAULT: + raise LookupError( + f'Context value for "{x.context.name}" does not exist, ' + f"requested by {x.debug_name}()." + ) + async for chunk in _aiter_node_context(x.func(context_value), context_dict): + yield chunk + + elif isinstance(x, str | _HasHtml): + yield str(_escape(x)) + elif isinstance(x, Iterable): + for child in x: # type: ignore[assignment] + async for chunk in _aiter_node_context(child, context_dict): + yield chunk + elif isinstance(x, AsyncIterable): # pyright: ignore[reportUnnecessaryIsInstance] + async for child in x: # type: ignore[assignment] + async for chunk in _aiter_node_context(child, context_dict): # pyright: ignore[reportUnknownArgumentType] + yield chunk + else: + raise ValueError(f"{x!r} is not a valid async child element") + + @functools.lru_cache(maxsize=300) def _get_element(name: str) -> Element: if not name.islower(): @@ -223,7 +293,10 @@ def __str__(self) -> _Markup: @t.overload def __call__( - self: BaseElementSelf, id_class: str, attrs: dict[str, Attribute], **kwargs: Attribute + self: BaseElementSelf, + id_class: str, + attrs: dict[str, Attribute], + **kwargs: Attribute, ) -> BaseElementSelf: ... @t.overload def __call__( @@ -262,6 +335,15 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen self._children, ) + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield f"<{self._name}{_attrs_string(self._attrs)}>" + async for x in _aiter_node_context(self._children, context): + yield x + yield f"" + + def __aiter__(self) -> AsyncIterator[str]: + return self._aiter_context({}) + def __iter__(self) -> Iterator[str]: return self._iter_context({}) @@ -296,8 +378,16 @@ def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield "" yield from super()._iter_context(ctx) + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield "" + async for x in super()._aiter_context(context): + yield x + class VoidElement(BaseElement): + async def _aiter_context(self, context: dict[Context[t.Any], t.Any]) -> AsyncIterator[str]: + yield f"<{self._name}{_attrs_string(self._attrs)}>" + def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield f"<{self._name}{_attrs_string(self._attrs)}>" @@ -328,6 +418,8 @@ def __html__(self) -> str: ... | Callable[[], "Node"] | ContextProvider[t.Any] | ContextConsumer[t.Any] + | AsyncIterable["Node"] + | Awaitable["Node"] ) Attribute: t.TypeAlias = None | bool | str | _HasHtml | _ClassNames diff --git a/pyproject.toml b/pyproject.toml index f2fc1cd..5506f73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ optional-dependencies.dev = [ "mypy", "pyright", "pytest", + "pytest-asyncio", "black", "ruff", "django", diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 0000000..cb6a4c7 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,48 @@ +from collections.abc import AsyncIterator + +import pytest + +from htpy import Element, li, ul + + +async def async_lis() -> AsyncIterator[Element]: + yield li["a"] + yield li["b"] + + +async def hi() -> Element: + return li["hi"] + + +@pytest.mark.asyncio +async def test_async_iterator() -> None: + result = [chunk async for chunk in ul[async_lis()]] + assert result == [""] + + +@pytest.mark.asyncio +async def test_cororoutinefunction_children() -> None: + result = [chunk async for chunk in ul[hi]] + assert result == [""] + + +@pytest.mark.asyncio +async def test_cororoutine_children() -> None: + result = [chunk async for chunk in ul[hi()]] + assert result == [""] + + +def test_sync_iteration_with_async_children() -> None: + with pytest.raises( + ValueError, + match=( + r" is not a valid child element\. " + r"Use async iteration to retrieve element content: https://htpy.dev/async/" + ), + ): + str(ul[async_lis()]) + + +@pytest.mark.xfail +def test_repr_with_async_children() -> None: + assert repr(ul[async_lis()]) == "'>" diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 13c09eb..6b62330 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing as t import pytest @@ -5,33 +7,36 @@ from htpy import button, div +if t.TYPE_CHECKING: + from .types import ToStr + def test_attribute() -> None: assert str(div(id="hello")["hi"]) == '
hi
' class Test_class_names: - def test_str(self) -> None: + def test_str(self, to_str: ToStr) -> None: result = div(class_='">foo bar') - assert str(result) == '
' + assert to_str(result) == '
' - def test_safestring(self) -> None: + def test_safestring(self, to_str: ToStr) -> None: result = div(class_=Markup('">foo bar')) - assert str(result) == '
' + assert to_str(result) == '
' - def test_list(self) -> None: + def test_list(self, to_str: ToStr) -> None: result = div(class_=['">foo', Markup('">bar'), False, None, "", "baz"]) - assert str(result) == '
' + assert to_str(result) == '
' - def test_tuple(self) -> None: + def test_tuple(self, to_str: ToStr) -> None: result = div(class_=('">foo', Markup('">bar'), False, None, "", "baz")) - assert str(result) == '
' + assert to_str(result) == '
' - def test_dict(self) -> None: + def test_dict(self, to_str: ToStr) -> None: result = div(class_={'">foo': True, Markup('">bar'): True, "x": False, "baz": True}) - assert str(result) == '
' + assert to_str(result) == '
' - def test_nested_dict(self) -> None: + def test_nested_dict(self, to_str: ToStr) -> None: result = div( class_=[ '">list-foo', @@ -39,54 +44,54 @@ def test_nested_dict(self) -> None: {'">dict-foo': True, Markup('">list-bar'): True, "x": False}, ] ) - assert str(result) == ( + assert to_str(result) == ( '
' ) - def test_false(self) -> None: - result = str(div(class_=False)) + def test_false(self, to_str: ToStr) -> None: + result = to_str(div(class_=False)) assert result == "
" - def test_none(self) -> None: - result = str(div(class_=None)) + def test_none(self, to_str: ToStr) -> None: + result = to_str(div(class_=None)) assert result == "
" - def test_no_classes(self) -> None: - result = str(div(class_={"foo": False})) + def test_no_classes(self, to_str: ToStr) -> None: + result = to_str(div(class_={"foo": False})) assert result == "
" -def test_dict_attributes() -> None: +def test_dict_attributes(to_str: ToStr) -> None: result = div({"@click": 'hi = "hello"'}) - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_underscore() -> None: +def test_underscore(to_str: ToStr) -> None: # Hyperscript (https://hyperscript.org/) uses _, make sure it works good. result = div(_="foo") - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_dict_attributes_avoid_replace() -> None: +def test_dict_attributes_avoid_replace(to_str: ToStr) -> None: result = div({"class_": "foo", "hello_hi": "abc"}) - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_dict_attribute_false() -> None: +def test_dict_attribute_false(to_str: ToStr) -> None: result = div({"bool-false": False}) - assert str(result) == "
" + assert to_str(result) == "
" -def test_dict_attribute_true() -> None: +def test_dict_attribute_true(to_str: ToStr) -> None: result = div({"bool-true": True}) - assert str(result) == "
" + assert to_str(result) == "
" -def test_underscore_replacement() -> None: +def test_underscore_replacement(to_str: ToStr) -> None: result = button(hx_post="/foo")["click me!"] - assert str(result) == """""" + assert to_str(result) == """""" class Test_attribute_escape: @@ -98,49 +103,48 @@ class Test_attribute_escape: ], ) - def test_dict(self, x: str) -> None: + def test_dict(self, x: str, to_str: ToStr) -> None: result = div({x: x}) - assert str(result) == """
""" + assert to_str(result) == """
""" - def test_kwarg(self, x: str) -> None: + def test_kwarg(self, x: str, to_str: ToStr) -> None: result = div(**{x: x}) - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_boolean_attribute_true() -> None: +def test_boolean_attribute_true(to_str: ToStr) -> None: result = button(disabled=True) - assert str(result) == "" + assert to_str(result) == "" -def test_kwarg_attribute_none() -> None: +def test_kwarg_attribute_none(to_str: ToStr) -> None: result = div(foo=None) - assert str(result) == "
" + assert to_str(result) == "
" -def test_dict_attribute_none() -> None: +def test_dict_attribute_none(to_str: ToStr) -> None: result = div({"foo": None}) - assert str(result) == "
" + assert to_str(result) == "
" -def test_boolean_attribute_false() -> None: +def test_boolean_attribute_false(to_str: ToStr) -> None: result = button(disabled=False) - assert str(result) == "" + assert to_str(result) == "" -def test_id_class() -> None: +def test_id_class(to_str: ToStr) -> None: result = div("#myid.cls1.cls2") - - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_id_class_only_id() -> None: +def test_id_class_only_id(to_str: ToStr) -> None: result = div("#myid") - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_id_class_only_classes() -> None: +def test_id_class_only_classes(to_str: ToStr) -> None: result = div(".foo.bar") - assert str(result) == """
""" + assert to_str(result) == """
""" def test_id_class_wrong_order() -> None: @@ -158,39 +162,39 @@ def test_id_class_bad_type() -> None: div({"oops": "yes"}, {}) # type: ignore -def test_id_class_and_kwargs() -> None: +def test_id_class_and_kwargs(to_str: ToStr) -> None: result = div("#theid", for_="hello", data_foo="""" + assert to_str(result) == """
""" -def test_attrs_and_kwargs() -> None: +def test_attrs_and_kwargs(to_str: ToStr) -> None: result = div({"a": "1", "for": "a"}, for_="b", b="2") - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_class_priority() -> None: +def test_class_priority(to_str: ToStr) -> None: result = div(".a", {"class": "b"}, class_="c") - assert str(result) == """
""" + assert to_str(result) == """
""" result = div(".a", {"class": "b"}) - assert str(result) == """
""" + assert to_str(result) == """
""" -def test_attribute_priority() -> None: +def test_attribute_priority(to_str: ToStr) -> None: result = div({"foo": "a"}, foo="b") - assert str(result) == """
""" + assert to_str(result) == """
""" @pytest.mark.parametrize("not_an_attr", [1234, b"foo", object(), object, 1, 0, None]) -def test_invalid_attribute_key(not_an_attr: t.Any) -> None: +def test_invalid_attribute_key(not_an_attr: t.Any, to_str: ToStr) -> None: with pytest.raises(ValueError, match="Attribute key must be a string"): - str(div({not_an_attr: "foo"})) + to_str(div({not_an_attr: "foo"})) @pytest.mark.parametrize( "not_an_attr", [1234, b"foo", object(), object, 1, 0], ) -def test_invalid_attribute_value(not_an_attr: t.Any) -> None: +def test_invalid_attribute_value(not_an_attr: t.Any, to_str: ToStr) -> None: with pytest.raises(ValueError, match="Attribute value must be a string"): str(div(foo=not_an_attr)) diff --git a/tests/test_children.py b/tests/test_children.py index 6cffcfb..441fcae 100644 --- a/tests/test_children.py +++ b/tests/test_children.py @@ -13,54 +13,55 @@ from htpy import Node + from .types import ToList, ToStr -def test_void_element() -> None: - element = input(name="foo") - assert_type(element, VoidElement) - assert isinstance(element, VoidElement) - result = str(element) - assert str(result) == '' +def test_void_element(to_str: ToStr) -> None: + result = input(name="foo") + assert_type(result, VoidElement) + assert isinstance(result, VoidElement) + assert to_str(result) == '' -def test_children() -> None: - assert str(div[img]) == "
" +def test_children(to_str: ToStr) -> None: + assert to_str(div[img]) == "
" -def test_multiple_children() -> None: + +def test_multiple_children(to_str: ToStr) -> None: result = ul[li, li] - assert str(result) == "
" + assert to_str(result) == "
" -def test_list_children() -> None: +def test_list_children(to_str: ToStr) -> None: children: list[Element] = [li["a"], li["b"]] result = ul[children] - assert str(result) == "
  • a
  • b
" + assert to_str(result) == "
  • a
  • b
" -def test_tuple_children() -> None: +def test_tuple_children(to_str: ToStr) -> None: result = ul[(li["a"], li["b"])] - assert str(result) == "
  • a
  • b
" + assert to_str(result) == "
  • a
  • b
" -def test_flatten_nested_children() -> None: +def test_flatten_nested_children(to_str: ToStr) -> None: result = dl[ [ (dt["a"], dd["b"]), (dt["c"], dd["d"]), ] ] - assert str(result) == """
a
b
c
d
""" + assert to_str(result) == """
a
b
c
d
""" -def test_flatten_very_nested_children() -> None: +def test_flatten_very_nested_children(to_str: ToStr) -> None: # maybe not super useful but the nesting may be arbitrarily deep result = div[[([["a"]],)], [([["b"]],)]] - assert str(result) == """
ab
""" + assert to_str(result) == """
ab
""" -def test_flatten_nested_generators() -> None: +def test_flatten_nested_generators(to_str: ToStr) -> None: def cols() -> Generator[str, None, None]: yield "a" yield "b" @@ -73,43 +74,43 @@ def rows() -> Generator[Generator[str, None, None], None, None]: result = div[rows()] - assert str(result) == """
abcabcabc
""" + assert to_str(result) == """
abcabcabc
""" -def test_generator_children() -> None: +def test_generator_children(to_str: ToStr) -> None: gen: Generator[Element, None, None] = (li[x] for x in ["a", "b"]) result = ul[gen] - assert str(result) == "
  • a
  • b
" + assert to_str(result) == "
  • a
  • b
" -def test_html_tag_with_doctype() -> None: +def test_html_tag_with_doctype(to_str: ToStr) -> None: result = html(foo="bar")["hello"] - assert str(result) == 'hello' + assert to_str(result) == 'hello' -def test_void_element_children() -> None: +def test_void_element_children(to_str: ToStr) -> None: with pytest.raises(TypeError): img["hey"] # type: ignore[index] -def test_call_without_args() -> None: +def test_call_without_args(to_str: ToStr) -> None: result = img() - assert str(result) == "" + assert to_str(result) == "" -def test_custom_element() -> None: - el = my_custom_element() - assert_type(el, Element) - assert isinstance(el, Element) - assert str(el) == "" +def test_custom_element(to_str: ToStr) -> None: + result = my_custom_element() + assert_type(result, Element) + assert isinstance(result, Element) + assert to_str(result) == "" @pytest.mark.parametrize("ignored_value", [None, True, False]) -def test_ignored(ignored_value: t.Any) -> None: - assert str(div[ignored_value]) == "
" +def test_ignored(to_str: ToStr, ignored_value: t.Any) -> None: + assert to_str(div[ignored_value]) == "
" -def test_iter() -> None: +def test_sync_iter() -> None: trace = "not started" def generate_list() -> Generator[Element, None, None]: @@ -134,8 +135,8 @@ def generate_list() -> Generator[Element, None, None]: assert trace == "done" -def test_iter_str() -> None: - _, child, _ = div["a"] +def test_iter_str(to_list: ToList) -> None: + _, child, _ = to_list(div["a"]) assert child == "a" # Make sure we dont get Markup (subclass of str) diff --git a/tests/test_comment.py b/tests/test_comment.py index 4844edd..9049c4f 100644 --- a/tests/test_comment.py +++ b/tests/test_comment.py @@ -1,17 +1,24 @@ +from __future__ import annotations + +import typing as t + from htpy import comment, div +if t.TYPE_CHECKING: + from .types import ToStr + -def test_simple() -> None: - assert str(div[comment("hi")]) == "
" +def test_simple(to_str: ToStr) -> None: + assert to_str(div[comment("hi")]) == "
" -def test_escape_two_dashes() -> None: - assert str(div[comment("foo--bar")]) == "
" +def test_escape_two_dashes(to_str: ToStr) -> None: + assert to_str(div[comment("foo--bar")]) == "
" -def test_escape_three_dashes() -> None: - assert str(div[comment("foo---bar")]) == "
" +def test_escape_three_dashes(to_str: ToStr) -> None: + assert to_str(div[comment("foo---bar")]) == "
" -def test_escape_four_dashes() -> None: - assert str(div[comment("foo----bar")]) == "
" +def test_escape_four_dashes(to_str: ToStr) -> None: + assert to_str(div[comment("foo----bar")]) == "
" diff --git a/tests/test_context.py b/tests/test_context.py index 77a88a2..c538222 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,9 +1,14 @@ +from __future__ import annotations + import typing as t import pytest from htpy import Context, Node, div +if t.TYPE_CHECKING: + from .types import ToStr + letter_ctx: Context[t.Literal["a", "b", "c"]] = Context("letter", default="a") no_default_ctx = Context[str]("no_default") @@ -18,25 +23,25 @@ def display_no_default(value: str) -> str: return f"{value=}" -def test_context_default() -> None: +def test_context_default(to_str: ToStr) -> None: result = div[display_letter("Yo")] - assert str(result) == "
Yo: a!
" + assert to_str(result) == "
Yo: a!
" -def test_context_provider() -> None: +def test_context_provider(to_str: ToStr) -> None: result = letter_ctx.provider("c", lambda: div[display_letter("Hello")]) - assert str(result) == "
Hello: c!
" + assert to_str(result) == "
Hello: c!
" -def test_no_default() -> None: +def test_no_default(to_str: ToStr) -> None: with pytest.raises( LookupError, match='Context value for "no_default" does not exist, requested by display_no_default()', ): - str(div[display_no_default()]) + to_str(div[display_no_default()]) -def test_nested_override() -> None: +def test_nested_override(to_str: ToStr) -> None: result = div[ letter_ctx.provider( "b", @@ -46,10 +51,10 @@ def test_nested_override() -> None: ), ) ] - assert str(result) == "
Nested: c!
" + assert to_str(result) == "
Nested: c!
" -def test_multiple_consumers() -> None: +def test_multiple_consumers(to_str: ToStr) -> None: a_ctx: Context[t.Literal["a"]] = Context("a_ctx", default="a") b_ctx: Context[t.Literal["b"]] = Context("b_ctx", default="b") @@ -59,10 +64,10 @@ def ab_display(a: t.Literal["a"], b: t.Literal["b"], greeting: str) -> str: return f"{greeting} a={a}, b={b}" result = div[ab_display("Hello")] - assert str(result) == "
Hello a=a, b=b
" + assert to_str(result) == "
Hello a=a, b=b
" -def test_nested_consumer() -> None: +def test_nested_consumer(to_str: ToStr) -> None: ctx: Context[str] = Context("ctx") @ctx.consumer @@ -75,10 +80,10 @@ def inner(value: str, from_outer: str) -> Node: result = div[ctx.provider("foo", outer)] - assert str(result) == "
outer: foo, inner: foo
" + assert to_str(result) == "
outer: foo, inner: foo
" -def test_context_passed_via_iterable() -> None: +def test_context_passed_via_iterable(to_str: ToStr) -> None: ctx: Context[str] = Context("ctx") @ctx.consumer @@ -87,4 +92,4 @@ def echo(value: str) -> str: result = div[ctx.provider("foo", lambda: [echo()])] - assert str(result) == "
foo
" + assert to_str(result) == "
foo
" diff --git a/tests/test_django.py b/tests/test_django.py index b3ca37a..4ca3fba 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -1,9 +1,10 @@ -from typing import Any +from __future__ import annotations + +import typing as t import pytest from django.core import management from django.forms.utils import ErrorList -from django.http import HttpRequest from django.template import Context, Template, TemplateDoesNotExist from django.template.loader import render_to_string from django.utils.html import escape @@ -11,6 +12,12 @@ from htpy import Element, Node, div, li, ul +if t.TYPE_CHECKING: + from django.http import HttpRequest + + from .types import ToStr + + pytestmark = pytest.mark.usefixtures("django_env") @@ -21,26 +28,26 @@ def test_template_injection() -> None: assert result == '
  • I am safe!
' -def test_SafeString() -> None: +def test_SafeString(to_str: ToStr) -> None: result = ul[SafeString("
  • hello
  • ")] - assert str(result) == "
    • hello
    " + assert to_str(result) == "
    • hello
    " -def test_explicit_escape() -> None: +def test_explicit_escape(to_str: ToStr) -> None: result = ul[escape("")] - assert str(result) == "
      <hello>
    " + assert to_str(result) == "
      <hello>
    " -def test_errorlist() -> None: +def test_errorlist(to_str: ToStr) -> None: result = div[ErrorList(["my error"])] - assert str(result) == """
    • my error
    """ + assert to_str(result) == """
    • my error
    """ -def my_template(context: dict[str, Any], request: HttpRequest | None) -> Element: +def my_template(context: dict[str, t.Any], request: HttpRequest | None) -> Element: return div[f"hey {context['name']}"] -def my_template_fragment(context: dict[str, Any], request: HttpRequest | None) -> Node: +def my_template_fragment(context: dict[str, t.Any], request: HttpRequest | None) -> Node: return [div[f"hey {context['name']}"]] diff --git a/tests/types.py b/tests/types.py new file mode 100644 index 0000000..8691d5a --- /dev/null +++ b/tests/types.py @@ -0,0 +1,7 @@ +import typing as t +from collections.abc import Callable + +from htpy import Node + +ToStr: t.TypeAlias = Callable[[Node], str] +ToList: t.TypeAlias = Callable[[Node], list[str]]