diff --git a/htpy/__init__.py b/htpy/__init__.py index 4001d92..143c424 100644 --- a/htpy/__init__.py +++ b/htpy/__init__.py @@ -109,7 +109,45 @@ def _attrs_string(attrs: dict[str, Attribute]) -> str: return " " + result -def iter_node(x: Node) -> Iterator[str]: +T = t.TypeVar("T") + + +class ContextProvider(t.Generic[T]): + def __init__(self, context: Context[T], value: T, children_func: Callable[[], Node]) -> None: + self.context = context + self.value = value + self.children_func = children_func + + def __iter__(self) -> Iterator[str]: + return iter_node(self) + + def __str__(self) -> str: + return "".join(self) + + +class ContextUse(t.Generic[T]): + def __init__(self, context: Context[T], children_func: Callable[[T], Node]): + self.context = context + self.children_func = children_func + + +class Context(t.Generic[T]): + def __init__(self, default: T) -> None: + self.default = default + + def provide(self, value: T, children_func: Callable[[], Node]) -> ContextProvider[T]: + return ContextProvider(self, value, children_func) + + def use(self, children_func: Callable[[T], Node]) -> ContextUse[T]: + return ContextUse(self, children_func) + + +ContextDict: t.TypeAlias = dict[t.Any, t.Any] + + +def iter_node(x: Node, ctx: ContextDict | None = None) -> Iterator[str]: + _ctx: ContextDict = ctx or {} + while not isinstance(x, BaseElement) and callable(x): x = x() @@ -123,12 +161,16 @@ def iter_node(x: Node) -> Iterator[str]: return if isinstance(x, BaseElement): - yield from x + yield from x.iter_context(_ctx) + elif isinstance(x, ContextProvider): + yield from iter_node(x.children_func(), {**_ctx, x.context: x.value}) + elif isinstance(x, ContextUse): + yield from iter_node(x.children_func(_ctx.get(x.context, x.context.default)), _ctx) elif isinstance(x, str | _HasHtml): yield str(_escape(x)) elif isinstance(x, Iterable): # pyright: ignore [reportUnnecessaryIsInstance] for child in x: - yield from iter_node(child) + yield from iter_node(child, _ctx) else: raise ValueError(f"{x!r} is not a valid child element") @@ -201,8 +243,11 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen ) def __iter__(self) -> Iterator[str]: + return self.iter_context({}) + + def iter_context(self, ctx: dict[Context, t.Any]) -> Iterator[str]: yield f"<{self._name}{_attrs_string(self._attrs)}>" - yield from iter_node(self._children) + yield from iter_node(self._children, ctx) yield f"{self._name}>" def __repr__(self) -> str: @@ -221,13 +266,13 @@ def __getitem__(self: ElementSelf, children: Node) -> ElementSelf: class HTMLElement(Element): - def __iter__(self) -> Iterator[str]: + def iter_context(self, ctx: dict[Context, t.Any]) -> Iterator[str]: yield "" - yield from super().__iter__() + yield from super().iter_context(ctx) class VoidElement(BaseElement): - def __iter__(self) -> Iterator[str]: + def iter_context(self, ctx: dict[Context, t.Any]) -> Iterator[str]: yield f"<{self._name}{_attrs_string(self._attrs)}>" @@ -248,7 +293,15 @@ def __html__(self) -> str: ... _ClassNamesDict: t.TypeAlias = dict[str, bool] _ClassNames: t.TypeAlias = Iterable[str | None | bool | _ClassNamesDict] | _ClassNamesDict Node: t.TypeAlias = ( - None | bool | str | BaseElement | _HasHtml | Iterable["Node"] | Callable[[], "Node"] + None + | bool + | str + | BaseElement + | _HasHtml + | Iterable["Node"] + | Callable[[], "Node"] + | ContextProvider[t.Any] + | ContextUse[t.Any] ) Attribute: t.TypeAlias = None | bool | str | _HasHtml | _ClassNames diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..1f1045b --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,15 @@ +from htpy import Context, div + + +def test_context_default() -> None: + theme: Context[str] = Context("light") + + result = div[theme.use(lambda value: f"My theme is {value}")] + assert str(result) == "