diff --git a/examples/context.py b/examples/context.py new file mode 100644 index 0000000..fbf14b8 --- /dev/null +++ b/examples/context.py @@ -0,0 +1,32 @@ +import htpy as h + +# Defaults +theme_ctx = h.Context("plain") +side_menu_open_ctx = h.Context(False) + + +# @app.get("/") +def get_root() -> h.Node: + current_theme = "blue" # get user theme from database/session etc + side_menu_open = True # get the side menu open state from database/session etc + + return theme_ctx.provider( + current_theme, + lambda: side_menu_open_ctx.provider(side_menu_open, lambda: index_page()), + ) + + +@theme_ctx.consumer +def index_page(theme: str) -> h.Node: + return h.body(class_=f"theme-{theme}")[ + side_menu(), + h.main(), + ] + + +@side_menu_open_ctx.consumer +def side_menu(is_open: bool) -> h.Node: + return h.aside(open=is_open)["The side menu"] + + +print(get_root()) diff --git a/htpy/__init__.py b/htpy/__init__.py index 4001d92..224b940 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -3,6 +3,7 @@ __version__ = "24.8.1" __all__: list[str] = [] +import dataclasses import functools import typing as t from collections.abc import Callable, Iterable, Iterator @@ -109,7 +110,59 @@ def _attrs_string(attrs: dict[str, Attribute]) -> str: return " " + result -def iter_node(x: Node) -> Iterator[str]: +T = t.TypeVar("T") +P = t.ParamSpec("P") + + +@dataclasses.dataclass(frozen=True) +class ContextProvider(t.Generic[T]): + context: Context[T] + value: T + func: Callable[[], Node] + + def __iter__(self) -> Iterator[str]: + return iter_node(self) + + def __str__(self) -> str: + return "".join(self) + + +@dataclasses.dataclass(frozen=True) +class ContextConsumer(t.Generic[T]): + context: Context[T] + func: Callable[[T], Node] + args: t.Any + kwargs: t.Any + + +class _NO_DEFAULT: + pass + + +class Context(t.Generic[T]): + def __init__(self, default: T | type[_NO_DEFAULT] = _NO_DEFAULT) -> None: + self.default = default + + def provider(self, value: T, children_func: Callable[[], Node]) -> ContextProvider[T]: + return ContextProvider(self, value, children_func) + + def consumer( + self, + func: Callable[t.Concatenate[T, P], Node], + ) -> Callable[P, ContextConsumer[T]]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextConsumer[T]: + return ContextConsumer(self, func, args, kwargs) + + return wrapper + + +ContextDict: t.TypeAlias = dict[t.Any, t.Any] + + +def iter_node(x: Node, context: ContextDict | None = None) -> Iterator[str]: + context_dict: ContextDict = context or {} + while not isinstance(x, BaseElement) and callable(x): x = x() @@ -123,12 +176,19 @@ def iter_node(x: Node) -> Iterator[str]: return if isinstance(x, BaseElement): - yield from x + yield from x.iter_context(context_dict) + elif isinstance(x, ContextProvider): + yield from iter_node(x.func(), {**context_dict, x.context: x.value}) # pyright: ignore [reportUnknownMemberType] + elif isinstance(x, ContextConsumer): + context_value = context_dict.get(x.context, x.context.default) + if context_value is _NO_DEFAULT: + raise ValueError("Context has no value") + yield from iter_node(x.func(context_value, *x.args, **x.kwargs)) elif isinstance(x, str | _HasHtml): yield str(_escape(x)) elif isinstance(x, Iterable): # pyright: ignore [reportUnnecessaryIsInstance] for child in x: - yield from iter_node(child) + yield from iter_node(child, context_dict) else: raise ValueError(f"{x!r} is not a valid child element") @@ -201,8 +261,11 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen ) def __iter__(self) -> Iterator[str]: + return self.iter_context({}) + + def iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield f"<{self._name}{_attrs_string(self._attrs)}>" - yield from iter_node(self._children) + yield from iter_node(self._children, ctx) yield f"{self._name}>" def __repr__(self) -> str: @@ -221,13 +284,13 @@ def __getitem__(self: ElementSelf, children: Node) -> ElementSelf: class HTMLElement(Element): - def __iter__(self) -> Iterator[str]: + def iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield "" - yield from super().__iter__() + yield from super().iter_context(ctx) class VoidElement(BaseElement): - def __iter__(self) -> Iterator[str]: + def iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]: yield f"<{self._name}{_attrs_string(self._attrs)}>" @@ -248,7 +311,15 @@ def __html__(self) -> str: ... _ClassNamesDict: t.TypeAlias = dict[str, bool] _ClassNames: t.TypeAlias = Iterable[str | None | bool | _ClassNamesDict] | _ClassNamesDict Node: t.TypeAlias = ( - None | bool | str | BaseElement | _HasHtml | Iterable["Node"] | Callable[[], "Node"] + None + | bool + | str + | BaseElement + | _HasHtml + | Iterable["Node"] + | Callable[[], "Node"] + | ContextProvider[t.Any] + | ContextConsumer[t.Any] ) Attribute: t.TypeAlias = None | bool | str | _HasHtml | _ClassNames diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..7fe691a --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,46 @@ +import typing as t + +import pytest + +from htpy import Context, div + +letter_ctx: Context[t.Literal["a", "b", "c"]] = Context("a") +no_default_ctx = Context[str]() + + +@letter_ctx.consumer +def display_letter(letter: t.Literal["a", "b", "c"], greeting: str) -> str: + return f"{greeting}: {letter}!" + + +@no_default_ctx.consumer +def display_no_default(value: str) -> str: + return f"{value=}" + + +def test_context_default() -> None: + result = div[display_letter("Yo")] + assert str(result) == "