Skip to content

Commit

Permalink
expose rule and extend matcher from key to key and leaf pair
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 20, 2024
1 parent bfb8031 commit f3cb70f
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 43 deletions.
36 changes: 36 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,42 @@

### Additions

- Expose `at.def_rule` to define custom matchers.

Example: Define a type matcher that matches based on the name, dtype, and shape
of the leaf and then apply a function to the matched leaf.

```python
import sepes as sp
import jax
import jax.numpy as jnp
import dataclasses as dc
@dc.dataclass
class NameDtypeShapeMatcher:
name: str
dtype: jnp.dtype
shape: tuple[int, ...]
def compare(matcher: NameDtypeShapeMatcher, key, leaf) -> bool:
if not isinstance(leaf, jax.Array):
return False
if isinstance(key, str):
key = key
elif isinstance(key, jax.tree_util.GetAttrKey):
key = key.name
elif isinstance(key, jax.tree_util.DictKey):
key = key.key
return matcher.name == key and matcher.dtype == leaf.dtype and matcher.shape == leaf.shape
tree = dict(weight=jnp.arange(9).reshape(3, 3), bias=jnp.zeros(3))
sp.at.def_rule(NameDtypeShapeMatcher, compare)
matcher = NameDtypeShapeMatcher('weight', jnp.int32, (3, 3))
to_symmetric = lambda x: (x + x.T) / 2
sp.at(tree)[matcher].apply(to_symmetric)
# {'bias': Array([0., 0., 0.], dtype=float32),
# 'weight': Array([[0., 2., 4.],
# [2., 4., 6.],
# [4., 6., 8.]], dtype=float32)}
```

- Add ability to register custom types for masking wrappers.

Example to define a custom masking wrapper for a specific type.
Expand Down
1 change: 1 addition & 0 deletions docs/API/tree.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
scan,
reduce,
pluck,
def_rule,

.. autofunction:: value_and_tree
.. autofunction:: bcmap
4 changes: 2 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ Install from pip::


.. toctree::
:caption: 📖 Guides
:caption: 📖 Recipes
:maxdepth: 1

notebooks/[guides]surgery
recipes

