diff --git a/.github/workflows/test_jax.yml b/.github/workflows/test_jax.yml
index 644a14b..2b04bd8 100644
--- a/.github/workflows/test_jax.yml
+++ b/.github/workflows/test_jax.yml
@@ -29,6 +29,7 @@ jobs:
run: |
export SEPES_TEST_ARRAYLIB=jax
export SEPES_BACKEND=jax
+ export XLA_FLAGS=--xla_force_host_platform_device_count=8
python -m pip install .
coverage run -m pytest tests
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a6fd830..82fa907 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,51 @@
# Changelog
+## V0.12
+
+### Deprecations
+
+- Reduce the core API size by removing:
+ 1) `tree_graph` (for graphviz)
+ 2) `tree_mermaid` (mermaidjs)
+ 3) `Partial/partial` -> Use `jax.tree_util.Partial` instead.
+ 4) `is_tree_equal` -> Use `bcmap(numpy.testing.*)(pytree1, pytree2)` instead.
+ 5) `freeze` -> Use `ft.partial(tree_mask, lambda _: True)` instead.
+ 6) `unfreeze` -> Use `tree_unmask` instead.
+ 7) `is_nondiff`
+ 8) `BaseKey`
+
+
+### Changes
+
+- `tree_{mask,unmask}` now accepts only callable `cond` argument.
+
+ For masking using pytree boolean mask use the following pattern:
+
+ ```python
+ import jax
+ import sepes as sp
+ import functools as ft
+ tree = [[1, 2], 3] # the nested tree
+ where = [[True, False], True] # mask tree[0][1] and tree[1]
+ mask = ft.partial(sp.tree_mask, cond=lambda _: True)
+ sp.at(tree)[where].apply(mask) # apply using `at`
+ # [[#1, 2], #3]
+ # or simply apply to the node directly
+ tree = [[mask(1), 2], mask(3)]
+ # [[#1, 2], #3]
+ ```
+
+- Rename `is_frozen` to `is_masked`
+ - frozen could mean non-trainable array, however the masking is not only for arrays but also for other types that will be hidden across jax transformations.
+
+- Rename `AtIndexer` to `at` for shorter syntax.
+
+### Additions
+
+- Add `fill_value` in `at[...].get(fill_value=...)` to add default value for non
+ selected leaves. Useful for arrays under `jax.jit` to avoid variable size related errors.
+- Add `jax.tree_util.{SequenceKey,GetAttrKey,DictKey}` as valid path keys in `at[...]`.
+
## V0.11.3
- Raise error if `autoinit` is used with `__init__` method defined.
@@ -7,20 +53,43 @@
- Add `at` as an alias for `AtIndexer` for shorter syntax.
- Deprecate `AtIndexer.__call__` in favor of `value_and_tree` to apply function in a functional manner by copying the input argument.
-```python
-import sepes as sp
-class Counter(sp.TreeClass):
- def __init__(self, count: int):
- self.count = count
- def increment(self, value):
- self.count += value
- return self.count
-counter = Counter(0)
-# the function follow jax.value_and_grad semantics where the tree is the
-# copied mutated input argument, if the function mutates the input arguments
-sp.value_and_tree(lambda C: C.increment(1))(counter)
-# (1, Counter(count=1))
-```
+ ```python
+ import sepes as sp
+ class Counter(sp.TreeClass):
+ def __init__(self, count: int):
+ self.count = count
+ def increment(self, value):
+ self.count += value
+ return self.count
+ counter = Counter(0)
+ # the function follow jax.value_and_grad semantics where the tree is the
+ # copied mutated input argument, if the function mutates the input arguments
+ sp.value_and_tree(lambda C: C.increment(1))(counter)
+ # (1, Counter(count=1))
+ ```
+
+- Add sharding info in `tree_summary`, `G` for global, `S` for sharded shape.
+
+ ```python
+ import jax
+ import sepes as sp
+ from jax.sharding import Mesh, NamedSharding as N, PartitionSpec as P
+ import numpy as np
+ import os
+ os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
+ x = jax.numpy.ones([4 * 4, 2 * 2])
+ mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=["i", "j"])
+ sharding = N(mesh=mesh, spec=P("i", "j"))
+ x = jax.device_put(x, device=sharding)
+
+ print(sp.tree_summary(x))
+ ┌────┬───────────┬─────┬───────┐
+ │Name│Type │Count│Size │
+ ├────┼───────────┼─────┼───────┤
+ │Σ │G:f32[16,4]│64 │256.00B│
+ │ │S:f32[4,2] │ │ │
+ └────┴───────────┴─────┴───────┘
+ ```
- Updated docstrings. e.g. How to construct flops counter in `tree_summary` using `jax.jit`
diff --git a/docs/API/constructor.rst b/docs/API/constructor.rst
new file mode 100644
index 0000000..3afe172
--- /dev/null
+++ b/docs/API/constructor.rst
@@ -0,0 +1,10 @@
+🏗️ Constructor utils API
+=============================
+
+
+.. currentmodule:: sepes
+
+.. autofunction:: field
+.. autofunction:: fields
+.. autofunction:: autoinit
+.. autofunction:: leafwise
\ No newline at end of file
diff --git a/docs/API/core.rst b/docs/API/core.rst
deleted file mode 100644
index e23f8c7..0000000
--- a/docs/API/core.rst
+++ /dev/null
@@ -1,31 +0,0 @@
-🎯 Core API
-=============================
-
-
-.. currentmodule:: sepes
-
-.. autoclass:: TreeClass
- :members:
- at
-.. autoclass:: Partial
-.. autoclass:: partial
-.. autoclass:: AtIndexer
- :members:
- get,
- set,
- apply,
- scan,
- reduce,
- pluck,
-
-.. autoclass:: at
-.. autoclass:: BaseKey
- :members:
- __eq__
-.. autofunction:: autoinit
-.. autofunction:: leafwise
-.. autofunction:: field
-.. autofunction:: fields
-.. autofunction:: bcmap
-.. autofunction:: is_tree_equal
-.. autofunction:: value_and_tree
\ No newline at end of file
diff --git a/docs/API/masking.rst b/docs/API/masking.rst
index c1be176..3414824 100644
--- a/docs/API/masking.rst
+++ b/docs/API/masking.rst
@@ -3,9 +3,6 @@
.. currentmodule:: sepes
-.. autofunction:: is_nondiff
-.. autofunction:: freeze
-.. autofunction:: unfreeze
-.. autofunction:: is_frozen
+.. autofunction:: is_masked
.. autofunction:: tree_mask
.. autofunction:: tree_unmask
diff --git a/docs/API/module.rst b/docs/API/module.rst
new file mode 100644
index 0000000..4b56db8
--- /dev/null
+++ b/docs/API/module.rst
@@ -0,0 +1,10 @@
+📍 Module API
+=============================
+
+
+.. currentmodule:: sepes
+
+.. autoclass:: TreeClass
+ :members:
+ at
+
diff --git a/docs/API/pretty_print.rst b/docs/API/pretty_print.rst
index d863ea9..627fdf4 100644
--- a/docs/API/pretty_print.rst
+++ b/docs/API/pretty_print.rst
@@ -4,8 +4,6 @@
.. currentmodule:: sepes
.. autofunction:: tree_diagram
-.. autofunction:: tree_graph
-.. autofunction:: tree_mermaid
.. autofunction:: tree_repr
.. autofunction:: tree_str
.. autofunction:: tree_summary
\ No newline at end of file
diff --git a/docs/API/sepes.rst b/docs/API/sepes.rst
index a857ed8..3e396bd 100644
--- a/docs/API/sepes.rst
+++ b/docs/API/sepes.rst
@@ -5,7 +5,9 @@
:maxdepth: 2
:caption: API Documentation
- core
+ module
masking
+ tree
+ constructor
pretty_print
backend
diff --git a/docs/API/tree.rst b/docs/API/tree.rst
new file mode 100644
index 0000000..0cb4f9a
--- /dev/null
+++ b/docs/API/tree.rst
@@ -0,0 +1,17 @@
+🌲 Tree utils API
+=============================
+
+
+.. currentmodule:: sepes
+
+.. autoclass:: at
+ :members:
+ get,
+ set,
+ apply,
+ scan,
+ reduce,
+ pluck,
+
+.. autofunction:: value_and_tree
+.. autofunction:: bcmap
\ No newline at end of file
diff --git a/docs/_static/tree_graph.svg b/docs/_static/tree_graph.svg
deleted file mode 100644
index 380a167..0000000
--- a/docs/_static/tree_graph.svg
+++ /dev/null
@@ -1,67 +0,0 @@
-
-
-
-
-
diff --git a/docs/_static/tree_graph_stylized.svg b/docs/_static/tree_graph_stylized.svg
deleted file mode 100644
index f6a8d7b..0000000
--- a/docs/_static/tree_graph_stylized.svg
+++ /dev/null
@@ -1,67 +0,0 @@
-
-
-
-
-
diff --git a/docs/_static/tree_mermaid.jpg b/docs/_static/tree_mermaid.jpg
deleted file mode 100644
index 07f1d82..0000000
Binary files a/docs/_static/tree_mermaid.jpg and /dev/null differ
diff --git a/sepes/__init__.py b/sepes/__init__.py
index 75f69f1..d4ff6a3 100644
--- a/sepes/__init__.py
+++ b/sepes/__init__.py
@@ -15,69 +15,37 @@
from sepes._src.backend import backend_context
from sepes._src.code_build import autoinit, field, fields
from sepes._src.tree_base import TreeClass
-from sepes._src.tree_index import AtIndexer, BaseKey, at
-from sepes._src.tree_mask import (
- freeze,
- is_frozen,
- is_nondiff,
- tree_mask,
- tree_unmask,
- unfreeze,
-)
-from sepes._src.tree_pprint import (
- tree_diagram,
- tree_graph,
- tree_mermaid,
- tree_repr,
- tree_str,
- tree_summary,
-)
-from sepes._src.tree_util import (
- Partial,
- bcmap,
- is_tree_equal,
- leafwise,
- partial,
- value_and_tree,
-)
+from sepes._src.tree_index import at
+from sepes._src.tree_mask import is_masked, tree_mask, tree_unmask
+from sepes._src.tree_pprint import tree_diagram, tree_repr, tree_str, tree_summary
+from sepes._src.tree_util import bcmap, leafwise, value_and_tree
-__all__ = (
- # general utils
+__all__ = [
+ # module utils
"TreeClass",
- "is_tree_equal",
- "field",
- "fields",
- "autoinit",
# pprint utils
"tree_diagram",
- "tree_graph",
- "tree_mermaid",
"tree_repr",
"tree_str",
"tree_summary",
# masking utils
- "is_nondiff",
- "is_frozen",
- "freeze",
- "unfreeze",
+ "is_masked",
"tree_unmask",
"tree_mask",
- # indexing utils
- "AtIndexer",
- "at",
- "BaseKey",
# tree utils
+ "at",
"bcmap",
- "Partial",
- "partial",
- "leafwise",
"value_and_tree",
+ # construction utils
+ "field",
+ "fields",
+ "autoinit",
+ "leafwise",
# backend utils
"backend_context",
-)
+]
-__version__ = "0.11.3"
+__version__ = "0.12.0"
-AtIndexer.__module__ = "sepes"
+at.__module__ = "sepes"
TreeClass.__module__ = "sepes"
-Partial.__module__ = "sepes"
diff --git a/sepes/_src/backend/arraylib/__init__.py b/sepes/_src/backend/arraylib/__init__.py
index bbdd925..2bfeeaa 100644
--- a/sepes/_src/backend/arraylib/__init__.py
+++ b/sepes/_src/backend/arraylib/__init__.py
@@ -14,20 +14,32 @@
"""Backend tools for sepes."""
from __future__ import annotations
+
import functools as ft
+from typing import Callable, NamedTuple
+
+
+class NoImplError(NamedTuple):
+ op: Callable
+
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError(f"No implementation for {self.op}"
+ f" with {args=} {kwargs=}")
+
-tobytes = ft.singledispatch(lambda array: ...)
-where = ft.singledispatch(lambda condition, x, y: ...)
-nbytes = ft.singledispatch(lambda array: ...)
-shape = ft.singledispatch(lambda array: ...)
-dtype = ft.singledispatch(lambda array: ...)
-min = ft.singledispatch(lambda array: ...)
-max = ft.singledispatch(lambda array: ...)
-mean = ft.singledispatch(lambda array: ...)
-std = ft.singledispatch(lambda array: ...)
-all = ft.singledispatch(lambda array: ...)
-is_floating = ft.singledispatch(lambda array: ...)
-is_integer = ft.singledispatch(lambda array: ...)
-is_inexact = ft.singledispatch(lambda array: ...)
-is_bool = ft.singledispatch(lambda array: ...)
-ndarrays: tuple[type, ...] = ()
+tobytes = ft.singledispatch(NoImplError("tobytes"))
+where = ft.singledispatch(NoImplError("where"))
+nbytes = ft.singledispatch(NoImplError("nbytes"))
+shape = ft.singledispatch(NoImplError("shape"))
+dtype = ft.singledispatch(NoImplError("dtype"))
+min = ft.singledispatch(NoImplError("min"))
+max = ft.singledispatch(NoImplError("max"))
+mean = ft.singledispatch(NoImplError("mean"))
+std = ft.singledispatch(NoImplError("std"))
+all = ft.singledispatch(NoImplError("all"))
+array_equal = ft.singledispatch(NoImplError("array_equal"))
+is_floating = ft.singledispatch(NoImplError("is_floating"))
+is_integer = ft.singledispatch(NoImplError("is_integer"))
+is_inexact = ft.singledispatch(NoImplError("is_inexact"))
+is_bool = ft.singledispatch(NoImplError("is_bool"))
+ndarrays: list[type] = []
diff --git a/sepes/_src/backend/arraylib/jax.py b/sepes/_src/backend/arraylib/jax.py
index 1022f8e..c494b78 100644
--- a/sepes/_src/backend/arraylib/jax.py
+++ b/sepes/_src/backend/arraylib/jax.py
@@ -14,12 +14,13 @@
from __future__ import annotations
-
-from jax import Array
import jax.numpy as jnp
+import numpy as np
+from jax import Array
+
import sepes._src.backend.arraylib as arraylib
-arraylib.tobytes.register(Array, lambda x: jnp.array(x).tobytes())
+arraylib.tobytes.register(Array, lambda x: np.array(x).tobytes())
arraylib.where.register(Array, jnp.where)
arraylib.nbytes.register(Array, lambda x: x.nbytes)
arraylib.shape.register(Array, jnp.shape)
@@ -29,8 +30,9 @@
arraylib.mean.register(Array, jnp.mean)
arraylib.std.register(Array, jnp.std)
arraylib.all.register(Array, jnp.all)
+arraylib.array_equal.register(Array, np.array_equal) # NOTE: not traceable
arraylib.is_floating.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.floating))
arraylib.is_integer.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.integer))
arraylib.is_inexact.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.inexact))
arraylib.is_bool.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.bool_))
-arraylib.ndarrays += (Array,)
+arraylib.ndarrays.append(Array)
diff --git a/sepes/_src/backend/arraylib/numpy.py b/sepes/_src/backend/arraylib/numpy.py
index 285c916..1bf8d55 100644
--- a/sepes/_src/backend/arraylib/numpy.py
+++ b/sepes/_src/backend/arraylib/numpy.py
@@ -16,6 +16,7 @@
import numpy as np
from numpy import ndarray
+
import sepes._src.backend.arraylib as arraylib
arraylib.tobytes.register(ndarray, lambda x: np.array(x).tobytes())
@@ -28,8 +29,9 @@
arraylib.mean.register(ndarray, np.mean)
arraylib.std.register(ndarray, np.std)
arraylib.all.register(ndarray, np.all)
+arraylib.array_equal.register(ndarray, np.array_equal)
arraylib.is_floating.register(ndarray, lambda x: np.issubdtype(x.dtype, np.floating))
arraylib.is_integer.register(ndarray, lambda x: np.issubdtype(x.dtype, np.integer))
arraylib.is_inexact.register(ndarray, lambda x: np.issubdtype(x.dtype, np.inexact))
arraylib.is_bool.register(ndarray, lambda x: np.issubdtype(x.dtype, np.bool_))
-arraylib.ndarrays += (ndarray,)
+arraylib.ndarrays.append(ndarray)
diff --git a/sepes/_src/backend/arraylib/torch.py b/sepes/_src/backend/arraylib/torch.py
index 696a309..6ddac57 100644
--- a/sepes/_src/backend/arraylib/torch.py
+++ b/sepes/_src/backend/arraylib/torch.py
@@ -17,6 +17,7 @@
import numpy as np
import torch
from torch import Tensor
+
import sepes._src.backend.arraylib as arraylib
floatings = [torch.float16, torch.float32, torch.float64]
@@ -33,8 +34,9 @@
arraylib.mean.register(Tensor, torch.mean)
arraylib.std.register(Tensor, torch.std)
arraylib.all.register(Tensor, torch.all)
+arraylib.array_equal.register(Tensor, torch.equal)
arraylib.is_floating.register(Tensor, lambda x: x.dtype in floatings)
arraylib.is_integer.register(Tensor, lambda x: x.dtype in integers)
arraylib.is_inexact.register(Tensor, lambda x: x.dtype in floatings + complexes)
arraylib.is_bool.register(Tensor, lambda x: x.dtype == torch.bool)
-arraylib.ndarrays += (Tensor,)
+arraylib.ndarrays.append(Tensor)
diff --git a/sepes/_src/backend/treelib/__init__.py b/sepes/_src/backend/treelib/__init__.py
index 90655b7..a6cb835 100644
--- a/sepes/_src/backend/treelib/__init__.py
+++ b/sepes/_src/backend/treelib/__init__.py
@@ -61,7 +61,7 @@ class AbstractTreeLib(abc.ABC):
@staticmethod
@abc.abstractmethod
- def tree_map(
+ def map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
@@ -72,7 +72,7 @@ def tree_map(
@staticmethod
@abc.abstractmethod
- def tree_path_map(
+ def path_map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
@@ -83,7 +83,7 @@ def tree_path_map(
@staticmethod
@abc.abstractmethod
- def tree_flatten(
+ def flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
@@ -92,7 +92,7 @@ def tree_flatten(
@staticmethod
@abc.abstractmethod
- def tree_path_flatten(
+ def path_flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
@@ -101,7 +101,7 @@ def tree_path_flatten(
@staticmethod
@abc.abstractmethod
- def tree_unflatten(treedef: Any, leaves: Iterable[Any]) -> Any:
+ def unflatten(treedef: Any, leaves: Iterable[Any]) -> Any:
...
@staticmethod
diff --git a/sepes/_src/backend/treelib/jax.py b/sepes/_src/backend/treelib/jax.py
index cf49f8a..43bc9d7 100644
--- a/sepes/_src/backend/treelib/jax.py
+++ b/sepes/_src/backend/treelib/jax.py
@@ -36,7 +36,7 @@ def __str__(self):
class JaxTreeLib(AbstractTreeLib):
@staticmethod
- def tree_map(
+ def map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
@@ -51,7 +51,7 @@ def tree_map(
return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config))
@staticmethod
- def tree_path_map(
+ def path_map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
@@ -66,7 +66,7 @@ def tree_path_map(
return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config))
@staticmethod
- def tree_flatten(
+ def flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
@@ -74,7 +74,7 @@ def tree_flatten(
return jtu.tree_flatten(tree, is_leaf=is_leaf)
@staticmethod
- def tree_path_flatten(
+ def path_flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
@@ -82,7 +82,7 @@ def tree_path_flatten(
return jtu.tree_flatten_with_path(tree, is_leaf=is_leaf)
@staticmethod
- def tree_unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any:
+ def unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any:
return jtu.tree_unflatten(treedef, leaves)
@staticmethod
diff --git a/sepes/_src/backend/treelib/optree.py b/sepes/_src/backend/treelib/optree.py
index 78015ad..4747494 100644
--- a/sepes/_src/backend/treelib/optree.py
+++ b/sepes/_src/backend/treelib/optree.py
@@ -61,7 +61,7 @@ def __str__(self) -> str:
class OpTreeTreeLib(AbstractTreeLib):
@staticmethod
- def tree_map(
+ def map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
@@ -76,7 +76,7 @@ def tree_map(
return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config))
@staticmethod
- def tree_path_map(
+ def path_map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
@@ -92,7 +92,7 @@ def tree_path_map(
return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config))
@staticmethod
- def tree_flatten(
+ def flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
@@ -101,7 +101,7 @@ def tree_flatten(
return (leaves, treedef)
@staticmethod
- def tree_path_flatten(
+ def path_flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
@@ -110,7 +110,7 @@ def tree_path_flatten(
return (list(zip(ot.treespec_paths(treedef), leaves)), treedef)
@staticmethod
- def tree_unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any:
+ def unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any:
return ot.tree_unflatten(treedef, leaves)
@staticmethod
diff --git a/sepes/_src/code_build.py b/sepes/_src/code_build.py
index 4889a8c..e2e44bd 100644
--- a/sepes/_src/code_build.py
+++ b/sepes/_src/code_build.py
@@ -294,6 +294,7 @@ def field(
Buffer creation using :attr:`on_getattr`:
>>> import sepes as sp
+ >>> import jax
>>> import jax.numpy as jnp
>>> @sp.autoinit
... class Tree(sp.TreeClass):
@@ -308,6 +309,7 @@ def field(
Parameterization using :attr:`on_getattr`:
>>> import sepes as sp
+ >>> import jax
>>> import jax.numpy as jnp
>>> def symmetric(array: jax.Array) -> jax.Array:
... triangle = jnp.triu(array) # upper triangle
diff --git a/sepes/_src/tree_base.py b/sepes/_src/tree_base.py
index 978ada2..1f5a0b3 100644
--- a/sepes/_src/tree_base.py
+++ b/sepes/_src/tree_base.py
@@ -19,14 +19,13 @@
import abc
from typing import Any, Hashable, TypeVar
-from typing_extensions import Unpack
+from typing_extensions import Self, Unpack
import sepes
from sepes._src.code_build import fields
+from sepes._src.tree_index import at
from sepes._src.tree_pprint import PPSpec, tree_repr, tree_str
from sepes._src.tree_util import is_tree_equal, tree_copy, tree_hash, value_and_tree
-from typing_extensions import Self
-from sepes._src.tree_index import AtIndexer
T = TypeVar("T", bound=Hashable)
S = TypeVar("S")
@@ -148,11 +147,11 @@ class TreeClass(metaclass=TreeClassMeta):
the tree. for example:
>>> @sp.leafwise
- ... @sp.autoinit
... class Tree(sp.TreeClass):
- ... a:int = 1
- ... b:float = 2.0
- >>> tree = Tree()
+ ... def __init__(self, a:int, b:float):
+ ... self.a = a
+ ... self.b = b
+ >>> tree = Tree(a=1, b=2.0)
>>> tree + 1 # will add 1 to each leaf
Tree(a=2, b=3.0)
@@ -161,45 +160,16 @@ class TreeClass(metaclass=TreeClassMeta):
used to ``get``, ``set``, or ``apply`` a function to a leaf or a group of
leaves using ``leaf`` name, index or by a boolean mask.
- >>> @sp.autoinit
- ... class Tree(sp.TreeClass):
- ... a:int = 1
- ... b:float = 2.0
- >>> tree = Tree()
+ >>> class Tree(sp.TreeClass):
+ ... def __init__(self, a:int, b:float):
+ ... self.a = a
+ ... self.b = b
+ >>> tree = Tree(a=1, b=2.0)
>>> tree.at["a"].get()
Tree(a=1, b=None)
>>> tree.at[0].get()
Tree(a=1, b=None)
- Note:
- - Under ``jax.tree_util.***`` or ``optree`` all :class:`.TreeClass`
- attributes are treated as leaves.
- - To hide/ignore a specific attribute from the tree leaves, during
- ``jax.tree_util.***`` operations, freeze the leaf using :func:`.freeze`
- or :func:`.tree_mask`.
-
- >>> # freeze(exclude) a leaf from the tree leaves:
- >>> import jax
- >>> import sepes as sp
- >>> @sp.autoinit
- ... class Tree(sp.TreeClass):
- ... a:int = 1
- ... b:float = 2.0
- >>> tree = Tree()
- >>> tree = tree.at["a"].apply(sp.freeze)
- >>> jax.tree_util.tree_leaves(tree)
- [2.0]
-
- >>> # undo the freeze
- >>> tree = tree.at["a"].apply(sp.unfreeze, is_leaf=sp.is_frozen)
- >>> jax.tree_util.tree_leaves(tree)
- [1, 2.0]
-
- >>> # using `tree_mask` to exclude a leaf from the tree leaves
- >>> freeze_mask = Tree(a=True, b=False)
- >>> jax.tree_util.tree_leaves(sp.tree_mask(tree, freeze_mask))
- [2.0]
-
Note:
``AttributeError`` is raised, If a method that mutates the instance
is called directly. Instead use :func:`.value_and_tree` to call
@@ -236,23 +206,23 @@ def __init_subclass__(klass: type[T], **k):
if "__delattr__" in vars(klass):
raise TypeError(f"Reserved method `__delattr__` defined in `{klass}`.")
super().__init_subclass__(**k)
- # register the class with the proper tree backend.
- # the registration envolves defining two rules: how to flatten the nested
- # structure of the class and how to unflatten the flattened structure.
- # The flatten rule for `TreeClass` is equivalent to vars(self). and the
- # unflatten rule is equivalent to `klass(**flat_tree)`. The flatten/unflatten
- # rule is exactly same as the flatten rule for normal dictionaries.
+ # - register the class with the proper tree backend.
+ # - the registration envolves defining two rules: how to flatten the nested
+ # structure of the class and how to unflatten the flattened structure.
+ # The flatten rule for `TreeClass` is equivalent to vars(self). and the
+ # unflatten rule is equivalent to `klass(**flat_tree)`. The flatten/unflatten
+ # rule is exactly same as the flatten rule for normal dictionaries.
treelib = sepes._src.backend.treelib
treelib.register_treeclass(klass)
def __setattr__(self, key: str, value: Any) -> None:
- # implements the controlled mutability behavior.
- # In essence, setattr is allowed to set attributes during initialization
- # and during functional call using .at["method"](*, **) by marking the
- # instnace as mutable. Otherwise, setattr is disallowed.
- # recall that during the functional call using .at["method"](*, **)
- # the tree is always copied and the copy is marked as mutable, thus
- # setattr is allowed to set attributes on the copy not the original.
+ # - implements the controlled mutability behavior.
+ # - In essence, setattr is allowed to set attributes during initialization
+ # and during functional call using `value_and_tree(method)(*, **)` by marking the
+ # instnace as mutable. Otherwise, setattr is disallowed.
+ # - recall that during the functional call using `value_and_tree(method)(*, **)`
+ # the tree is always copied and the copy is marked as mutable, thus
+ # setattr is allowed to set attributes on the copy not the original.
if id(self) not in _mutable_instance_registry:
raise AttributeError(
f"Cannot set attribute {value=} to `{key=}` "
@@ -262,13 +232,13 @@ def __setattr__(self, key: str, value: Any) -> None:
getattr(object, "__setattr__")(self, key, value)
def __delattr__(self, key: str) -> None:
- # same as __setattr__ but for delattr.
- # both __setattr__ and __delattr__ are used to implement the
- # controlled mutability behavior during initialization and
- # during functional call using .at["method"](*, **).
- # recall that during the functional call using .at["method"](*, **)
- # the tree is always copied and the copy is marked as mutable, thus
- # setattr is allowed to set attributes on the copy not the original.
+ # - same as __setattr__ but for delattr.
+ # - both __setattr__ and __delattr__ are used to implement the
+ # - controlled mutability behavior during initialization and
+ # during functional call using `value_and_tree(method)(*, **)`.
+ # - recall that during the functional call using `value_and_tree(method)(*, **)`
+ # the tree is always copied and the copy is marked as mutable, thus
+ # setattr is allowed to set attributes on the copy not the original.
if id(self) not in _mutable_instance_registry:
raise AttributeError(
f"Cannot delete attribute `{key}` "
@@ -277,7 +247,7 @@ def __delattr__(self, key: str) -> None:
getattr(object, "__delattr__")(self, key)
@property
- def at(self) -> AtIndexer[Self]:
+ def at(self) -> at[Self]:
"""Immutable out-of-place indexing.
- ``.at[***].get()``:
@@ -292,20 +262,18 @@ def at(self) -> AtIndexer[Self]:
- ``int`` for positional indexing for sequences.
- ``...`` to select all leaves.
- a boolean mask of the same structure as the tree
- - ``re.Pattern`` to index all keys matching a regex pattern.
- - an instance of ``BaseKey`` with custom logic to index a pytree.
- a tuple of the above types to index multiple keys at same level.
Example:
>>> import sepes as sp
- >>> @sp.autoinit
- ... class Tree(sp.TreeClass):
- ... a: int = 1
- ... b: float = 2.0
+ >>> class Tree(sp.TreeClass):
+ ... def __init__(self, a:int, b:float):
+ ... self.a = a
+ ... self.b = b
... def add(self, x: int) -> int:
... self.a += x
... return self.a
- >>> tree = Tree()
+ >>> tree = Tree(a=1, b=2.0)
>>> tree.at["a"].get()
Tree(a=1, b=None)
>>> tree.at["a"].set(100)
@@ -317,7 +285,10 @@ def at(self) -> AtIndexer[Self]:
- ``pytree.at[*][**]`` is equivalent to selecting pytree.*.** .
- ``pytree.at[*, **]`` is equivalent selecting pytree.* and pytree.**
"""
- return AtIndexer(self)
+ # NOTE: use `at` as a property to enable chaining syntax.
+ # instead of at(at(tree)[...].apply(...))[...].set(...)
+ # chaining syntax is tree.at[...].apply(...).at[...].set(...)
+ return at(self)
def __repr__(self) -> str:
return tree_repr(self)
diff --git a/sepes/_src/tree_index.py b/sepes/_src/tree_index.py
index 721c388..4bb98f4 100644
--- a/sepes/_src/tree_index.py
+++ b/sepes/_src/tree_index.py
@@ -12,31 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Define lens-like indexing/masking for pytrees."""
-
-# enable get/set/apply/scan/reduce operations on selected parts of a nested
-# structure -pytree- in out-of-place manner. this process invovles defining two
-# parts: 1) *where* to select the parts of the pytree and 2) *what* to do with
-# the selected parts. the *where* part is defined either by a path or a boolean
-# mask. the *what* part is defined by a set value, or a function to apply to
-# the selected parts. once we have a *final* boolean mask that encompasses all
-# path and the boolean mask, we can use `tree_map` to apply the *what* part to
-# the *where* part. for example, for a tree = [[1, 2], 3, 4] and boolean mask
-# [[True, False], False, True] and path mask [0][1], then we select only leaf
-# 1 that is at the intersection of the boolean mask and the path mask. then we
-# apply the *what* part to the *where* part.
+"""Define lens-like indexing for pytrees
+
+This module provides a way to index and mask pytrees (e.g. TreeClass) in an
+out-of-place manner.Out-of-place means that the original pytree is not modified,
+instead a new pytree with the selected leaves are modified.
+
+The indexing is done through two concepts:
+
+1) Selection (Where): Determines parts of the pytree for manipulation via a path or a boolean mask.
+2) Operation (What): Defines actions on selected parts, such as setting values or applying functions.
+
+For example, the following code defines a dict pytree with where of same structure
+as the tree. The where (Selection) defines which parts of the tree to select and
+the set (Operation) operation sets the selected parts to 100.
+
+>>> import sepes as sp
+>>> tree = {"a": 1, "b": [1, 2, 3]}
+>>> where = {"a": True, "b": [False, True, False]}
+>>> sp.at(tree)[where].set(100)
+{'a': 100, 'b': [1, 100, 3]}
+"""
from __future__ import annotations
import abc
import functools as ft
import re
-from typing import Any, Callable, Hashable, Tuple, TypeVar, Generic
+from typing import Any, Callable, Generic, Hashable, TypeVar, Sequence
from typing_extensions import Self
import sepes
import sepes._src.backend.arraylib as arraylib
+from sepes._src.backend import is_package_avaiable
from sepes._src.backend.treelib import ParallelConfig
from sepes._src.tree_pprint import tree_repr
@@ -44,241 +53,21 @@
S = TypeVar("S")
PyTree = Any
EllipsisType = TypeVar("EllipsisType")
-KeyEntry = TypeVar("KeyEntry", bound=Hashable)
-KeyPath = Tuple[KeyEntry, ...]
+PathKeyEntry = TypeVar("PathKeyEntry", bound=Hashable)
_no_initializer = object()
+_no_fill_value = object()
class BaseKey(abc.ABC):
- """Parent class for all match classes.
-
- - Subclass this class to create custom match keys by implementing
- the `__eq__` method. The ``__eq__`` method should return True if the
- key matches the given path entry and False otherwise. The path entry
- refers to the entry defined in the ``tree_flatten_with_keys`` method of
- the pytree class.
-
- - Typical path entries in ``jax`` are:
-
- - ``jax.tree_util.GetAttrKey`` for attributes
- - ``jax.tree_util.DictKey`` for mapping keys
- - ``jax.tree_util.SequenceKey`` for sequence indices
-
- - When implementing the ``__eq__`` method you can use the ``singledispatchmethod``
- to unpack the path entry for example:
-
- - ``jax.tree_util.GetAttrKey`` -> `key.name`
- - ``jax.tree_util.DictKey`` -> `key.key`
- - ``jax.tree_util.SequenceKey`` -> `key.index`
-
-
- See Examples for more details.
-
- Example:
- >>> # define an match strategy to match a leaf with a given name and type
- >>> import sepes as sp
- >>> from typing import NamedTuple
- >>> import jax
- >>> class NameTypeContainer(NamedTuple):
- ... name: str
- ... type: type
- >>> @jax.tree_util.register_pytree_with_keys_class
- ... class Tree:
- ... def __init__(self, a, b) -> None:
- ... self.a = a
- ... self.b = b
- ... def tree_flatten_with_keys(self):
- ... ak = (NameTypeContainer("a", type(self.a)), self.a)
- ... bk = (NameTypeContainer("b", type(self.b)), self.b)
- ... return (ak, bk), None
- ... @classmethod
- ... def tree_unflatten(cls, aux_data, children):
- ... return cls(*children)
- ... @property
- ... def at(self):
- ... return sp.at(self)
- >>> tree = Tree(1, 2)
- >>> class MatchNameType(sp.BaseKey):
- ... def __init__(self, name, type):
- ... self.name = name
- ... self.type = type
- ... def __eq__(self, other):
- ... if isinstance(other, NameTypeContainer):
- ... return other == (self.name, self.type)
- ... return False
- >>> tree = tree.at[MatchNameType("a", int)].get()
- >>> assert jax.tree_util.tree_leaves(tree) == [1]
-
- Note:
- - use ``BaseKey.def_alias(type, func)`` to define an index type alias
- for `BaseKey` subclasses. This is useful for convience when
- creating new match strategies.
-
- >>> import sepes as sp
- >>> import functools as ft
- >>> from types import FunctionType
- >>> import jax.tree_util as jtu
- >>> # lets define a new match strategy called `FuncKey` that applies
- >>> # a function to the path entry and returns True if the function
- >>> # returns True and False otherwise.
- >>> # for example `FuncKey(lambda x: x.startswith("a"))` will match
- >>> # all leaves that start with "a".
- >>> class FuncKey(sp.BaseKey):
- ... def __init__(self, func):
- ... self.func = func
- ... @ft.singledispatchmethod
- ... def __eq__(self, key):
- ... return self.func(key)
- ... @__eq__.register(jtu.GetAttrKey)
- ... def _(self, key: jtu.GetAttrKey):
- ... # unpack the GetAttrKey
- ... return self.func(key.name)
- ... @__eq__.register(jtu.DictKey)
- ... def _(self, key: jtu.DictKey):
- ... # unpack the DictKey
- ... return self.func(key.key)
- ... @__eq__.register(jtu.SequenceKey)
- ... def _(self, key: jtu.SequenceKey):
- ... return self.func(key.index)
- >>> # instead of using ``FuncKey(function)`` we can define an alias
- >>> # for `FuncKey`, for this example we will define any FunctionType
- >>> # as a `FuncKey` by default.
- >>> @sp.BaseKey.def_alias(FunctionType)
- ... def _(func):
- ... return FuncKey(func)
- >>> # create a simple pytree
- >>> @sp.autoinit
- ... class Tree(sp.TreeClass):
- ... a: int
- ... b: str
- >>> tree = Tree(1, "string")
- >>> # now we can use the `FuncKey` alias to match all leaves that
- >>> # are strings and start with "a"
- >>> tree.at[lambda x: isinstance(x, str) and x.startswith("a")].get()
- Tree(a=1, b=None)
- """
+ """Parent class for all match classes."""
@abc.abstractmethod
- def __eq__(self, entry: KeyEntry) -> bool:
+ def __eq__(self, entry: PathKeyEntry) -> bool:
pass
- broadcastable: bool = False
-
-
-class IndexKey(BaseKey):
- """Match a leaf with a given index."""
-
- def __init__(self, idx: int) -> None:
- self.idx = idx
-
- def __eq__(self, key: KeyEntry) -> bool:
- if isinstance(key, int):
- return self.idx == key
- treelib = sepes._src.backend.treelib
- if isinstance(key, type(treelib.sequence_key(0))):
- return self.idx == key.idx
- return False
-
- def __repr__(self) -> str:
- return f"{self.idx}"
-
-
-class NameKey(BaseKey):
- """Match a leaf with a given key."""
-
- def __init__(self, name: str) -> None:
- self.name = name
-
- def __eq__(self, key: KeyEntry) -> bool:
- if isinstance(key, str):
- return self.name == key
- treelib = sepes._src.backend.treelib
- if isinstance(key, type(treelib.attribute_key(""))):
- return self.name == key.name
- if isinstance(key, type(treelib.dict_key(""))):
- return self.name == key.key
- return False
-
- def __repr__(self) -> str:
- return f"{self.name}"
-
-
-class EllipsisKey(BaseKey):
- """Match all leaves."""
-
- broadcastable = True
-
- def __init__(self, _):
- del _
-
- def __eq__(self, _: KeyEntry) -> bool:
- return True
-
- def __repr__(self) -> str:
- return "..."
-
-
-class MultiKey(BaseKey):
- """Match a leaf with multiple keys at the same level."""
-
- def __init__(self, *keys: tuple[BaseKey, ...]):
- self.keys = tuple(keys)
-
- def __eq__(self, entry) -> bool:
- return any(entry == key for key in self.keys)
-
- def __repr__(self) -> str:
- return f"({', '.join(map(repr, self.keys))})"
-
-
-class RegexKey(BaseKey):
- """Match a leaf with a regex pattern inside 'at' property.
-
- Args:
- pattern: regex pattern to match.
-
- Example:
- >>> import sepes as sp
- >>> import re
- >>> @sp.autoinit
- ... class Tree(sp.TreeClass):
- ... weight_1: float = 1.0
- ... weight_2: float = 2.0
- ... weight_3: float = 3.0
- ... bias: float = 0.0
- >>> tree = Tree()
- >>> tree.at[re.compile(r"weight_.*")].set(100.0) # set all weights to 100.0
- Tree(weight_1=100.0, weight_2=100.0, weight_3=100.0, bias=0.0)
- """
-
- def __init__(self, pattern: str) -> None:
- self.pattern = pattern
-
- def __eq__(self, key: KeyEntry) -> bool:
- if isinstance(key, str):
- return re.fullmatch(self.pattern, key) is not None
- treelib = sepes._src.backend.treelib
- if isinstance(key, type(treelib.attribute_key(""))):
- return re.fullmatch(self.pattern, key.name) is not None
- if isinstance(key, type(treelib.dict_key(""))):
- return re.fullmatch(self.pattern, key.key) is not None
- return False
-
- def __repr__(self) -> str:
- return f"{self.pattern}"
-
-
-# dispatch on type of indexer to convert input item to at indexer
-# `__getitem__` to the appropriate key
-# avoid using container pytree types to avoid conflict between
-# matching as a mask or as an instance of `BaseKey`
-indexer_dispatcher = ft.singledispatch(lambda x: x)
-indexer_dispatcher.register(type(...), EllipsisKey)
-indexer_dispatcher.register(int, IndexKey)
-indexer_dispatcher.register(str, NameKey)
-indexer_dispatcher.register(re.Pattern, RegexKey)
-
-BaseKey.def_alias = indexer_dispatcher.register
+ @property
+ @abc.abstractmethod
+ def broadcast(self): ...
_INVALID_INDEXER = """\
@@ -286,14 +75,13 @@ def __repr__(self) -> str:
- `str` for mapping keys or class attributes.
- `int` for positional indexing for sequences.
- `...` to select all leaves.
+ - ``re.Pattern`` to match a leaf level path with a regex pattern.
- Boolean mask of a compatible structure as the pytree.
- - `re.Pattern` to index all keys matching a regex pattern.
- - Instance of `BaseKey` with custom logic to index a pytree.
- `tuple` of the above types to match multiple leaves at the same level.
"""
_NO_LEAF_MATCH = """\
-No leaf match is found for where={where}. Available keys are {names}.
+No leaf match is found for where={where}, Available keys are {names}
Check the following:
- If where is `str` then check if the key exists as a key or attribute.
- If where is `int` then check if the index is in range.
@@ -321,9 +109,9 @@ def is_leaf_func(node) -> bool:
return False
return True
- return treelib.tree_path_map(func, tree, is_leaf=is_leaf_func)
+ return treelib.path_map(func, tree, is_leaf=is_leaf_func)
- if any(mask.broadcastable for mask in where):
+ if any(where_i.broadcast for where_i in where):
# should the selected subtree be broadcasted to the full tree
# e.g. tree = [[1, 2], 3, 4] and where = [0], then
# broadcast with True will be [[True, True], False, False]
@@ -334,8 +122,8 @@ def is_leaf_func(node) -> bool:
# and without broadcast the result will be [100, 3, 4]
def bool_tree(value: bool, tree: Any):
- leaves, treedef = treelib.tree_flatten(tree, is_leaf=is_leaf)
- return treelib.tree_unflatten(treedef, [value] * len(leaves))
+ leaves, treedef = treelib.flatten(tree, is_leaf=is_leaf)
+ return treelib.unflatten(treedef, [value] * len(leaves))
true_tree = ft.partial(bool_tree, True)
false_tree = ft.partial(bool_tree, False)
@@ -380,9 +168,10 @@ def path_map_func(path, leaf):
mask = one_level_tree_path_map(path_map_func, tree)
if not match:
- path_leaf, _ = treelib.tree_path_flatten(tree, is_leaf=is_leaf)
+ path_leaf, _ = treelib.path_flatten(tree, is_leaf=is_leaf)
+ path = "/".join(str(where_i.input) for where_i in where)
names = "".join("\n - " + treelib.keystr(path) for path, _ in path_leaf)
- raise LookupError(_NO_LEAF_MATCH.format(where=where, names=names))
+ raise LookupError(_NO_LEAF_MATCH.format(where=path, names=names))
return mask
@@ -390,9 +179,10 @@ def path_map_func(path, leaf):
def resolve_where(
where: list[Any],
tree: T,
- is_leaf: Callable[[Any], None] | None = None,
+ is_leaf: Callable[[Any], bool] | None = None,
):
treelib = sepes._src.backend.treelib
+ ndarrays = tuple(arraylib.ndarrays)
def combine_bool_leaves(*leaves):
# given a list of boolean leaves, combine them using `and`
@@ -404,7 +194,7 @@ def combine_bool_leaves(*leaves):
return verdict
def is_bool_leaf(leaf: Any) -> bool:
- if isinstance(leaf, arraylib.ndarrays):
+ if isinstance(leaf, ndarrays):
return arraylib.is_bool(leaf)
return isinstance(leaf, bool)
@@ -423,7 +213,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
nonlocal seen_tuple, level_paths, bool_masks
# used to check if a pytree is a valid indexing pytree
# used with `is_leaf` argument of any `tree_*` function
- leaves, _ = treelib.tree_flatten(node)
+ leaves, _ = treelib.flatten(node)
if all(map(is_bool_leaf, leaves)):
# if all leaves are boolean then this is maybe a boolean mask.
@@ -442,7 +232,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
bool_masks += [node]
return True
- if isinstance(resolved_key := indexer_dispatcher(node), BaseKey):
+ if isinstance(resolved_key := at.dispatcher(node), BaseKey):
# valid resolution of `BaseKey` is a valid indexing leaf
# makes it possible to dispatch on multi-leaf pytree
level_paths += [resolved_key]
@@ -463,7 +253,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
# each for loop iteration is a level in the where path
# this means that if where = ("a", "b", "c") then this means
# we are travering the tree at level "a" then level "b" then level "c"
- treelib.tree_flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf)
+ treelib.flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf)
# if len(level_paths) > 1 then this means that we have multiple keys
# at the same level, for example where = ("a", ("b", "c")) then this
# means that for a parent "a", select "b" and "c".
@@ -476,17 +266,14 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
if bool_masks:
all_masks = [mask, *bool_masks] if mask else bool_masks
- mask = treelib.tree_map(combine_bool_leaves, *all_masks)
+ mask = treelib.map(combine_bool_leaves, *all_masks)
return mask
-class AtIndexer(Generic[T]):
+class at(Generic[T]):
"""Operate on a pytree at a given path using a path or mask in out-of-place manner.
- Note:
- Use :class:`.at` as a shorter alias for this class.
-
Args:
tree: pytree to operate on.
where: one of the following:
@@ -495,13 +282,9 @@ class AtIndexer(Generic[T]):
- ``int`` for positional indexing for sequences.
- ``...`` to select all leaves.
- a boolean mask of the same structure as the tree
- - ``re.Pattern`` to index all keys matching a regex pattern.
- - an instance of ``BaseKey`` with custom logic to index a pytree.
+ - ``re.Pattern`` to match a leaf level path with a regex pattern.
- a tuple of the above to match multiple keys at the same level.
- Note:
- Alternatively, use ``at(tree)[where]`` to index a pytree.
-
Example:
>>> import jax
>>> import sepes as sp
@@ -514,26 +297,23 @@ class AtIndexer(Generic[T]):
>>> sp.at(tree)[mask].set(100)
{'a': 1, 'b': [1, 100, 100]}
"""
-
def __init__(self, tree: T, where: list[Any] | None = None) -> None:
- vars(self)["tree"] = tree
- vars(self)["where"] = [] if where is None else where
-
- def __setattr__(self, key: str, _: Any) -> None:
- raise AttributeError(f"Cannot set {key=} on {type(self).__name__} instance")
+ self.tree = tree
+ self.where = [] if where is None else where
def __getitem__(self, where: Any) -> Self:
"""Index a pytree at a given path using a path or mask."""
return type(self)(self.tree, [*self.where, where])
def __repr__(self) -> str:
- return f"{type(self).__name__}(tree={tree_repr(self.tree)}, where={self.where})"
+ return f"{type(self).__name__}({tree_repr(self.tree)}, where={self.where})"
def get(
self,
*,
- is_leaf: Callable[[Any], None] | None = None,
+ is_leaf: Callable[[Any], bool] | None = None,
is_parallel: bool | ParallelConfig = False,
+ fill_value: Any = _no_fill_value,
):
"""Get the leaf values at the specified location.
@@ -547,6 +327,10 @@ def get(
- ``max_workers``: maximum number of workers to use.
- ``kind``: kind of pool to use, either ``thread`` or ``process``.
+ fill_value: the value to fill the non-selected leaves with.
+ Useful to use with ``jax.jit`` to avoid variable size arrays
+ leaves related errors.
+
Returns:
A _new_ pytree of leaf values at the specified location, with the
non-selected leaf values set to None if the leaf is not an array.
@@ -558,19 +342,25 @@ def get(
{'a': None, 'b': [1, None, None]}
"""
treelib = sepes._src.backend.treelib
+ ndarrays = tuple(arraylib.ndarrays)
def leaf_get(where: Any, leaf: Any):
# support both array and non-array leaves
# for array boolean mask we select **parts** of the array that
# matches the mask, for example if the mask is Array([True, False, False])
# and the leaf is Array([1, 2, 3]) then the result is Array([1])
- if isinstance(where, arraylib.ndarrays) and len(arraylib.shape(where)):
+ # because of the variable resultant size of the output
+ if isinstance(where, ndarrays) and len(arraylib.shape(where)):
+ if fill_value is not _no_fill_value:
+ return arraylib.where(where, leaf, fill_value)
return leaf[where]
# non-array boolean mask we select the leaf if the mask is True
# and `None` otherwise
+ if fill_value is not _no_fill_value:
+ return leaf if where else fill_value
return leaf if where else None
- return treelib.tree_map(
+ return treelib.map(
leaf_get,
resolve_where(self.where, self.tree, is_leaf),
self.tree,
@@ -582,7 +372,7 @@ def set(
self,
set_value: Any,
*,
- is_leaf: Callable[[Any], None] | None = None,
+ is_leaf: Callable[[Any], bool] | None = None,
is_parallel: bool | ParallelConfig = False,
):
"""Set the leaf values at the specified location.
@@ -609,6 +399,7 @@ def set(
{'a': 1, 'b': [100, 2, 3]}
"""
treelib = sepes._src.backend.treelib
+ ndarrays = tuple(arraylib.ndarrays)
def leaf_set(where: Any, leaf: Any, set_value: Any):
# support both array and non-array leaves
@@ -616,12 +407,12 @@ def leaf_set(where: Any, leaf: Any, set_value: Any):
# matches the mask, for example if the mask is Array([True, False, False])
# and the leaf is Array([1, 2, 3]) then the result is Array([1, 100, 100])
# with set_value = 100
- if isinstance(where, arraylib.ndarrays):
+ if isinstance(where, ndarrays):
return arraylib.where(where, set_value, leaf)
return set_value if where else leaf
- _, lhsdef = treelib.tree_flatten(self.tree, is_leaf=is_leaf)
- _, rhsdef = treelib.tree_flatten(set_value, is_leaf=is_leaf)
+ _, lhsdef = treelib.flatten(self.tree, is_leaf=is_leaf)
+ _, rhsdef = treelib.flatten(set_value, is_leaf=is_leaf)
if lhsdef == rhsdef:
# do not broadcast set_value if it is a pytree of same structure
@@ -629,7 +420,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any):
# to tree2 leaves if tree2 is a pytree of same structure as tree
# instead of making each leaf of tree a copy of tree2
# is design is similar to ``numpy`` design `np.at[...].set(Array)`
- return treelib.tree_map(
+ return treelib.map(
leaf_set,
resolve_where(self.where, self.tree, is_leaf),
self.tree,
@@ -638,7 +429,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any):
is_parallel=is_parallel,
)
- return treelib.tree_map(
+ return treelib.map(
ft.partial(leaf_set, set_value=set_value),
resolve_where(self.where, self.tree, is_leaf),
self.tree,
@@ -650,7 +441,7 @@ def apply(
self,
func: Callable[[Any], Any],
*,
- is_leaf: Callable[[Any], None] | None = None,
+ is_leaf: Callable[[Any], bool] | None = None,
is_parallel: bool | ParallelConfig = False,
):
"""Apply a function to the leaf values at the specified location.
@@ -685,19 +476,19 @@ def apply(
>>> is_parallel = dict(max_workers=2)
>>> images = sp.at(path)[...].apply(imread, is_parallel=is_parallel) # doctest: +SKIP
"""
-
treelib = sepes._src.backend.treelib
+ ndarrays = tuple(arraylib.ndarrays)
def leaf_apply(where: Any, leaf: Any):
# same as `leaf_set` but with `func` applied to the leaf
# one thing to note is that, the where mask select an array
# then the function needs work properly when applied to the selected
# array elements
- if isinstance(where, arraylib.ndarrays):
+ if isinstance(where, ndarrays):
return arraylib.where(where, func(leaf), leaf)
return func(leaf) if where else leaf
- return treelib.tree_map(
+ return treelib.map(
leaf_apply,
resolve_where(self.where, self.tree, is_leaf),
self.tree,
@@ -710,7 +501,7 @@ def scan(
func: Callable[[Any, S], tuple[Any, S]],
state: S,
*,
- is_leaf: Callable[[Any], None] | None = None,
+ is_leaf: Callable[[Any], bool] | None = None,
) -> tuple[Any, S]:
"""Apply a function while carrying a state.
@@ -746,6 +537,7 @@ def scan(
leaf values while carrying a state and returning a single value.
"""
treelib = sepes._src.backend.treelib
+ ndarrays = tuple(arraylib.ndarrays)
running_state = state
def stateless_func(leaf):
@@ -754,11 +546,11 @@ def stateless_func(leaf):
return leaf
def leaf_apply(where: Any, leaf: Any):
- if isinstance(where, arraylib.ndarrays):
+ if isinstance(where, ndarrays):
return arraylib.where(where, stateless_func(leaf), leaf)
return stateless_func(leaf) if where else leaf
- out_tree = treelib.tree_map(
+ out_tree = treelib.map(
leaf_apply,
resolve_where(self.where, self.tree, is_leaf),
self.tree,
@@ -771,7 +563,7 @@ def reduce(
func: Callable[[Any, Any], Any],
*,
initializer: Any = _no_initializer,
- is_leaf: Callable[[Any], None] | None = None,
+ is_leaf: Callable[[Any], bool] | None = None,
) -> Any:
"""Reduce the leaf values at the specified location.
@@ -799,7 +591,7 @@ def reduce(
"""
treelib = sepes._src.backend.treelib
tree = self.get(is_leaf=is_leaf) # type: ignore
- leaves, _ = treelib.tree_flatten(tree, is_leaf=is_leaf)
+ leaves, _ = treelib.flatten(tree, is_leaf=is_leaf)
if initializer is _no_initializer:
return ft.reduce(func, leaves)
return ft.reduce(func, leaves, initializer)
@@ -808,7 +600,7 @@ def pluck(
self,
count: int | None = None,
*,
- is_leaf: Callable[[Any], None] | None = None,
+ is_leaf: Callable[[Any], bool] | None = None,
is_parallel: bool | ParallelConfig = False,
) -> list[Any]:
"""Extract subtrees at the specified location.
@@ -880,7 +672,7 @@ def aggregate_subtrees(node: Any) -> bool:
# for example if tree = dict(a=1) and mask is dict(a=True)
# then returns [1] and not [dict(a=1)]
return False
- leaves, _ = treelib.tree_flatten(node, is_leaf=lambda x: x is None)
+ leaves, _ = treelib.flatten(node, is_leaf=lambda x: x is None)
# in essence if the subtree does not contain any None leaves
# then it is a valid subtree to be plucked
# this because `get` sets the non-selected leaves to None
@@ -890,9 +682,102 @@ def aggregate_subtrees(node: Any) -> bool:
count -= 1
return True
- treelib.tree_flatten(tree, is_leaf=aggregate_subtrees)
+ treelib.flatten(tree, is_leaf=aggregate_subtrees)
return subtrees
-# shorter alias
-at = AtIndexer
+# pass through for boolean pytrees masks and tuple of keys
+at.dispatcher = ft.singledispatch(lambda x: x)
+
+
+def def_rule(
+ user_type: type[T],
+ path_compare_func: Callable[[T, PathKeyEntry], bool],
+ *,
+ broadcastable: bool = False,
+) -> None:
+ # remove the BaseKey abstraction from the user-facing function
+ class UserKey(BaseKey):
+ broadcast: bool = broadcastable
+
+ def __init__(self, input: T):
+ self.input = input
+
+ def __eq__(self, key: PathKeyEntry) -> bool:
+ return path_compare_func(self.input, key)
+
+ at.dispatcher.register(user_type, UserKey)
+
+
+at.def_rule = def_rule
+
+
+# key rules to match user input to with the path entry
+
+
+def str_compare(name: str, key: PathKeyEntry):
+ """Match a leaf with a given name."""
+ if isinstance(key, str):
+ return name == key
+ treelib = sepes._src.backend.treelib
+ if isinstance(key, type(treelib.attribute_key(""))):
+ return name == key.name
+ if isinstance(key, type(treelib.dict_key(""))):
+ return name == key.key
+ return False
+
+
+def int_compare(idx: int, key: PathKeyEntry) -> bool:
+ """Match a leaf with a given index."""
+ if isinstance(key, int):
+ return idx == key
+ treelib = sepes._src.backend.treelib
+ if isinstance(key, type(treelib.sequence_key(0))):
+ return idx == key.idx
+ return False
+
+
+def regex_compare(pattern: re.Pattern, key: PathKeyEntry) -> bool:
+ """Match a path with a regex pattern inside 'at' property."""
+ if isinstance(key, str):
+ return re.fullmatch(pattern, key) is not None
+ treelib = sepes._src.backend.treelib
+ if isinstance(key, type(treelib.attribute_key(""))):
+ return re.fullmatch(pattern, key.name) is not None
+ if isinstance(key, type(treelib.dict_key(""))):
+ return re.fullmatch(pattern, key.key) is not None
+ return False
+
+
+def ellipsis_compare(_, __):
+ return True
+
+
+at.def_rule(str, str_compare, broadcastable=False)
+at.def_rule(int, int_compare, broadcastable=False)
+at.def_rule(re.Pattern, regex_compare, broadcastable=False)
+at.def_rule(type(...), ellipsis_compare, broadcastable=True)
+
+
+class MultiKey(BaseKey):
+ """Match a leaf with multiple keys at the same level."""
+
+ def __init__(self, *keys):
+ self.keys = tuple(keys)
+
+ def __eq__(self, entry: PathKeyEntry) -> bool:
+ return any(entry == key for key in self.keys)
+
+ broadcast: bool = False
+
+
+if is_package_avaiable("jax"):
+ import jax.tree_util as jtu
+
+ def jax_key_compare(input, key: PathKeyEntry) -> bool:
+ """Enable indexing with jax keys directly in `at`."""
+ return input == key
+
+ at.def_rule(jtu.SequenceKey, jax_key_compare, broadcastable=False)
+ at.def_rule(jtu.GetAttrKey, jax_key_compare, broadcastable=False)
+ at.def_rule(jtu.DictKey, jax_key_compare, broadcastable=False)
diff --git a/sepes/_src/tree_mask.py b/sepes/_src/tree_mask.py
index e113e98..5f14370 100644
--- a/sepes/_src/tree_mask.py
+++ b/sepes/_src/tree_mask.py
@@ -30,23 +30,19 @@
MaskType = Union[T, Callable[[Any], bool]]
-class _FrozenError(NamedTuple):
+class _MaskedError(NamedTuple):
opname: str
def __call__(self, *a, **k):
raise NotImplementedError(
- f"Cannot apply `{self.opname}` operation to a frozen object "
+ f"Cannot apply `{self.opname}` operation on a masked object "
f"{', '.join(map(str, a))} "
f"{', '.join(k + '=' + str(v) for k, v in k.items())}.\n"
- "Unfreeze the object first by unmasking the frozen mask:\n"
- "Example:\n"
- ">>> import jax\n"
- ">>> import sepes as sp\n"
- ">>> tree = sp.tree_unmask(tree)"
+ "Unmask the object first using `tree_unmask`"
)
-class _FrozenBase(Static):
+class _MaskBase(Static[T]):
# the objective of this class is to wrap a pytree node with a custom wrapper
# that yields no leaves when flattened. This is useful to avoid updating
# the node by effectivly *hiding it* from function transformations that operates
@@ -69,43 +65,44 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return "#" + tree_str(self.__wrapped__)
- def __copy__(self) -> _FrozenBase[T]:
+ def __copy__(self) -> _MaskBase[T]:
return type(self)(tree_copy(self.__wrapped__))
# raise helpful error message when trying to interact with frozen object
- __add__ = __radd__ = __iadd__ = _FrozenError("+")
- __sub__ = __rsub__ = __isub__ = _FrozenError("-")
- __mul__ = __rmul__ = __imul__ = _FrozenError("*")
- __matmul__ = __rmatmul__ = __imatmul__ = _FrozenError("@")
- __truediv__ = __rtruediv__ = __itruediv__ = _FrozenError("/")
- __floordiv__ = __rfloordiv__ = __ifloordiv__ = _FrozenError("//")
- __mod__ = __rmod__ = __imod__ = _FrozenError("%")
- __pow__ = __rpow__ = __ipow__ = _FrozenError("**")
- __lshift__ = __rlshift__ = __ilshift__ = _FrozenError("<<")
- __rshift__ = __rrshift__ = __irshift__ = _FrozenError(">>")
- __and__ = __rand__ = __iand__ = _FrozenError("and")
- __xor__ = __rxor__ = __ixor__ = _FrozenError("")
- __or__ = __ror__ = __ior__ = _FrozenError("or")
- __neg__ = __pos__ = __abs__ = __invert__ = _FrozenError("unary operation")
- __call__ = _FrozenError("__call__")
-
-
-@tree_summary.def_type(_FrozenBase)
+ __add__ = __radd__ = __iadd__ = _MaskedError("+")
+ __sub__ = __rsub__ = __isub__ = _MaskedError("-")
+ __mul__ = __rmul__ = __imul__ = _MaskedError("*")
+ __matmul__ = __rmatmul__ = __imatmul__ = _MaskedError("@")
+ __truediv__ = __rtruediv__ = __itruediv__ = _MaskedError("/")
+ __floordiv__ = __rfloordiv__ = __ifloordiv__ = _MaskedError("//")
+ __mod__ = __rmod__ = __imod__ = _MaskedError("%")
+ __pow__ = __rpow__ = __ipow__ = _MaskedError("**")
+ __lshift__ = __rlshift__ = __ilshift__ = _MaskedError("<<")
+ __rshift__ = __rrshift__ = __irshift__ = _MaskedError(">>")
+ __and__ = __rand__ = __iand__ = _MaskedError("and")
+ __xor__ = __rxor__ = __ixor__ = _MaskedError("")
+ __or__ = __ror__ = __ior__ = _MaskedError("or")
+ __neg__ = __pos__ = __abs__ = __invert__ = _MaskedError("unary")
+ __lt__ = __le__ = __gt__ = __ge__ = _MaskedError("comparison")
+ __call__ = _MaskedError("__call__")
+
+
+@tree_summary.def_type(_MaskBase)
def _(node) -> str:
return f"#{tree_summary.type_dispatcher(node.__wrapped__)}"
-class _FrozenHashable(_FrozenBase):
+class _MaskedHashable(_MaskBase):
def __hash__(self) -> int:
return tree_hash(self.__wrapped__)
def __eq__(self, rhs: Any) -> bool:
- if not isinstance(rhs, _FrozenHashable):
+ if not isinstance(rhs, _MaskedHashable):
return False
return is_tree_equal(self.__wrapped__, rhs.__wrapped__)
-class _FrozenArray(_FrozenBase):
+class _MaskedArray(_MaskBase):
# wrap arrays with a custom wrapper that implements hash and equality
# using the wrapped array's bytes representation and sha256 hash function
# this is useful to select some array to hold without updating in the process
@@ -115,7 +112,7 @@ def __hash__(self) -> int:
return int(hashlib.sha256(bytes).hexdigest(), 16)
def __eq__(self, other) -> bool:
- if not isinstance(other, _FrozenArray):
+ if not isinstance(other, _MaskedArray):
return False
lhs, rhs = self.__wrapped__, other.__wrapped__
# fast path to avoid calling `all` on large arrays
@@ -123,136 +120,62 @@ def __eq__(self, other) -> bool:
return False
if arraylib.dtype(lhs) != arraylib.dtype(rhs):
return False
- return arraylib.all(lhs == rhs)
+ return arraylib.array_equal(lhs, rhs)
-def freeze(value: T) -> _FrozenBase[T]:
- """Freeze a value to avoid updating it by through function transformations.
-
- Args:
- value: A value to freeze.
-
- Note:
- - :func:`.freeze` is idempotent, i.e. ``freeze(freeze(x)) == freeze(x)``.
-
- Example:
- >>> import jax
- >>> import sepes as sp
- >>> import jax.tree_util as jtu
- >>> # Usage with `jax.tree_util.tree_leaves`
- >>> # no leaves for a wrapped value
- >>> jtu.tree_leaves(sp.freeze(2.))
- []
-
- >>> # retrieve the frozen wrapper value using `is_leaf=sp.is_frozen`
- >>> jtu.tree_leaves(sp.freeze(2.), is_leaf=sp.is_frozen)
- [#2.0]
-
- >>> # Usage with `jax.tree_util.tree_map`
- >>> a= [1,2,3]
- >>> a[1] = sp.freeze(a[1])
- >>> jtu.tree_map(lambda x:x+100, a)
- [101, #2, 103]
- """
+def mask(value: T) -> _MaskBase[T]:
# dispatching is used to customize the type of the wrapper based on the type
# of the value. For instance, hashable values dont need custom hash and
# equality implementations, so they are wrapped with a simpler wrapper.
# this approach avoids type logic in the wrapper equality and hash methods,
# thus effectively improving performance of the wrapper.
- return freeze.type_dispatcher(value)
+ return mask.type_dispatcher(value)
-freeze.type_dispatcher = ft.singledispatch(_FrozenHashable)
-freeze.def_type = freeze.type_dispatcher.register
+mask.type_dispatcher = ft.singledispatch(_MaskedHashable)
+mask.def_type = mask.type_dispatcher.register
for ndarray in arraylib.ndarrays:
- @freeze.def_type(ndarray)
- def freeze_array(value: T) -> _FrozenArray[T]:
+ @mask.def_type(ndarray)
+ def mask_array(value: T) -> _MaskedArray[T]:
# wrap arrays with a custom wrapper that implements hash and equality
# arrays can be hashed by converting them to bytes and hashing the bytes
- return _FrozenArray(value)
+ return _MaskedArray(value)
-@freeze.def_type(_FrozenBase)
-def _(value: _FrozenBase[T]) -> _FrozenBase[T]:
- # idempotent freeze operation, meaning that freeze(freeze(x)) == freeze(x)
+@mask.def_type(_MaskBase)
+def _(value: _MaskBase[T]) -> _MaskBase[T]:
+ # idempotent mask operation, meaning that mask(mask(x)) == mask(x)
# this is useful to avoid recursive unwrapping of frozen values, plus its
- # meaningless to freeze a frozen value.
+ # meaningless to mask a frozen value.
return value
-def is_frozen(value: Any) -> bool:
+def is_masked(value: Any) -> bool:
"""Returns True if the value is a frozen wrapper."""
- return isinstance(value, _FrozenBase)
-
-
-def unfreeze(value: T) -> T:
- """Unfreeze :func:`.freeze` value, otherwise return the value itself.
+ return isinstance(value, _MaskBase)
- Args:
- value: A value to unfreeze.
-
- Note:
- - use ``is_leaf=sp.is_frozen`` with ``tree_map`` to unfreeze a tree.**
- Example:
- >>> import sepes as sp
- >>> import jax
- >>> frozen_value = sp.freeze(1)
- >>> sp.unfreeze(frozen_value)
- 1
- >>> # usage with `jax.tree_map`
- >>> frozen_tree = jax.tree_map(sp.freeze, {"a": 1, "b": 2})
- >>> unfrozen_tree = jax.tree_map(sp.unfreeze, frozen_tree, is_leaf=sp.is_frozen)
- >>> unfrozen_tree
- {'a': 1, 'b': 2}
- """
- return unfreeze.type_dispatcher(value)
+def unmask(value: T) -> T:
+ return unmask.type_dispatcher(value)
-unfreeze.type_dispatcher = ft.singledispatch(lambda x: x)
-unfreeze.def_type = unfreeze.type_dispatcher.register
+unmask.type_dispatcher = ft.singledispatch(lambda x: x)
+unmask.def_type = unmask.type_dispatcher.register
-@unfreeze.def_type(_FrozenBase)
-def _(value: _FrozenBase[T]) -> T:
+@unmask.def_type(_MaskBase)
+def _(value: _MaskBase[T]) -> T:
return getattr(value, "__wrapped__")
def is_nondiff(value: Any) -> bool:
- """Returns True for non-inexact types, False otherwise.
-
- Args:
- value: A value to check.
-
- Note:
- - :func:`.is_nondiff` uses single dispatch to support custom types. To define
- a custom behavior for a certain type, use ``is_nondiff.def_type(type, func)``.
-
- Example:
- >>> import sepes as sp
- >>> import jax.numpy as jnp
- >>> sp.is_nondiff(jnp.array(1)) # int array is non-diff type
- True
- >>> sp.is_nondiff(jnp.array(1.)) # float array is diff type
- False
- >>> sp.is_nondiff(1) # int is non-diff type
- True
- >>> sp.is_nondiff(1.) # float is diff type
- False
-
- Note:
- This function is meant to be used with ``jax.tree_map`` to
- create a mask for non-differentiable nodes in a tree, that can be used
- to freeze the non-differentiable nodes before passing the tree to a
- ``jax`` transformation.
- """
return is_nondiff.type_dispatcher(value)
-is_nondiff.type_dispatcher = ft.singledispatch(lambda x: True)
+is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True)
is_nondiff.def_type = is_nondiff.type_dispatcher.register
@@ -274,78 +197,53 @@ def _(_: float | complex) -> bool:
def _tree_mask_map(
tree: T,
- mask: MaskType,
+ cond: Callable[[Any], bool],
func: type | Callable[[Any], Any],
*,
is_leaf: Callable[[Any], None] | None = None,
):
- treelib = sepes._src.backend.treelib
- # apply func to leaves satisfying mask pytree/condtion
- _, lhsdef = treelib.tree_flatten(tree, is_leaf=is_leaf)
- _, rhsdef = treelib.tree_flatten(mask, is_leaf=is_leaf)
-
- if (lhsdef == rhsdef) and (type(mask) is type(tree)):
- # a tree with the same structure as tree with boolean values
- # and also a callable.
- def map_func(x, y):
- return func(x) if y else x
- return treelib.tree_map(map_func, tree, mask, is_leaf=is_leaf)
-
- if isinstance(mask, Callable):
+ if not isinstance(cond, Callable):
# a callable that accepts a leaf and returns a boolean
# but *not* a tree with the same structure as tree with boolean values.
- def map_func(x):
- return func(x) if mask(x) else x
+ raise TypeError(
+ f"`cond` must be a callable that accepts a leaf and returns a boolean "
+ f" Got {cond=} and {tree=}."
+ )
- return treelib.tree_map(map_func, tree, is_leaf=is_leaf)
+ treelib = sepes._src.backend.treelib
- raise ValueError(
- f"`mask` must be a callable that accepts a leaf and returns a boolean "
- f"or a tree with the same structure as tree with boolean values."
- f" Got {mask=} and {tree=}."
- )
+ def map_func(x):
+ return func(x) if cond(x) else x
+
+ return treelib.map(map_func, tree, is_leaf=is_leaf)
def tree_mask(
tree: T,
- mask: MaskType = is_nondiff,
+ cond: Callable[[Any], bool] = is_nondiff,
*,
is_leaf: Callable[[Any], None] | None = None,
):
"""Mask leaves of a pytree based on ``mask`` boolean pytree or callable.
+ Masked leaves are wrapped with a wrapper that yields no leaves when
+ ``tree_flatten`` is called on it.
+
Args:
tree: A pytree of values.
- mask: A pytree of boolean values or a callable that accepts a leaf and
- returns a boolean. If a leaf is ``True`` either in the mask or the
- callable, the leaf is wrapped by with a wrapper that yields no
- leaves when ``tree_flatten`` is called on it, otherwise
- it is unchanged. defaults to :func:`.is_nondiff` which returns true for
- non-differentiable nodes.
+ cond: A callable that accepts a leaf and returns a boolean to mark the leaf
+ for masking. Defaults to masking non-differentiable leaf nodes that
+ are not instances of of python float, python complex, or inexact
+ array types.
is_leaf: A callable that accepts a leaf and returns a boolean. If
provided, it is used to determine if a value is a leaf. for example,
``is_leaf=lambda x: isinstance(x, list)`` will treat lists as leaves
and will not recurse into them.
- Note:
- - Masked leaves are wrapped with a wrapper that yields no leaves when
- ``tree_flatten`` is called on it.
- - Masking is equivalent to applying :func:`.freeze` to the masked leaves.
-
- >>> import sepes as sp
- >>> import jax
- >>> tree = [1, 2, {"a": 3, "b": 4.}]
- >>> # mask all non-differentiable nodes by default
- >>> def mask_if_nondiff(x):
- ... return sp.freeze(x) if sp.is_nondiff(x) else x
- >>> masked_tree = jax.tree_map(mask_if_nondiff, tree)
-
- - Use masking on tree containing non-differentiable nodes before passing
- the tree to a ``jax`` transformation.
-
Example:
>>> import sepes as sp
+ >>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = sp.tree_mask(tree)
@@ -357,32 +255,32 @@ def tree_mask(
[1, 2, {'a': 3, 'b': 4.0}]
Example:
- >>> # pass non-differentiable values to `jax.grad`
+ Pass non-differentiable values to ``jax.grad``
+
>>> import sepes as sp
>>> import jax
>>> @jax.grad
... def square(tree):
... tree = sp.tree_unmask(tree)
- ... return tree[0]**2
+ ... return tree[0] ** 2
>>> tree = (1., 2) # contains a non-differentiable node
>>> square(sp.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
"""
- return _tree_mask_map(tree, mask=mask, func=freeze, is_leaf=is_leaf)
+ return _tree_mask_map(tree, cond=cond, func=mask, is_leaf=is_leaf)
-def tree_unmask(tree: T, mask: MaskType = lambda _: True):
- """Undo the masking of tree leaves according to ``mask``. defaults to unmasking all leaves.
+def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True):
+ """Undo the masking of tree leaves according to ``cond``. defaults to unmasking all leaves.
Args:
tree: A pytree of values.
- mask: A pytree of boolean values or a callable that accepts a leaf and
- returns a boolean. If a leaf is True either in the mask or the
- callable, the leaf is unfrozen, otherwise it is unchanged. defaults
- unmasking all nodes.
+ cond: A callable that accepts a leaf and returns a boolean to mark the
+ leaf to be unmasked. Defaults to always unmask.
Example:
>>> import sepes as sp
+ >>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = sp.tree_mask(tree)
@@ -394,27 +292,19 @@ def tree_unmask(tree: T, mask: MaskType = lambda _: True):
[1, 2, {'a': 3, 'b': 4.0}]
Example:
- >>> # pass non-differentiable values to `jax.grad`
+ Pass non-differentiable values to ``jax.grad``
+
>>> import sepes as sp
>>> import jax
>>> @jax.grad
... def square(tree):
... tree = sp.tree_unmask(tree)
- ... return tree[0]**2
+ ... return tree[0] ** 2
>>> tree = (1., 2) # contains a non-differentiable node
>>> square(sp.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
-
- Note:
- - Unmasking is equivalent to applying :func:`.unfreeze` on the masked leaves.
-
- >>> import sepes as sp
- >>> import jax
- >>> tree = [1, 2, {"a": 3, "b": 4.}]
- >>> # unmask all nodes
- >>> tree = jax.tree_map(sp.unfreeze, tree, is_leaf=sp.is_frozen)
"""
- return _tree_mask_map(tree, mask=mask, func=unfreeze, is_leaf=is_frozen)
+ return _tree_mask_map(tree, cond=cond, func=unmask, is_leaf=is_masked)
if is_package_avaiable("jax"):
@@ -424,6 +314,6 @@ def tree_unmask(tree: T, mask: MaskType = lambda _: True):
# otherwise calling `freeze` inside a jax transformation on
# a tracer will hide the tracer from jax and will cause leaked tracer
# error.
- @freeze.def_type(jax.core.Tracer)
+ @mask.def_type(jax.core.Tracer)
def _(value: jax.core.Tracer) -> jax.core.Tracer:
return value
diff --git a/sepes/_src/tree_pprint.py b/sepes/_src/tree_pprint.py
index 27ea5f2..fb24e65 100644
--- a/sepes/_src/tree_pprint.py
+++ b/sepes/_src/tree_pprint.py
@@ -31,7 +31,6 @@
from sepes._src.backend import is_package_avaiable
from sepes._src.tree_util import (
Node,
- Partial,
construct_tree,
is_path_leaf_depth_factory,
tree_type_path_leaves,
@@ -178,13 +177,12 @@ def _(func: Callable, **spec: Unpack[PPSpec]) -> str:
return f"{name}({', '.join(header)})"
-@tree_str.def_type(Partial)
@tree_str.def_type(ft.partial)
def _(node: ft.partial, **spec: Unpack[PPSpec]) -> str:
func = tree_str.pp(node.func, **spec)
args = tree_str.pps(tree_str.pp, node.args, **spec)
keywords = tree_str.pps(tree_str.kv_pp, node.keywords, **spec)
- return f"Partial(" + ",".join([func, args, keywords]) + ")"
+ return "partial(" + ",".join([func, args, keywords]) + ")"
@tree_str.def_type(list)
@@ -242,7 +240,6 @@ def array_pp(node, **spec: Unpack[PPSpec]) -> str:
return f"{base}(μ={mean}, σ={std}, ∈{interval})"
-@tree_repr.def_type(Partial)
@tree_repr.def_type(ft.partial)
def _(node: ft.partial, **spec: Unpack[PPSpec]) -> str:
func = tree_repr.pp(node.func, **spec)
@@ -363,139 +360,6 @@ def step(
return (text if tabwidth is None else text.expandtabs(tabwidth)).rstrip()
-def tree_mermaid(
- tree: PyTree,
- depth: int | float = float("inf"),
- is_leaf: Callable[[Any], None] | None = None,
- tabwidth: int | None = 4,
-) -> str:
- """Generate a mermaid diagram syntax for arbitrary pytrees.
-
- Args:
- tree: PyTree
- depth: depth of the tree to print. default is max depth
- is_leaf: function to determine if a node is a leaf. default is None
- tabwidth: tab width of the repr string. default is 4.
-
- Example:
- >>> import sepes as sp
- >>> tree = [1, 2, dict(a=3)]
- >>> # as rendered by mermaid
- >>> print(sp.tree_mermaid(tree)) # doctest: +SKIP
-
- .. image:: ../_static/tree_mermaid.jpg
- :width: 300px
- :align: center
-
- Note:
- - Copy the output and paste it in the mermaid live editor to interact with
- the diagram. https://mermaid.live
- """
-
- def step(node: Node, depth: int = 0) -> str:
- if len(node.children) == 0:
- (key, _), value = node.data
- ppstr = f"{key}=" if key is not None else ""
- ppstr += tree_repr(value, depth=0)
- ppstr = "" + ppstr + ""
- return f'\tid{id(node.parent)} --- id{id(node)}("{ppstr}")\n'
-
- (key, type), _ = node.data
- ppstr = f"{key}:" if key is not None else ""
- ppstr += f"{type.__name__}"
- ppstr = "" + ppstr + ""
-
- if node.parent is None:
- text = f'\tid{id(node)}("{ppstr}")\n'
- else:
- text = f'\tid{id(node.parent)} --- id{id(node)}("{ppstr}")\n'
-
- for child in node.children.values():
- text += step(child, depth=depth + 1)
- return text
-
- is_path_leaf = is_path_leaf_depth_factory(depth)
- root = construct_tree(tree, is_leaf=is_leaf, is_path_leaf=is_path_leaf)
- text = "flowchart LR\n" + step(root)
- return (text.expandtabs(tabwidth) if tabwidth is not None else text).rstrip()
-
-
-# dispatcher for dot nodestyles
-dot_dispatcher = ft.singledispatch(lambda _: dict(shape="box"))
-
-
-def tree_graph(
- tree: PyTree,
- depth: int | float = float("inf"),
- is_leaf: Callable[[Any], None] | None = None,
- tabwidth: int | None = 4,
-) -> str:
- """Generate a dot diagram syntax for arbitrary pytrees.
-
- Args:
- tree: pytree
- depth: depth of the tree to print. default is max depth
- is_leaf: function to determine if a node is a leaf. default is None
- tabwidth: tab width of the repr string. default is 4.
-
- Returns:
- str: dot diagram syntax
-
- Example:
- >>> import sepes as sp
- >>> tree = [1, 2, dict(a=3)]
- >>> # as rendered by graphviz
-
- .. image:: ../_static/tree_graph.svg
-
- Example:
- >>> # define custom style for a node by dispatching on the value
- >>> # the defined function should return a dict of attributes
- >>> # that will be passed to graphviz.
- >>> import sepes as sp
- >>> tree = [1, 2, dict(a=3)]
- >>> @sp.tree_graph.def_nodestyle(list)
- ... def _(_) -> dict[str, str]:
- ... return dict(shape="circle", style="filled", fillcolor="lightblue")
-
- .. image:: ../_static/tree_graph_stylized.svg
- """
-
- def step(node: Node, depth: int = 0) -> str:
- (key, type), value = node.data
-
- # dispatch node style
- style = ", ".join(f"{k}={v}" for k, v in dot_dispatcher(value).items())
-
- if len(node.children) == 0:
- ppstr = f"{key}=" if key is not None else ""
- ppstr += tree_repr(value, depth=0)
- text = f'\t{id(node)} [label="{ppstr}", {style}];\n'
- text += f"\t{id(node.parent)} -> {id(node)};\n"
- return text
-
- ppstr = f"{key}:" if key is not None else ""
- ppstr += f"{type.__name__}"
-
- if node.parent is None:
- text = f'\t{id(node)} [label="{ppstr}", {style}];\n'
- else:
- text = f'\t{id(node)} [label="{ppstr}", {style}];\n'
- text += f"\t{id(node.parent)} -> {id(node)};\n"
-
- for child in node.children.values():
- text += step(child, depth=depth + 1)
- return text
-
- is_path_leaf = is_path_leaf_depth_factory(depth)
- root = construct_tree(tree, is_leaf=is_leaf, is_path_leaf=is_path_leaf)
- text = "digraph G {\n" + step(root) + "}"
- return (text.expandtabs(tabwidth) if tabwidth is not None else text).rstrip()
-
-
-tree_graph.def_nodestyle = dot_dispatcher.register
-
-
def format_width(string, width=60):
"""Strip newline/tab characters if less than max width."""
children_length = len(string) - string.count("\n") - string.count("\t")
@@ -570,39 +434,19 @@ def tree_summary(
>>> import sepes as sp
>>> import jax.numpy as jnp
>>> print(sp.tree_summary([1, [2, [3]], jnp.array([1, 2, 3])]))
- ┌─────────┬──────┬─────┬──────┐
- │Name │Type │Count│Size │
- ├─────────┼──────┼─────┼──────┤
- │[0] │int │1 │ │
- ├─────────┼──────┼─────┼──────┤
- │[1][0] │int │1 │ │
- ├─────────┼──────┼─────┼──────┤
- │[1][1][0]│int │1 │ │
- ├─────────┼──────┼─────┼──────┤
- │[2] │i32[3]│3 │12.00B│
- ├─────────┼──────┼─────┼──────┤
- │Σ │list │6 │12.00B│
- └─────────┴──────┴─────┴──────┘
-
- Example:
- Set custom type display for ``jax`` jaxprs
-
- >>> import jax
- >>> import sepes as sp
- >>> ClosedJaxprType = type(jax.make_jaxpr(lambda x: x)(1))
- >>> @sp.tree_summary.def_type(ClosedJaxprType)
- ... def _(expr: ClosedJaxprType) -> str:
- ... jaxpr = expr.jaxpr
- ... return f"Jaxpr({jaxpr.invars}, {jaxpr.outvars})"
- >>> def func(x, y):
- ... return x
- >>> jaxpr = jax.make_jaxpr(func)(1, 2)
- >>> print(sp.tree_summary(jaxpr))
- ┌────┬──────────────────┬─────┬────┐
- │Name│Type │Count│Size│
- ├────┼──────────────────┼─────┼────┤
- │Σ │Jaxpr([a, b], [a])│1 │ │
- └────┴──────────────────┴─────┴────┘
+ ┌─────────┬────────────────────────────────────┬─────┬──────┐
+ │Name │Type │Count│Size │
+ ├─────────┼────────────────────────────────────┼─────┼──────┤
+ │[0] │int │1 │ │
+ ├─────────┼────────────────────────────────────┼─────┼──────┤
+ │[1][0] │int │1 │ │
+ ├─────────┼────────────────────────────────────┼─────┼──────┤
+ │[1][1][0]│int │1 │ │
+ ├─────────┼────────────────────────────────────┼─────┼──────┤
+ │[2] │i32[3] │3 │12.00B│
+ ├─────────┼────────────────────────────────────┼─────┼──────┤
+ │Σ │list[int,list[int,list[int]],i32[3]]│6 │12.00B│
+ └─────────┴────────────────────────────────────┴─────┴──────┘
Example:
Display flops of a function in tree summary
@@ -662,14 +506,14 @@ def tree_size(tree: PyTree) -> int:
def reduce_func(acc, node):
return acc + tree_summary.size_dispatcher(node)
- leaves, _ = treelib.tree_flatten(tree)
+ leaves, _ = treelib.flatten(tree)
return ft.reduce(reduce_func, leaves, 0)
def tree_count(tree: PyTree) -> int:
def reduce_func(acc, node):
return acc + tree_summary.count_dispatcher(node)
- leaves, _ = treelib.tree_flatten(tree)
+ leaves, _ = treelib.flatten(tree)
return ft.reduce(reduce_func, leaves, 0)
traces_leaves = tree_type_path_leaves(
@@ -725,6 +569,21 @@ def _(node: Any) -> str:
dtype = arraylib.dtype(node)
return tree_repr(ShapeDTypePP(shape, dtype))
+@tree_summary.def_type(list)
+@tree_summary.def_type(tuple)
+def _(node: tuple) -> str:
+ # - output Container[types,...] instead of just container type in the type col.
+ # - usually this encounterd if the tree_summary depth is not inf
+ # so the tree leaves could contain non-atomic types.
+ treelib = sepes._src.backend.treelib
+
+ one_level_types = treelib.map(
+ tree_summary.type_dispatcher,
+ node,
+ is_leaf=lambda x: False if id(x) == id(node) else True,
+ )
+ return f"{type(node).__name__}[{','.join(one_level_types)}]"
+
if is_package_avaiable("jax"):
# jax pretty printing extra handlers
@@ -764,4 +623,19 @@ def _(node, **spec: Unpack[PPSpec]) -> str:
shape = node.aval.shape
dtype = node.aval.dtype
string = tree_repr.dispatch(ShapeDTypePP(shape, dtype), **spec)
- return f"Tracer({string})"
+ return f"{type(node).__name__}({string})"
+
+ # handle the sharding info if it is sharded
+ @tree_summary.def_type(jax.Array)
+ def _(node: Any) -> str:
+ """Return the type repr of the node."""
+ # global shape
+ global_shape = arraylib.shape(node)
+ shard_shape = node.sharding.shard_shape(global_shape)
+ dtype = arraylib.dtype(node)
+ global_info = tree_repr(ShapeDTypePP(global_shape, dtype))
+
+ if global_shape == shard_shape:
+ return global_info
+ shard_info = tree_repr(ShapeDTypePP(shard_shape, dtype))
+ return f"G:{global_info}\nS:{shard_info}"
diff --git a/sepes/_src/tree_util.py b/sepes/_src/tree_util.py
index 88fb24d..429b1b6 100644
--- a/sepes/_src/tree_util.py
+++ b/sepes/_src/tree_util.py
@@ -16,15 +16,17 @@
from __future__ import annotations
+import copy
import functools as ft
import operator as op
-import copy
from math import ceil, floor, trunc
-from typing import Any, Callable, Hashable, Iterator, Sequence, Tuple, TypeVar, Generic
-import sepes._src.backend.arraylib as arraylib
-from sepes._src.backend import is_package_avaiable
+from typing import Any, Callable, Generic, Hashable, Iterator, Sequence, Tuple, TypeVar
+
from typing_extensions import ParamSpec
+
import sepes
+import sepes._src.backend.arraylib as arraylib
+from sepes._src.backend import is_package_avaiable
T = TypeVar("T")
T1 = TypeVar("T1")
@@ -42,7 +44,7 @@
def tree_hash(*trees: PyTree) -> int:
treelib = sepes._src.backend.treelib
- leaves, treedef = treelib.tree_flatten(trees)
+ leaves, treedef = treelib.flatten(trees)
return hash((*leaves, treedef))
@@ -57,7 +59,7 @@ def tree_copy(tree: T) -> T:
def is_leaf(node) -> bool:
return isinstance(node, types)
- return treelib.tree_map(tree_copy.copy_dispatcher, tree, is_leaf=is_leaf)
+ return treelib.map(tree_copy.copy_dispatcher, tree, is_leaf=is_leaf)
# default behavior is to copy the tree elements except for registered types
@@ -66,18 +68,31 @@ def is_leaf(node) -> bool:
tree_copy.def_type = tree_copy.copy_dispatcher.register
+@tree_copy.def_type(int)
+@tree_copy.def_type(float)
+@tree_copy.def_type(complex)
+@tree_copy.def_type(str)
+@tree_copy.def_type(bytes)
+def _(x: T) -> T:
+ # skip applying `copy.copy` on immutable atom types
+ return x
+
+
def is_array_like(node) -> bool:
return hasattr(node, "shape") and hasattr(node, "dtype")
-def _is_leaf_rhs_equal(leaf, rhs) -> bool:
+def _is_leaf_rhs_equal(leaf, rhs):
if is_array_like(leaf):
if is_array_like(rhs):
if leaf.shape != rhs.shape:
return False
if leaf.dtype != rhs.dtype:
return False
- verdict = arraylib.all(leaf == rhs)
+ try:
+ verdict = arraylib.all(leaf == rhs)
+ except NotImplementedError:
+ verdict = leaf == rhs
try:
return bool(verdict)
except Exception:
@@ -94,11 +109,11 @@ def is_tree_equal(*trees: Any) -> bool:
"""
treelib = sepes._src.backend.treelib
tree0, *rest = trees
- leaves0, treedef0 = treelib.tree_flatten(tree0)
+ leaves0, treedef0 = treelib.flatten(tree0)
verdict = True
for tree in rest:
- leaves, treedef = treelib.tree_flatten(tree)
+ leaves, treedef = treelib.flatten(tree)
if (treedef != treedef0) or verdict is False:
return False
verdict = ft.reduce(op.and_, map(_is_leaf_rhs_equal, leaves0, leaves), verdict)
@@ -116,73 +131,28 @@ def __init_subclass__(klass, **k) -> None:
treelib.register_static(klass)
-class Partial(Static):
- """``Partial`` function with support for positional partial application.
-
- Args:
- func: The function to be partially applied.
- args: Positional arguments to be partially applied. use ``...`` as a
- placeholder for positional arguments.
- kwargs: Keyword arguments to be partially applied.
-
- Example:
- >>> import sepes as sp
- >>> def f(a, b, c):
- ... print(f"a: {a}, b: {b}, c: {c}")
- ... return a + b + c
-
- >>> # positional arguments using `...` placeholder
- >>> f_a = sp.Partial(f, ..., 2, 3)
- >>> f_a(1)
- a: 1, b: 2, c: 3
- 6
-
- >>> # keyword arguments
- >>> f_b = sp.Partial(f, b=2, c=3)
- >>> f_a(1)
- a: 1, b: 2, c: 3
- 6
-
- Note:
- - The ``...`` is used to indicate a placeholder for positional arguments.
- - https://stackoverflow.com/a/7811270
- """
-
- __slots__ = ["func", "args", "keywords"] # type: ignore
-
- def __init__(self, func: Callable[..., Any], *args: Any, **kwargs: Any):
- self.func = func
- self.args = args
- self.keywords = kwargs
-
+class partial(ft.partial):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
iargs = iter(args)
args = (next(iargs) if arg is ... else arg for arg in self.args) # type: ignore
return self.func(*args, *iargs, **{**self.keywords, **kwargs})
- def __repr__(self) -> str:
- return f"Partial({self.func}, {self.args}, {self.keywords})"
-
- def __hash__(self) -> int:
- return tree_hash(self)
-
- def __eq__(self, other: Any) -> bool:
- return is_tree_equal(self, other)
-
-
-# to match python
-partial = Partial
-
def bcmap(
func: Callable[P, T],
+ broadcast_to: int | str | None = None,
*,
is_leaf: Callable[[Any], bool] | None = None,
) -> Callable[P, T]:
"""Map a function over pytree leaves with automatic broadcasting for scalar arguments.
Args:
- func: the function to be mapped over the pytree
+ func: the function to be mapped over the pytree.
+ broadcast_to: Accepts integer for broadcasting to a specific argument
+ or string for broadcasting to a specific keyword argument.
+ If ``None``, then the function is broadcasted to the first argument
+ or the first keyword argument if no positional arguments are provided.
+ Defaults to ``None``.
is_leaf: a predicate function that returns True if the node is a leaf.
Example:
@@ -199,7 +169,6 @@ def bcmap(
>>> print(sp.tree_str(tree_add(tree_of_arrays, 1)))
dict(a=[2 3 4], b=[5 6 7])
"""
- # add broadcasting argnum/argname to the function later
treelib = sepes._src.backend.treelib
@ft.wraps(func)
@@ -209,23 +178,29 @@ def wrapper(*args, **kwargs):
leaves = []
kwargs_keys: list[str] = []
+ bdcst_to = (
+ (0 if len(args) else next(iter(kwargs)))
+ if broadcast_to is None
+ else broadcast_to
+ )
+
treedef0 = (
# reference treedef is the first positional argument
- treelib.tree_flatten(args[0], is_leaf=is_leaf)[1]
+ treelib.flatten(args[bdcst_to], is_leaf=is_leaf)[1]
if len(args)
# reference treedef is the first keyword argument
- else treelib.tree_flatten(kwargs[next(iter(kwargs))], is_leaf=is_leaf)[1]
+ else treelib.flatten(kwargs[bdcst_to], is_leaf=is_leaf)[1]
)
for arg in args:
- if treedef0 == treelib.tree_flatten(arg, is_leaf=is_leaf)[1]:
+ if treedef0 == treelib.flatten(arg, is_leaf=is_leaf)[1]:
cargs += [...]
leaves += [treedef0.flatten_up_to(arg)]
else:
cargs += [arg]
for key in kwargs:
- if treedef0 == treelib.tree_flatten(kwargs[key], is_leaf=is_leaf)[1]:
+ if treedef0 == treelib.flatten(kwargs[key], is_leaf=is_leaf)[1]:
ckwargs[key] = ...
leaves += [treedef0.flatten_up_to(kwargs[key])]
kwargs_keys += [key]
@@ -239,7 +214,7 @@ def wrapper(*args, **kwargs):
args = args_kwargs_values[:split_index]
kwargs = dict(zip(kwargs_keys, args_kwargs_values[split_index:]))
all_leaves += [bfunc(*args, **kwargs)]
- return treelib.tree_unflatten(treedef0, all_leaves)
+ return treelib.unflatten(treedef0, all_leaves)
return wrapper
@@ -266,7 +241,8 @@ def leafwise(klass: type[T]) -> type[T]:
The decorated class.
Example:
- >>> # use ``numpy`` functions on :class:`TreeClass`` classes decorated with ``leafwise``
+ Use ``numpy`` functions on :class:`TreeClass`` classes decorated with :func:`leafwise`
+
>>> import sepes as sp
>>> import jax.numpy as jnp
>>> @sp.leafwise
@@ -321,15 +297,15 @@ def leafwise(klass: type[T]) -> type[T]:
def uop(func):
def wrapper(self):
- return treelib.tree_map(func, self)
+ return treelib.map(func, self)
return ft.wraps(func)(wrapper)
def bop(func):
def wrapper(leaf, rhs=None):
if isinstance(rhs, type(leaf)):
- return treelib.tree_map(func, leaf, rhs)
- return treelib.tree_map(lambda x: func(x, rhs), leaf)
+ return treelib.map(func, leaf, rhs)
+ return treelib.map(lambda x: func(x, rhs), leaf)
return ft.wraps(func)(wrapper)
@@ -391,7 +367,7 @@ def tree_type_path_leaves(
is_path_leaf: Callable[[KeyTypePath], bool] | None = None,
) -> Sequence[tuple[KeyTypePath, Any]]:
treelib = sepes._src.backend.treelib
- _, atomicdef = treelib.tree_flatten(1)
+ _, atomicdef = treelib.flatten(1)
# mainly used for visualization
def flatten_one_level(type_path: KeyTypePath, tree: PyTree):
@@ -407,7 +383,7 @@ def one_level_is_leaf(node) -> bool:
return False
return True
- path_leaf, treedef = treelib.tree_path_flatten(tree, is_leaf=one_level_is_leaf)
+ path_leaf, treedef = treelib.path_flatten(tree, is_leaf=one_level_is_leaf)
if treedef == atomicdef:
yield type_path, tree
@@ -501,7 +477,7 @@ def construct_tree(
return root
-def value_and_tree(func, argnums: int | Sequence[int] = 0):
+def value_and_tree(func: Callable[..., T], argnums: int | Sequence[int] = 0):
"""Call a function on copied input argument and return the value and the tree.
Input arguments are copied before calling the function, and the argument
@@ -614,15 +590,15 @@ def immutate_is_leaf(node):
return False
@ft.wraps(func)
- def stateless_func(*args, **kwargs) -> tuple[Any, PyTree | tuple[PyTree, ...]]:
+ def stateless_func(*args, **kwargs) -> tuple[T, PyTree | tuple[PyTree, ...]]:
# copy the incoming inputs
(args, kwargs) = tree_copy((args, kwargs))
# and edit the node/record to make it mutable (if there is a rule for it)
- treelib.tree_map(lambda _: _, (args, kwargs), is_leaf=mutate_is_leaf)
+ treelib.map(lambda _: _, (args, kwargs), is_leaf=mutate_is_leaf)
output = func(*args, **kwargs)
# traverse each node in the tree depth-first manner
# to undo the mutation (if there is a rule for it)
- treelib.tree_map(lambda _: _, (args, kwargs), is_leaf=immutate_is_leaf)
+ treelib.map(lambda _: _, (args, kwargs), is_leaf=immutate_is_leaf)
out_args = tuple(a for i, a in enumerate(args) if i in argnums)
out_args = out_args[0] if is_int_argnum else out_args
return output, out_args
diff --git a/tests/test_index.py b/tests/test_index.py
index 097543e..0728380 100644
--- a/tests/test_index.py
+++ b/tests/test_index.py
@@ -23,7 +23,7 @@
from sepes._src.backend import arraylib, backend, treelib
from sepes._src.code_build import autoinit
from sepes._src.tree_base import TreeClass, _mutable_instance_registry
-from sepes._src.tree_index import AtIndexer, BaseKey
+from sepes._src.tree_index import at, BaseKey
from sepes._src.tree_util import is_tree_equal, leafwise, value_and_tree
test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy")
@@ -117,7 +117,7 @@ def __init__(self, c: int, d: int):
],
)
def test_indexer_get(tree, expected, where):
- indexer = AtIndexer(tree, where=where)
+ indexer = at(tree, where=where)
assert is_tree_equal(indexer.get(), expected)
assert is_tree_equal(indexer.get(is_parallel=True), expected)
@@ -150,11 +150,33 @@ def test_indexer_get(tree, expected, where):
],
)
def test_array_indexer_get(tree, expected, where):
- indexer = AtIndexer(tree, where=where)
+ indexer = at(tree, where=where)
assert is_tree_equal(indexer.get(), expected)
assert is_tree_equal(indexer.get(is_parallel=True), expected)
+@pytest.mark.skipif(backend != "jax", reason="test jax jit with get")
+def test_get_fill_value():
+ import jax
+ import jax.numpy as jnp
+
+ tree = dict(a=jnp.array([1, 2, 3]), b=jnp.array([4, 5, 6]))
+ mask = dict(
+ a=jnp.array([False, True, False]),
+ b=jnp.array([False, True, False]),
+ )
+
+ @jax.jit
+ def jit_func(tree):
+ return at(tree)[mask].get(fill_value=0)
+
+ out = jit_func(tree)
+ a = out["a"]
+ b = out["b"]
+ assert jnp.all(a == jnp.array([0, 2, 0]))
+ assert jnp.all(b == jnp.array([0, 5, 0]))
+
+
@pytest.mark.parametrize(
["tree", "expected", "where", "set_value"],
[
@@ -191,7 +213,7 @@ def test_array_indexer_get(tree, expected, where):
],
)
def test_indexer_set(tree, expected, where, set_value):
- indexer = AtIndexer(tree, where=where)
+ indexer = at(tree, where=where)
assert is_tree_equal(indexer.set(set_value), expected)
assert is_tree_equal(indexer.set(set_value, is_parallel=True), expected)
@@ -233,7 +255,7 @@ def test_indexer_set(tree, expected, where, set_value):
],
)
def test_array_indexer_set(tree, expected, where, set_value):
- indexer = AtIndexer(tree, where=where)
+ indexer = at(tree, where=where)
assert is_tree_equal(indexer.set(set_value), expected)
assert is_tree_equal(indexer.set(set_value, is_parallel=True), expected)
@@ -268,7 +290,7 @@ def test_array_indexer_set(tree, expected, where, set_value):
],
)
def test_indexer_apply(tree, expected, where):
- indexer = AtIndexer(tree, where=where)
+ indexer = at(tree, where=where)
assert is_tree_equal(indexer.apply(lambda _: _X), expected)
assert is_tree_equal(
indexer.apply(lambda _: _X, is_parallel=True),
@@ -307,7 +329,7 @@ def test_indexer_apply(tree, expected, where):
],
)
def test_array_indexer_apply(tree, expected, where):
- indexer = AtIndexer(tree, where=where)
+ indexer = at(tree, where=where)
assert is_tree_equal(indexer.apply(lambda _: _X), expected)
assert is_tree_equal(
indexer.apply(lambda _: _X, is_parallel=True),
@@ -343,7 +365,7 @@ def test_array_indexer_apply(tree, expected, where):
],
)
def test_indexer_reduce(tree, expected, where):
- indexer = AtIndexer(tree, where=where)
+ indexer = at(tree, where=where)
assert is_tree_equal(
indexer.reduce(lambda x, y: x + y, initializer=0),
expected,
@@ -378,7 +400,7 @@ def test_indexer_reduce(tree, expected, where):
],
)
def test_array_indexer_reduce(tree, expected, where):
- indexer = AtIndexer(tree, where=where)
+ indexer = at(tree, where=where)
assert is_tree_equal(
indexer.reduce(lambda x, y: x + y, initializer=0),
expected,
@@ -405,7 +427,7 @@ def test_array_indexer_reduce(tree, expected, where):
],
)
def test_indexer_scan(tree, expected, where):
- indexer = AtIndexer(tree, where=where)
+ indexer = at(tree, where=where)
assert is_tree_equal(
indexer.scan(lambda x, s: (x + s, x), state=0),
expected,
@@ -451,8 +473,8 @@ def __call__(self, x):
a = A(1)
_, b = value_and_tree(lambda A: A(2))(a)
- assert treelib.tree_flatten(a)[0] == [1]
- assert treelib.tree_flatten(b)[0] == [3]
+ assert treelib.flatten(a)[0] == [1]
+ assert treelib.flatten(b)[0] == [3]
with pytest.raises(TypeError):
a.at[0](1)
@@ -480,7 +502,7 @@ def delete(self, name):
def test_unsupported_where(where):
t = namedtuple("a", ["x", "y"])(1, 2)
with pytest.raises(NotImplementedError):
- AtIndexer(t, where=where).get()
+ at(t, where=where).get()
@pytest.mark.skipif(backend != "jax", reason="jax backend needed")
@@ -496,7 +518,7 @@ def __init__(self, a, b) -> None:
@property
def at(self):
- return AtIndexer(self)
+ return at(self)
if backend == "jax":
import jax.tree_util as jtu
@@ -533,7 +555,7 @@ def __init__(self, a, b) -> None:
@property
def at(self):
- return AtIndexer(self)
+ return at(self)
import optree as ot
@@ -575,26 +597,26 @@ class Tree(TreeClass):
t = Tree()
- assert repr(t.at["a"]) == "AtIndexer(tree=Tree(a=1, b=2), where=['a'])"
- assert str(t.at["a"]) == "AtIndexer(tree=Tree(a=1, b=2), where=['a'])"
- assert repr(t.at[...]) == "AtIndexer(tree=Tree(a=1, b=2), where=[Ellipsis])"
+ assert repr(t.at["a"]) == "at(Tree(a=1, b=2), where=['a'])"
+ assert str(t.at["a"]) == "at(Tree(a=1, b=2), where=['a'])"
+ assert repr(t.at[...]) == "at(Tree(a=1, b=2), where=[Ellipsis])"
def test_compat_mask():
tree = [1, 2, [3, 4]]
- tree_ = AtIndexer(tree)[[False, False, True]].set(10)
+ tree_ = at(tree)[[False, False, True]].set(10)
assert tree_ == [1, 2, 10]
def test_pluck():
tree = [1, 2, [3, 4]]
- subtrees = AtIndexer(tree)[2].pluck()
+ subtrees = at(tree)[2].pluck()
assert subtrees[0] == [3, 4]
- assert AtIndexer(tree)[0, 1].pluck(1) == [1]
- assert AtIndexer(tree)[0, 1].pluck(2) == [1, 2]
+ assert at(tree)[0, 1].pluck(1) == [1]
+ assert at(tree)[0, 1].pluck(2) == [1, 2]
tree = dict(a=1, b=2)
- assert AtIndexer(tree)[...].pluck() == [1, 2]
+ assert at(tree)[...].pluck() == [1, 2]
@pytest.mark.skipif(backend != "jax", reason="jax backend needed")
diff --git a/tests/test_mask.py b/tests/test_mask.py
index d03cf8b..9b8de81 100644
--- a/tests/test_mask.py
+++ b/tests/test_mask.py
@@ -13,6 +13,8 @@
# limitations under the License.
import copy
+import functools as ft
+import os
from typing import Any
import pytest
@@ -20,17 +22,12 @@
from sepes._src.backend import backend, treelib
from sepes._src.code_build import autoinit
from sepes._src.tree_base import TreeClass
-from sepes._src.tree_mask import (
- freeze,
- is_frozen,
- tree_mask,
- tree_unmask,
- unfreeze,
-)
-import os
+from sepes._src.tree_mask import is_masked, tree_mask, tree_unmask
from sepes._src.tree_util import is_tree_equal, leafwise, tree_hash
test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy")
+freeze = ft.partial(tree_mask, cond=lambda _: True)
+unfreeze = ft.partial(tree_unmask, cond=lambda _: True)
if test_arraylib == "jax":
import jax.numpy as arraylib
@@ -54,14 +51,14 @@ class A(TreeClass):
b = a.at[...].apply(freeze)
c = (
a.at["a"]
- .apply(unfreeze, is_leaf=is_frozen)
+ .apply(unfreeze, is_leaf=is_masked)
.at["b"]
- .apply(unfreeze, is_leaf=is_frozen)
+ .apply(unfreeze, is_leaf=is_masked)
)
- assert treelib.tree_flatten(a)[0] == [1, 2]
- assert treelib.tree_flatten(b)[0] == []
- assert treelib.tree_flatten(c)[0] == [1, 2]
+ assert treelib.flatten(a)[0] == [1, 2]
+ assert treelib.flatten(b)[0] == []
+ assert treelib.flatten(c)[0] == [1, 2]
assert unfreeze(freeze(1.0)) == 1.0
@autoinit
@@ -80,17 +77,17 @@ class A(TreeClass):
b: int
a = A(1, 2)
- b = treelib.tree_map(freeze, a)
+ b = treelib.map(freeze, a)
c = (
a.at["a"]
- .apply(unfreeze, is_leaf=is_frozen)
+ .apply(unfreeze, is_leaf=is_masked)
.at["b"]
- .apply(unfreeze, is_leaf=is_frozen)
+ .apply(unfreeze, is_leaf=is_masked)
)
- assert treelib.tree_flatten(a)[0] == [1, 2]
- assert treelib.tree_flatten(b)[0] == []
- assert treelib.tree_flatten(c)[0] == [1, 2]
+ assert treelib.flatten(a)[0] == [1, 2]
+ assert treelib.flatten(b)[0] == []
+ assert treelib.flatten(c)[0] == [1, 2]
@autoinit
class L0(TreeClass):
@@ -104,11 +101,11 @@ class L1(TreeClass):
class L2(TreeClass):
c: L1 = L1()
- t = treelib.tree_map(freeze, L2())
+ t = treelib.map(freeze, L2())
- assert treelib.tree_flatten(t)[0] == []
- assert treelib.tree_flatten(t.c)[0] == []
- assert treelib.tree_flatten(t.c.b)[0] == []
+ assert treelib.flatten(t)[0] == []
+ assert treelib.flatten(t.c)[0] == []
+ assert treelib.flatten(t.c.b)[0] == []
class L1(TreeClass):
def __init__(self):
@@ -118,9 +115,9 @@ class L2(TreeClass):
def __init__(self):
self.c = L1()
- t = treelib.tree_map(freeze, L2())
- assert treelib.tree_flatten(t.c)[0] == []
- assert treelib.tree_flatten(t.c.b)[0] == []
+ t = treelib.map(freeze, L2())
+ assert treelib.flatten(t.c)[0] == []
+ assert treelib.flatten(t.c.b)[0] == []
def test_freeze_errors():
@@ -160,25 +157,25 @@ class Test(TreeClass):
c: str = freeze("test")
t = Test()
- assert treelib.tree_flatten(t)[0] == [1]
+ assert treelib.flatten(t)[0] == [1]
with pytest.raises(AttributeError):
- treelib.tree_map(freeze, t).a = 1
+ treelib.map(freeze, t).a = 1
with pytest.raises(AttributeError):
- treelib.tree_map(unfreeze, t).a = 1
+ treelib.map(unfreeze, t).a = 1
hash(t)
t = Test()
- treelib.tree_map(unfreeze, t, is_leaf=is_frozen)
- treelib.tree_map(freeze, t)
+ treelib.map(unfreeze, t, is_leaf=is_masked)
+ treelib.map(freeze, t)
@autoinit
class Test(TreeClass):
a: int
- t = treelib.tree_map(freeze, (Test(100)))
+ t = treelib.map(freeze, (Test(100)))
class Test(TreeClass):
def __init__(self, x):
@@ -223,7 +220,7 @@ class Test(TreeClass):
t = Test()
- assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == []
+ assert treelib.flatten(treelib.map(freeze, t))[0] == []
def test_freeze_nondiff():
@@ -234,10 +231,10 @@ class Test(TreeClass):
t = Test()
- assert treelib.tree_flatten(t)[0] == ["a"]
- assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == []
- assert treelib.tree_flatten(
- (treelib.tree_map(freeze, t)).at["b"].apply(unfreeze, is_leaf=is_frozen)
+ assert treelib.flatten(t)[0] == ["a"]
+ assert treelib.flatten(treelib.map(freeze, t))[0] == []
+ assert treelib.flatten(
+ (treelib.map(freeze, t)).at["b"].apply(unfreeze, is_leaf=is_masked)
)[0] == ["a"]
@autoinit
@@ -246,11 +243,11 @@ class T0(TreeClass):
t = T0()
- assert treelib.tree_flatten(t)[0] == ["a"]
- assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == []
+ assert treelib.flatten(t)[0] == ["a"]
+ assert treelib.flatten(treelib.map(freeze, t))[0] == []
- assert treelib.tree_flatten(t)[0] == ["a"]
- assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == []
+ assert treelib.flatten(t)[0] == ["a"]
+ assert treelib.flatten(treelib.map(freeze, t))[0] == []
def test_freeze_nondiff_with_mask():
@@ -278,11 +275,11 @@ class L2(TreeClass):
t = t.at["d"]["d"]["a"].apply(freeze)
t = t.at["d"]["d"]["b"].apply(freeze)
- assert treelib.tree_flatten(t)[0] == [10, 20, 30, 1, 2, 3, 3]
+ assert treelib.flatten(t)[0] == [10, 20, 30, 1, 2, 3, 3]
def test_non_dataclass_input_to_freeze():
- assert treelib.tree_flatten(freeze(1))[0] == []
+ assert treelib.flatten(freeze(1))[0] == []
def test_tree_mask():
@@ -299,18 +296,18 @@ class L1(TreeClass):
tree = L1()
- assert treelib.tree_flatten(tree)[0] == [1, 2, 3]
- assert treelib.tree_flatten(treelib.tree_map(freeze, tree))[0] == []
- assert treelib.tree_flatten(treelib.tree_map(freeze, tree))[0] == []
- assert treelib.tree_flatten(tree.at[...].apply(freeze))[0] == []
- assert treelib.tree_flatten(tree.at[tree > 1].apply(freeze))[0] == [1]
- assert treelib.tree_flatten(tree.at[tree == 1].apply(freeze))[0] == [2, 3]
- assert treelib.tree_flatten(tree.at[tree < 1].apply(freeze))[0] == [1, 2, 3]
+ assert treelib.flatten(tree)[0] == [1, 2, 3]
+ assert treelib.flatten(treelib.map(freeze, tree))[0] == []
+ assert treelib.flatten(treelib.map(freeze, tree))[0] == []
+ assert treelib.flatten(tree.at[...].apply(freeze))[0] == []
+ assert treelib.flatten(tree.at[tree > 1].apply(freeze))[0] == [1]
+ assert treelib.flatten(tree.at[tree == 1].apply(freeze))[0] == [2, 3]
+ assert treelib.flatten(tree.at[tree < 1].apply(freeze))[0] == [1, 2, 3]
- assert treelib.tree_flatten(tree.at["a"].apply(freeze))[0] == [2, 3]
- assert treelib.tree_flatten(tree.at["b"].apply(freeze))[0] == [1]
- assert treelib.tree_flatten(tree.at["b"]["x"].apply(freeze))[0] == [1, 3]
- assert treelib.tree_flatten(tree.at["b"]["y"].apply(freeze))[0] == [1, 2]
+ assert treelib.flatten(tree.at["a"].apply(freeze))[0] == [2, 3]
+ assert treelib.flatten(tree.at["b"].apply(freeze))[0] == [1]
+ assert treelib.flatten(tree.at["b"]["x"].apply(freeze))[0] == [1, 3]
+ assert treelib.flatten(tree.at["b"]["y"].apply(freeze))[0] == [1, 2]
def test_tree_unmask():
@@ -328,21 +325,21 @@ class L1(TreeClass):
tree = L1()
frozen_tree = tree.at[...].apply(freeze)
- assert treelib.tree_flatten(frozen_tree)[0] == []
+ assert treelib.flatten(frozen_tree)[0] == []
mask = tree == tree
- unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_frozen)
- assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3]
+ unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked)
+ assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3]
mask = tree > 1
- unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_frozen)
- assert treelib.tree_flatten(unfrozen_tree)[0] == [2, 3]
+ unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked)
+ assert treelib.flatten(unfrozen_tree)[0] == [2, 3]
- unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_frozen)
- # assert treelib.tree_flatten(unfrozen_tree)[0] == [1]
+ unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_masked)
+ # assert treelib.flatten(unfrozen_tree)[0] == [1]
- # unfrozen_tree = frozen_tree.at["b"].apply(unfreeze, is_leaf=is_frozen)
- # assert treelib.tree_flatten(unfrozen_tree)[0] == [2, 3]
+ # unfrozen_tree = frozen_tree.at["b"].apply(unfreeze, is_leaf=is_masked)
+ # assert treelib.flatten(unfrozen_tree)[0] == [2, 3]
def test_tree_mask_unfreeze():
@@ -361,12 +358,12 @@ class L1(TreeClass):
mask = tree == tree
frozen_tree = tree.at[...].apply(freeze)
- unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_frozen)
- assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3]
+ unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked)
+ assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3]
# frozen_tree = tree.at["a"].apply(freeze)
- # unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_frozen)
- # assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3]
+ # unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_masked)
+ # assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3]
def test_wrapper():
@@ -403,18 +400,16 @@ def test_wrapper():
@pytest.mark.skipif(backend == "default", reason="no array backend installed")
def test_tree_mask_tree_unmask():
tree = [1, 2, 3.0]
- assert treelib.tree_flatten(tree_mask(tree))[0] == [3.0]
- assert treelib.tree_flatten(tree_unmask(tree_mask(tree)))[0] == [1, 2, 3.0]
+ assert treelib.flatten(tree_mask(tree))[0] == [3.0]
+ assert treelib.flatten(tree_unmask(tree_mask(tree)))[0] == [1, 2, 3.0]
mask_func = lambda x: x < 2
- assert treelib.tree_flatten(tree_mask(tree, mask_func))[0] == [2, 3.0]
+ assert treelib.flatten(tree_mask(tree, mask_func))[0] == [2, 3.0]
assert freeze(freeze(1)) == freeze(1)
- assert tree_mask({"a": 1}, mask={"a": True}) == {"a": freeze(1)}
-
- with pytest.raises(ValueError):
- tree_mask({"a": 1}, mask=1.0)
+ with pytest.raises(TypeError):
+ tree_mask({"a": 1}, cond=1.0)
assert copy.copy(freeze(1)) == freeze(1)
@@ -424,7 +419,7 @@ def test_tree_mask_tree_unmask():
@pytest.mark.skipif(backend == "default", reason="no array backend installed")
def test_array_tree_mask_tree_unmask():
- frozen_array = tree_mask(arraylib.ones((5, 5)), mask=lambda _: True)
+ frozen_array = tree_mask(arraylib.ones((5, 5)), cond=lambda _: True)
assert frozen_array == frozen_array
assert not (frozen_array == freeze(arraylib.ones((5, 6))))
diff --git a/tests/test_operator.py b/tests/test_operator.py
index 8a89205..b8908c1 100644
--- a/tests/test_operator.py
+++ b/tests/test_operator.py
@@ -15,6 +15,7 @@
from __future__ import annotations
import math
+import os
from typing import Any
import pytest
@@ -22,9 +23,10 @@
from sepes._src.backend import backend
from sepes._src.code_build import autoinit, field
from sepes._src.tree_base import TreeClass
-from sepes._src.tree_mask import freeze
+from sepes._src.tree_mask import tree_mask
from sepes._src.tree_util import bcmap, is_tree_equal, leafwise
-import os
+
+freeze = lambda x: tree_mask(x, cond=lambda _: True)
test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy")
if test_arraylib == "jax":
@@ -171,3 +173,17 @@ def test_bcmap(tree, expected):
def test_math_operations_errors():
with pytest.raises(TypeError):
tree1 + "s"
+
+
+def test_bcmap_int_argnum_broadcast_to():
+ def func(x, y):
+ return x + y
+
+ assert bcmap(func, broadcast_to=1)(1, [2, 3, 4]) == [3, 4, 5]
+
+
+def test_bcmap_key_argnum_broadcast_to():
+ def func(x, y):
+ return x + y
+
+ assert bcmap(func, broadcast_to="y")(x=1, y=[2, 3, 4]) == [3, 4, 5]
diff --git a/tests/test_pprint.py b/tests/test_pprint.py
index 622f654..8eb8003 100644
--- a/tests/test_pprint.py
+++ b/tests/test_pprint.py
@@ -15,12 +15,11 @@
from __future__ import annotations
import dataclasses as dc
-import re
+import os
from collections import namedtuple
from typing import Any
import pytest
-import os
test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy")
backend = os.environ.get("SEPES_BACKEND", "jax")
@@ -29,8 +28,6 @@
from sepes._src.tree_pprint import (
_table,
tree_diagram,
- tree_graph,
- tree_mermaid,
tree_repr,
tree_str,
tree_summary,
@@ -144,7 +141,7 @@ def test_tree_summary():
assert (
tree_summary(r1, depth=1)
# trunk-ignore(flake8/E501)
- == "┌────┬────────────┬─────┬───────┐\n│Name│Type │Count│Size │\n├────┼────────────┼─────┼───────┤\n│.a │int │1 │ │\n├────┼────────────┼─────┼───────┤\n│.b │str │1 │ │\n├────┼────────────┼─────┼───────┤\n│.c │float │1 │ │\n├────┼────────────┼─────┼───────┤\n│.d │str │1 │ │\n├────┼────────────┼─────┼───────┤\n│.e │list │5 │ │\n├────┼────────────┼─────┼───────┤\n│.f │set │1 │ │\n├────┼────────────┼─────┼───────┤\n│.g │dict │27 │100.00B│\n├────┼────────────┼─────┼───────┤\n│.h │f32[5,1] │5 │20.00B │\n├────┼────────────┼─────┼───────┤\n│.i │f32[1,6] │6 │24.00B │\n├────┼────────────┼─────┼───────┤\n│.j │f32[1,1,4,5]│20 │80.00B │\n├────┼────────────┼─────┼───────┤\n│.k │tuple │3 │ │\n├────┼────────────┼─────┼───────┤\n│.l │a │2 │ │\n├────┼────────────┼─────┼───────┤\n│.m │f32[5,5] │25 │100.00B│\n├────┼────────────┼─────┼───────┤\n│.n │bool[] │1 │1.00B │\n├────┼────────────┼─────┼───────┤\n│.o │c64[2] │2 │16.00B │\n├────┼────────────┼─────┼───────┤\n│Σ │Repr1 │101 │341.00B│\n└────┴────────────┴─────┴───────┘"
+ == "┌────┬─────────────────────────┬─────┬───────┐\n│Name│Type │Count│Size │\n├────┼─────────────────────────┼─────┼───────┤\n│.a │int │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.b │str │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.c │float │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.d │str │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.e │list[int,int,int,int,int]│5 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.f │set │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.g │dict │27 │100.00B│\n├────┼─────────────────────────┼─────┼───────┤\n│.h │f32[5,1] │5 │20.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.i │f32[1,6] │6 │24.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.j │f32[1,1,4,5] │20 │80.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.k │tuple[int,int,int] │3 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.l │a[int,int] │2 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.m │f32[5,5] │25 │100.00B│\n├────┼─────────────────────────┼─────┼───────┤\n│.n │bool[] │1 │1.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.o │c64[2] │2 │16.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│Σ │Repr1 │101 │341.00B│\n└────┴─────────────────────────┴─────┴───────┘"
)
assert (
@@ -165,20 +162,6 @@ def test_tree_diagram():
assert tree_diagram(r1, depth=1) == out
-@pytest.mark.skipif(backend != "jax", reason="jax is not installed")
-def test_tree_mermaid():
- assert (
- re.sub(r"id\d*", "***", tree_mermaid(r1, depth=1))
- # trunk-ignore(flake8/E501)
- == 'flowchart LR\n ***("Repr1")\n *** --- ***(".a=1")\n *** --- ***(".b=string")\n *** --- ***(".c=1.0")\n *** --- ***(".d=aaaaa")\n *** --- ***(".e=[...]")\n *** --- ***(".f={...}")\n *** --- ***(".g=dict(...)")\n *** --- ***(".h=f32[5,1](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".i=f32[1,6](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".j=f32[1,1,4,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".k=(...)")\n *** --- ***(".l=a(...)")\n *** --- ***(".m=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".n=bool[]")\n *** --- ***(".o=c64[2]")'
- )
- assert (
- re.sub(r"id\d*", "***", tree_mermaid(r1, depth=2))
- # trunk-ignore(flake8/E501)
- == 'flowchart LR\n ***("Repr1")\n *** --- ***(".a=1")\n *** --- ***(".b=string")\n *** --- ***(".c=1.0")\n *** --- ***(".d=aaaaa")\n *** --- ***(".e:list")\n *** --- ***("[0]=10")\n *** --- ***("[1]=10")\n *** --- ***("[2]=10")\n *** --- ***("[3]=10")\n *** --- ***("[4]=10")\n *** --- ***(".f={...}")\n *** --- ***(".g:dict")\n *** --- ***("[\'a\']=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")\n *** --- ***("[\'b\']=bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")\n *** --- ***("[\'c\']=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".h=f32[5,1](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".i=f32[1,6](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".j=f32[1,1,4,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".k:tuple")\n *** --- ***("[0]=1")\n *** --- ***("[1]=2")\n *** --- ***("[2]=3")\n *** --- ***(".l:a")\n *** --- ***(".b=1")\n *** --- ***(".c=2")\n *** --- ***(".m=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".n=bool[]")\n *** --- ***(".o=c64[2]")'
- )
-
-
@pytest.mark.skipif(backend != "jax", reason="jax is not installed")
def test_misc():
x = (1, 2, 3)
@@ -251,16 +234,6 @@ def test_invalid_depth():
tree_diagram(1, depth="a")
with pytest.raises(TypeError):
tree_summary(1, depth="a")
- with pytest.raises(TypeError):
- tree_mermaid(1, depth="a")
-
-
-@pytest.mark.skipif(backend != "jax", reason="jax is not installed")
-def test_tree_graph():
- assert (
- re.sub(r"\b\d{10,}", "***", tree_graph(r1))
- == 'digraph G {\n *** [label="Repr1", shape=box];\n *** [label=".a=1", shape=box];\n *** -> ***;\n *** [label=".b=string", shape=box];\n *** -> ***;\n *** [label=".c=1.0", shape=box];\n *** -> ***;\n *** [label=".d=aaaaa", shape=box];\n *** -> ***;\n *** [label=".e:list", shape=box];\n *** -> ***;\n *** [label="[0]=10", shape=box];\n *** -> ***;\n *** [label="[1]=10", shape=box];\n *** -> ***;\n *** [label="[2]=10", shape=box];\n *** -> ***;\n *** [label="[3]=10", shape=box];\n *** -> ***;\n *** [label="[4]=10", shape=box];\n *** -> ***;\n *** [label=".f={...}", shape=box];\n *** -> ***;\n *** [label=".g:dict", shape=box];\n *** -> ***;\n *** [label="[\'a\']=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", shape=box];\n *** -> ***;\n *** [label="[\'b\']=bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", shape=box];\n *** -> ***;\n *** [label="[\'c\']=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".h=f32[5,1](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".i=f32[1,6](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".j=f32[1,1,4,5](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".k:tuple", shape=box];\n *** -> ***;\n *** [label="[0]=1", shape=box];\n *** -> ***;\n *** [label="[1]=2", shape=box];\n *** -> ***;\n *** [label="[2]=3", shape=box];\n *** -> ***;\n *** [label=".l:a", shape=box];\n *** -> ***;\n *** [label=".b=1", shape=box];\n *** -> ***;\n *** [label=".c=2", shape=box];\n *** -> ***;\n *** [label=".m=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".n=bool[]", shape=box];\n *** -> ***;\n *** [label=".o=c64[2]", shape=box];\n *** -> ***;\n}'
- )
@pytest.mark.skipif(backend != "jax", reason="jax is not installed")
@@ -270,9 +243,25 @@ def test_tracer_repr():
@jax.jit
def f(x):
out = tree_repr(x)
- assert out == "Tracer(f32[10,10])"
+ assert out == "DynamicJaxprTracer(f32[10,10])"
out = tree_str(x)
- assert out == "Tracer(f32[10,10])"
+ assert out == "DynamicJaxprTracer(f32[10,10])"
return x
f(jax.numpy.ones((10, 10)))
+
+
+@pytest.mark.skipif(backend != "jax", reason="testing jax specific sharding info")
+def test_jax_sharding_tree_summary():
+ import jax
+ import numpy as np
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+
+ x = jax.numpy.ones([4 * 4, 2 * 2])
+ mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=["i", "j"])
+ sharding = NamedSharding(mesh=mesh, spec=PartitionSpec("i", "j"))
+ x = jax.device_put(x, device=sharding)
+ assert (
+ tree_summary(x)
+ == "┌────┬───────────┬─────┬───────┐\n│Name│Type │Count│Size │\n├────┼───────────┼─────┼───────┤\n│Σ │G:f32[16,4]│64 │256.00B│\n│ │S:f32[4,2] │ │ │\n└────┴───────────┴─────┴───────┘"
+ )
diff --git a/tests/test_treeclass.py b/tests/test_treeclass.py
index 67da1f3..61887e3 100644
--- a/tests/test_treeclass.py
+++ b/tests/test_treeclass.py
@@ -15,11 +15,12 @@
import copy
import dataclasses as dc
import inspect
+import os
from typing import Any
import numpy.testing as npt
import pytest
-import os
+
from sepes._src.backend import backend, treelib
from sepes._src.code_build import (
autoinit,
@@ -29,8 +30,10 @@
fields,
)
from sepes._src.tree_base import TreeClass
-from sepes._src.tree_mask import freeze
-from sepes._src.tree_util import Partial, is_tree_equal, value_and_tree
+from sepes._src.tree_mask import tree_mask
+from sepes._src.tree_util import is_tree_equal, partial, value_and_tree
+
+freeze = lambda x: tree_mask(x, cond=lambda _: True)
test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy")
if test_arraylib == "jax":
@@ -147,7 +150,7 @@ def __init__(
test = Test()
- assert treelib.tree_flatten(test)[0] == []
+ assert treelib.flatten(test)[0] == []
class Test(TreeClass):
def __init__(self, a=arraylib.array([1, 2, 3]), b=arraylib.array([4, 5, 6])):
@@ -155,7 +158,7 @@ def __init__(self, a=arraylib.array([1, 2, 3]), b=arraylib.array([4, 5, 6])):
self.b = b
test = Test()
- npt.assert_allclose(treelib.tree_flatten(test)[0][0], arraylib.array([4, 5, 6]))
+ npt.assert_allclose(treelib.flatten(test)[0][0], arraylib.array([4, 5, 6]))
def test_post_init():
@@ -200,7 +203,7 @@ def inc(self, x):
l1 = L1()
- assert treelib.tree_flatten(l1)[0] == [2, 4, 5, 5]
+ assert treelib.flatten(l1)[0] == [2, 4, 5, 5]
assert l1.inc(10) == 20
assert l1.sub(10) == 0
assert l1.d == 5
@@ -212,7 +215,7 @@ class L1(L0):
l1 = L1()
- assert treelib.tree_flatten(l1)[0] == [2, 4, 5]
+ assert treelib.flatten(l1)[0] == [2, 4, 5]
def test_registering_state():
@@ -414,7 +417,7 @@ class Test(TreeClass):
t = Test(1)
assert t.a == freeze(1)
- assert treelib.tree_flatten(t)[0] == []
+ assert treelib.flatten(t)[0] == []
def test_super():
@@ -522,10 +525,10 @@ def test_partial():
def f(a, b, c):
return a + b + c
- f_a = Partial(f, ..., 2, 3)
+ f_a = partial(f, ..., 2, 3)
assert f_a(1) == 6
- f_b = Partial(f, 1, ..., 3)
+ f_b = partial(f, 1, ..., 3)
assert f_b(2) == 6
assert f_b == f_b