diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ea9268..18c7c8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,7 +35,8 @@ - `tree_{mask,unmask}` now accepts only callable `cond` argument. - Rename `is_frozen` to `is_masked` - - frozen could mean non-trainable array, however the masking is not only for arrays but also for other types that will be hidden from jax transformations. + - frozen could mean non-trainable array, however the masking is not only for arrays but also for other types that will be hidden across jax transformations. +- Remove support for `re.Pattern` (i.e. `re.compile(pattern)`) in the `where` argument of `AtIndexer`; pass the `pattern` string directly instead, since string keys are now matched as regex patterns. ## V0.11.3 diff --git a/docs/API/core.rst b/docs/API/core.rst index f976053..3021e35 100644 --- a/docs/API/core.rst +++ b/docs/API/core.rst @@ -7,7 +7,7 @@ .. autoclass:: TreeClass :members: at -.. autoclass:: AtIndexer +.. autoclass:: at :members: get, set, @@ -16,10 +16,7 @@ reduce, pluck, -.. autoclass:: at -.. autoclass:: BaseKey - :members: - __eq__ +.. autoclass:: AtIndexer .. autofunction:: autoinit .. autofunction:: leafwise .. autofunction:: field diff --git a/sepes/__init__.py b/sepes/__init__.py index 607860a..c27b625 100644 --- a/sepes/__init__.py +++ b/sepes/__init__.py @@ -15,7 +15,7 @@ from sepes._src.backend import backend_context from sepes._src.code_build import autoinit, field, fields from sepes._src.tree_base import TreeClass -from sepes._src.tree_index import AtIndexer, BaseKey, at +from sepes._src.tree_index import AtIndexer, at from sepes._src.tree_mask import ( is_masked, is_nondiff, @@ -53,7 +53,6 @@ # indexing utils "AtIndexer", "at", - "BaseKey", # tree utils "bcmap", "leafwise", diff --git a/sepes/_src/tree_base.py b/sepes/_src/tree_base.py index 9207aa0..d53dcb4 100644 --- a/sepes/_src/tree_base.py +++ b/sepes/_src/tree_base.py @@ -270,8 +270,6 @@ def at(self) -> AtIndexer[Self]: - ``int`` for positional indexing for sequences. - ``...`` to select all leaves. 
- a boolean mask of the same structure as the tree - - ``re.Pattern`` to index all keys matching a regex pattern. - - an instance of ``BaseKey`` with custom logic to index a pytree. - a tuple of the above types to index multiple keys at same level. Example: diff --git a/sepes/_src/tree_index.py b/sepes/_src/tree_index.py index 721c388..f49b4d2 100644 --- a/sepes/_src/tree_index.py +++ b/sepes/_src/tree_index.py @@ -49,123 +49,15 @@ _no_initializer = object() -class BaseKey(abc.ABC): - """Parent class for all match classes. - - - Subclass this class to create custom match keys by implementing - the `__eq__` method. The ``__eq__`` method should return True if the - key matches the given path entry and False otherwise. The path entry - refers to the entry defined in the ``tree_flatten_with_keys`` method of - the pytree class. - - - Typical path entries in ``jax`` are: - - - ``jax.tree_util.GetAttrKey`` for attributes - - ``jax.tree_util.DictKey`` for mapping keys - - ``jax.tree_util.SequenceKey`` for sequence indices - - - When implementing the ``__eq__`` method you can use the ``singledispatchmethod`` - to unpack the path entry for example: - - - ``jax.tree_util.GetAttrKey`` -> `key.name` - - ``jax.tree_util.DictKey`` -> `key.key` - - ``jax.tree_util.SequenceKey`` -> `key.index` - - - See Examples for more details. - - Example: - >>> # define an match strategy to match a leaf with a given name and type - >>> import sepes as sp - >>> from typing import NamedTuple - >>> import jax - >>> class NameTypeContainer(NamedTuple): - ... name: str - ... type: type - >>> @jax.tree_util.register_pytree_with_keys_class - ... class Tree: - ... def __init__(self, a, b) -> None: - ... self.a = a - ... self.b = b - ... def tree_flatten_with_keys(self): - ... ak = (NameTypeContainer("a", type(self.a)), self.a) - ... bk = (NameTypeContainer("b", type(self.b)), self.b) - ... return (ak, bk), None - ... @classmethod - ... def tree_unflatten(cls, aux_data, children): - ... 
return cls(*children) - ... @property - ... def at(self): - ... return sp.at(self) - >>> tree = Tree(1, 2) - >>> class MatchNameType(sp.BaseKey): - ... def __init__(self, name, type): - ... self.name = name - ... self.type = type - ... def __eq__(self, other): - ... if isinstance(other, NameTypeContainer): - ... return other == (self.name, self.type) - ... return False - >>> tree = tree.at[MatchNameType("a", int)].get() - >>> assert jax.tree_util.tree_leaves(tree) == [1] - - Note: - - use ``BaseKey.def_alias(type, func)`` to define an index type alias - for `BaseKey` subclasses. This is useful for convience when - creating new match strategies. - - >>> import sepes as sp - >>> import functools as ft - >>> from types import FunctionType - >>> import jax.tree_util as jtu - >>> # lets define a new match strategy called `FuncKey` that applies - >>> # a function to the path entry and returns True if the function - >>> # returns True and False otherwise. - >>> # for example `FuncKey(lambda x: x.startswith("a"))` will match - >>> # all leaves that start with "a". - >>> class FuncKey(sp.BaseKey): - ... def __init__(self, func): - ... self.func = func - ... @ft.singledispatchmethod - ... def __eq__(self, key): - ... return self.func(key) - ... @__eq__.register(jtu.GetAttrKey) - ... def _(self, key: jtu.GetAttrKey): - ... # unpack the GetAttrKey - ... return self.func(key.name) - ... @__eq__.register(jtu.DictKey) - ... def _(self, key: jtu.DictKey): - ... # unpack the DictKey - ... return self.func(key.key) - ... @__eq__.register(jtu.SequenceKey) - ... def _(self, key: jtu.SequenceKey): - ... return self.func(key.index) - >>> # instead of using ``FuncKey(function)`` we can define an alias - >>> # for `FuncKey`, for this example we will define any FunctionType - >>> # as a `FuncKey` by default. - >>> @sp.BaseKey.def_alias(FunctionType) - ... def _(func): - ... return FuncKey(func) - >>> # create a simple pytree - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... 
a: int - ... b: str - >>> tree = Tree(1, "string") - >>> # now we can use the `FuncKey` alias to match all leaves that - >>> # are strings and start with "a" - >>> tree.at[lambda x: isinstance(x, str) and x.startswith("a")].get() - Tree(a=1, b=None) - """ +class BaseMatchKey(abc.ABC): + """Parent class for all match classes.""" @abc.abstractmethod def __eq__(self, entry: KeyEntry) -> bool: pass - broadcastable: bool = False - -class IndexKey(BaseKey): +class IndexMatchKey(BaseMatchKey): """Match a leaf with a given index.""" def __init__(self, idx: int) -> None: @@ -179,77 +71,29 @@ def __eq__(self, key: KeyEntry) -> bool: return self.idx == key.idx return False - def __repr__(self) -> str: - return f"{self.idx}" - - -class NameKey(BaseKey): - """Match a leaf with a given key.""" - - def __init__(self, name: str) -> None: - self.name = name - - def __eq__(self, key: KeyEntry) -> bool: - if isinstance(key, str): - return self.name == key - treelib = sepes._src.backend.treelib - if isinstance(key, type(treelib.attribute_key(""))): - return self.name == key.name - if isinstance(key, type(treelib.dict_key(""))): - return self.name == key.key - return False - - def __repr__(self) -> str: - return f"{self.name}" - -class EllipsisKey(BaseKey): +class EllipsisMatchKey(BaseMatchKey): """Match all leaves.""" - broadcastable = True - def __init__(self, _): del _ def __eq__(self, _: KeyEntry) -> bool: return True - def __repr__(self) -> str: - return "..." - -class MultiKey(BaseKey): +class MultiMatchKey(BaseMatchKey): """Match a leaf with multiple keys at the same level.""" - def __init__(self, *keys: tuple[BaseKey, ...]): + def __init__(self, *keys: tuple[BaseMatchKey, ...]): self.keys = tuple(keys) def __eq__(self, entry) -> bool: return any(entry == key for key in self.keys) - def __repr__(self) -> str: - return f"({', '.join(map(repr, self.keys))})" - -class RegexKey(BaseKey): - """Match a leaf with a regex pattern inside 'at' property. 
- - Args: - pattern: regex pattern to match. - - Example: - >>> import sepes as sp - >>> import re - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... weight_1: float = 1.0 - ... weight_2: float = 2.0 - ... weight_3: float = 3.0 - ... bias: float = 0.0 - >>> tree = Tree() - >>> tree.at[re.compile(r"weight_.*")].set(100.0) # set all weights to 100.0 - Tree(weight_1=100.0, weight_2=100.0, weight_3=100.0, bias=0.0) - """ +class RegexMatchKey(BaseMatchKey): + """Match a leaf with a regex pattern inside 'at' property.""" def __init__(self, pattern: str) -> None: self.pattern = pattern @@ -264,22 +108,6 @@ def __eq__(self, key: KeyEntry) -> bool: return re.fullmatch(self.pattern, key.key) is not None return False - def __repr__(self) -> str: - return f"{self.pattern}" - - -# dispatch on type of indexer to convert input item to at indexer -# `__getitem__` to the appropriate key -# avoid using container pytree types to avoid conflict between -# matching as a mask or as an instance of `BaseKey` -indexer_dispatcher = ft.singledispatch(lambda x: x) -indexer_dispatcher.register(type(...), EllipsisKey) -indexer_dispatcher.register(int, IndexKey) -indexer_dispatcher.register(str, NameKey) -indexer_dispatcher.register(re.Pattern, RegexKey) - -BaseKey.def_alias = indexer_dispatcher.register - _INVALID_INDEXER = """\ Indexing with {indexer} is not implemented, supported indexing types are: @@ -287,22 +115,20 @@ def __repr__(self) -> str: - `int` for positional indexing for sequences. - `...` to select all leaves. - Boolean mask of a compatible structure as the pytree. - - `re.Pattern` to index all keys matching a regex pattern. - - Instance of `BaseKey` with custom logic to index a pytree. - `tuple` of the above types to match multiple leaves at the same level. """ _NO_LEAF_MATCH = """\ No leaf match is found for where={where}. Available keys are {names}. +No leaf match is found for where={where}. Available keys are {names}. 
Check the following: - If where is `str` then check if the key exists as a key or attribute. - If where is `int` then check if the index is in range. - - If where is `re.Pattern` then check if the pattern matches any key. - If where is a `tuple` of the above types then check if any of the tuple elements match. """ -def generate_path_mask(tree, where: tuple[BaseKey, ...], *, is_leaf=None): +def generate_path_mask(tree, where: tuple[BaseMatchKey, ...], *, is_leaf=None): # given a pytree `tree` and a `where` path, that is composed of keys # generate a boolean mask that will be eventually used to with `tree_map` # to mark the leaves at the specified location. @@ -323,7 +149,7 @@ def is_leaf_func(node) -> bool: return treelib.tree_path_map(func, tree, is_leaf=is_leaf_func) - if any(mask.broadcastable for mask in where): + if any(isinstance(mask, EllipsisMatchKey) for mask in where): # should the selected subtree be broadcasted to the full tree # e.g. tree = [[1, 2], 3, 4] and where = [0], then # broadcast with True will be [[True, True], False, False] @@ -413,7 +239,7 @@ def is_bool_leaf(leaf: Any) -> bool: # with `tree_map` to select the leaves at the specified location. 
mask = None bool_masks: list[T] = [] - path_masks: list[BaseKey] = [] + path_masks: list[BaseMatchKey] = [] seen_tuple = False # handle multiple keys at the same level level_paths = [] @@ -442,8 +268,8 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: bool_masks += [node] return True - if isinstance(resolved_key := indexer_dispatcher(node), BaseKey): - # valid resolution of `BaseKey` is a valid indexing leaf + if isinstance(resolved_key := at.alias_dispatcher(node), BaseMatchKey): + # valid resolution of `BaseMatchKey` is a valid indexing leaf # makes it possible to dispatch on multi-leaf pytree level_paths += [resolved_key] return False @@ -467,7 +293,9 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: # if len(level_paths) > 1 then this means that we have multiple keys # at the same level, for example where = ("a", ("b", "c")) then this # means that for a parent "a", select "b" and "c". - path_masks += [MultiKey(*level_paths)] if len(level_paths) > 1 else level_paths + path_masks += ( + [MultiMatchKey(*level_paths)] if len(level_paths) > 1 else level_paths + ) level_paths = [] seen_tuple = False @@ -481,12 +309,9 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: return mask -class AtIndexer(Generic[T]): +class at(Generic[T]): """Operate on a pytree at a given path using a path or mask in out-of-place manner. - Note: - Use :class:`.at` as a shorter alias for this class. - Args: tree: pytree to operate on. where: one of the following: @@ -495,8 +320,6 @@ class AtIndexer(Generic[T]): - ``int`` for positional indexing for sequences. - ``...`` to select all leaves. - a boolean mask of the same structure as the tree - - ``re.Pattern`` to index all keys matching a regex pattern. - - an instance of ``BaseKey`` with custom logic to index a pytree. - a tuple of the above to match multiple keys at the same level. 
Note: @@ -516,11 +339,8 @@ class AtIndexer(Generic[T]): """ def __init__(self, tree: T, where: list[Any] | None = None) -> None: - vars(self)["tree"] = tree - vars(self)["where"] = [] if where is None else where - - def __setattr__(self, key: str, _: Any) -> None: - raise AttributeError(f"Cannot set {key=} on {type(self).__name__} instance") + self.tree = tree + self.where = [] if where is None else where def __getitem__(self, where: Any) -> Self: """Index a pytree at a given path using a path or mask.""" @@ -894,5 +714,15 @@ def aggregate_subtrees(node: Any) -> bool: return subtrees -# shorter alias -at = AtIndexer +# dispatch on type of indexer to convert input item to at indexer +# `__getitem__` to the appropriate key +# avoid using container pytree types to avoid conflict between +# matching as a mask or as an instance of `BaseMatchKey` + +at.alias_dispatcher = ft.singledispatch(lambda x: x) +at.def_alias = at.alias_dispatcher.register +at.def_alias(type(...), EllipsisMatchKey) +at.def_alias(int, IndexMatchKey) +at.def_alias(str, RegexMatchKey) +# backward compatibility +AtIndexer = at diff --git a/tests/test_index.py b/tests/test_index.py index 097543e..580533d 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -14,7 +14,6 @@ import os -import re from collections import namedtuple from typing import NamedTuple @@ -23,7 +22,7 @@ from sepes._src.backend import arraylib, backend, treelib from sepes._src.code_build import autoinit from sepes._src.tree_base import TreeClass, _mutable_instance_registry -from sepes._src.tree_index import AtIndexer, BaseKey +from sepes._src.tree_index import AtIndexer, BaseMatchKey from sepes._src.tree_util import is_tree_equal, leafwise, value_and_tree test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") @@ -102,9 +101,9 @@ def __init__(self, c: int, d: int): [tree7, dict(a=None, b=[2, None], c=None), ("b", 0)], [tree8, dict(a=None, b=ClassSubTree(c=2, d=None), e=None), ("b", 0)], # by regex - [tree1, 
dict(a=None, b=dict(c=2, d=None), e=None), ("b", re.compile("c"))], - [tree2, ClassTree(None, dict(c=2, d=None), None), ("b", re.compile("c"))], - [tree3, ClassTree(None, ClassSubTree(2, None), None), ("b", re.compile("c"))], + [tree1, dict(a=None, b=dict(c=2, d=None), e=None), ("b", ("c"))], + [tree2, ClassTree(None, dict(c=2, d=None), None), ("b", ("c"))], + [tree3, ClassTree(None, ClassSubTree(2, None), None), ("b", ("c"))], # by ellipsis [tree1, tree1, (...,)], [tree2, tree2, (...,)], @@ -171,9 +170,9 @@ def test_array_indexer_get(tree, expected, where): [tree7, dict(a=1, b=[2, _X], c=4), ("b", 1), _X], [tree8, dict(a=1, b=ClassSubTree(c=2, d=_X), e=4), ("b", 1), _X], # by regex - [tree1, dict(a=1, b=dict(c=_X, d=3), e=4), ("b", re.compile("c")), _X], - [tree2, ClassTree(1, dict(c=_X, d=3), 4), ("b", re.compile("c")), _X], - [tree3, ClassTree(1, ClassSubTree(_X, 3), 4), ("b", re.compile("c")), _X], + [tree1, dict(a=1, b=dict(c=_X, d=3), e=4), ("b", ("c")), _X], + [tree2, ClassTree(1, dict(c=_X, d=3), 4), ("b", ("c")), _X], + [tree3, ClassTree(1, ClassSubTree(_X, 3), 4), ("b", ("c")), _X], # by ellipsis [ tree1, @@ -213,9 +212,9 @@ def test_indexer_set(tree, expected, where, set_value): [tree7, dict(a=1, b=[2, _X], c=4), ("b", 1), _X], [tree8, dict(a=1, b=ClassSubTree(c=2, d=_X), e=4), ("b", 1), _X], # by regex - [tree1, dict(a=1, b=dict(c=_X, d=3), e=4), ("b", re.compile("c")), _X], - [tree2, ClassTree(1, dict(c=_X, d=3), 4), ("b", re.compile("c")), _X], - [tree3, ClassTree(1, ClassSubTree(_X, 3), 4), ("b", re.compile("c")), _X], + [tree1, dict(a=1, b=dict(c=_X, d=3), e=4), ("b", ("c")), _X], + [tree2, ClassTree(1, dict(c=_X, d=3), 4), ("b", ("c")), _X], + [tree3, ClassTree(1, ClassSubTree(_X, 3), 4), ("b", ("c")), _X], # by ellipsis [ tree1, @@ -253,9 +252,9 @@ def test_array_indexer_set(tree, expected, where, set_value): [tree7, dict(a=1, b=[2, _X], c=4), ("b", 1)], [tree8, dict(a=1, b=ClassSubTree(c=2, d=_X), e=4), ("b", 1)], # by regex - [tree1, dict(a=1, 
b=dict(c=_X, d=3), e=4), ("b", re.compile("c"))], - [tree2, ClassTree(1, dict(c=_X, d=3), 4), ("b", re.compile("c"))], - [tree3, ClassTree(1, ClassSubTree(_X, 3), 4), ("b", re.compile("c"))], + [tree1, dict(a=1, b=dict(c=_X, d=3), e=4), ("b", ("c"))], + [tree2, ClassTree(1, dict(c=_X, d=3), 4), ("b", ("c"))], + [tree3, ClassTree(1, ClassSubTree(_X, 3), 4), ("b", ("c"))], # by ellipsis [tree1, dict(a=_X, b=dict(c=_X, d=_X), e=_X), (...,)], [tree2, ClassTree(_X, dict(c=_X, d=_X), _X), (...,)], @@ -292,9 +291,9 @@ def test_indexer_apply(tree, expected, where): [tree7, dict(a=1, b=[2, _X], c=4), ("b", 1)], [tree8, dict(a=1, b=ClassSubTree(c=2, d=_X), e=4), ("b", 1)], # by regex - [tree1, dict(a=1, b=dict(c=_X, d=3), e=4), ("b", re.compile("c"))], - [tree2, ClassTree(1, dict(c=_X, d=3), 4), ("b", re.compile("c"))], - [tree3, ClassTree(1, ClassSubTree(_X, 3), 4), ("b", re.compile("c"))], + [tree1, dict(a=1, b=dict(c=_X, d=3), e=4), ("b", ("c"))], + [tree2, ClassTree(1, dict(c=_X, d=3), 4), ("b", ("c"))], + [tree3, ClassTree(1, ClassSubTree(_X, 3), 4), ("b", ("c"))], # by ellipsis [tree1, dict(a=_X, b=dict(c=_X, d=_X), e=_X), (...,)], [tree2, ClassTree(_X, dict(c=_X, d=_X), _X), (...,)], @@ -328,9 +327,9 @@ def test_array_indexer_apply(tree, expected, where): # mixed [tree7, 5, ("b", (0, 1))], # by regex - [tree1, 5, ("b", re.compile("c|d"))], - [tree2, 5, ("b", re.compile("c|d"))], - [tree3, 5, ("b", re.compile("c|d"))], + [tree1, 5, ("b", ("c|d"))], + [tree2, 5, ("b", ("c|d"))], + [tree3, 5, ("b", ("c|d"))], # by ellipsis [tree1, 1 + 2 + 3 + 4, (...,)], [tree2, 1 + 2 + 3 + 4, (...,)], @@ -363,9 +362,9 @@ def test_indexer_reduce(tree, expected, where): # mixed [tree7, 5, ("b", (0, 1))], # by regex - [tree1, 5, ("b", re.compile("c|d"))], - [tree2, 5, ("b", re.compile("c|d"))], - [tree3, 5, ("b", re.compile("c|d"))], + [tree1, 5, ("b", ("c|d"))], + [tree2, 5, ("b", ("c|d"))], + [tree3, 5, ("b", ("c|d"))], # by ellipsis [tree1, 1 + 2 + 3 + 4, (...,)], [tree2, 1 + 2 + 3 + 4, 
(...,)], @@ -399,9 +398,9 @@ def test_array_indexer_reduce(tree, expected, where): [tree7, (dict(a=1, b=[2, 5], c=4), 3), ("b", (0, 1))], # [tree8, (dict(a=1, b=ClassSubTree(c=2, d=5), e=4), 3), ("b", (0, 1))], # by regex - [tree1, (dict(a=1, b=dict(c=2, d=5), e=4), 3), ("b", re.compile("c|d"))], - [tree2, (ClassTree(1, dict(c=2, d=5), 4), 3), ("b", re.compile("c|d"))], - [tree3, (ClassTree(1, ClassSubTree(2, 5), 4), 3), ("b", re.compile("c|d"))], + [tree1, (dict(a=1, b=dict(c=2, d=5), e=4), 3), ("b", ("c|d"))], + [tree2, (ClassTree(1, dict(c=2, d=5), 4), 3), ("b", ("c|d"))], + [tree3, (ClassTree(1, ClassSubTree(2, 5), 4), 3), ("b", ("c|d"))], ], ) def test_indexer_scan(tree, expected, where): @@ -554,7 +553,7 @@ def tree_unflatten(aux_data, children): tree = Tree(1, 2) - class MatchNameType(BaseKey): + class MatchNameType(BaseMatchKey): def __init__(self, name, type): self.name = name self.type = type @@ -575,9 +574,9 @@ class Tree(TreeClass): t = Tree() - assert repr(t.at["a"]) == "AtIndexer(tree=Tree(a=1, b=2), where=['a'])" - assert str(t.at["a"]) == "AtIndexer(tree=Tree(a=1, b=2), where=['a'])" - assert repr(t.at[...]) == "AtIndexer(tree=Tree(a=1, b=2), where=[Ellipsis])" + assert repr(t.at["a"]) == "at(tree=Tree(a=1, b=2), where=['a'])" + assert str(t.at["a"]) == "at(tree=Tree(a=1, b=2), where=['a'])" + assert repr(t.at[...]) == "at(tree=Tree(a=1, b=2), where=[Ellipsis])" def test_compat_mask():