diff --git a/CHANGELOG.md b/CHANGELOG.md index c0230a2..36acd93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,42 @@ ### Additions +- Expose `at.def_rule` to define custom matchers. + + Example: Define a type matcher that matches based on the name, dtype, and shape + of the leaf and then apply a function to the matched leaf. + + ```python + import sepes as sp + import jax + import jax.numpy as jnp + import dataclasses as dc + @dc.dataclass + class NameDtypeShapeMatcher: + name: str + dtype: jnp.dtype + shape: tuple[int, ...] + def compare(matcher: NameDtypeShapeMatcher, key, leaf) -> bool: + if not isinstance(leaf, jax.Array): + return False + if isinstance(key, str): + key = key + elif isinstance(key, jax.tree_util.GetAttrKey): + key = key.name + elif isinstance(key, jax.tree_util.DictKey): + key = key.key + return matcher.name == key and matcher.dtype == leaf.dtype and matcher.shape == leaf.shape + tree = dict(weight=jnp.arange(9).reshape(3, 3), bias=jnp.zeros(3)) + sp.at.def_rule(NameDtypeShapeMatcher, compare) + matcher = NameDtypeShapeMatcher('weight', jnp.int32, (3, 3)) + to_symmetric = lambda x: (x + x.T) / 2 + sp.at(tree)[matcher].apply(to_symmetric) + # {'bias': Array([0., 0., 0.], dtype=float32), + # 'weight': Array([[0., 2., 4.], + # [2., 4., 6.], + # [4., 6., 8.]], dtype=float32)} + ``` + - Add ability to register custom types for masking wrappers. Example to define a custom masking wrapper for a specific type. diff --git a/docs/API/tree.rst b/docs/API/tree.rst index d408467..ae85aed 100644 --- a/docs/API/tree.rst +++ b/docs/API/tree.rst @@ -12,6 +12,7 @@ scan, reduce, pluck, + def_rule, .. autofunction:: value_and_tree .. autofunction:: bcmap \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index da44229..d9285b3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,10 +10,10 @@ Install from pip:: .. toctree:: - :caption: 📖 Guides + :caption: 📖 Recipes :maxdepth: 1 - notebooks/[guides]surgery + recipes .. toctree:: :caption: API Documentation diff --git a/docs/notebooks/[guides]surgery.ipynb b/docs/notebooks/[recipes]surgery.ipynb similarity index 95% rename from docs/notebooks/[guides]surgery.ipynb rename to docs/notebooks/[recipes]surgery.ipynb index ab71cc5..7f63fed 100644 --- a/docs/notebooks/[guides]surgery.ipynb +++ b/docs/notebooks/[recipes]surgery.ipynb @@ -53,6 +53,7 @@ ], "source": [ "import sepes as sp\n", + "\n", "pytree1 = [1, [2, 3], 4]\n", "pytree2 = sp.at(pytree1)[...].get() # get the whole pytree using ...\n", "print(f\"{pytree1=}, {pytree2=}\")\n", @@ -97,6 +98,7 @@ ], "source": [ "import sepes as sp\n", + "\n", "pytree1 = [1, [2, 3], 4]\n", "plus_one = lambda x: x + 1\n", "pytree2 = sp.at(pytree1)[...].apply(plus_one)\n", @@ -130,6 +132,7 @@ ], "source": [ "import sepes as sp\n", + "\n", "pytree1 = [1, [2, 3], 4]\n", "pytree2 = sp.at(pytree1)[1][0].set(100) # equivalent to pytree1[1][0] = 100\n", "pytree2" @@ -161,8 +164,9 @@ ], "source": [ "import sepes as sp\n", + "\n", "pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4, \"f\": {\"g\": 7, \"h\": 8}}\n", - "pytree2 = sp.at(pytree1)[\"b\"].set(100) # equivalent to pytree1[\"b\"] = 100\n", + "pytree2 = sp.at(pytree1)[\"b\"].set(100) # equivalent to pytree1[\"b\"] = 100\n", "pytree2" ] }, @@ -193,7 +197,13 @@ "source": [ "import sepes as sp\n", "import re\n", - "pytree1 = {\"key_1\": 1, \"key_2\": {\"key_3\": 3, \"key_4\": 4}, \"key_5\": 5, \"key_6\": {\"key_7\": 7, \"key_8\": 8}}\n", + "\n", + "pytree1 = {\n", + " \"key_1\": 1,\n", + " \"key_2\": {\"key_3\": 3, \"key_4\": 4},\n", + " \"key_5\": 5,\n", + " \"key_6\": {\"key_7\": 7, \"key_8\": 8},\n", + "}\n", "# from 1 - 5, set the value to 100\n", "pattern = re.compile(r\"key_[1-5]\")\n", "pytree2 = sp.at(pytree1)[pattern].set(100)\n", @@ -231,6 +241,7 @@ "source": [ "import sepes as sp\n", "import jax\n", + "\n", "pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4}\n", "# mask defines all desired entries to apply the function\n", "mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)\n", diff --git a/docs/recipes.rst b/docs/recipes.rst index f78baee..0d93c13 100644 --- a/docs/recipes.rst +++ b/docs/recipes.rst @@ -1,12 +1,13 @@ -🍳 Recipes ----------------------- +🍳 Tree recipes +---------------- .. toctree:: :caption: Recipes :maxdepth: 1 + notebooks/[recipes]surgery notebooks/[recipes]fields notebooks/[recipes]intermediates + notebooks/[recipes]misc notebooks/[recipes]sharing - notebooks/[recipes]transformations - notebooks/[recipes]misc \ No newline at end of file + notebooks/[recipes]transformations \ No newline at end of file diff --git a/sepes/_src/tree_index.py b/sepes/_src/tree_index.py index 4bb98f4..542dc2a 100644 --- a/sepes/_src/tree_index.py +++ b/sepes/_src/tree_index.py @@ -39,7 +39,7 @@ import abc import functools as ft import re -from typing import Any, Callable, Generic, Hashable, TypeVar, Sequence +from typing import Any, Callable, Generic, Hashable, Sequence, TypeVar from typing_extensions import Self @@ -52,6 +52,7 @@ T = TypeVar("T") S = TypeVar("S") PyTree = Any +Leaf = Any EllipsisType = TypeVar("EllipsisType") PathKeyEntry = TypeVar("PathKeyEntry", bound=Hashable) _no_initializer = object() @@ -62,12 +63,13 @@ class BaseKey(abc.ABC): """Parent class for all match classes.""" @abc.abstractmethod - def __eq__(self, entry: PathKeyEntry) -> bool: + def compare(self, entry: PathKeyEntry, leaf: Leaf) -> bool: pass @property @abc.abstractmethod - def broadcast(self): ... + def broadcast(self): + ... _INVALID_INDEXER = """\ @@ -78,6 +80,7 @@ def broadcast(self): ... - ``re.Pattern`` to match a leaf level path with a regex pattern. - Boolean mask of a compatible structure as the pytree. - `tuple` of the above types to match multiple leaves at the same level. + - Custom matchers defined with `at.def_rule`. """ _NO_LEAF_MATCH = """\ @@ -139,8 +142,8 @@ def path_map_func(path, leaf): # ensure that the path is not empty if len(path) == len(where): - for pi, ki in zip(path, where): - if pi != ki: + for wi, pi in zip(where, path): + if not wi.compare(pi, leaf): return false_tree(leaf) match = True return true_tree(leaf) @@ -150,7 +153,7 @@ def path_map_func(path, leaf): # path entry matches the current where entry, if not then return # a false tree to stop traversing deeper into the tree. (cur_where, *rest_where), (cur_path, *_) = where, path - if cur_where == cur_path: + if cur_where.compare(cur_path, leaf): # where is nonlocal to the function # so reduce the where path by one level and traverse deeper # then restore the where path to the original value before @@ -236,7 +239,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: # valid resolution of `BaseKey` is a valid indexing leaf # makes it possible to dispatch on multi-leaf pytree level_paths += [resolved_key] - return False + return True if type(node) is tuple and seen_tuple is False: # e.g. `at[1,2,3]` but not `at[1,(2,3)]`` @@ -283,6 +286,7 @@ class at(Generic[T]): - ``...`` to select all leaves. - a boolean mask of the same structure as the tree - ``re.Pattern`` to match a leaf level path with a regex pattern. + - Custom matchers defined with ``at.def_rule``. - a tuple of the above to match multiple keys at the same level. Example: @@ -297,6 +301,7 @@ class at(Generic[T]): >>> sp.at(tree)[mask].set(100) {'a': 1, 'b': [1, 100, 100]} """ + def __init__(self, tree: T, where: list[Any] | None = None) -> None: self.tree = tree self.where = [] if where is None else where @@ -685,38 +690,81 @@ def aggregate_subtrees(node: Any) -> bool: treelib.flatten(tree, is_leaf=aggregate_subtrees) return subtrees + @staticmethod + def def_rule( + matcher_type: type[T], + compare: Callable[[T, PathKeyEntry, Leaf], bool], + *, + broadcastable: bool = False, + ) -> None: + """Define a rule to match user input to with the corresponding path and leaf entry. -# pass through for boolean pytrees masks and tuple of keys -at.dispatcher = ft.singledispatch(lambda x: x) + Args: + matcher_type: the user match object type to match with the path and leaf entry. + compare: a function to compare the user matcher object with the path + and leaf entry the function accepts the user input, the path entry, + and the leaf value and returns a boolean value to mark if the user + input matches the path and leaf entry. + broadcastable: if the user type match result should be broadcasted to the + full subtree. Default to ``False``. + Example: + Define a type matcher that matches based on the name, dtype, and shape + of the leaf and then apply a function to the matched leaf. + + >>> import sepes as sp + >>> import jax + >>> import jax.numpy as jnp + >>> import dataclasses as dc + >>> @dc.dataclass + ... class NameDtypeShapeMatcher: + ... name: str + ... dtype: jnp.dtype + ... shape: tuple[int, ...] + >>> def compare(matcher: NameDtypeShapeMatcher, key, leaf) -> bool: + ... if not isinstance(leaf, jax.Array): + ... return False + ... if isinstance(key, str): + ... key = key + ... elif isinstance(key, jax.tree_util.GetAttrKey): + ... key = key.name + ... elif isinstance(key, jax.tree_util.DictKey): + ... key = key.key + ... return matcher.name == key and matcher.dtype == leaf.dtype and matcher.shape == leaf.shape + >>> tree = dict(weight=jnp.arange(9).reshape(3, 3), bias=jnp.zeros(3)) + >>> sp.at.def_rule(NameDtypeShapeMatcher, compare) + >>> matcher = NameDtypeShapeMatcher('weight', jnp.int32, (3, 3)) + >>> to_symmetric = lambda x: (x + x.T) / 2 + >>> sp.at(tree)[matcher].apply(to_symmetric) + {'bias': Array([0., 0., 0.], dtype=float32), + 'weight': Array([[0., 2., 4.], + [2., 4., 6.], + [4., 6., 8.]], dtype=float32)} + """ -def def_rule( - user_type: type[T], - path_compare_func: Callable[[T, PathKeyEntry], bool], - *, - broadcastable: bool = False, -) -> None: - # remove the BaseKey abstraction from the user-facing function - class UserKey(BaseKey): - broadcast: bool = broadcastable + # remove the BaseKey abstraction from the user-facing function + class UserKey(BaseKey): + broadcast: bool = broadcastable - def __init__(self, input: T): - self.input = input + def __init__(self, input: T): + self.input = input - def __eq__(self, key: PathKeyEntry) -> bool: - return path_compare_func(self.input, key) + def compare(self, key: PathKeyEntry, leaf: Leaf) -> bool: + return compare(self.input, key, leaf) - at.dispatcher.register(user_type, UserKey) + at.dispatcher.register(matcher_type, UserKey) -at.def_rule = def_rule +# pass through for boolean pytrees masks and tuple of keys +at.dispatcher = ft.singledispatch(lambda x: x) -# key rules to match user input to with the path entry +# key rules to match user input to with the path and leaf entry -def str_compare(name: str, key: PathKeyEntry): +def str_compare(name: str, key: PathKeyEntry, leaf: Leaf) -> bool: """Match a leaf with a given name.""" + del leaf if isinstance(key, str): return name == key treelib = sepes._src.backend.treelib @@ -727,8 +775,9 @@ def str_compare(name: str, key: PathKeyEntry): return False -def int_compare(idx: int, key: PathKeyEntry) -> bool: +def int_compare(idx: int, key: PathKeyEntry, leaf: Leaf) -> bool: """Match a leaf with a given index.""" + del leaf if isinstance(key, int): return idx == key treelib = sepes._src.backend.treelib @@ -737,8 +786,9 @@ def int_compare(idx: int, key: PathKeyEntry) -> bool: return False -def regex_compare(pattern: re.Pattern, key: PathKeyEntry) -> bool: +def regex_compare(pattern: re.Pattern, key: PathKeyEntry, leaf: Leaf) -> bool: """Match a path with a regex pattern inside 'at' property.""" + del leaf if isinstance(key, str): return re.fullmatch(pattern, key) is not None treelib = sepes._src.backend.treelib @@ -749,7 +799,8 @@ def regex_compare(pattern: re.Pattern, key: PathKeyEntry) -> bool: return False -def ellipsis_compare(_, __): +def ellipsis_compare(_, key: PathKeyEntry, leaf: Leaf) -> bool: + del key, leaf return True @@ -762,11 +813,11 @@ def ellipsis_compare(_, __): class MultiKey(BaseKey): """Match a leaf with multiple keys at the same level.""" - def __init__(self, *keys): - self.keys = tuple(keys) + def __init__(self, *keys: BaseKey): + self.keys = keys - def __eq__(self, entry: PathKeyEntry) -> bool: - return any(entry == key for key in self.keys) + def compare(self, entry: PathKeyEntry, leaf: Leaf) -> bool: + return any(key.compare(entry, leaf) for key in self.keys) broadcast: bool = False @@ -774,8 +825,9 @@ def __eq__(self, entry: PathKeyEntry) -> bool: if is_package_avaiable("jax"): import jax.tree_util as jtu - def jax_key_compare(input, key: PathKeyEntry) -> bool: + def jax_key_compare(input, key: PathKeyEntry, leaf: Leaf) -> bool: """Enable indexing with jax keys directly in `at`.""" + del leaf return input == key at.def_rule(jtu.SequenceKey, jax_key_compare, broadcastable=False) diff --git a/tests/test_index.py b/tests/test_index.py index 0728380..e398c1f 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -23,7 +23,7 @@ from sepes._src.backend import arraylib, backend, treelib from sepes._src.code_build import autoinit from sepes._src.tree_base import TreeClass, _mutable_instance_registry -from sepes._src.tree_index import at, BaseKey +from sepes._src.tree_index import BaseKey, at from sepes._src.tree_util import is_tree_equal, leafwise, value_and_tree test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") @@ -650,3 +650,36 @@ def __repr__(self) -> str: assert cur_count == 1 assert new_counter.count == 1 assert not (counter is new_counter) + + +@pytest.mark.skipif(backend != "jax", reason="jax backend needed") +def test_pytree_matcher(): + import jax + import jax.numpy as jnp + import jax.tree_util as jtu + + class NameDtypeShapeMatcher(NamedTuple): + name: str + dtype: str + shape: tuple[int, ...] + + def compare(matcher: NameDtypeShapeMatcher, key, leaf) -> bool: + if not isinstance(leaf, jax.Array): + return False + if isinstance(key, str): + key = key + elif isinstance(key, jtu.GetAttrKey): + key = key.name + elif isinstance(key, jtu.DictKey): + key = key.key + return ( + matcher.name == key + and matcher.dtype == leaf.dtype + and matcher.shape == leaf.shape + ) + + tree = dict(weight=jnp.arange(9).reshape(3, 3), bias=jnp.zeros(3)) + at.def_rule(NameDtypeShapeMatcher, compare) + matcher = NameDtypeShapeMatcher("weight", jnp.int32, (3, 3)) + to_symmetric = lambda x: (x + x.T) / 2 + at(tree)[matcher].apply(to_symmetric)