.. toctree::
:caption: API Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
],
"source": [
"import sepes as sp\n",
"\n",
"pytree1 = [1, [2, 3], 4]\n",
"pytree2 = sp.at(pytree1)[...].get() # get the whole pytree using ...\n",
"print(f\"{pytree1=}, {pytree2=}\")\n",
Expand Down Expand Up @@ -97,6 +98,7 @@
],
"source": [
"import sepes as sp\n",
"\n",
"pytree1 = [1, [2, 3], 4]\n",
"plus_one = lambda x: x + 1\n",
"pytree2 = sp.at(pytree1)[...].apply(plus_one)\n",
Expand Down Expand Up @@ -130,6 +132,7 @@
],
"source": [
"import sepes as sp\n",
"\n",
"pytree1 = [1, [2, 3], 4]\n",
"pytree2 = sp.at(pytree1)[1][0].set(100) # equivalent to pytree1[1][0] = 100\n",
"pytree2"
Expand Down Expand Up @@ -161,8 +164,9 @@
],
"source": [
"import sepes as sp\n",
"\n",
"pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4, \"f\": {\"g\": 7, \"h\": 8}}\n",
"pytree2 = sp.at(pytree1)[\"b\"].set(100) # equivalent to pytree1[\"b\"] = 100\n",
"pytree2 = sp.at(pytree1)[\"b\"].set(100) # equivalent to pytree1[\"b\"] = 100\n",
"pytree2"
]
},
Expand Down Expand Up @@ -193,7 +197,13 @@
"source": [
"import sepes as sp\n",
"import re\n",
"pytree1 = {\"key_1\": 1, \"key_2\": {\"key_3\": 3, \"key_4\": 4}, \"key_5\": 5, \"key_6\": {\"key_7\": 7, \"key_8\": 8}}\n",
"\n",
"pytree1 = {\n",
" \"key_1\": 1,\n",
" \"key_2\": {\"key_3\": 3, \"key_4\": 4},\n",
" \"key_5\": 5,\n",
" \"key_6\": {\"key_7\": 7, \"key_8\": 8},\n",
"}\n",
"# from 1 - 5, set the value to 100\n",
"pattern = re.compile(r\"key_[1-5]\")\n",
"pytree2 = sp.at(pytree1)[pattern].set(100)\n",
Expand Down Expand Up @@ -231,6 +241,7 @@
"source": [
"import sepes as sp\n",
"import jax\n",
"\n",
"pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4}\n",
"# mask defines all desired entries to apply the function\n",
"mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)\n",
Expand Down
9 changes: 5 additions & 4 deletions docs/recipes.rst
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
🍳 Recipes
----------------------
🍳 Tree recipes
----------------

.. toctree::
:caption: Recipes
:maxdepth: 1

notebooks/[recipes]surgery
notebooks/[recipes]fields
notebooks/[recipes]intermediates
notebooks/[recipes]misc
notebooks/[recipes]sharing
notebooks/[recipes]transformations
notebooks/[recipes]misc
notebooks/[recipes]transformations
120 changes: 86 additions & 34 deletions sepes/_src/tree_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import abc
import functools as ft
import re
from typing import Any, Callable, Generic, Hashable, TypeVar, Sequence
from typing import Any, Callable, Generic, Hashable, Sequence, TypeVar

from typing_extensions import Self

Expand All @@ -52,6 +52,7 @@
T = TypeVar("T")
S = TypeVar("S")
PyTree = Any
Leaf = Any
EllipsisType = TypeVar("EllipsisType")
PathKeyEntry = TypeVar("PathKeyEntry", bound=Hashable)
_no_initializer = object()
Expand All @@ -62,12 +63,13 @@ class BaseKey(abc.ABC):
"""Parent class for all match classes."""

@abc.abstractmethod
def __eq__(self, entry: PathKeyEntry) -> bool:
def compare(self, entry: PathKeyEntry, leaf: Leaf) -> bool:
pass

@property
@abc.abstractmethod
def broadcast(self): ...
def broadcast(self):
...


_INVALID_INDEXER = """\
Expand All @@ -78,6 +80,7 @@ def broadcast(self): ...
- ``re.Pattern`` to match a leaf level path with a regex pattern.
- Boolean mask of a compatible structure as the pytree.
- `tuple` of the above types to match multiple leaves at the same level.
- Custom matchers defined with `at.def_rule`.
"""

_NO_LEAF_MATCH = """\
Expand Down Expand Up @@ -139,8 +142,8 @@ def path_map_func(path, leaf):

# ensure that the path is not empty
if len(path) == len(where):
for pi, ki in zip(path, where):
if pi != ki:
for wi, pi in zip(where, path):
if not wi.compare(pi, leaf):
return false_tree(leaf)
match = True
return true_tree(leaf)
Expand All @@ -150,7 +153,7 @@ def path_map_func(path, leaf):
# path entry matches the current where entry, if not then return
# a false tree to stop traversing deeper into the tree.
(cur_where, *rest_where), (cur_path, *_) = where, path
if cur_where == cur_path:
if cur_where.compare(cur_path, leaf):
# where is nonlocal to the function
# so reduce the where path by one level and traverse deeper
# then restore the where path to the original value before
Expand Down Expand Up @@ -236,7 +239,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
# valid resolution of `BaseKey` is a valid indexing leaf
# makes it possible to dispatch on multi-leaf pytree
level_paths += [resolved_key]
return False
return True

if type(node) is tuple and seen_tuple is False:
# e.g. `at[1,2,3]` but not `at[1,(2,3)]``
Expand Down Expand Up @@ -283,6 +286,7 @@ class at(Generic[T]):
- ``...`` to select all leaves.
- a boolean mask of the same structure as the tree
- ``re.Pattern`` to match a leaf level path with a regex pattern.
- Custom matchers defined with ``at.def_rule``.
- a tuple of the above to match multiple keys at the same level.
Example:
Expand All @@ -297,6 +301,7 @@ class at(Generic[T]):
>>> sp.at(tree)[mask].set(100)
{'a': 1, 'b': [1, 100, 100]}
"""

def __init__(self, tree: T, where: list[Any] | None = None) -> None:
self.tree = tree
self.where = [] if where is None else where
Expand Down Expand Up @@ -685,38 +690,81 @@ def aggregate_subtrees(node: Any) -> bool:
treelib.flatten(tree, is_leaf=aggregate_subtrees)
return subtrees

@staticmethod
def def_rule(
matcher_type: type[T],
compare: Callable[[T, PathKeyEntry, Leaf], bool],
*,
broadcastable: bool = False,
) -> None:
"""Define a rule to match user input to with the corresponding path and leaf entry.
# pass through for boolean pytrees masks and tuple of keys
at.dispatcher = ft.singledispatch(lambda x: x)
Args:
matcher_type: the user match object type to match with the path and leaf entry.
compare: a function to compare the user matcher object with the path
and leaf entry the function accepts the user input, the path entry,
and the leaf value and returns a boolean value to mark if the user
input matches the path and leaf entry.
broadcastable: if the user type match result should be broadcasted to the
full subtree. Default to ``False``.
Example:
Define a type matcher that matches based on the name, dtype, and shape
of the leaf and then apply a function to the matched leaf.
>>> import sepes as sp
>>> import jax
>>> import jax.numpy as jnp
>>> import dataclasses as dc
>>> @dc.dataclass
... class NameDtypeShapeMatcher:
... name: str
... dtype: jnp.dtype
... shape: tuple[int, ...]
>>> def compare(matcher: NameDtypeShapeMatcher, key, leaf) -> bool:
... if not isinstance(leaf, jax.Array):
... return False
... if isinstance(key, str):
... key = key
... elif isinstance(key, jax.tree_util.GetAttrKey):
... key = key.name
... elif isinstance(key, jax.tree_util.DictKey):
... key = key.key
... return matcher.name == key and matcher.dtype == leaf.dtype and matcher.shape == leaf.shape
>>> tree = dict(weight=jnp.arange(9).reshape(3, 3), bias=jnp.zeros(3))
>>> sp.at.def_rule(NameDtypeShapeMatcher, compare)
>>> matcher = NameDtypeShapeMatcher('weight', jnp.int32, (3, 3))
>>> to_symmetric = lambda x: (x + x.T) / 2
>>> sp.at(tree)[matcher].apply(to_symmetric)
{'bias': Array([0., 0., 0.], dtype=float32),
'weight': Array([[0., 2., 4.],
[2., 4., 6.],
[4., 6., 8.]], dtype=float32)}
"""

def def_rule(
user_type: type[T],
path_compare_func: Callable[[T, PathKeyEntry], bool],
*,
broadcastable: bool = False,
) -> None:
# remove the BaseKey abstraction from the user-facing function
class UserKey(BaseKey):
broadcast: bool = broadcastable
# remove the BaseKey abstraction from the user-facing function
class UserKey(BaseKey):
broadcast: bool = broadcastable

def __init__(self, input: T):
self.input = input
def __init__(self, input: T):
self.input = input

def __eq__(self, key: PathKeyEntry) -> bool:
return path_compare_func(self.input, key)
def compare(self, key: PathKeyEntry, leaf: Leaf) -> bool:
return compare(self.input, key, leaf)

at.dispatcher.register(user_type, UserKey)
at.dispatcher.register(matcher_type, UserKey)


at.def_rule = def_rule
# pass through for boolean pytrees masks and tuple of keys
at.dispatcher = ft.singledispatch(lambda x: x)


# key rules to match user input to with the path entry
# key rules to match user input to with the path and leaf entry


def str_compare(name: str, key: PathKeyEntry):
def str_compare(name: str, key: PathKeyEntry, leaf: Leaf) -> bool:
"""Match a leaf with a given name."""
del leaf
if isinstance(key, str):
return name == key
treelib = sepes._src.backend.treelib
Expand All @@ -727,8 +775,9 @@ def str_compare(name: str, key: PathKeyEntry):
return False


def int_compare(idx: int, key: PathKeyEntry) -> bool:
def int_compare(idx: int, key: PathKeyEntry, leaf: Leaf) -> bool:
"""Match a leaf with a given index."""
del leaf
if isinstance(key, int):
return idx == key
treelib = sepes._src.backend.treelib
Expand All @@ -737,8 +786,9 @@ def int_compare(idx: int, key: PathKeyEntry) -> bool:
return False


def regex_compare(pattern: re.Pattern, key: PathKeyEntry) -> bool:
def regex_compare(pattern: re.Pattern, key: PathKeyEntry, leaf: Leaf) -> bool:
"""Match a path with a regex pattern inside 'at' property."""
del leaf
if isinstance(key, str):
return re.fullmatch(pattern, key) is not None
treelib = sepes._src.backend.treelib
Expand All @@ -749,7 +799,8 @@ def regex_compare(pattern: re.Pattern, key: PathKeyEntry) -> bool:
return False


def ellipsis_compare(_, __):
def ellipsis_compare(_, key: PathKeyEntry, leaf: Leaf) -> bool:
del key, leaf
return True


Expand All @@ -762,20 +813,21 @@ def ellipsis_compare(_, __):
class MultiKey(BaseKey):
"""Match a leaf with multiple keys at the same level."""

def __init__(self, *keys):
self.keys = tuple(keys)
def __init__(self, *keys: BaseKey):
self.keys = keys

def __eq__(self, entry: PathKeyEntry) -> bool:
return any(entry == key for key in self.keys)
def compare(self, entry: PathKeyEntry, leaf: Leaf) -> bool:
return any(key.compare(entry, leaf) for key in self.keys)

broadcast: bool = False


if is_package_avaiable("jax"):
import jax.tree_util as jtu

def jax_key_compare(input, key: PathKeyEntry) -> bool:
def jax_key_compare(input, key: PathKeyEntry, leaf: Leaf) -> bool:
"""Enable indexing with jax keys directly in `at`."""
del leaf
return input == key

at.def_rule(jtu.SequenceKey, jax_key_compare, broadcastable=False)
Expand Down
Loading

0 comments on commit f3cb70f

Please sign in to comment.