diff --git a/.github/workflows/test_jax.yml b/.github/workflows/test_jax.yml index 644a14b..2b04bd8 100644 --- a/.github/workflows/test_jax.yml +++ b/.github/workflows/test_jax.yml @@ -29,6 +29,7 @@ jobs: run: | export SEPES_TEST_ARRAYLIB=jax export SEPES_BACKEND=jax + export XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m pip install . coverage run -m pytest tests diff --git a/CHANGELOG.md b/CHANGELOG.md index a6fd830..82fa907 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,51 @@ # Changelog +## V0.12 + +### Deprecations + +- Reduce the core API size by removing: + 1) `tree_graph` (for graphviz) + 2) `tree_mermaid` (mermaidjs) + 3) `Partial/partial` -> Use `jax.tree_util.Partial` instead. + 4) `is_tree_equal` -> Use `bcmap(numpy.testing.*)(pytree1, pytree2)` instead. + 5) `freeze` -> Use `ft.partial(tree_mask, lambda _: True)` instead. + 6) `unfreeze` -> Use `tree_unmask` instead. + 7) `is_nondiff` + 8) `BaseKey` + + +### Changes + +- `tree_{mask,unmask}` now accepts only callable `cond` argument. + + For masking using pytree boolean mask use the following pattern: + + ```python + import jax + import sepes as sp + import functools as ft + tree = [[1, 2], 3] # the nested tree + where = [[True, False], True] # mask tree[0][1] and tree[1] + mask = ft.partial(sp.tree_mask, cond=lambda _: True) + sp.at(tree)[where].apply(mask) # apply using `at` + # [[#1, 2], #3] + # or simply apply to the node directly + tree = [[mask(1), 2], mask(3)] + # [[#1, 2], #3] + ``` + +- Rename `is_frozen` to `is_masked` + - frozen could mean non-trainable array, however the masking is not only for arrays but also for other types that will be hidden across jax transformations. + +- Rename `AtIndexer` to `at` for shorter syntax. + +### Additions + +- Add `fill_value` in `at[...].get(fill_value=...)` to add default value for non + selected leaves. Useful for arrays under `jax.jit` to avoid variable size related errors. +- Add `jax.tree_util.{SequenceKey,GetAttrKey,DictKey}` as valid path keys in `at[...]`. + ## V0.11.3 - Raise error if `autoinit` is used with `__init__` method defined. @@ -7,20 +53,43 @@ - Add `at` as an alias for `AtIndexer` for shorter syntax. - Deprecate `AtIndexer.__call__` in favor of `value_and_tree` to apply function in a functional manner by copying the input argument. -```python -import sepes as sp -class Counter(sp.TreeClass): - def __init__(self, count: int): - self.count = count - def increment(self, value): - self.count += value - return self.count -counter = Counter(0) -# the function follow jax.value_and_grad semantics where the tree is the -# copied mutated input argument, if the function mutates the input arguments -sp.value_and_tree(lambda C: C.increment(1))(counter) -# (1, Counter(count=1)) -``` + ```python + import sepes as sp + class Counter(sp.TreeClass): + def __init__(self, count: int): + self.count = count + def increment(self, value): + self.count += value + return self.count + counter = Counter(0) + # the function follow jax.value_and_grad semantics where the tree is the + # copied mutated input argument, if the function mutates the input arguments + sp.value_and_tree(lambda C: C.increment(1))(counter) + # (1, Counter(count=1)) + ``` + +- Add sharding info in `tree_summary`, `G` for global, `S` for sharded shape. + + ```python + import jax + import sepes as sp + from jax.sharding import Mesh, NamedSharding as N, PartitionSpec as P + import numpy as np + import os + os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + x = jax.numpy.ones([4 * 4, 2 * 2]) + mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=["i", "j"]) + sharding = N(mesh=mesh, spec=P("i", "j")) + x = jax.device_put(x, device=sharding) + + print(sp.tree_summary(x)) + ┌────┬───────────┬─────┬───────┐ + │Name│Type │Count│Size │ + ├────┼───────────┼─────┼───────┤ + │Σ │G:f32[16,4]│64 │256.00B│ + │ │S:f32[4,2] │ │ │ + └────┴───────────┴─────┴───────┘ + ``` - Updated docstrings. e.g. How to construct flops counter in `tree_summary` using `jax.jit` diff --git a/docs/API/constructor.rst b/docs/API/constructor.rst new file mode 100644 index 0000000..3afe172 --- /dev/null +++ b/docs/API/constructor.rst @@ -0,0 +1,10 @@ +🏗️ Constructor utils API +============================= + + +.. currentmodule:: sepes + +.. autofunction:: field +.. autofunction:: fields +.. autofunction:: autoinit +.. autofunction:: leafwise \ No newline at end of file diff --git a/docs/API/core.rst b/docs/API/core.rst deleted file mode 100644 index e23f8c7..0000000 --- a/docs/API/core.rst +++ /dev/null @@ -1,31 +0,0 @@ -🎯 Core API -============================= - - -.. currentmodule:: sepes - -.. autoclass:: TreeClass - :members: - at -.. autoclass:: Partial -.. autoclass:: partial -.. autoclass:: AtIndexer - :members: - get, - set, - apply, - scan, - reduce, - pluck, - -.. autoclass:: at -.. autoclass:: BaseKey - :members: - __eq__ -.. autofunction:: autoinit -.. autofunction:: leafwise -.. autofunction:: field -.. autofunction:: fields -.. autofunction:: bcmap -.. autofunction:: is_tree_equal -.. autofunction:: value_and_tree \ No newline at end of file diff --git a/docs/API/masking.rst b/docs/API/masking.rst index c1be176..3414824 100644 --- a/docs/API/masking.rst +++ b/docs/API/masking.rst @@ -3,9 +3,6 @@ .. currentmodule:: sepes -.. autofunction:: is_nondiff -.. autofunction:: freeze -.. autofunction:: unfreeze -.. autofunction:: is_frozen +.. autofunction:: is_masked .. autofunction:: tree_mask .. autofunction:: tree_unmask diff --git a/docs/API/module.rst b/docs/API/module.rst new file mode 100644 index 0000000..4b56db8 --- /dev/null +++ b/docs/API/module.rst @@ -0,0 +1,10 @@ +📍 Module API +============================= + + +.. currentmodule:: sepes + +.. autoclass:: TreeClass + :members: + at + diff --git a/docs/API/pretty_print.rst b/docs/API/pretty_print.rst index d863ea9..627fdf4 100644 --- a/docs/API/pretty_print.rst +++ b/docs/API/pretty_print.rst @@ -4,8 +4,6 @@ .. currentmodule:: sepes .. autofunction:: tree_diagram -.. autofunction:: tree_graph -.. autofunction:: tree_mermaid .. autofunction:: tree_repr .. autofunction:: tree_str .. autofunction:: tree_summary \ No newline at end of file diff --git a/docs/API/sepes.rst b/docs/API/sepes.rst index a857ed8..3e396bd 100644 --- a/docs/API/sepes.rst +++ b/docs/API/sepes.rst @@ -5,7 +5,9 @@ :maxdepth: 2 :caption: API Documentation - core + module masking + tree + constructor pretty_print backend diff --git a/docs/API/tree.rst b/docs/API/tree.rst new file mode 100644 index 0000000..0cb4f9a --- /dev/null +++ b/docs/API/tree.rst @@ -0,0 +1,17 @@ +🌲 Tree utils API +============================= + + +.. currentmodule:: sepes + +.. autoclass:: at + :members: + get, + set, + apply, + scan, + reduce, + pluck, + +.. autofunction:: value_and_tree +.. autofunction:: bcmap \ No newline at end of file diff --git a/docs/_static/tree_graph.svg b/docs/_static/tree_graph.svg deleted file mode 100644 index 380a167..0000000 --- a/docs/_static/tree_graph.svg +++ /dev/null @@ -1,67 +0,0 @@ - - - - - - -G - - - -5353602176 - -list - - - -5353602432 - -[0]=1 - - - -5353602176->5353602432 - - - - - -5353602496 - -[1]=2 - - - -5353602176->5353602496 - - - - - -5353602816 - -[2]:dict - - - -5353602176->5353602816 - - - - - -5353602560 - -['a']=3 - - - -5353602816->5353602560 - - - - - diff --git a/docs/_static/tree_graph_stylized.svg b/docs/_static/tree_graph_stylized.svg deleted file mode 100644 index f6a8d7b..0000000 --- a/docs/_static/tree_graph_stylized.svg +++ /dev/null @@ -1,67 +0,0 @@ - - - - - - -G - - - -5345369024 - -list - - - -5353442880 - -[0]=1 - - - -5345369024->5353442880 - - - - - -5353442496 - -[1]=2 - - - -5345369024->5353442496 - - - - - -5353171392 - -[2]:dict - - - -5345369024->5353171392 - - - - - -5353173184 - -['a']=3 - - - -5353171392->5353173184 - - - - - diff --git a/docs/_static/tree_mermaid.jpg b/docs/_static/tree_mermaid.jpg deleted file mode 100644 index 07f1d82..0000000 Binary files a/docs/_static/tree_mermaid.jpg and /dev/null differ diff --git a/sepes/__init__.py b/sepes/__init__.py index 75f69f1..d4ff6a3 100644 --- a/sepes/__init__.py +++ b/sepes/__init__.py @@ -15,69 +15,37 @@ from sepes._src.backend import backend_context from sepes._src.code_build import autoinit, field, fields from sepes._src.tree_base import TreeClass -from sepes._src.tree_index import AtIndexer, BaseKey, at -from sepes._src.tree_mask import ( - freeze, - is_frozen, - is_nondiff, - tree_mask, - tree_unmask, - unfreeze, -) -from sepes._src.tree_pprint import ( - tree_diagram, - tree_graph, - tree_mermaid, - tree_repr, - tree_str, - tree_summary, -) -from sepes._src.tree_util import ( - Partial, - bcmap, - is_tree_equal, - leafwise, - partial, - value_and_tree, -) +from sepes._src.tree_index import at +from sepes._src.tree_mask import is_masked, tree_mask, tree_unmask +from sepes._src.tree_pprint import tree_diagram, tree_repr, tree_str, tree_summary +from sepes._src.tree_util import bcmap, leafwise, value_and_tree -__all__ = ( - # general utils +__all__ = [ + # module utils "TreeClass", - "is_tree_equal", - "field", - "fields", - "autoinit", # pprint utils "tree_diagram", - "tree_graph", - "tree_mermaid", "tree_repr", "tree_str", "tree_summary", # masking utils - "is_nondiff", - "is_frozen", - "freeze", - "unfreeze", + "is_masked", "tree_unmask", "tree_mask", - # indexing utils - "AtIndexer", - "at", - "BaseKey", # tree utils + "at", "bcmap", - "Partial", - "partial", - "leafwise", "value_and_tree", + # construction utils + "field", + "fields", + "autoinit", + "leafwise", # backend utils "backend_context", -) +] -__version__ = "0.11.3" +__version__ = "0.12.0" -AtIndexer.__module__ = "sepes" +at.__module__ = "sepes" TreeClass.__module__ = "sepes" -Partial.__module__ = "sepes" diff --git a/sepes/_src/backend/arraylib/__init__.py b/sepes/_src/backend/arraylib/__init__.py index bbdd925..2bfeeaa 100644 --- a/sepes/_src/backend/arraylib/__init__.py +++ b/sepes/_src/backend/arraylib/__init__.py @@ -14,20 +14,32 @@ """Backend tools for sepes.""" from __future__ import annotations + import functools as ft +from typing import Callable, NamedTuple + + +class NoImplError(NamedTuple): + op: Callable + + def __call__(self, *args, **kwargs): + raise NotImplementedError(f"No implementation for {self.op}" + f" with {args=} {kwargs=}") + -tobytes = ft.singledispatch(lambda array: ...) -where = ft.singledispatch(lambda condition, x, y: ...) -nbytes = ft.singledispatch(lambda array: ...) -shape = ft.singledispatch(lambda array: ...) -dtype = ft.singledispatch(lambda array: ...) -min = ft.singledispatch(lambda array: ...) -max = ft.singledispatch(lambda array: ...) -mean = ft.singledispatch(lambda array: ...) -std = ft.singledispatch(lambda array: ...) -all = ft.singledispatch(lambda array: ...) -is_floating = ft.singledispatch(lambda array: ...) -is_integer = ft.singledispatch(lambda array: ...) -is_inexact = ft.singledispatch(lambda array: ...) -is_bool = ft.singledispatch(lambda array: ...) -ndarrays: tuple[type, ...] = () +tobytes = ft.singledispatch(NoImplError("tobytes")) +where = ft.singledispatch(NoImplError("where")) +nbytes = ft.singledispatch(NoImplError("nbytes")) +shape = ft.singledispatch(NoImplError("shape")) +dtype = ft.singledispatch(NoImplError("dtype")) +min = ft.singledispatch(NoImplError("min")) +max = ft.singledispatch(NoImplError("max")) +mean = ft.singledispatch(NoImplError("mean")) +std = ft.singledispatch(NoImplError("std")) +all = ft.singledispatch(NoImplError("all")) +array_equal = ft.singledispatch(NoImplError("array_equal")) +is_floating = ft.singledispatch(NoImplError("is_floating")) +is_integer = ft.singledispatch(NoImplError("is_integer")) +is_inexact = ft.singledispatch(NoImplError("is_inexact")) +is_bool = ft.singledispatch(NoImplError("is_bool")) +ndarrays: list[type] = [] diff --git a/sepes/_src/backend/arraylib/jax.py b/sepes/_src/backend/arraylib/jax.py index 1022f8e..c494b78 100644 --- a/sepes/_src/backend/arraylib/jax.py +++ b/sepes/_src/backend/arraylib/jax.py @@ -14,12 +14,13 @@ from __future__ import annotations - -from jax import Array import jax.numpy as jnp +import numpy as np +from jax import Array + import sepes._src.backend.arraylib as arraylib -arraylib.tobytes.register(Array, lambda x: jnp.array(x).tobytes()) +arraylib.tobytes.register(Array, lambda x: np.array(x).tobytes()) arraylib.where.register(Array, jnp.where) arraylib.nbytes.register(Array, lambda x: x.nbytes) arraylib.shape.register(Array, jnp.shape) @@ -29,8 +30,9 @@ arraylib.mean.register(Array, jnp.mean) arraylib.std.register(Array, jnp.std) arraylib.all.register(Array, jnp.all) +arraylib.array_equal.register(Array, np.array_equal) # NOTE: not traceable arraylib.is_floating.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.floating)) arraylib.is_integer.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.integer)) arraylib.is_inexact.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.inexact)) arraylib.is_bool.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.bool_)) -arraylib.ndarrays += (Array,) +arraylib.ndarrays.append(Array) diff --git a/sepes/_src/backend/arraylib/numpy.py b/sepes/_src/backend/arraylib/numpy.py index 285c916..1bf8d55 100644 --- a/sepes/_src/backend/arraylib/numpy.py +++ b/sepes/_src/backend/arraylib/numpy.py @@ -16,6 +16,7 @@ import numpy as np from numpy import ndarray + import sepes._src.backend.arraylib as arraylib arraylib.tobytes.register(ndarray, lambda x: np.array(x).tobytes()) @@ -28,8 +29,9 @@ arraylib.mean.register(ndarray, np.mean) arraylib.std.register(ndarray, np.std) arraylib.all.register(ndarray, np.all) +arraylib.array_equal.register(ndarray, np.array_equal) arraylib.is_floating.register(ndarray, lambda x: np.issubdtype(x.dtype, np.floating)) arraylib.is_integer.register(ndarray, lambda x: np.issubdtype(x.dtype, np.integer)) arraylib.is_inexact.register(ndarray, lambda x: np.issubdtype(x.dtype, np.inexact)) arraylib.is_bool.register(ndarray, lambda x: np.issubdtype(x.dtype, np.bool_)) -arraylib.ndarrays += (ndarray,) +arraylib.ndarrays.append(ndarray) diff --git a/sepes/_src/backend/arraylib/torch.py b/sepes/_src/backend/arraylib/torch.py index 696a309..6ddac57 100644 --- a/sepes/_src/backend/arraylib/torch.py +++ b/sepes/_src/backend/arraylib/torch.py @@ -17,6 +17,7 @@ import numpy as np import torch from torch import Tensor + import sepes._src.backend.arraylib as arraylib floatings = [torch.float16, torch.float32, torch.float64] @@ -33,8 +34,9 @@ arraylib.mean.register(Tensor, torch.mean) arraylib.std.register(Tensor, torch.std) arraylib.all.register(Tensor, torch.all) +arraylib.array_equal.register(Tensor, torch.equal) arraylib.is_floating.register(Tensor, lambda x: x.dtype in floatings) arraylib.is_integer.register(Tensor, lambda x: x.dtype in integers) arraylib.is_inexact.register(Tensor, lambda x: x.dtype in floatings + complexes) arraylib.is_bool.register(Tensor, lambda x: x.dtype == torch.bool) -arraylib.ndarrays += (Tensor,) +arraylib.ndarrays.append(Tensor) diff --git a/sepes/_src/backend/treelib/__init__.py b/sepes/_src/backend/treelib/__init__.py index 90655b7..a6cb835 100644 --- a/sepes/_src/backend/treelib/__init__.py +++ b/sepes/_src/backend/treelib/__init__.py @@ -61,7 +61,7 @@ class AbstractTreeLib(abc.ABC): @staticmethod @abc.abstractmethod - def tree_map( + def map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -72,7 +72,7 @@ def tree_map( @staticmethod @abc.abstractmethod - def tree_path_map( + def path_map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -83,7 +83,7 @@ def tree_path_map( @staticmethod @abc.abstractmethod - def tree_flatten( + def flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -92,7 +92,7 @@ def tree_flatten( @staticmethod @abc.abstractmethod - def tree_path_flatten( + def path_flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -101,7 +101,7 @@ def tree_path_flatten( @staticmethod @abc.abstractmethod - def tree_unflatten(treedef: Any, leaves: Iterable[Any]) -> Any: + def unflatten(treedef: Any, leaves: Iterable[Any]) -> Any: ... @staticmethod diff --git a/sepes/_src/backend/treelib/jax.py b/sepes/_src/backend/treelib/jax.py index cf49f8a..43bc9d7 100644 --- a/sepes/_src/backend/treelib/jax.py +++ b/sepes/_src/backend/treelib/jax.py @@ -36,7 +36,7 @@ def __str__(self): class JaxTreeLib(AbstractTreeLib): @staticmethod - def tree_map( + def map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -51,7 +51,7 @@ def tree_map( return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_path_map( + def path_map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -66,7 +66,7 @@ def tree_path_map( return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_flatten( + def flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -74,7 +74,7 @@ def tree_flatten( return jtu.tree_flatten(tree, is_leaf=is_leaf) @staticmethod - def tree_path_flatten( + def path_flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -82,7 +82,7 @@ def tree_path_flatten( return jtu.tree_flatten_with_path(tree, is_leaf=is_leaf) @staticmethod - def tree_unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any: + def unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any: return jtu.tree_unflatten(treedef, leaves) @staticmethod diff --git a/sepes/_src/backend/treelib/optree.py b/sepes/_src/backend/treelib/optree.py index 78015ad..4747494 100644 --- a/sepes/_src/backend/treelib/optree.py +++ b/sepes/_src/backend/treelib/optree.py @@ -61,7 +61,7 @@ def __str__(self) -> str: class OpTreeTreeLib(AbstractTreeLib): @staticmethod - def tree_map( + def map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -76,7 +76,7 @@ def tree_map( return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_path_map( + def path_map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -92,7 +92,7 @@ def tree_path_map( return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_flatten( + def flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -101,7 +101,7 @@ def tree_flatten( return (leaves, treedef) @staticmethod - def tree_path_flatten( + def path_flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -110,7 +110,7 @@ def tree_path_flatten( return (list(zip(ot.treespec_paths(treedef), leaves)), treedef) @staticmethod - def tree_unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any: + def unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any: return ot.tree_unflatten(treedef, leaves) @staticmethod diff --git a/sepes/_src/code_build.py b/sepes/_src/code_build.py index 4889a8c..e2e44bd 100644 --- a/sepes/_src/code_build.py +++ b/sepes/_src/code_build.py @@ -294,6 +294,7 @@ def field( Buffer creation using :attr:`on_getattr`: >>> import sepes as sp + >>> import jax >>> import jax.numpy as jnp >>> @sp.autoinit ... class Tree(sp.TreeClass): @@ -308,6 +309,7 @@ def field( Parameterization using :attr:`on_getattr`: >>> import sepes as sp + >>> import jax >>> import jax.numpy as jnp >>> def symmetric(array: jax.Array) -> jax.Array: ... triangle = jnp.triu(array) # upper triangle diff --git a/sepes/_src/tree_base.py b/sepes/_src/tree_base.py index 978ada2..1f5a0b3 100644 --- a/sepes/_src/tree_base.py +++ b/sepes/_src/tree_base.py @@ -19,14 +19,13 @@ import abc from typing import Any, Hashable, TypeVar -from typing_extensions import Unpack +from typing_extensions import Self, Unpack import sepes from sepes._src.code_build import fields +from sepes._src.tree_index import at from sepes._src.tree_pprint import PPSpec, tree_repr, tree_str from sepes._src.tree_util import is_tree_equal, tree_copy, tree_hash, value_and_tree -from typing_extensions import Self -from sepes._src.tree_index import AtIndexer T = TypeVar("T", bound=Hashable) S = TypeVar("S") @@ -148,11 +147,11 @@ class TreeClass(metaclass=TreeClassMeta): the tree. for example: >>> @sp.leafwise - ... @sp.autoinit ... class Tree(sp.TreeClass): - ... a:int = 1 - ... b:float = 2.0 - >>> tree = Tree() + ... def __init__(self, a:int, b:float): + ... self.a = a + ... self.b = b + >>> tree = Tree(a=1, b=2.0) >>> tree + 1 # will add 1 to each leaf Tree(a=2, b=3.0) @@ -161,45 +160,16 @@ class TreeClass(metaclass=TreeClassMeta): used to ``get``, ``set``, or ``apply`` a function to a leaf or a group of leaves using ``leaf`` name, index or by a boolean mask. - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a:int = 1 - ... b:float = 2.0 - >>> tree = Tree() + >>> class Tree(sp.TreeClass): + ... def __init__(self, a:int, b:float): + ... self.a = a + ... self.b = b + >>> tree = Tree(a=1, b=2.0) >>> tree.at["a"].get() Tree(a=1, b=None) >>> tree.at[0].get() Tree(a=1, b=None) - Note: - - Under ``jax.tree_util.***`` or ``optree`` all :class:`.TreeClass` - attributes are treated as leaves. - - To hide/ignore a specific attribute from the tree leaves, during - ``jax.tree_util.***`` operations, freeze the leaf using :func:`.freeze` - or :func:`.tree_mask`. - - >>> # freeze(exclude) a leaf from the tree leaves: - >>> import jax - >>> import sepes as sp - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a:int = 1 - ... b:float = 2.0 - >>> tree = Tree() - >>> tree = tree.at["a"].apply(sp.freeze) - >>> jax.tree_util.tree_leaves(tree) - [2.0] - - >>> # undo the freeze - >>> tree = tree.at["a"].apply(sp.unfreeze, is_leaf=sp.is_frozen) - >>> jax.tree_util.tree_leaves(tree) - [1, 2.0] - - >>> # using `tree_mask` to exclude a leaf from the tree leaves - >>> freeze_mask = Tree(a=True, b=False) - >>> jax.tree_util.tree_leaves(sp.tree_mask(tree, freeze_mask)) - [2.0] - Note: ``AttributeError`` is raised, If a method that mutates the instance is called directly. Instead use :func:`.value_and_tree` to call @@ -236,23 +206,23 @@ def __init_subclass__(klass: type[T], **k): if "__delattr__" in vars(klass): raise TypeError(f"Reserved method `__delattr__` defined in `{klass}`.") super().__init_subclass__(**k) - # register the class with the proper tree backend. - # the registration envolves defining two rules: how to flatten the nested - # structure of the class and how to unflatten the flattened structure. - # The flatten rule for `TreeClass` is equivalent to vars(self). and the - # unflatten rule is equivalent to `klass(**flat_tree)`. The flatten/unflatten - # rule is exactly same as the flatten rule for normal dictionaries. + # - register the class with the proper tree backend. + # - the registration envolves defining two rules: how to flatten the nested + # structure of the class and how to unflatten the flattened structure. + # The flatten rule for `TreeClass` is equivalent to vars(self). and the + # unflatten rule is equivalent to `klass(**flat_tree)`. The flatten/unflatten + # rule is exactly same as the flatten rule for normal dictionaries. treelib = sepes._src.backend.treelib treelib.register_treeclass(klass) def __setattr__(self, key: str, value: Any) -> None: - # implements the controlled mutability behavior. - # In essence, setattr is allowed to set attributes during initialization - # and during functional call using .at["method"](*, **) by marking the - # instnace as mutable. Otherwise, setattr is disallowed. - # recall that during the functional call using .at["method"](*, **) - # the tree is always copied and the copy is marked as mutable, thus - # setattr is allowed to set attributes on the copy not the original. + # - implements the controlled mutability behavior. + # - In essence, setattr is allowed to set attributes during initialization + # and during functional call using `value_and_tree(method)(*, **)` by marking the + # instnace as mutable. Otherwise, setattr is disallowed. + # - recall that during the functional call using `value_and_tree(method)(*, **)` + # the tree is always copied and the copy is marked as mutable, thus + # setattr is allowed to set attributes on the copy not the original. if id(self) not in _mutable_instance_registry: raise AttributeError( f"Cannot set attribute {value=} to `{key=}` " @@ -262,13 +232,13 @@ def __setattr__(self, key: str, value: Any) -> None: getattr(object, "__setattr__")(self, key, value) def __delattr__(self, key: str) -> None: - # same as __setattr__ but for delattr. - # both __setattr__ and __delattr__ are used to implement the - # controlled mutability behavior during initialization and - # during functional call using .at["method"](*, **). - # recall that during the functional call using .at["method"](*, **) - # the tree is always copied and the copy is marked as mutable, thus - # setattr is allowed to set attributes on the copy not the original. + # - same as __setattr__ but for delattr. + # - both __setattr__ and __delattr__ are used to implement the + # - controlled mutability behavior during initialization and + # during functional call using `value_and_tree(method)(*, **)`. + # - recall that during the functional call using `value_and_tree(method)(*, **)` + # the tree is always copied and the copy is marked as mutable, thus + # setattr is allowed to set attributes on the copy not the original. if id(self) not in _mutable_instance_registry: raise AttributeError( f"Cannot delete attribute `{key}` " @@ -277,7 +247,7 @@ def __delattr__(self, key: str) -> None: getattr(object, "__delattr__")(self, key) @property - def at(self) -> AtIndexer[Self]: + def at(self) -> at[Self]: """Immutable out-of-place indexing. - ``.at[***].get()``: @@ -292,20 +262,18 @@ def at(self) -> AtIndexer[Self]: - ``int`` for positional indexing for sequences. - ``...`` to select all leaves. - a boolean mask of the same structure as the tree - - ``re.Pattern`` to index all keys matching a regex pattern. - - an instance of ``BaseKey`` with custom logic to index a pytree. - a tuple of the above types to index multiple keys at same level. Example: >>> import sepes as sp - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a: int = 1 - ... b: float = 2.0 + >>> class Tree(sp.TreeClass): + ... def __init__(self, a:int, b:float): + ... self.a = a + ... self.b = b ... def add(self, x: int) -> int: ... self.a += x ... return self.a - >>> tree = Tree() + >>> tree = Tree(a=1, b=2.0) >>> tree.at["a"].get() Tree(a=1, b=None) >>> tree.at["a"].set(100) @@ -317,7 +285,10 @@ def at(self) -> AtIndexer[Self]: - ``pytree.at[*][**]`` is equivalent to selecting pytree.*.** . - ``pytree.at[*, **]`` is equivalent selecting pytree.* and pytree.** """ - return AtIndexer(self) + # NOTE: use `at` as a property to enable chaining syntax. + # instead of at(at(tree)[...].apply(...))[...].set(...) + # chaining syntax is tree.at[...].apply(...).at[...].set(...) + return at(self) def __repr__(self) -> str: return tree_repr(self) diff --git a/sepes/_src/tree_index.py b/sepes/_src/tree_index.py index 721c388..4bb98f4 100644 --- a/sepes/_src/tree_index.py +++ b/sepes/_src/tree_index.py @@ -12,31 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Define lens-like indexing/masking for pytrees.""" - -# enable get/set/apply/scan/reduce operations on selected parts of a nested -# structure -pytree- in out-of-place manner. this process invovles defining two -# parts: 1) *where* to select the parts of the pytree and 2) *what* to do with -# the selected parts. the *where* part is defined either by a path or a boolean -# mask. the *what* part is defined by a set value, or a function to apply to -# the selected parts. once we have a *final* boolean mask that encompasses all -# path and the boolean mask, we can use `tree_map` to apply the *what* part to -# the *where* part. for example, for a tree = [[1, 2], 3, 4] and boolean mask -# [[True, False], False, True] and path mask [0][1], then we select only leaf -# 1 that is at the intersection of the boolean mask and the path mask. then we -# apply the *what* part to the *where* part. +"""Define lens-like indexing for pytrees + +This module provides a way to index and mask pytrees (e.g. TreeClass) in an +out-of-place manner.Out-of-place means that the original pytree is not modified, +instead a new pytree with the selected leaves are modified. + +The indexing is done through two concepts: + +1) Selection (Where): Determines parts of the pytree for manipulation via a path or a boolean mask. +2) Operation (What): Defines actions on selected parts, such as setting values or applying functions. + +For example, the following code defines a dict pytree with where of same structure +as the tree. The where (Selection) defines which parts of the tree to select and +the set (Operation) operation sets the selected parts to 100. + +>>> import sepes as sp +>>> tree = {"a": 1, "b": [1, 2, 3]} +>>> where = {"a": True, "b": [False, True, False]} +>>> sp.at(tree)[where].set(100) +{'a': 100, 'b': [1, 100, 3]} +""" from __future__ import annotations import abc import functools as ft import re -from typing import Any, Callable, Hashable, Tuple, TypeVar, Generic +from typing import Any, Callable, Generic, Hashable, TypeVar, Sequence from typing_extensions import Self import sepes import sepes._src.backend.arraylib as arraylib +from sepes._src.backend import is_package_avaiable from sepes._src.backend.treelib import ParallelConfig from sepes._src.tree_pprint import tree_repr @@ -44,241 +53,21 @@ S = TypeVar("S") PyTree = Any EllipsisType = TypeVar("EllipsisType") -KeyEntry = TypeVar("KeyEntry", bound=Hashable) -KeyPath = Tuple[KeyEntry, ...] +PathKeyEntry = TypeVar("PathKeyEntry", bound=Hashable) _no_initializer = object() +_no_fill_value = object() class BaseKey(abc.ABC): - """Parent class for all match classes. - - - Subclass this class to create custom match keys by implementing - the `__eq__` method. The ``__eq__`` method should return True if the - key matches the given path entry and False otherwise. The path entry - refers to the entry defined in the ``tree_flatten_with_keys`` method of - the pytree class. - - - Typical path entries in ``jax`` are: - - - ``jax.tree_util.GetAttrKey`` for attributes - - ``jax.tree_util.DictKey`` for mapping keys - - ``jax.tree_util.SequenceKey`` for sequence indices - - - When implementing the ``__eq__`` method you can use the ``singledispatchmethod`` - to unpack the path entry for example: - - - ``jax.tree_util.GetAttrKey`` -> `key.name` - - ``jax.tree_util.DictKey`` -> `key.key` - - ``jax.tree_util.SequenceKey`` -> `key.index` - - - See Examples for more details. - - Example: - >>> # define an match strategy to match a leaf with a given name and type - >>> import sepes as sp - >>> from typing import NamedTuple - >>> import jax - >>> class NameTypeContainer(NamedTuple): - ... name: str - ... type: type - >>> @jax.tree_util.register_pytree_with_keys_class - ... class Tree: - ... def __init__(self, a, b) -> None: - ... self.a = a - ... self.b = b - ... def tree_flatten_with_keys(self): - ... ak = (NameTypeContainer("a", type(self.a)), self.a) - ... bk = (NameTypeContainer("b", type(self.b)), self.b) - ... return (ak, bk), None - ... @classmethod - ... def tree_unflatten(cls, aux_data, children): - ... return cls(*children) - ... @property - ... def at(self): - ... return sp.at(self) - >>> tree = Tree(1, 2) - >>> class MatchNameType(sp.BaseKey): - ... def __init__(self, name, type): - ... self.name = name - ... self.type = type - ... def __eq__(self, other): - ... if isinstance(other, NameTypeContainer): - ... return other == (self.name, self.type) - ... return False - >>> tree = tree.at[MatchNameType("a", int)].get() - >>> assert jax.tree_util.tree_leaves(tree) == [1] - - Note: - - use ``BaseKey.def_alias(type, func)`` to define an index type alias - for `BaseKey` subclasses. This is useful for convience when - creating new match strategies. - - >>> import sepes as sp - >>> import functools as ft - >>> from types import FunctionType - >>> import jax.tree_util as jtu - >>> # lets define a new match strategy called `FuncKey` that applies - >>> # a function to the path entry and returns True if the function - >>> # returns True and False otherwise. - >>> # for example `FuncKey(lambda x: x.startswith("a"))` will match - >>> # all leaves that start with "a". - >>> class FuncKey(sp.BaseKey): - ... def __init__(self, func): - ... self.func = func - ... @ft.singledispatchmethod - ... def __eq__(self, key): - ... return self.func(key) - ... @__eq__.register(jtu.GetAttrKey) - ... def _(self, key: jtu.GetAttrKey): - ... # unpack the GetAttrKey - ... return self.func(key.name) - ... @__eq__.register(jtu.DictKey) - ... def _(self, key: jtu.DictKey): - ... # unpack the DictKey - ... return self.func(key.key) - ... @__eq__.register(jtu.SequenceKey) - ... def _(self, key: jtu.SequenceKey): - ... return self.func(key.index) - >>> # instead of using ``FuncKey(function)`` we can define an alias - >>> # for `FuncKey`, for this example we will define any FunctionType - >>> # as a `FuncKey` by default. - >>> @sp.BaseKey.def_alias(FunctionType) - ... def _(func): - ... return FuncKey(func) - >>> # create a simple pytree - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a: int - ... b: str - >>> tree = Tree(1, "string") - >>> # now we can use the `FuncKey` alias to match all leaves that - >>> # are strings and start with "a" - >>> tree.at[lambda x: isinstance(x, str) and x.startswith("a")].get() - Tree(a=1, b=None) - """ + """Parent class for all match classes.""" @abc.abstractmethod - def __eq__(self, entry: KeyEntry) -> bool: + def __eq__(self, entry: PathKeyEntry) -> bool: pass - broadcastable: bool = False - - -class IndexKey(BaseKey): - """Match a leaf with a given index.""" - - def __init__(self, idx: int) -> None: - self.idx = idx - - def __eq__(self, key: KeyEntry) -> bool: - if isinstance(key, int): - return self.idx == key - treelib = sepes._src.backend.treelib - if isinstance(key, type(treelib.sequence_key(0))): - return self.idx == key.idx - return False - - def __repr__(self) -> str: - return f"{self.idx}" - - -class NameKey(BaseKey): - """Match a leaf with a given key.""" - - def __init__(self, name: str) -> None: - self.name = name - - def __eq__(self, key: KeyEntry) -> bool: - if isinstance(key, str): - return self.name == key - treelib = sepes._src.backend.treelib - if isinstance(key, type(treelib.attribute_key(""))): - return self.name == key.name - if isinstance(key, type(treelib.dict_key(""))): - return self.name == key.key - return False - - def __repr__(self) -> str: - return f"{self.name}" - - -class EllipsisKey(BaseKey): - """Match all leaves.""" - - broadcastable = True - - def __init__(self, _): - del _ - - def __eq__(self, _: KeyEntry) -> bool: - return True - - def __repr__(self) -> str: - return "..." - - -class MultiKey(BaseKey): - """Match a leaf with multiple keys at the same level.""" - - def __init__(self, *keys: tuple[BaseKey, ...]): - self.keys = tuple(keys) - - def __eq__(self, entry) -> bool: - return any(entry == key for key in self.keys) - - def __repr__(self) -> str: - return f"({', '.join(map(repr, self.keys))})" - - -class RegexKey(BaseKey): - """Match a leaf with a regex pattern inside 'at' property. - - Args: - pattern: regex pattern to match. - - Example: - >>> import sepes as sp - >>> import re - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... weight_1: float = 1.0 - ... weight_2: float = 2.0 - ... weight_3: float = 3.0 - ... bias: float = 0.0 - >>> tree = Tree() - >>> tree.at[re.compile(r"weight_.*")].set(100.0) # set all weights to 100.0 - Tree(weight_1=100.0, weight_2=100.0, weight_3=100.0, bias=0.0) - """ - - def __init__(self, pattern: str) -> None: - self.pattern = pattern - - def __eq__(self, key: KeyEntry) -> bool: - if isinstance(key, str): - return re.fullmatch(self.pattern, key) is not None - treelib = sepes._src.backend.treelib - if isinstance(key, type(treelib.attribute_key(""))): - return re.fullmatch(self.pattern, key.name) is not None - if isinstance(key, type(treelib.dict_key(""))): - return re.fullmatch(self.pattern, key.key) is not None - return False - - def __repr__(self) -> str: - return f"{self.pattern}" - - -# dispatch on type of indexer to convert input item to at indexer -# `__getitem__` to the appropriate key -# avoid using container pytree types to avoid conflict between -# matching as a mask or as an instance of `BaseKey` -indexer_dispatcher = ft.singledispatch(lambda x: x) -indexer_dispatcher.register(type(...), EllipsisKey) -indexer_dispatcher.register(int, IndexKey) -indexer_dispatcher.register(str, NameKey) -indexer_dispatcher.register(re.Pattern, RegexKey) - -BaseKey.def_alias = indexer_dispatcher.register + @property + @abc.abstractmethod + def broadcast(self): ... _INVALID_INDEXER = """\ @@ -286,14 +75,13 @@ def __repr__(self) -> str: - `str` for mapping keys or class attributes. - `int` for positional indexing for sequences. - `...` to select all leaves. + - ``re.Pattern`` to match a leaf level path with a regex pattern. - Boolean mask of a compatible structure as the pytree. - - `re.Pattern` to index all keys matching a regex pattern. - - Instance of `BaseKey` with custom logic to index a pytree. - `tuple` of the above types to match multiple leaves at the same level. """ _NO_LEAF_MATCH = """\ -No leaf match is found for where={where}. Available keys are {names}. +No leaf match is found for where={where}, Available keys are {names} Check the following: - If where is `str` then check if the key exists as a key or attribute. - If where is `int` then check if the index is in range. @@ -321,9 +109,9 @@ def is_leaf_func(node) -> bool: return False return True - return treelib.tree_path_map(func, tree, is_leaf=is_leaf_func) + return treelib.path_map(func, tree, is_leaf=is_leaf_func) - if any(mask.broadcastable for mask in where): + if any(where_i.broadcast for where_i in where): # should the selected subtree be broadcasted to the full tree # e.g. tree = [[1, 2], 3, 4] and where = [0], then # broadcast with True will be [[True, True], False, False] @@ -334,8 +122,8 @@ def is_leaf_func(node) -> bool: # and without broadcast the result will be [100, 3, 4] def bool_tree(value: bool, tree: Any): - leaves, treedef = treelib.tree_flatten(tree, is_leaf=is_leaf) - return treelib.tree_unflatten(treedef, [value] * len(leaves)) + leaves, treedef = treelib.flatten(tree, is_leaf=is_leaf) + return treelib.unflatten(treedef, [value] * len(leaves)) true_tree = ft.partial(bool_tree, True) false_tree = ft.partial(bool_tree, False) @@ -380,9 +168,10 @@ def path_map_func(path, leaf): mask = one_level_tree_path_map(path_map_func, tree) if not match: - path_leaf, _ = treelib.tree_path_flatten(tree, is_leaf=is_leaf) + path_leaf, _ = treelib.path_flatten(tree, is_leaf=is_leaf) + path = "/".join(str(where_i.input) for where_i in where) names = "".join("\n - " + treelib.keystr(path) for path, _ in path_leaf) - raise LookupError(_NO_LEAF_MATCH.format(where=where, names=names)) + raise LookupError(_NO_LEAF_MATCH.format(where=path, names=names)) return mask @@ -390,9 +179,10 @@ def path_map_func(path, leaf): def resolve_where( where: list[Any], tree: T, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, ): treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) def combine_bool_leaves(*leaves): # given a list of boolean leaves, combine them using `and` @@ -404,7 +194,7 @@ def combine_bool_leaves(*leaves): return verdict def is_bool_leaf(leaf: Any) -> bool: - if isinstance(leaf, arraylib.ndarrays): + if isinstance(leaf, ndarrays): return arraylib.is_bool(leaf) return isinstance(leaf, bool) @@ -423,7 +213,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: nonlocal seen_tuple, level_paths, bool_masks # used to check if a pytree is a valid indexing pytree # used with `is_leaf` argument of any `tree_*` function - leaves, _ = treelib.tree_flatten(node) + leaves, _ = treelib.flatten(node) if all(map(is_bool_leaf, leaves)): # if all leaves are boolean then this is maybe a boolean mask. @@ -442,7 +232,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: bool_masks += [node] return True - if isinstance(resolved_key := indexer_dispatcher(node), BaseKey): + if isinstance(resolved_key := at.dispatcher(node), BaseKey): # valid resolution of `BaseKey` is a valid indexing leaf # makes it possible to dispatch on multi-leaf pytree level_paths += [resolved_key] @@ -463,7 +253,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: # each for loop iteration is a level in the where path # this means that if where = ("a", "b", "c") then this means # we are travering the tree at level "a" then level "b" then level "c" - treelib.tree_flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf) + treelib.flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf) # if len(level_paths) > 1 then this means that we have multiple keys # at the same level, for example where = ("a", ("b", "c")) then this # means that for a parent "a", select "b" and "c". @@ -476,17 +266,14 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: if bool_masks: all_masks = [mask, *bool_masks] if mask else bool_masks - mask = treelib.tree_map(combine_bool_leaves, *all_masks) + mask = treelib.map(combine_bool_leaves, *all_masks) return mask -class AtIndexer(Generic[T]): +class at(Generic[T]): """Operate on a pytree at a given path using a path or mask in out-of-place manner. - Note: - Use :class:`.at` as a shorter alias for this class. - Args: tree: pytree to operate on. where: one of the following: @@ -495,13 +282,9 @@ class AtIndexer(Generic[T]): - ``int`` for positional indexing for sequences. - ``...`` to select all leaves. - a boolean mask of the same structure as the tree - - ``re.Pattern`` to index all keys matching a regex pattern. - - an instance of ``BaseKey`` with custom logic to index a pytree. + - ``re.Pattern`` to match a leaf level path with a regex pattern. - a tuple of the above to match multiple keys at the same level. - Note: - Alternatively, use ``at(tree)[where]`` to index a pytree. - Example: >>> import jax >>> import sepes as sp @@ -514,26 +297,23 @@ class AtIndexer(Generic[T]): >>> sp.at(tree)[mask].set(100) {'a': 1, 'b': [1, 100, 100]} """ - def __init__(self, tree: T, where: list[Any] | None = None) -> None: - vars(self)["tree"] = tree - vars(self)["where"] = [] if where is None else where - - def __setattr__(self, key: str, _: Any) -> None: - raise AttributeError(f"Cannot set {key=} on {type(self).__name__} instance") + self.tree = tree + self.where = [] if where is None else where def __getitem__(self, where: Any) -> Self: """Index a pytree at a given path using a path or mask.""" return type(self)(self.tree, [*self.where, where]) def __repr__(self) -> str: - return f"{type(self).__name__}(tree={tree_repr(self.tree)}, where={self.where})" + return f"{type(self).__name__}({tree_repr(self.tree)}, where={self.where})" def get( self, *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, + fill_value: Any = _no_fill_value, ): """Get the leaf values at the specified location. @@ -547,6 +327,10 @@ def get( - ``max_workers``: maximum number of workers to use. - ``kind``: kind of pool to use, either ``thread`` or ``process``. + fill_value: the value to fill the non-selected leaves with. + Useful to use with ``jax.jit`` to avoid variable size arrays + leaves related errors. + Returns: A _new_ pytree of leaf values at the specified location, with the non-selected leaf values set to None if the leaf is not an array. @@ -558,19 +342,25 @@ def get( {'a': None, 'b': [1, None, None]} """ treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) def leaf_get(where: Any, leaf: Any): # support both array and non-array leaves # for array boolean mask we select **parts** of the array that # matches the mask, for example if the mask is Array([True, False, False]) # and the leaf is Array([1, 2, 3]) then the result is Array([1]) - if isinstance(where, arraylib.ndarrays) and len(arraylib.shape(where)): + # because of the variable resultant size of the output + if isinstance(where, ndarrays) and len(arraylib.shape(where)): + if fill_value is not _no_fill_value: + return arraylib.where(where, leaf, fill_value) return leaf[where] # non-array boolean mask we select the leaf if the mask is True # and `None` otherwise + if fill_value is not _no_fill_value: + return leaf if where else fill_value return leaf if where else None - return treelib.tree_map( + return treelib.map( leaf_get, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -582,7 +372,7 @@ def set( self, set_value: Any, *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, ): """Set the leaf values at the specified location. @@ -609,6 +399,7 @@ def set( {'a': 1, 'b': [100, 2, 3]} """ treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) def leaf_set(where: Any, leaf: Any, set_value: Any): # support both array and non-array leaves @@ -616,12 +407,12 @@ def leaf_set(where: Any, leaf: Any, set_value: Any): # matches the mask, for example if the mask is Array([True, False, False]) # and the leaf is Array([1, 2, 3]) then the result is Array([1, 100, 100]) # with set_value = 100 - if isinstance(where, arraylib.ndarrays): + if isinstance(where, ndarrays): return arraylib.where(where, set_value, leaf) return set_value if where else leaf - _, lhsdef = treelib.tree_flatten(self.tree, is_leaf=is_leaf) - _, rhsdef = treelib.tree_flatten(set_value, is_leaf=is_leaf) + _, lhsdef = treelib.flatten(self.tree, is_leaf=is_leaf) + _, rhsdef = treelib.flatten(set_value, is_leaf=is_leaf) if lhsdef == rhsdef: # do not broadcast set_value if it is a pytree of same structure @@ -629,7 +420,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any): # to tree2 leaves if tree2 is a pytree of same structure as tree # instead of making each leaf of tree a copy of tree2 # is design is similar to ``numpy`` design `np.at[...].set(Array)` - return treelib.tree_map( + return treelib.map( leaf_set, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -638,7 +429,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any): is_parallel=is_parallel, ) - return treelib.tree_map( + return treelib.map( ft.partial(leaf_set, set_value=set_value), resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -650,7 +441,7 @@ def apply( self, func: Callable[[Any], Any], *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, ): """Apply a function to the leaf values at the specified location. @@ -685,19 +476,19 @@ def apply( >>> is_parallel = dict(max_workers=2) >>> images = sp.at(path)[...].apply(imread, is_parallel=is_parallel) # doctest: +SKIP """ - treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) def leaf_apply(where: Any, leaf: Any): # same as `leaf_set` but with `func` applied to the leaf # one thing to note is that, the where mask select an array # then the function needs work properly when applied to the selected # array elements - if isinstance(where, arraylib.ndarrays): + if isinstance(where, ndarrays): return arraylib.where(where, func(leaf), leaf) return func(leaf) if where else leaf - return treelib.tree_map( + return treelib.map( leaf_apply, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -710,7 +501,7 @@ def scan( func: Callable[[Any, S], tuple[Any, S]], state: S, *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, ) -> tuple[Any, S]: """Apply a function while carrying a state. @@ -746,6 +537,7 @@ def scan( leaf values while carrying a state and returning a single value. """ treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) running_state = state def stateless_func(leaf): @@ -754,11 +546,11 @@ def stateless_func(leaf): return leaf def leaf_apply(where: Any, leaf: Any): - if isinstance(where, arraylib.ndarrays): + if isinstance(where, ndarrays): return arraylib.where(where, stateless_func(leaf), leaf) return stateless_func(leaf) if where else leaf - out_tree = treelib.tree_map( + out_tree = treelib.map( leaf_apply, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -771,7 +563,7 @@ def reduce( func: Callable[[Any, Any], Any], *, initializer: Any = _no_initializer, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, ) -> Any: """Reduce the leaf values at the specified location. @@ -799,7 +591,7 @@ def reduce( """ treelib = sepes._src.backend.treelib tree = self.get(is_leaf=is_leaf) # type: ignore - leaves, _ = treelib.tree_flatten(tree, is_leaf=is_leaf) + leaves, _ = treelib.flatten(tree, is_leaf=is_leaf) if initializer is _no_initializer: return ft.reduce(func, leaves) return ft.reduce(func, leaves, initializer) @@ -808,7 +600,7 @@ def pluck( self, count: int | None = None, *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, ) -> list[Any]: """Extract subtrees at the specified location. @@ -880,7 +672,7 @@ def aggregate_subtrees(node: Any) -> bool: # for example if tree = dict(a=1) and mask is dict(a=True) # then returns [1] and not [dict(a=1)] return False - leaves, _ = treelib.tree_flatten(node, is_leaf=lambda x: x is None) + leaves, _ = treelib.flatten(node, is_leaf=lambda x: x is None) # in essence if the subtree does not contain any None leaves # then it is a valid subtree to be plucked # this because `get` sets the non-selected leaves to None @@ -890,9 +682,102 @@ def aggregate_subtrees(node: Any) -> bool: count -= 1 return True - treelib.tree_flatten(tree, is_leaf=aggregate_subtrees) + treelib.flatten(tree, is_leaf=aggregate_subtrees) return subtrees -# shorter alias -at = AtIndexer +# pass through for boolean pytrees masks and tuple of keys +at.dispatcher = ft.singledispatch(lambda x: x) + + +def def_rule( + user_type: type[T], + path_compare_func: Callable[[T, PathKeyEntry], bool], + *, + broadcastable: bool = False, +) -> None: + # remove the BaseKey abstraction from the user-facing function + class UserKey(BaseKey): + broadcast: bool = broadcastable + + def __init__(self, input: T): + self.input = input + + def __eq__(self, key: PathKeyEntry) -> bool: + return path_compare_func(self.input, key) + + at.dispatcher.register(user_type, UserKey) + + +at.def_rule = def_rule + + +# key rules to match user input to with the path entry + + +def str_compare(name: str, key: PathKeyEntry): + """Match a leaf with a given name.""" + if isinstance(key, str): + return name == key + treelib = sepes._src.backend.treelib + if isinstance(key, type(treelib.attribute_key(""))): + return name == key.name + if isinstance(key, type(treelib.dict_key(""))): + return name == key.key + return False + + +def int_compare(idx: int, key: PathKeyEntry) -> bool: + """Match a leaf with a given index.""" + if isinstance(key, int): + return idx == key + treelib = sepes._src.backend.treelib + if isinstance(key, type(treelib.sequence_key(0))): + return idx == key.idx + return False + + +def regex_compare(pattern: re.Pattern, key: PathKeyEntry) -> bool: + """Match a path with a regex pattern inside 'at' property.""" + if isinstance(key, str): + return re.fullmatch(pattern, key) is not None + treelib = sepes._src.backend.treelib + if isinstance(key, type(treelib.attribute_key(""))): + return re.fullmatch(pattern, key.name) is not None + if isinstance(key, type(treelib.dict_key(""))): + return re.fullmatch(pattern, key.key) is not None + return False + + +def ellipsis_compare(_, __): + return True + + +at.def_rule(str, str_compare, broadcastable=False) +at.def_rule(int, int_compare, broadcastable=False) +at.def_rule(re.Pattern, regex_compare, broadcastable=False) +at.def_rule(type(...), ellipsis_compare, broadcastable=True) + + +class MultiKey(BaseKey): + """Match a leaf with multiple keys at the same level.""" + + def __init__(self, *keys): + self.keys = tuple(keys) + + def __eq__(self, entry: PathKeyEntry) -> bool: + return any(entry == key for key in self.keys) + + broadcast: bool = False + + +if is_package_avaiable("jax"): + import jax.tree_util as jtu + + def jax_key_compare(input, key: PathKeyEntry) -> bool: + """Enable indexing with jax keys directly in `at`.""" + return input == key + + at.def_rule(jtu.SequenceKey, jax_key_compare, broadcastable=False) + at.def_rule(jtu.GetAttrKey, jax_key_compare, broadcastable=False) + at.def_rule(jtu.DictKey, jax_key_compare, broadcastable=False) diff --git a/sepes/_src/tree_mask.py b/sepes/_src/tree_mask.py index e113e98..5f14370 100644 --- a/sepes/_src/tree_mask.py +++ b/sepes/_src/tree_mask.py @@ -30,23 +30,19 @@ MaskType = Union[T, Callable[[Any], bool]] -class _FrozenError(NamedTuple): +class _MaskedError(NamedTuple): opname: str def __call__(self, *a, **k): raise NotImplementedError( - f"Cannot apply `{self.opname}` operation to a frozen object " + f"Cannot apply `{self.opname}` operation on a masked object " f"{', '.join(map(str, a))} " f"{', '.join(k + '=' + str(v) for k, v in k.items())}.\n" - "Unfreeze the object first by unmasking the frozen mask:\n" - "Example:\n" - ">>> import jax\n" - ">>> import sepes as sp\n" - ">>> tree = sp.tree_unmask(tree)" + "Unmask the object first using `tree_unmask`" ) -class _FrozenBase(Static): +class _MaskBase(Static[T]): # the objective of this class is to wrap a pytree node with a custom wrapper # that yields no leaves when flattened. This is useful to avoid updating # the node by effectivly *hiding it* from function transformations that operates @@ -69,43 +65,44 @@ def __repr__(self) -> str: def __str__(self) -> str: return "#" + tree_str(self.__wrapped__) - def __copy__(self) -> _FrozenBase[T]: + def __copy__(self) -> _MaskBase[T]: return type(self)(tree_copy(self.__wrapped__)) # raise helpful error message when trying to interact with frozen object - __add__ = __radd__ = __iadd__ = _FrozenError("+") - __sub__ = __rsub__ = __isub__ = _FrozenError("-") - __mul__ = __rmul__ = __imul__ = _FrozenError("*") - __matmul__ = __rmatmul__ = __imatmul__ = _FrozenError("@") - __truediv__ = __rtruediv__ = __itruediv__ = _FrozenError("/") - __floordiv__ = __rfloordiv__ = __ifloordiv__ = _FrozenError("//") - __mod__ = __rmod__ = __imod__ = _FrozenError("%") - __pow__ = __rpow__ = __ipow__ = _FrozenError("**") - __lshift__ = __rlshift__ = __ilshift__ = _FrozenError("<<") - __rshift__ = __rrshift__ = __irshift__ = _FrozenError(">>") - __and__ = __rand__ = __iand__ = _FrozenError("and") - __xor__ = __rxor__ = __ixor__ = _FrozenError("") - __or__ = __ror__ = __ior__ = _FrozenError("or") - __neg__ = __pos__ = __abs__ = __invert__ = _FrozenError("unary operation") - __call__ = _FrozenError("__call__") - - -@tree_summary.def_type(_FrozenBase) + __add__ = __radd__ = __iadd__ = _MaskedError("+") + __sub__ = __rsub__ = __isub__ = _MaskedError("-") + __mul__ = __rmul__ = __imul__ = _MaskedError("*") + __matmul__ = __rmatmul__ = __imatmul__ = _MaskedError("@") + __truediv__ = __rtruediv__ = __itruediv__ = _MaskedError("/") + __floordiv__ = __rfloordiv__ = __ifloordiv__ = _MaskedError("//") + __mod__ = __rmod__ = __imod__ = _MaskedError("%") + __pow__ = __rpow__ = __ipow__ = _MaskedError("**") + __lshift__ = __rlshift__ = __ilshift__ = _MaskedError("<<") + __rshift__ = __rrshift__ = __irshift__ = _MaskedError(">>") + __and__ = __rand__ = __iand__ = _MaskedError("and") + __xor__ = __rxor__ = __ixor__ = _MaskedError("") + __or__ = __ror__ = __ior__ = _MaskedError("or") + __neg__ = __pos__ = __abs__ = __invert__ = _MaskedError("unary") + __lt__ = __le__ = __gt__ = __ge__ = _MaskedError("comparison") + __call__ = _MaskedError("__call__") + + +@tree_summary.def_type(_MaskBase) def _(node) -> str: return f"#{tree_summary.type_dispatcher(node.__wrapped__)}" -class _FrozenHashable(_FrozenBase): +class _MaskedHashable(_MaskBase): def __hash__(self) -> int: return tree_hash(self.__wrapped__) def __eq__(self, rhs: Any) -> bool: - if not isinstance(rhs, _FrozenHashable): + if not isinstance(rhs, _MaskedHashable): return False return is_tree_equal(self.__wrapped__, rhs.__wrapped__) -class _FrozenArray(_FrozenBase): +class _MaskedArray(_MaskBase): # wrap arrays with a custom wrapper that implements hash and equality # using the wrapped array's bytes representation and sha256 hash function # this is useful to select some array to hold without updating in the process @@ -115,7 +112,7 @@ def __hash__(self) -> int: return int(hashlib.sha256(bytes).hexdigest(), 16) def __eq__(self, other) -> bool: - if not isinstance(other, _FrozenArray): + if not isinstance(other, _MaskedArray): return False lhs, rhs = self.__wrapped__, other.__wrapped__ # fast path to avoid calling `all` on large arrays @@ -123,136 +120,62 @@ def __eq__(self, other) -> bool: return False if arraylib.dtype(lhs) != arraylib.dtype(rhs): return False - return arraylib.all(lhs == rhs) + return arraylib.array_equal(lhs, rhs) -def freeze(value: T) -> _FrozenBase[T]: - """Freeze a value to avoid updating it by through function transformations. - - Args: - value: A value to freeze. - - Note: - - :func:`.freeze` is idempotent, i.e. ``freeze(freeze(x)) == freeze(x)``. - - Example: - >>> import jax - >>> import sepes as sp - >>> import jax.tree_util as jtu - >>> # Usage with `jax.tree_util.tree_leaves` - >>> # no leaves for a wrapped value - >>> jtu.tree_leaves(sp.freeze(2.)) - [] - - >>> # retrieve the frozen wrapper value using `is_leaf=sp.is_frozen` - >>> jtu.tree_leaves(sp.freeze(2.), is_leaf=sp.is_frozen) - [#2.0] - - >>> # Usage with `jax.tree_util.tree_map` - >>> a= [1,2,3] - >>> a[1] = sp.freeze(a[1]) - >>> jtu.tree_map(lambda x:x+100, a) - [101, #2, 103] - """ +def mask(value: T) -> _MaskBase[T]: # dispatching is used to customize the type of the wrapper based on the type # of the value. For instance, hashable values dont need custom hash and # equality implementations, so they are wrapped with a simpler wrapper. # this approach avoids type logic in the wrapper equality and hash methods, # thus effectively improving performance of the wrapper. - return freeze.type_dispatcher(value) + return mask.type_dispatcher(value) -freeze.type_dispatcher = ft.singledispatch(_FrozenHashable) -freeze.def_type = freeze.type_dispatcher.register +mask.type_dispatcher = ft.singledispatch(_MaskedHashable) +mask.def_type = mask.type_dispatcher.register for ndarray in arraylib.ndarrays: - @freeze.def_type(ndarray) - def freeze_array(value: T) -> _FrozenArray[T]: + @mask.def_type(ndarray) + def mask_array(value: T) -> _MaskedArray[T]: # wrap arrays with a custom wrapper that implements hash and equality # arrays can be hashed by converting them to bytes and hashing the bytes - return _FrozenArray(value) + return _MaskedArray(value) -@freeze.def_type(_FrozenBase) -def _(value: _FrozenBase[T]) -> _FrozenBase[T]: - # idempotent freeze operation, meaning that freeze(freeze(x)) == freeze(x) +@mask.def_type(_MaskBase) +def _(value: _MaskBase[T]) -> _MaskBase[T]: + # idempotent mask operation, meaning that mask(mask(x)) == mask(x) # this is useful to avoid recursive unwrapping of frozen values, plus its - # meaningless to freeze a frozen value. + # meaningless to mask a frozen value. return value -def is_frozen(value: Any) -> bool: +def is_masked(value: Any) -> bool: """Returns True if the value is a frozen wrapper.""" - return isinstance(value, _FrozenBase) - - -def unfreeze(value: T) -> T: - """Unfreeze :func:`.freeze` value, otherwise return the value itself. + return isinstance(value, _MaskBase) - Args: - value: A value to unfreeze. - - Note: - - use ``is_leaf=sp.is_frozen`` with ``tree_map`` to unfreeze a tree.** - Example: - >>> import sepes as sp - >>> import jax - >>> frozen_value = sp.freeze(1) - >>> sp.unfreeze(frozen_value) - 1 - >>> # usage with `jax.tree_map` - >>> frozen_tree = jax.tree_map(sp.freeze, {"a": 1, "b": 2}) - >>> unfrozen_tree = jax.tree_map(sp.unfreeze, frozen_tree, is_leaf=sp.is_frozen) - >>> unfrozen_tree - {'a': 1, 'b': 2} - """ - return unfreeze.type_dispatcher(value) +def unmask(value: T) -> T: + return unmask.type_dispatcher(value) -unfreeze.type_dispatcher = ft.singledispatch(lambda x: x) -unfreeze.def_type = unfreeze.type_dispatcher.register +unmask.type_dispatcher = ft.singledispatch(lambda x: x) +unmask.def_type = unmask.type_dispatcher.register -@unfreeze.def_type(_FrozenBase) -def _(value: _FrozenBase[T]) -> T: +@unmask.def_type(_MaskBase) +def _(value: _MaskBase[T]) -> T: return getattr(value, "__wrapped__") def is_nondiff(value: Any) -> bool: - """Returns True for non-inexact types, False otherwise. - - Args: - value: A value to check. - - Note: - - :func:`.is_nondiff` uses single dispatch to support custom types. To define - a custom behavior for a certain type, use ``is_nondiff.def_type(type, func)``. - - Example: - >>> import sepes as sp - >>> import jax.numpy as jnp - >>> sp.is_nondiff(jnp.array(1)) # int array is non-diff type - True - >>> sp.is_nondiff(jnp.array(1.)) # float array is diff type - False - >>> sp.is_nondiff(1) # int is non-diff type - True - >>> sp.is_nondiff(1.) # float is diff type - False - - Note: - This function is meant to be used with ``jax.tree_map`` to - create a mask for non-differentiable nodes in a tree, that can be used - to freeze the non-differentiable nodes before passing the tree to a - ``jax`` transformation. - """ return is_nondiff.type_dispatcher(value) -is_nondiff.type_dispatcher = ft.singledispatch(lambda x: True) +is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True) is_nondiff.def_type = is_nondiff.type_dispatcher.register @@ -274,78 +197,53 @@ def _(_: float | complex) -> bool: def _tree_mask_map( tree: T, - mask: MaskType, + cond: Callable[[Any], bool], func: type | Callable[[Any], Any], *, is_leaf: Callable[[Any], None] | None = None, ): - treelib = sepes._src.backend.treelib - # apply func to leaves satisfying mask pytree/condtion - _, lhsdef = treelib.tree_flatten(tree, is_leaf=is_leaf) - _, rhsdef = treelib.tree_flatten(mask, is_leaf=is_leaf) - - if (lhsdef == rhsdef) and (type(mask) is type(tree)): - # a tree with the same structure as tree with boolean values - # and also a callable. - def map_func(x, y): - return func(x) if y else x - return treelib.tree_map(map_func, tree, mask, is_leaf=is_leaf) - - if isinstance(mask, Callable): + if not isinstance(cond, Callable): # a callable that accepts a leaf and returns a boolean # but *not* a tree with the same structure as tree with boolean values. - def map_func(x): - return func(x) if mask(x) else x + raise TypeError( + f"`cond` must be a callable that accepts a leaf and returns a boolean " + f" Got {cond=} and {tree=}." + ) - return treelib.tree_map(map_func, tree, is_leaf=is_leaf) + treelib = sepes._src.backend.treelib - raise ValueError( - f"`mask` must be a callable that accepts a leaf and returns a boolean " - f"or a tree with the same structure as tree with boolean values." - f" Got {mask=} and {tree=}." - ) + def map_func(x): + return func(x) if cond(x) else x + + return treelib.map(map_func, tree, is_leaf=is_leaf) def tree_mask( tree: T, - mask: MaskType = is_nondiff, + cond: Callable[[Any], bool] = is_nondiff, *, is_leaf: Callable[[Any], None] | None = None, ): """Mask leaves of a pytree based on ``mask`` boolean pytree or callable. + Masked leaves are wrapped with a wrapper that yields no leaves when + ``tree_flatten`` is called on it. + Args: tree: A pytree of values. - mask: A pytree of boolean values or a callable that accepts a leaf and - returns a boolean. If a leaf is ``True`` either in the mask or the - callable, the leaf is wrapped by with a wrapper that yields no - leaves when ``tree_flatten`` is called on it, otherwise - it is unchanged. defaults to :func:`.is_nondiff` which returns true for - non-differentiable nodes. + cond: A callable that accepts a leaf and returns a boolean to mark the leaf + for masking. Defaults to masking non-differentiable leaf nodes that + are not instances of of python float, python complex, or inexact + array types. is_leaf: A callable that accepts a leaf and returns a boolean. If provided, it is used to determine if a value is a leaf. for example, ``is_leaf=lambda x: isinstance(x, list)`` will treat lists as leaves and will not recurse into them. - Note: - - Masked leaves are wrapped with a wrapper that yields no leaves when - ``tree_flatten`` is called on it. - - Masking is equivalent to applying :func:`.freeze` to the masked leaves. - - >>> import sepes as sp - >>> import jax - >>> tree = [1, 2, {"a": 3, "b": 4.}] - >>> # mask all non-differentiable nodes by default - >>> def mask_if_nondiff(x): - ... return sp.freeze(x) if sp.is_nondiff(x) else x - >>> masked_tree = jax.tree_map(mask_if_nondiff, tree) - - - Use masking on tree containing non-differentiable nodes before passing - the tree to a ``jax`` transformation. - Example: >>> import sepes as sp + >>> import jax >>> tree = [1, 2, {"a": 3, "b": 4.}] >>> # mask all non-differentiable nodes by default >>> masked_tree = sp.tree_mask(tree) @@ -357,32 +255,32 @@ def tree_mask( [1, 2, {'a': 3, 'b': 4.0}] Example: - >>> # pass non-differentiable values to `jax.grad` + Pass non-differentiable values to ``jax.grad`` + >>> import sepes as sp >>> import jax >>> @jax.grad ... def square(tree): ... tree = sp.tree_unmask(tree) - ... return tree[0]**2 + ... return tree[0] ** 2 >>> tree = (1., 2) # contains a non-differentiable node >>> square(sp.tree_mask(tree)) (Array(2., dtype=float32, weak_type=True), #2) """ - return _tree_mask_map(tree, mask=mask, func=freeze, is_leaf=is_leaf) + return _tree_mask_map(tree, cond=cond, func=mask, is_leaf=is_leaf) -def tree_unmask(tree: T, mask: MaskType = lambda _: True): - """Undo the masking of tree leaves according to ``mask``. defaults to unmasking all leaves. +def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True): + """Undo the masking of tree leaves according to ``cond``. defaults to unmasking all leaves. Args: tree: A pytree of values. - mask: A pytree of boolean values or a callable that accepts a leaf and - returns a boolean. If a leaf is True either in the mask or the - callable, the leaf is unfrozen, otherwise it is unchanged. defaults - unmasking all nodes. + cond: A callable that accepts a leaf and returns a boolean to mark the + leaf to be unmasked. Defaults to always unmask. Example: >>> import sepes as sp + >>> import jax >>> tree = [1, 2, {"a": 3, "b": 4.}] >>> # mask all non-differentiable nodes by default >>> masked_tree = sp.tree_mask(tree) @@ -394,27 +292,19 @@ def tree_unmask(tree: T, mask: MaskType = lambda _: True): [1, 2, {'a': 3, 'b': 4.0}] Example: - >>> # pass non-differentiable values to `jax.grad` + Pass non-differentiable values to ``jax.grad`` + >>> import sepes as sp >>> import jax >>> @jax.grad ... def square(tree): ... tree = sp.tree_unmask(tree) - ... return tree[0]**2 + ... return tree[0] ** 2 >>> tree = (1., 2) # contains a non-differentiable node >>> square(sp.tree_mask(tree)) (Array(2., dtype=float32, weak_type=True), #2) - - Note: - - Unmasking is equivalent to applying :func:`.unfreeze` on the masked leaves. - - >>> import sepes as sp - >>> import jax - >>> tree = [1, 2, {"a": 3, "b": 4.}] - >>> # unmask all nodes - >>> tree = jax.tree_map(sp.unfreeze, tree, is_leaf=sp.is_frozen) """ - return _tree_mask_map(tree, mask=mask, func=unfreeze, is_leaf=is_frozen) + return _tree_mask_map(tree, cond=cond, func=unmask, is_leaf=is_masked) if is_package_avaiable("jax"): @@ -424,6 +314,6 @@ def tree_unmask(tree: T, mask: MaskType = lambda _: True): # otherwise calling `freeze` inside a jax transformation on # a tracer will hide the tracer from jax and will cause leaked tracer # error. - @freeze.def_type(jax.core.Tracer) + @mask.def_type(jax.core.Tracer) def _(value: jax.core.Tracer) -> jax.core.Tracer: return value diff --git a/sepes/_src/tree_pprint.py b/sepes/_src/tree_pprint.py index 27ea5f2..fb24e65 100644 --- a/sepes/_src/tree_pprint.py +++ b/sepes/_src/tree_pprint.py @@ -31,7 +31,6 @@ from sepes._src.backend import is_package_avaiable from sepes._src.tree_util import ( Node, - Partial, construct_tree, is_path_leaf_depth_factory, tree_type_path_leaves, @@ -178,13 +177,12 @@ def _(func: Callable, **spec: Unpack[PPSpec]) -> str: return f"{name}({', '.join(header)})" -@tree_str.def_type(Partial) @tree_str.def_type(ft.partial) def _(node: ft.partial, **spec: Unpack[PPSpec]) -> str: func = tree_str.pp(node.func, **spec) args = tree_str.pps(tree_str.pp, node.args, **spec) keywords = tree_str.pps(tree_str.kv_pp, node.keywords, **spec) - return f"Partial(" + ",".join([func, args, keywords]) + ")" + return "partial(" + ",".join([func, args, keywords]) + ")" @tree_str.def_type(list) @@ -242,7 +240,6 @@ def array_pp(node, **spec: Unpack[PPSpec]) -> str: return f"{base}(μ={mean}, σ={std}, ∈{interval})" -@tree_repr.def_type(Partial) @tree_repr.def_type(ft.partial) def _(node: ft.partial, **spec: Unpack[PPSpec]) -> str: func = tree_repr.pp(node.func, **spec) @@ -363,139 +360,6 @@ def step( return (text if tabwidth is None else text.expandtabs(tabwidth)).rstrip() -def tree_mermaid( - tree: PyTree, - depth: int | float = float("inf"), - is_leaf: Callable[[Any], None] | None = None, - tabwidth: int | None = 4, -) -> str: - """Generate a mermaid diagram syntax for arbitrary pytrees. - - Args: - tree: PyTree - depth: depth of the tree to print. default is max depth - is_leaf: function to determine if a node is a leaf. default is None - tabwidth: tab width of the repr string. default is 4. - - Example: - >>> import sepes as sp - >>> tree = [1, 2, dict(a=3)] - >>> # as rendered by mermaid - >>> print(sp.tree_mermaid(tree)) # doctest: +SKIP - - .. image:: ../_static/tree_mermaid.jpg - :width: 300px - :align: center - - Note: - - Copy the output and paste it in the mermaid live editor to interact with - the diagram. https://mermaid.live - """ - - def step(node: Node, depth: int = 0) -> str: - if len(node.children) == 0: - (key, _), value = node.data - ppstr = f"{key}=" if key is not None else "" - ppstr += tree_repr(value, depth=0) - ppstr = "" + ppstr + "" - return f'\tid{id(node.parent)} --- id{id(node)}("{ppstr}")\n' - - (key, type), _ = node.data - ppstr = f"{key}:" if key is not None else "" - ppstr += f"{type.__name__}" - ppstr = "" + ppstr + "" - - if node.parent is None: - text = f'\tid{id(node)}("{ppstr}")\n' - else: - text = f'\tid{id(node.parent)} --- id{id(node)}("{ppstr}")\n' - - for child in node.children.values(): - text += step(child, depth=depth + 1) - return text - - is_path_leaf = is_path_leaf_depth_factory(depth) - root = construct_tree(tree, is_leaf=is_leaf, is_path_leaf=is_path_leaf) - text = "flowchart LR\n" + step(root) - return (text.expandtabs(tabwidth) if tabwidth is not None else text).rstrip() - - -# dispatcher for dot nodestyles -dot_dispatcher = ft.singledispatch(lambda _: dict(shape="box")) - - -def tree_graph( - tree: PyTree, - depth: int | float = float("inf"), - is_leaf: Callable[[Any], None] | None = None, - tabwidth: int | None = 4, -) -> str: - """Generate a dot diagram syntax for arbitrary pytrees. - - Args: - tree: pytree - depth: depth of the tree to print. default is max depth - is_leaf: function to determine if a node is a leaf. default is None - tabwidth: tab width of the repr string. default is 4. - - Returns: - str: dot diagram syntax - - Example: - >>> import sepes as sp - >>> tree = [1, 2, dict(a=3)] - >>> # as rendered by graphviz - - .. image:: ../_static/tree_graph.svg - - Example: - >>> # define custom style for a node by dispatching on the value - >>> # the defined function should return a dict of attributes - >>> # that will be passed to graphviz. - >>> import sepes as sp - >>> tree = [1, 2, dict(a=3)] - >>> @sp.tree_graph.def_nodestyle(list) - ... def _(_) -> dict[str, str]: - ... return dict(shape="circle", style="filled", fillcolor="lightblue") - - .. image:: ../_static/tree_graph_stylized.svg - """ - - def step(node: Node, depth: int = 0) -> str: - (key, type), value = node.data - - # dispatch node style - style = ", ".join(f"{k}={v}" for k, v in dot_dispatcher(value).items()) - - if len(node.children) == 0: - ppstr = f"{key}=" if key is not None else "" - ppstr += tree_repr(value, depth=0) - text = f'\t{id(node)} [label="{ppstr}", {style}];\n' - text += f"\t{id(node.parent)} -> {id(node)};\n" - return text - - ppstr = f"{key}:" if key is not None else "" - ppstr += f"{type.__name__}" - - if node.parent is None: - text = f'\t{id(node)} [label="{ppstr}", {style}];\n' - else: - text = f'\t{id(node)} [label="{ppstr}", {style}];\n' - text += f"\t{id(node.parent)} -> {id(node)};\n" - - for child in node.children.values(): - text += step(child, depth=depth + 1) - return text - - is_path_leaf = is_path_leaf_depth_factory(depth) - root = construct_tree(tree, is_leaf=is_leaf, is_path_leaf=is_path_leaf) - text = "digraph G {\n" + step(root) + "}" - return (text.expandtabs(tabwidth) if tabwidth is not None else text).rstrip() - - -tree_graph.def_nodestyle = dot_dispatcher.register - - def format_width(string, width=60): """Strip newline/tab characters if less than max width.""" children_length = len(string) - string.count("\n") - string.count("\t") @@ -570,39 +434,19 @@ def tree_summary( >>> import sepes as sp >>> import jax.numpy as jnp >>> print(sp.tree_summary([1, [2, [3]], jnp.array([1, 2, 3])])) - ┌─────────┬──────┬─────┬──────┐ - │Name │Type │Count│Size │ - ├─────────┼──────┼─────┼──────┤ - │[0] │int │1 │ │ - ├─────────┼──────┼─────┼──────┤ - │[1][0] │int │1 │ │ - ├─────────┼──────┼─────┼──────┤ - │[1][1][0]│int │1 │ │ - ├─────────┼──────┼─────┼──────┤ - │[2] │i32[3]│3 │12.00B│ - ├─────────┼──────┼─────┼──────┤ - │Σ │list │6 │12.00B│ - └─────────┴──────┴─────┴──────┘ - - Example: - Set custom type display for ``jax`` jaxprs - - >>> import jax - >>> import sepes as sp - >>> ClosedJaxprType = type(jax.make_jaxpr(lambda x: x)(1)) - >>> @sp.tree_summary.def_type(ClosedJaxprType) - ... def _(expr: ClosedJaxprType) -> str: - ... jaxpr = expr.jaxpr - ... return f"Jaxpr({jaxpr.invars}, {jaxpr.outvars})" - >>> def func(x, y): - ... return x - >>> jaxpr = jax.make_jaxpr(func)(1, 2) - >>> print(sp.tree_summary(jaxpr)) - ┌────┬──────────────────┬─────┬────┐ - │Name│Type │Count│Size│ - ├────┼──────────────────┼─────┼────┤ - │Σ │Jaxpr([a, b], [a])│1 │ │ - └────┴──────────────────┴─────┴────┘ + ┌─────────┬────────────────────────────────────┬─────┬──────┐ + │Name │Type │Count│Size │ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │[0] │int │1 │ │ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │[1][0] │int │1 │ │ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │[1][1][0]│int │1 │ │ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │[2] │i32[3] │3 │12.00B│ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │Σ │list[int,list[int,list[int]],i32[3]]│6 │12.00B│ + └─────────┴────────────────────────────────────┴─────┴──────┘ Example: Display flops of a function in tree summary @@ -662,14 +506,14 @@ def tree_size(tree: PyTree) -> int: def reduce_func(acc, node): return acc + tree_summary.size_dispatcher(node) - leaves, _ = treelib.tree_flatten(tree) + leaves, _ = treelib.flatten(tree) return ft.reduce(reduce_func, leaves, 0) def tree_count(tree: PyTree) -> int: def reduce_func(acc, node): return acc + tree_summary.count_dispatcher(node) - leaves, _ = treelib.tree_flatten(tree) + leaves, _ = treelib.flatten(tree) return ft.reduce(reduce_func, leaves, 0) traces_leaves = tree_type_path_leaves( @@ -725,6 +569,21 @@ def _(node: Any) -> str: dtype = arraylib.dtype(node) return tree_repr(ShapeDTypePP(shape, dtype)) +@tree_summary.def_type(list) +@tree_summary.def_type(tuple) +def _(node: tuple) -> str: + # - output Container[types,...] instead of just container type in the type col. + # - usually this encounterd if the tree_summary depth is not inf + # so the tree leaves could contain non-atomic types. + treelib = sepes._src.backend.treelib + + one_level_types = treelib.map( + tree_summary.type_dispatcher, + node, + is_leaf=lambda x: False if id(x) == id(node) else True, + ) + return f"{type(node).__name__}[{','.join(one_level_types)}]" + if is_package_avaiable("jax"): # jax pretty printing extra handlers @@ -764,4 +623,19 @@ def _(node, **spec: Unpack[PPSpec]) -> str: shape = node.aval.shape dtype = node.aval.dtype string = tree_repr.dispatch(ShapeDTypePP(shape, dtype), **spec) - return f"Tracer({string})" + return f"{type(node).__name__}({string})" + + # handle the sharding info if it is sharded + @tree_summary.def_type(jax.Array) + def _(node: Any) -> str: + """Return the type repr of the node.""" + # global shape + global_shape = arraylib.shape(node) + shard_shape = node.sharding.shard_shape(global_shape) + dtype = arraylib.dtype(node) + global_info = tree_repr(ShapeDTypePP(global_shape, dtype)) + + if global_shape == shard_shape: + return global_info + shard_info = tree_repr(ShapeDTypePP(shard_shape, dtype)) + return f"G:{global_info}\nS:{shard_info}" diff --git a/sepes/_src/tree_util.py b/sepes/_src/tree_util.py index 88fb24d..429b1b6 100644 --- a/sepes/_src/tree_util.py +++ b/sepes/_src/tree_util.py @@ -16,15 +16,17 @@ from __future__ import annotations +import copy import functools as ft import operator as op -import copy from math import ceil, floor, trunc -from typing import Any, Callable, Hashable, Iterator, Sequence, Tuple, TypeVar, Generic -import sepes._src.backend.arraylib as arraylib -from sepes._src.backend import is_package_avaiable +from typing import Any, Callable, Generic, Hashable, Iterator, Sequence, Tuple, TypeVar + from typing_extensions import ParamSpec + import sepes +import sepes._src.backend.arraylib as arraylib +from sepes._src.backend import is_package_avaiable T = TypeVar("T") T1 = TypeVar("T1") @@ -42,7 +44,7 @@ def tree_hash(*trees: PyTree) -> int: treelib = sepes._src.backend.treelib - leaves, treedef = treelib.tree_flatten(trees) + leaves, treedef = treelib.flatten(trees) return hash((*leaves, treedef)) @@ -57,7 +59,7 @@ def tree_copy(tree: T) -> T: def is_leaf(node) -> bool: return isinstance(node, types) - return treelib.tree_map(tree_copy.copy_dispatcher, tree, is_leaf=is_leaf) + return treelib.map(tree_copy.copy_dispatcher, tree, is_leaf=is_leaf) # default behavior is to copy the tree elements except for registered types @@ -66,18 +68,31 @@ def is_leaf(node) -> bool: tree_copy.def_type = tree_copy.copy_dispatcher.register +@tree_copy.def_type(int) +@tree_copy.def_type(float) +@tree_copy.def_type(complex) +@tree_copy.def_type(str) +@tree_copy.def_type(bytes) +def _(x: T) -> T: + # skip applying `copy.copy` on immutable atom types + return x + + def is_array_like(node) -> bool: return hasattr(node, "shape") and hasattr(node, "dtype") -def _is_leaf_rhs_equal(leaf, rhs) -> bool: +def _is_leaf_rhs_equal(leaf, rhs): if is_array_like(leaf): if is_array_like(rhs): if leaf.shape != rhs.shape: return False if leaf.dtype != rhs.dtype: return False - verdict = arraylib.all(leaf == rhs) + try: + verdict = arraylib.all(leaf == rhs) + except NotImplementedError: + verdict = leaf == rhs try: return bool(verdict) except Exception: @@ -94,11 +109,11 @@ def is_tree_equal(*trees: Any) -> bool: """ treelib = sepes._src.backend.treelib tree0, *rest = trees - leaves0, treedef0 = treelib.tree_flatten(tree0) + leaves0, treedef0 = treelib.flatten(tree0) verdict = True for tree in rest: - leaves, treedef = treelib.tree_flatten(tree) + leaves, treedef = treelib.flatten(tree) if (treedef != treedef0) or verdict is False: return False verdict = ft.reduce(op.and_, map(_is_leaf_rhs_equal, leaves0, leaves), verdict) @@ -116,73 +131,28 @@ def __init_subclass__(klass, **k) -> None: treelib.register_static(klass) -class Partial(Static): - """``Partial`` function with support for positional partial application. - - Args: - func: The function to be partially applied. - args: Positional arguments to be partially applied. use ``...`` as a - placeholder for positional arguments. - kwargs: Keyword arguments to be partially applied. - - Example: - >>> import sepes as sp - >>> def f(a, b, c): - ... print(f"a: {a}, b: {b}, c: {c}") - ... return a + b + c - - >>> # positional arguments using `...` placeholder - >>> f_a = sp.Partial(f, ..., 2, 3) - >>> f_a(1) - a: 1, b: 2, c: 3 - 6 - - >>> # keyword arguments - >>> f_b = sp.Partial(f, b=2, c=3) - >>> f_a(1) - a: 1, b: 2, c: 3 - 6 - - Note: - - The ``...`` is used to indicate a placeholder for positional arguments. - - https://stackoverflow.com/a/7811270 - """ - - __slots__ = ["func", "args", "keywords"] # type: ignore - - def __init__(self, func: Callable[..., Any], *args: Any, **kwargs: Any): - self.func = func - self.args = args - self.keywords = kwargs - +class partial(ft.partial): def __call__(self, *args: Any, **kwargs: Any) -> Any: iargs = iter(args) args = (next(iargs) if arg is ... else arg for arg in self.args) # type: ignore return self.func(*args, *iargs, **{**self.keywords, **kwargs}) - def __repr__(self) -> str: - return f"Partial({self.func}, {self.args}, {self.keywords})" - - def __hash__(self) -> int: - return tree_hash(self) - - def __eq__(self, other: Any) -> bool: - return is_tree_equal(self, other) - - -# to match python -partial = Partial - def bcmap( func: Callable[P, T], + broadcast_to: int | str | None = None, *, is_leaf: Callable[[Any], bool] | None = None, ) -> Callable[P, T]: """Map a function over pytree leaves with automatic broadcasting for scalar arguments. Args: - func: the function to be mapped over the pytree + func: the function to be mapped over the pytree. + broadcast_to: Accepts integer for broadcasting to a specific argument + or string for broadcasting to a specific keyword argument. + If ``None``, then the function is broadcasted to the first argument + or the first keyword argument if no positional arguments are provided. + Defaults to ``None``. is_leaf: a predicate function that returns True if the node is a leaf. Example: @@ -199,7 +169,6 @@ def bcmap( >>> print(sp.tree_str(tree_add(tree_of_arrays, 1))) dict(a=[2 3 4], b=[5 6 7]) """ - # add broadcasting argnum/argname to the function later treelib = sepes._src.backend.treelib @ft.wraps(func) @@ -209,23 +178,29 @@ def wrapper(*args, **kwargs): leaves = [] kwargs_keys: list[str] = [] + bdcst_to = ( + (0 if len(args) else next(iter(kwargs))) + if broadcast_to is None + else broadcast_to + ) + treedef0 = ( # reference treedef is the first positional argument - treelib.tree_flatten(args[0], is_leaf=is_leaf)[1] + treelib.flatten(args[bdcst_to], is_leaf=is_leaf)[1] if len(args) # reference treedef is the first keyword argument - else treelib.tree_flatten(kwargs[next(iter(kwargs))], is_leaf=is_leaf)[1] + else treelib.flatten(kwargs[bdcst_to], is_leaf=is_leaf)[1] ) for arg in args: - if treedef0 == treelib.tree_flatten(arg, is_leaf=is_leaf)[1]: + if treedef0 == treelib.flatten(arg, is_leaf=is_leaf)[1]: cargs += [...] leaves += [treedef0.flatten_up_to(arg)] else: cargs += [arg] for key in kwargs: - if treedef0 == treelib.tree_flatten(kwargs[key], is_leaf=is_leaf)[1]: + if treedef0 == treelib.flatten(kwargs[key], is_leaf=is_leaf)[1]: ckwargs[key] = ... leaves += [treedef0.flatten_up_to(kwargs[key])] kwargs_keys += [key] @@ -239,7 +214,7 @@ def wrapper(*args, **kwargs): args = args_kwargs_values[:split_index] kwargs = dict(zip(kwargs_keys, args_kwargs_values[split_index:])) all_leaves += [bfunc(*args, **kwargs)] - return treelib.tree_unflatten(treedef0, all_leaves) + return treelib.unflatten(treedef0, all_leaves) return wrapper @@ -266,7 +241,8 @@ def leafwise(klass: type[T]) -> type[T]: The decorated class. Example: - >>> # use ``numpy`` functions on :class:`TreeClass`` classes decorated with ``leafwise`` + Use ``numpy`` functions on :class:`TreeClass`` classes decorated with :func:`leafwise` + >>> import sepes as sp >>> import jax.numpy as jnp >>> @sp.leafwise @@ -321,15 +297,15 @@ def leafwise(klass: type[T]) -> type[T]: def uop(func): def wrapper(self): - return treelib.tree_map(func, self) + return treelib.map(func, self) return ft.wraps(func)(wrapper) def bop(func): def wrapper(leaf, rhs=None): if isinstance(rhs, type(leaf)): - return treelib.tree_map(func, leaf, rhs) - return treelib.tree_map(lambda x: func(x, rhs), leaf) + return treelib.map(func, leaf, rhs) + return treelib.map(lambda x: func(x, rhs), leaf) return ft.wraps(func)(wrapper) @@ -391,7 +367,7 @@ def tree_type_path_leaves( is_path_leaf: Callable[[KeyTypePath], bool] | None = None, ) -> Sequence[tuple[KeyTypePath, Any]]: treelib = sepes._src.backend.treelib - _, atomicdef = treelib.tree_flatten(1) + _, atomicdef = treelib.flatten(1) # mainly used for visualization def flatten_one_level(type_path: KeyTypePath, tree: PyTree): @@ -407,7 +383,7 @@ def one_level_is_leaf(node) -> bool: return False return True - path_leaf, treedef = treelib.tree_path_flatten(tree, is_leaf=one_level_is_leaf) + path_leaf, treedef = treelib.path_flatten(tree, is_leaf=one_level_is_leaf) if treedef == atomicdef: yield type_path, tree @@ -501,7 +477,7 @@ def construct_tree( return root -def value_and_tree(func, argnums: int | Sequence[int] = 0): +def value_and_tree(func: Callable[..., T], argnums: int | Sequence[int] = 0): """Call a function on copied input argument and return the value and the tree. Input arguments are copied before calling the function, and the argument @@ -614,15 +590,15 @@ def immutate_is_leaf(node): return False @ft.wraps(func) - def stateless_func(*args, **kwargs) -> tuple[Any, PyTree | tuple[PyTree, ...]]: + def stateless_func(*args, **kwargs) -> tuple[T, PyTree | tuple[PyTree, ...]]: # copy the incoming inputs (args, kwargs) = tree_copy((args, kwargs)) # and edit the node/record to make it mutable (if there is a rule for it) - treelib.tree_map(lambda _: _, (args, kwargs), is_leaf=mutate_is_leaf) + treelib.map(lambda _: _, (args, kwargs), is_leaf=mutate_is_leaf) output = func(*args, **kwargs) # traverse each node in the tree depth-first manner # to undo the mutation (if there is a rule for it) - treelib.tree_map(lambda _: _, (args, kwargs), is_leaf=immutate_is_leaf) + treelib.map(lambda _: _, (args, kwargs), is_leaf=immutate_is_leaf) out_args = tuple(a for i, a in enumerate(args) if i in argnums) out_args = out_args[0] if is_int_argnum else out_args return output, out_args diff --git a/tests/test_index.py b/tests/test_index.py index 097543e..0728380 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -23,7 +23,7 @@ from sepes._src.backend import arraylib, backend, treelib from sepes._src.code_build import autoinit from sepes._src.tree_base import TreeClass, _mutable_instance_registry -from sepes._src.tree_index import AtIndexer, BaseKey +from sepes._src.tree_index import at, BaseKey from sepes._src.tree_util import is_tree_equal, leafwise, value_and_tree test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") @@ -117,7 +117,7 @@ def __init__(self, c: int, d: int): ], ) def test_indexer_get(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.get(), expected) assert is_tree_equal(indexer.get(is_parallel=True), expected) @@ -150,11 +150,33 @@ def test_indexer_get(tree, expected, where): ], ) def test_array_indexer_get(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.get(), expected) assert is_tree_equal(indexer.get(is_parallel=True), expected) +@pytest.mark.skipif(backend != "jax", reason="test jax jit with get") +def test_get_fill_value(): + import jax + import jax.numpy as jnp + + tree = dict(a=jnp.array([1, 2, 3]), b=jnp.array([4, 5, 6])) + mask = dict( + a=jnp.array([False, True, False]), + b=jnp.array([False, True, False]), + ) + + @jax.jit + def jit_func(tree): + return at(tree)[mask].get(fill_value=0) + + out = jit_func(tree) + a = out["a"] + b = out["b"] + assert jnp.all(a == jnp.array([0, 2, 0])) + assert jnp.all(b == jnp.array([0, 5, 0])) + + @pytest.mark.parametrize( ["tree", "expected", "where", "set_value"], [ @@ -191,7 +213,7 @@ def test_array_indexer_get(tree, expected, where): ], ) def test_indexer_set(tree, expected, where, set_value): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.set(set_value), expected) assert is_tree_equal(indexer.set(set_value, is_parallel=True), expected) @@ -233,7 +255,7 @@ def test_indexer_set(tree, expected, where, set_value): ], ) def test_array_indexer_set(tree, expected, where, set_value): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.set(set_value), expected) assert is_tree_equal(indexer.set(set_value, is_parallel=True), expected) @@ -268,7 +290,7 @@ def test_array_indexer_set(tree, expected, where, set_value): ], ) def test_indexer_apply(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.apply(lambda _: _X), expected) assert is_tree_equal( indexer.apply(lambda _: _X, is_parallel=True), @@ -307,7 +329,7 @@ def test_indexer_apply(tree, expected, where): ], ) def test_array_indexer_apply(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.apply(lambda _: _X), expected) assert is_tree_equal( indexer.apply(lambda _: _X, is_parallel=True), @@ -343,7 +365,7 @@ def test_array_indexer_apply(tree, expected, where): ], ) def test_indexer_reduce(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal( indexer.reduce(lambda x, y: x + y, initializer=0), expected, @@ -378,7 +400,7 @@ def test_indexer_reduce(tree, expected, where): ], ) def test_array_indexer_reduce(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal( indexer.reduce(lambda x, y: x + y, initializer=0), expected, @@ -405,7 +427,7 @@ def test_array_indexer_reduce(tree, expected, where): ], ) def test_indexer_scan(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal( indexer.scan(lambda x, s: (x + s, x), state=0), expected, @@ -451,8 +473,8 @@ def __call__(self, x): a = A(1) _, b = value_and_tree(lambda A: A(2))(a) - assert treelib.tree_flatten(a)[0] == [1] - assert treelib.tree_flatten(b)[0] == [3] + assert treelib.flatten(a)[0] == [1] + assert treelib.flatten(b)[0] == [3] with pytest.raises(TypeError): a.at[0](1) @@ -480,7 +502,7 @@ def delete(self, name): def test_unsupported_where(where): t = namedtuple("a", ["x", "y"])(1, 2) with pytest.raises(NotImplementedError): - AtIndexer(t, where=where).get() + at(t, where=where).get() @pytest.mark.skipif(backend != "jax", reason="jax backend needed") @@ -496,7 +518,7 @@ def __init__(self, a, b) -> None: @property def at(self): - return AtIndexer(self) + return at(self) if backend == "jax": import jax.tree_util as jtu @@ -533,7 +555,7 @@ def __init__(self, a, b) -> None: @property def at(self): - return AtIndexer(self) + return at(self) import optree as ot @@ -575,26 +597,26 @@ class Tree(TreeClass): t = Tree() - assert repr(t.at["a"]) == "AtIndexer(tree=Tree(a=1, b=2), where=['a'])" - assert str(t.at["a"]) == "AtIndexer(tree=Tree(a=1, b=2), where=['a'])" - assert repr(t.at[...]) == "AtIndexer(tree=Tree(a=1, b=2), where=[Ellipsis])" + assert repr(t.at["a"]) == "at(Tree(a=1, b=2), where=['a'])" + assert str(t.at["a"]) == "at(Tree(a=1, b=2), where=['a'])" + assert repr(t.at[...]) == "at(Tree(a=1, b=2), where=[Ellipsis])" def test_compat_mask(): tree = [1, 2, [3, 4]] - tree_ = AtIndexer(tree)[[False, False, True]].set(10) + tree_ = at(tree)[[False, False, True]].set(10) assert tree_ == [1, 2, 10] def test_pluck(): tree = [1, 2, [3, 4]] - subtrees = AtIndexer(tree)[2].pluck() + subtrees = at(tree)[2].pluck() assert subtrees[0] == [3, 4] - assert AtIndexer(tree)[0, 1].pluck(1) == [1] - assert AtIndexer(tree)[0, 1].pluck(2) == [1, 2] + assert at(tree)[0, 1].pluck(1) == [1] + assert at(tree)[0, 1].pluck(2) == [1, 2] tree = dict(a=1, b=2) - assert AtIndexer(tree)[...].pluck() == [1, 2] + assert at(tree)[...].pluck() == [1, 2] @pytest.mark.skipif(backend != "jax", reason="jax backend needed") diff --git a/tests/test_mask.py b/tests/test_mask.py index d03cf8b..9b8de81 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -13,6 +13,8 @@ # limitations under the License. import copy +import functools as ft +import os from typing import Any import pytest @@ -20,17 +22,12 @@ from sepes._src.backend import backend, treelib from sepes._src.code_build import autoinit from sepes._src.tree_base import TreeClass -from sepes._src.tree_mask import ( - freeze, - is_frozen, - tree_mask, - tree_unmask, - unfreeze, -) -import os +from sepes._src.tree_mask import is_masked, tree_mask, tree_unmask from sepes._src.tree_util import is_tree_equal, leafwise, tree_hash test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") +freeze = ft.partial(tree_mask, cond=lambda _: True) +unfreeze = ft.partial(tree_unmask, cond=lambda _: True) if test_arraylib == "jax": import jax.numpy as arraylib @@ -54,14 +51,14 @@ class A(TreeClass): b = a.at[...].apply(freeze) c = ( a.at["a"] - .apply(unfreeze, is_leaf=is_frozen) + .apply(unfreeze, is_leaf=is_masked) .at["b"] - .apply(unfreeze, is_leaf=is_frozen) + .apply(unfreeze, is_leaf=is_masked) ) - assert treelib.tree_flatten(a)[0] == [1, 2] - assert treelib.tree_flatten(b)[0] == [] - assert treelib.tree_flatten(c)[0] == [1, 2] + assert treelib.flatten(a)[0] == [1, 2] + assert treelib.flatten(b)[0] == [] + assert treelib.flatten(c)[0] == [1, 2] assert unfreeze(freeze(1.0)) == 1.0 @autoinit @@ -80,17 +77,17 @@ class A(TreeClass): b: int a = A(1, 2) - b = treelib.tree_map(freeze, a) + b = treelib.map(freeze, a) c = ( a.at["a"] - .apply(unfreeze, is_leaf=is_frozen) + .apply(unfreeze, is_leaf=is_masked) .at["b"] - .apply(unfreeze, is_leaf=is_frozen) + .apply(unfreeze, is_leaf=is_masked) ) - assert treelib.tree_flatten(a)[0] == [1, 2] - assert treelib.tree_flatten(b)[0] == [] - assert treelib.tree_flatten(c)[0] == [1, 2] + assert treelib.flatten(a)[0] == [1, 2] + assert treelib.flatten(b)[0] == [] + assert treelib.flatten(c)[0] == [1, 2] @autoinit class L0(TreeClass): @@ -104,11 +101,11 @@ class L1(TreeClass): class L2(TreeClass): c: L1 = L1() - t = treelib.tree_map(freeze, L2()) + t = treelib.map(freeze, L2()) - assert treelib.tree_flatten(t)[0] == [] - assert treelib.tree_flatten(t.c)[0] == [] - assert treelib.tree_flatten(t.c.b)[0] == [] + assert treelib.flatten(t)[0] == [] + assert treelib.flatten(t.c)[0] == [] + assert treelib.flatten(t.c.b)[0] == [] class L1(TreeClass): def __init__(self): @@ -118,9 +115,9 @@ class L2(TreeClass): def __init__(self): self.c = L1() - t = treelib.tree_map(freeze, L2()) - assert treelib.tree_flatten(t.c)[0] == [] - assert treelib.tree_flatten(t.c.b)[0] == [] + t = treelib.map(freeze, L2()) + assert treelib.flatten(t.c)[0] == [] + assert treelib.flatten(t.c.b)[0] == [] def test_freeze_errors(): @@ -160,25 +157,25 @@ class Test(TreeClass): c: str = freeze("test") t = Test() - assert treelib.tree_flatten(t)[0] == [1] + assert treelib.flatten(t)[0] == [1] with pytest.raises(AttributeError): - treelib.tree_map(freeze, t).a = 1 + treelib.map(freeze, t).a = 1 with pytest.raises(AttributeError): - treelib.tree_map(unfreeze, t).a = 1 + treelib.map(unfreeze, t).a = 1 hash(t) t = Test() - treelib.tree_map(unfreeze, t, is_leaf=is_frozen) - treelib.tree_map(freeze, t) + treelib.map(unfreeze, t, is_leaf=is_masked) + treelib.map(freeze, t) @autoinit class Test(TreeClass): a: int - t = treelib.tree_map(freeze, (Test(100))) + t = treelib.map(freeze, (Test(100))) class Test(TreeClass): def __init__(self, x): @@ -223,7 +220,7 @@ class Test(TreeClass): t = Test() - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] def test_freeze_nondiff(): @@ -234,10 +231,10 @@ class Test(TreeClass): t = Test() - assert treelib.tree_flatten(t)[0] == ["a"] - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] - assert treelib.tree_flatten( - (treelib.tree_map(freeze, t)).at["b"].apply(unfreeze, is_leaf=is_frozen) + assert treelib.flatten(t)[0] == ["a"] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] + assert treelib.flatten( + (treelib.map(freeze, t)).at["b"].apply(unfreeze, is_leaf=is_masked) )[0] == ["a"] @autoinit @@ -246,11 +243,11 @@ class T0(TreeClass): t = T0() - assert treelib.tree_flatten(t)[0] == ["a"] - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] + assert treelib.flatten(t)[0] == ["a"] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] - assert treelib.tree_flatten(t)[0] == ["a"] - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] + assert treelib.flatten(t)[0] == ["a"] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] def test_freeze_nondiff_with_mask(): @@ -278,11 +275,11 @@ class L2(TreeClass): t = t.at["d"]["d"]["a"].apply(freeze) t = t.at["d"]["d"]["b"].apply(freeze) - assert treelib.tree_flatten(t)[0] == [10, 20, 30, 1, 2, 3, 3] + assert treelib.flatten(t)[0] == [10, 20, 30, 1, 2, 3, 3] def test_non_dataclass_input_to_freeze(): - assert treelib.tree_flatten(freeze(1))[0] == [] + assert treelib.flatten(freeze(1))[0] == [] def test_tree_mask(): @@ -299,18 +296,18 @@ class L1(TreeClass): tree = L1() - assert treelib.tree_flatten(tree)[0] == [1, 2, 3] - assert treelib.tree_flatten(treelib.tree_map(freeze, tree))[0] == [] - assert treelib.tree_flatten(treelib.tree_map(freeze, tree))[0] == [] - assert treelib.tree_flatten(tree.at[...].apply(freeze))[0] == [] - assert treelib.tree_flatten(tree.at[tree > 1].apply(freeze))[0] == [1] - assert treelib.tree_flatten(tree.at[tree == 1].apply(freeze))[0] == [2, 3] - assert treelib.tree_flatten(tree.at[tree < 1].apply(freeze))[0] == [1, 2, 3] + assert treelib.flatten(tree)[0] == [1, 2, 3] + assert treelib.flatten(treelib.map(freeze, tree))[0] == [] + assert treelib.flatten(treelib.map(freeze, tree))[0] == [] + assert treelib.flatten(tree.at[...].apply(freeze))[0] == [] + assert treelib.flatten(tree.at[tree > 1].apply(freeze))[0] == [1] + assert treelib.flatten(tree.at[tree == 1].apply(freeze))[0] == [2, 3] + assert treelib.flatten(tree.at[tree < 1].apply(freeze))[0] == [1, 2, 3] - assert treelib.tree_flatten(tree.at["a"].apply(freeze))[0] == [2, 3] - assert treelib.tree_flatten(tree.at["b"].apply(freeze))[0] == [1] - assert treelib.tree_flatten(tree.at["b"]["x"].apply(freeze))[0] == [1, 3] - assert treelib.tree_flatten(tree.at["b"]["y"].apply(freeze))[0] == [1, 2] + assert treelib.flatten(tree.at["a"].apply(freeze))[0] == [2, 3] + assert treelib.flatten(tree.at["b"].apply(freeze))[0] == [1] + assert treelib.flatten(tree.at["b"]["x"].apply(freeze))[0] == [1, 3] + assert treelib.flatten(tree.at["b"]["y"].apply(freeze))[0] == [1, 2] def test_tree_unmask(): @@ -328,21 +325,21 @@ class L1(TreeClass): tree = L1() frozen_tree = tree.at[...].apply(freeze) - assert treelib.tree_flatten(frozen_tree)[0] == [] + assert treelib.flatten(frozen_tree)[0] == [] mask = tree == tree - unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_frozen) - assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3] + unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked) + assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3] mask = tree > 1 - unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_frozen) - assert treelib.tree_flatten(unfrozen_tree)[0] == [2, 3] + unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked) + assert treelib.flatten(unfrozen_tree)[0] == [2, 3] - unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_frozen) - # assert treelib.tree_flatten(unfrozen_tree)[0] == [1] + unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_masked) + # assert treelib.flatten(unfrozen_tree)[0] == [1] - # unfrozen_tree = frozen_tree.at["b"].apply(unfreeze, is_leaf=is_frozen) - # assert treelib.tree_flatten(unfrozen_tree)[0] == [2, 3] + # unfrozen_tree = frozen_tree.at["b"].apply(unfreeze, is_leaf=is_masked) + # assert treelib.flatten(unfrozen_tree)[0] == [2, 3] def test_tree_mask_unfreeze(): @@ -361,12 +358,12 @@ class L1(TreeClass): mask = tree == tree frozen_tree = tree.at[...].apply(freeze) - unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_frozen) - assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3] + unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked) + assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3] # frozen_tree = tree.at["a"].apply(freeze) - # unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_frozen) - # assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3] + # unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_masked) + # assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3] def test_wrapper(): @@ -403,18 +400,16 @@ def test_wrapper(): @pytest.mark.skipif(backend == "default", reason="no array backend installed") def test_tree_mask_tree_unmask(): tree = [1, 2, 3.0] - assert treelib.tree_flatten(tree_mask(tree))[0] == [3.0] - assert treelib.tree_flatten(tree_unmask(tree_mask(tree)))[0] == [1, 2, 3.0] + assert treelib.flatten(tree_mask(tree))[0] == [3.0] + assert treelib.flatten(tree_unmask(tree_mask(tree)))[0] == [1, 2, 3.0] mask_func = lambda x: x < 2 - assert treelib.tree_flatten(tree_mask(tree, mask_func))[0] == [2, 3.0] + assert treelib.flatten(tree_mask(tree, mask_func))[0] == [2, 3.0] assert freeze(freeze(1)) == freeze(1) - assert tree_mask({"a": 1}, mask={"a": True}) == {"a": freeze(1)} - - with pytest.raises(ValueError): - tree_mask({"a": 1}, mask=1.0) + with pytest.raises(TypeError): + tree_mask({"a": 1}, cond=1.0) assert copy.copy(freeze(1)) == freeze(1) @@ -424,7 +419,7 @@ def test_tree_mask_tree_unmask(): @pytest.mark.skipif(backend == "default", reason="no array backend installed") def test_array_tree_mask_tree_unmask(): - frozen_array = tree_mask(arraylib.ones((5, 5)), mask=lambda _: True) + frozen_array = tree_mask(arraylib.ones((5, 5)), cond=lambda _: True) assert frozen_array == frozen_array assert not (frozen_array == freeze(arraylib.ones((5, 6)))) diff --git a/tests/test_operator.py b/tests/test_operator.py index 8a89205..b8908c1 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -15,6 +15,7 @@ from __future__ import annotations import math +import os from typing import Any import pytest @@ -22,9 +23,10 @@ from sepes._src.backend import backend from sepes._src.code_build import autoinit, field from sepes._src.tree_base import TreeClass -from sepes._src.tree_mask import freeze +from sepes._src.tree_mask import tree_mask from sepes._src.tree_util import bcmap, is_tree_equal, leafwise -import os + +freeze = lambda x: tree_mask(x, cond=lambda _: True) test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") if test_arraylib == "jax": @@ -171,3 +173,17 @@ def test_bcmap(tree, expected): def test_math_operations_errors(): with pytest.raises(TypeError): tree1 + "s" + + +def test_bcmap_int_argnum_broadcast_to(): + def func(x, y): + return x + y + + assert bcmap(func, broadcast_to=1)(1, [2, 3, 4]) == [3, 4, 5] + + +def test_bcmap_key_argnum_broadcast_to(): + def func(x, y): + return x + y + + assert bcmap(func, broadcast_to="y")(x=1, y=[2, 3, 4]) == [3, 4, 5] diff --git a/tests/test_pprint.py b/tests/test_pprint.py index 622f654..8eb8003 100644 --- a/tests/test_pprint.py +++ b/tests/test_pprint.py @@ -15,12 +15,11 @@ from __future__ import annotations import dataclasses as dc -import re +import os from collections import namedtuple from typing import Any import pytest -import os test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") backend = os.environ.get("SEPES_BACKEND", "jax") @@ -29,8 +28,6 @@ from sepes._src.tree_pprint import ( _table, tree_diagram, - tree_graph, - tree_mermaid, tree_repr, tree_str, tree_summary, @@ -144,7 +141,7 @@ def test_tree_summary(): assert ( tree_summary(r1, depth=1) # trunk-ignore(flake8/E501) - == "┌────┬────────────┬─────┬───────┐\n│Name│Type │Count│Size │\n├────┼────────────┼─────┼───────┤\n│.a │int │1 │ │\n├────┼────────────┼─────┼───────┤\n│.b │str │1 │ │\n├────┼────────────┼─────┼───────┤\n│.c │float │1 │ │\n├────┼────────────┼─────┼───────┤\n│.d │str │1 │ │\n├────┼────────────┼─────┼───────┤\n│.e │list │5 │ │\n├────┼────────────┼─────┼───────┤\n│.f │set │1 │ │\n├────┼────────────┼─────┼───────┤\n│.g │dict │27 │100.00B│\n├────┼────────────┼─────┼───────┤\n│.h │f32[5,1] │5 │20.00B │\n├────┼────────────┼─────┼───────┤\n│.i │f32[1,6] │6 │24.00B │\n├────┼────────────┼─────┼───────┤\n│.j │f32[1,1,4,5]│20 │80.00B │\n├────┼────────────┼─────┼───────┤\n│.k │tuple │3 │ │\n├────┼────────────┼─────┼───────┤\n│.l │a │2 │ │\n├────┼────────────┼─────┼───────┤\n│.m │f32[5,5] │25 │100.00B│\n├────┼────────────┼─────┼───────┤\n│.n │bool[] │1 │1.00B │\n├────┼────────────┼─────┼───────┤\n│.o │c64[2] │2 │16.00B │\n├────┼────────────┼─────┼───────┤\n│Σ │Repr1 │101 │341.00B│\n└────┴────────────┴─────┴───────┘" + == "┌────┬─────────────────────────┬─────┬───────┐\n│Name│Type │Count│Size │\n├────┼─────────────────────────┼─────┼───────┤\n│.a │int │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.b │str │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.c │float │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.d │str │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.e │list[int,int,int,int,int]│5 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.f │set │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.g │dict │27 │100.00B│\n├────┼─────────────────────────┼─────┼───────┤\n│.h │f32[5,1] │5 │20.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.i │f32[1,6] │6 │24.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.j │f32[1,1,4,5] │20 │80.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.k │tuple[int,int,int] │3 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.l │a[int,int] │2 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.m │f32[5,5] │25 │100.00B│\n├────┼─────────────────────────┼─────┼───────┤\n│.n │bool[] │1 │1.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.o │c64[2] │2 │16.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│Σ │Repr1 │101 │341.00B│\n└────┴─────────────────────────┴─────┴───────┘" ) assert ( @@ -165,20 +162,6 @@ def test_tree_diagram(): assert tree_diagram(r1, depth=1) == out -@pytest.mark.skipif(backend != "jax", reason="jax is not installed") -def test_tree_mermaid(): - assert ( - re.sub(r"id\d*", "***", tree_mermaid(r1, depth=1)) - # trunk-ignore(flake8/E501) - == 'flowchart LR\n ***("Repr1")\n *** --- ***(".a=1")\n *** --- ***(".b=string")\n *** --- ***(".c=1.0")\n *** --- ***(".d=aaaaa")\n *** --- ***(".e=[...]")\n *** --- ***(".f={...}")\n *** --- ***(".g=dict(...)")\n *** --- ***(".h=f32[5,1](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".i=f32[1,6](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".j=f32[1,1,4,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".k=(...)")\n *** --- ***(".l=a(...)")\n *** --- ***(".m=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".n=bool[]")\n *** --- ***(".o=c64[2]")' - ) - assert ( - re.sub(r"id\d*", "***", tree_mermaid(r1, depth=2)) - # trunk-ignore(flake8/E501) - == 'flowchart LR\n ***("Repr1")\n *** --- ***(".a=1")\n *** --- ***(".b=string")\n *** --- ***(".c=1.0")\n *** --- ***(".d=aaaaa")\n *** --- ***(".e:list")\n *** --- ***("[0]=10")\n *** --- ***("[1]=10")\n *** --- ***("[2]=10")\n *** --- ***("[3]=10")\n *** --- ***("[4]=10")\n *** --- ***(".f={...}")\n *** --- ***(".g:dict")\n *** --- ***("[\'a\']=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")\n *** --- ***("[\'b\']=bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")\n *** --- ***("[\'c\']=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".h=f32[5,1](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".i=f32[1,6](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".j=f32[1,1,4,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".k:tuple")\n *** --- ***("[0]=1")\n *** --- ***("[1]=2")\n *** --- ***("[2]=3")\n *** --- ***(".l:a")\n *** --- ***(".b=1")\n *** --- ***(".c=2")\n *** --- ***(".m=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".n=bool[]")\n *** --- ***(".o=c64[2]")' - ) - - @pytest.mark.skipif(backend != "jax", reason="jax is not installed") def test_misc(): x = (1, 2, 3) @@ -251,16 +234,6 @@ def test_invalid_depth(): tree_diagram(1, depth="a") with pytest.raises(TypeError): tree_summary(1, depth="a") - with pytest.raises(TypeError): - tree_mermaid(1, depth="a") - - -@pytest.mark.skipif(backend != "jax", reason="jax is not installed") -def test_tree_graph(): - assert ( - re.sub(r"\b\d{10,}", "***", tree_graph(r1)) - == 'digraph G {\n *** [label="Repr1", shape=box];\n *** [label=".a=1", shape=box];\n *** -> ***;\n *** [label=".b=string", shape=box];\n *** -> ***;\n *** [label=".c=1.0", shape=box];\n *** -> ***;\n *** [label=".d=aaaaa", shape=box];\n *** -> ***;\n *** [label=".e:list", shape=box];\n *** -> ***;\n *** [label="[0]=10", shape=box];\n *** -> ***;\n *** [label="[1]=10", shape=box];\n *** -> ***;\n *** [label="[2]=10", shape=box];\n *** -> ***;\n *** [label="[3]=10", shape=box];\n *** -> ***;\n *** [label="[4]=10", shape=box];\n *** -> ***;\n *** [label=".f={...}", shape=box];\n *** -> ***;\n *** [label=".g:dict", shape=box];\n *** -> ***;\n *** [label="[\'a\']=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", shape=box];\n *** -> ***;\n *** [label="[\'b\']=bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", shape=box];\n *** -> ***;\n *** [label="[\'c\']=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".h=f32[5,1](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".i=f32[1,6](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".j=f32[1,1,4,5](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".k:tuple", shape=box];\n *** -> ***;\n *** [label="[0]=1", shape=box];\n *** -> ***;\n *** [label="[1]=2", shape=box];\n *** -> ***;\n *** [label="[2]=3", shape=box];\n *** -> ***;\n *** [label=".l:a", shape=box];\n *** -> ***;\n *** [label=".b=1", shape=box];\n *** -> ***;\n *** [label=".c=2", shape=box];\n *** -> ***;\n *** [label=".m=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".n=bool[]", shape=box];\n *** -> ***;\n *** [label=".o=c64[2]", shape=box];\n *** -> ***;\n}' - ) @pytest.mark.skipif(backend != "jax", reason="jax is not installed") @@ -270,9 +243,25 @@ def test_tracer_repr(): @jax.jit def f(x): out = tree_repr(x) - assert out == "Tracer(f32[10,10])" + assert out == "DynamicJaxprTracer(f32[10,10])" out = tree_str(x) - assert out == "Tracer(f32[10,10])" + assert out == "DynamicJaxprTracer(f32[10,10])" return x f(jax.numpy.ones((10, 10))) + + +@pytest.mark.skipif(backend != "jax", reason="testing jax specific sharding info") +def test_jax_sharding_tree_summary(): + import jax + import numpy as np + from jax.sharding import Mesh, NamedSharding, PartitionSpec + + x = jax.numpy.ones([4 * 4, 2 * 2]) + mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=["i", "j"]) + sharding = NamedSharding(mesh=mesh, spec=PartitionSpec("i", "j")) + x = jax.device_put(x, device=sharding) + assert ( + tree_summary(x) + == "┌────┬───────────┬─────┬───────┐\n│Name│Type │Count│Size │\n├────┼───────────┼─────┼───────┤\n│Σ │G:f32[16,4]│64 │256.00B│\n│ │S:f32[4,2] │ │ │\n└────┴───────────┴─────┴───────┘" + ) diff --git a/tests/test_treeclass.py b/tests/test_treeclass.py index 67da1f3..61887e3 100644 --- a/tests/test_treeclass.py +++ b/tests/test_treeclass.py @@ -15,11 +15,12 @@ import copy import dataclasses as dc import inspect +import os from typing import Any import numpy.testing as npt import pytest -import os + from sepes._src.backend import backend, treelib from sepes._src.code_build import ( autoinit, @@ -29,8 +30,10 @@ fields, ) from sepes._src.tree_base import TreeClass -from sepes._src.tree_mask import freeze -from sepes._src.tree_util import Partial, is_tree_equal, value_and_tree +from sepes._src.tree_mask import tree_mask +from sepes._src.tree_util import is_tree_equal, partial, value_and_tree + +freeze = lambda x: tree_mask(x, cond=lambda _: True) test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") if test_arraylib == "jax": @@ -147,7 +150,7 @@ def __init__( test = Test() - assert treelib.tree_flatten(test)[0] == [] + assert treelib.flatten(test)[0] == [] class Test(TreeClass): def __init__(self, a=arraylib.array([1, 2, 3]), b=arraylib.array([4, 5, 6])): @@ -155,7 +158,7 @@ def __init__(self, a=arraylib.array([1, 2, 3]), b=arraylib.array([4, 5, 6])): self.b = b test = Test() - npt.assert_allclose(treelib.tree_flatten(test)[0][0], arraylib.array([4, 5, 6])) + npt.assert_allclose(treelib.flatten(test)[0][0], arraylib.array([4, 5, 6])) def test_post_init(): @@ -200,7 +203,7 @@ def inc(self, x): l1 = L1() - assert treelib.tree_flatten(l1)[0] == [2, 4, 5, 5] + assert treelib.flatten(l1)[0] == [2, 4, 5, 5] assert l1.inc(10) == 20 assert l1.sub(10) == 0 assert l1.d == 5 @@ -212,7 +215,7 @@ class L1(L0): l1 = L1() - assert treelib.tree_flatten(l1)[0] == [2, 4, 5] + assert treelib.flatten(l1)[0] == [2, 4, 5] def test_registering_state(): @@ -414,7 +417,7 @@ class Test(TreeClass): t = Test(1) assert t.a == freeze(1) - assert treelib.tree_flatten(t)[0] == [] + assert treelib.flatten(t)[0] == [] def test_super(): @@ -522,10 +525,10 @@ def test_partial(): def f(a, b, c): return a + b + c - f_a = Partial(f, ..., 2, 3) + f_a = partial(f, ..., 2, 3) assert f_a(1) == 6 - f_b = Partial(f, 1, ..., 3) + f_b = partial(f, 1, ..., 3) assert f_b(2) == 6 assert f_b == f_b