Skip to content

Commit

Permalink
Draft context implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
pelme committed Aug 20, 2024
1 parent 69f4b40 commit 611a4da
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 8 deletions.
32 changes: 32 additions & 0 deletions examples/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import htpy as h

# Defaults
theme_ctx = h.Context("plain")
side_menu_open_ctx = h.Context(False)


# @app.get("/")
def get_root() -> h.Node:
current_theme = "blue" # get user theme from database/session etc
side_menu_open = True # get the side menu open state from database/session etc

return theme_ctx.provider(
current_theme,
lambda: side_menu_open_ctx.provider(side_menu_open, lambda: index_page()),
)


@theme_ctx.consumer
def index_page(theme: str) -> h.Node:
return h.body(class_=f"theme-{theme}")[
side_menu(),
h.main(),
]


@side_menu_open_ctx.consumer
def side_menu(is_open: bool) -> h.Node:
return h.aside(open=is_open)["The side menu"]


print(get_root())
87 changes: 79 additions & 8 deletions htpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__version__ = "24.8.1"
__all__: list[str] = []

import dataclasses
import functools
import typing as t
from collections.abc import Callable, Iterable, Iterator
Expand Down Expand Up @@ -109,7 +110,59 @@ def _attrs_string(attrs: dict[str, Attribute]) -> str:
return " " + result


def iter_node(x: Node) -> Iterator[str]:
T = t.TypeVar("T")
P = t.ParamSpec("P")


@dataclasses.dataclass(frozen=True)
class ContextProvider(t.Generic[T]):
context: Context[T]
value: T
func: Callable[[], Node]

def __iter__(self) -> Iterator[str]:
return iter_node(self)

def __str__(self) -> str:
return "".join(self)


@dataclasses.dataclass(frozen=True)
class ContextConsumer(t.Generic[T]):
context: Context[T]
func: Callable[[T], Node]
args: t.Any
kwargs: t.Any


class _NO_DEFAULT:
pass


class Context(t.Generic[T]):
def __init__(self, default: T | type[_NO_DEFAULT] = _NO_DEFAULT) -> None:
self.default = default

def provider(self, value: T, children_func: Callable[[], Node]) -> ContextProvider[T]:
return ContextProvider(self, value, children_func)

def consumer(
self,
func: Callable[t.Concatenate[T, P], Node],
) -> Callable[P, ContextConsumer[T]]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextConsumer[T]:
return ContextConsumer(self, func, args, kwargs)

return wrapper


ContextDict: t.TypeAlias = dict[t.Any, t.Any]


def iter_node(x: Node, context: ContextDict | None = None) -> Iterator[str]:
context_dict: ContextDict = context or {}

while not isinstance(x, BaseElement) and callable(x):
x = x()

Expand All @@ -123,12 +176,19 @@ def iter_node(x: Node) -> Iterator[str]:
return

if isinstance(x, BaseElement):
yield from x
yield from x.iter_context(context_dict)
elif isinstance(x, ContextProvider):
yield from iter_node(x.func(), {**context_dict, x.context: x.value}) # pyright: ignore [reportUnknownMemberType]
elif isinstance(x, ContextConsumer):
context_value = context_dict.get(x.context, x.context.default)
if context_value is _NO_DEFAULT:
raise ValueError("Context has no value")
yield from iter_node(x.func(context_value, *x.args, **x.kwargs))
elif isinstance(x, str | _HasHtml):
yield str(_escape(x))
elif isinstance(x, Iterable): # pyright: ignore [reportUnnecessaryIsInstance]
for child in x:
yield from iter_node(child)
yield from iter_node(child, context_dict)
else:
raise ValueError(f"{x!r} is not a valid child element")

Expand Down Expand Up @@ -201,8 +261,11 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen
)

def __iter__(self) -> Iterator[str]:
return self.iter_context({})

def iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield f"<{self._name}{_attrs_string(self._attrs)}>"
yield from iter_node(self._children)
yield from iter_node(self._children, ctx)
yield f"</{self._name}>"

def __repr__(self) -> str:
Expand All @@ -221,13 +284,13 @@ def __getitem__(self: ElementSelf, children: Node) -> ElementSelf:


class HTMLElement(Element):
def __iter__(self) -> Iterator[str]:
def iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield "<!doctype html>"
yield from super().__iter__()
yield from super().iter_context(ctx)


class VoidElement(BaseElement):
def __iter__(self) -> Iterator[str]:
def iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
yield f"<{self._name}{_attrs_string(self._attrs)}>"


Expand All @@ -248,7 +311,15 @@ def __html__(self) -> str: ...
_ClassNamesDict: t.TypeAlias = dict[str, bool]
_ClassNames: t.TypeAlias = Iterable[str | None | bool | _ClassNamesDict] | _ClassNamesDict
Node: t.TypeAlias = (
None | bool | str | BaseElement | _HasHtml | Iterable["Node"] | Callable[[], "Node"]
None
| bool
| str
| BaseElement
| _HasHtml
| Iterable["Node"]
| Callable[[], "Node"]
| ContextProvider[t.Any]
| ContextConsumer[t.Any]
)

Attribute: t.TypeAlias = None | bool | str | _HasHtml | _ClassNames
Expand Down
46 changes: 46 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import typing as t

import pytest

from htpy import Context, div

letter_ctx: Context[t.Literal["a", "b", "c"]] = Context("a")
no_default_ctx = Context[str]()


@letter_ctx.consumer
def display_letter(letter: t.Literal["a", "b", "c"], greeting: str) -> str:
return f"{greeting}: {letter}!"


@no_default_ctx.consumer
def display_no_default(value: str) -> str:
return f"{value=}"


def test_context_default() -> None:
result = div[display_letter("Yo")]
assert str(result) == "<div>Yo: a!</div>"


def test_context_provider() -> None:
result = letter_ctx.provider("c", lambda: div[display_letter("Hello")])
assert str(result) == "<div>Hello: c!</div>"


def test_no_default() -> None:
with pytest.raises(ValueError, match="Context has no value"):
str(div[display_no_default()])


def test_nested_override() -> None:
result = div[
letter_ctx.provider(
"b",
lambda: letter_ctx.provider(
"c",
lambda: display_letter("Nested"),
),
)
]
assert str(result) == "<div>Nested: c!</div>"

0 comments on commit 611a4da

Please sign in to comment.