Skip to content

Commit

Permalink
revert #14 plus some mods
Browse files Browse the repository at this point in the history
- in favor of more explicit
- fails if dicts have keys similar to re.Pattern
  • Loading branch information
ASEM000 committed Mar 26, 2024
1 parent b31b74e commit d5966f2
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 110 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@
- `tree_{mask,unmask}` now accepts only callable `cond` argument.
- Rename `is_frozen` to `is_masked`
- "frozen" could mean a non-trainable array; however, masking applies not only to arrays but also to other types that will be hidden across jax transformations.
- Remove `re.compile(pattern)` to match `re.Pattern` in `where` argument in `AtIndexer`, instead use string `pattern` directly.


## V0.11.3

- Raise error if `autoinit` is used with `__init__` method defined.
Expand Down
171 changes: 94 additions & 77 deletions sepes/_src/tree_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@
# apply the *what* part to the *where* part.

from __future__ import annotations

import abc
import functools as ft
import re
from typing import Any, Callable, Hashable, Tuple, TypeVar, Generic
from typing import Any, Callable, Hashable, TypeVar, Generic

from typing_extensions import Self

Expand All @@ -44,76 +43,24 @@
S = TypeVar("S")
PyTree = Any
EllipsisType = TypeVar("EllipsisType")
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
KeyPath = Tuple[KeyEntry, ...]
PathKeyEntry = TypeVar("PathKeyEntry", bound=Hashable)
_no_initializer = object()


class BaseMatchKey(abc.ABC):
class BaseKey(abc.ABC):
"""Parent class for all match classes."""

@abc.abstractmethod
def __eq__(self, entry: KeyEntry) -> bool:
def __eq__(self, entry: PathKeyEntry) -> bool:
pass


class IndexMatchKey(BaseMatchKey):
    """Key that selects a sequence entry by its integer position.

    Args:
        idx: the position to compare path entries against.
    """

    def __init__(self, idx: int) -> None:
        self.idx = idx

    def __eq__(self, key: KeyEntry) -> bool:
        """Return True if ``key`` is, or wraps, the stored index."""
        # plain integer path entry
        if isinstance(key, int):
            return self.idx == key
        # backend sequence key: its type is looked up from a sample instance
        sequence_key_type = type(sepes._src.backend.treelib.sequence_key(0))
        if not isinstance(key, sequence_key_type):
            return False
        return self.idx == key.idx


class EllipsisMatchKey(BaseMatchKey):
    """Wildcard key: compares equal to every path entry."""

    def __init__(self, _):
        # the constructor argument carries no information and is ignored
        pass

    def __eq__(self, _: KeyEntry) -> bool:
        """Always report a match, whatever the path entry is."""
        return True


class MultiMatchKey(BaseMatchKey):
    """Key that matches when any one of several same-level keys matches.

    Args:
        *keys: the alternative keys to try against a path entry.
    """

    def __init__(self, *keys: tuple[BaseMatchKey, ...]):
        # variadic positional arguments always arrive as a tuple
        self.keys = keys

    def __eq__(self, entry) -> bool:
        """Return True as soon as ``entry`` equals one of the stored keys."""
        for candidate in self.keys:
            if entry == candidate:
                return True
        return False


class RegexMatchKey(BaseMatchKey):
    """Match a path entry whose name fully matches a regex pattern.

    Args:
        pattern: the regular expression handed to ``re.fullmatch``.
    """

    def __init__(self, pattern: str) -> None:
        self.pattern = pattern

    def __eq__(self, key: KeyEntry) -> bool:
        """Return True if ``key`` names an entry fully matched by the pattern.

        ``key`` may be a plain string, a backend attribute key (compared by
        its ``name``) or a backend dict key (compared by its ``key``). Any
        other entry, or a name that is not a string, is a non-match.
        """
        if isinstance(key, str):
            return re.fullmatch(self.pattern, key) is not None
        treelib = sepes._src.backend.treelib
        if isinstance(key, type(treelib.attribute_key(""))):
            name = key.name
        elif isinstance(key, type(treelib.dict_key(""))):
            name = key.key
        else:
            return False
        # mapping keys can be any hashable (e.g. ``int``); ``re.fullmatch``
        # raises ``TypeError`` on non-strings, so report a non-match instead.
        if not isinstance(name, str):
            return False
        return re.fullmatch(self.pattern, name) is not None


# Message template listing the supported indexing types. `{indexer}` is a
# format placeholder - presumably filled with the offending indexer by the
# code that rejects it (the use site is not visible here; TODO confirm).
_INVALID_INDEXER = """\
Indexing with {indexer} is not implemented, supported indexing types are:
- `str` for mapping keys or class attributes.
- `int` for positional indexing for sequences.
- `...` to select all leaves.
- ``re.Pattern`` to match a leaf level path with a regex pattern.
- Boolean mask of a compatible structure as the pytree.
- `tuple` of the above types to match multiple leaves at the same level.
"""
Expand All @@ -124,11 +71,12 @@ def __eq__(self, key: KeyEntry) -> bool:
Check the following:
- If where is `str` then check if the key exists as a key or attribute.
- If where is `int` then check if the index is in range.
- If where is `re.Pattern` then check if the pattern matches any key.
- If where is a `tuple` of the above types then check if any of the tuple elements match.
"""


def generate_path_mask(tree, where: tuple[BaseMatchKey, ...], *, is_leaf=None):
def generate_path_mask(tree, where: tuple[BaseKey, ...], *, is_leaf=None):
# given a pytree `tree` and a `where` path that is composed of keys,
# generate a boolean mask that will eventually be used with `tree_map`
# to mark the leaves at the specified location.
Expand All @@ -149,7 +97,7 @@ def is_leaf_func(node) -> bool:

return treelib.path_map(func, tree, is_leaf=is_leaf_func)

if any(isinstance(mask, EllipsisMatchKey) for mask in where):
if any(isinstance(mask, EllipsisKey) for mask in where):
# should the selected subtree be broadcasted to the full tree
# e.g. tree = [[1, 2], 3, 4] and where = [0], then
# broadcast with True will be [[True, True], False, False]
Expand Down Expand Up @@ -206,7 +154,7 @@ def path_map_func(path, leaf):
mask = one_level_tree_path_map(path_map_func, tree)

if not match:
path_leaf, _ = treelib.tree_path_flatten(tree, is_leaf=is_leaf)
path_leaf, _ = treelib.path_flatten(tree, is_leaf=is_leaf)
names = "".join("\n - " + treelib.keystr(path) for path, _ in path_leaf)
raise LookupError(_NO_LEAF_MATCH.format(where=where, names=names))

Expand Down Expand Up @@ -239,7 +187,7 @@ def is_bool_leaf(leaf: Any) -> bool:
# with `tree_map` to select the leaves at the specified location.
mask = None
bool_masks: list[T] = []
path_masks: list[BaseMatchKey] = []
path_masks: list[BaseKey] = []
seen_tuple = False # handle multiple keys at the same level
level_paths = []

Expand Down Expand Up @@ -268,8 +216,8 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
bool_masks += [node]
return True

if isinstance(resolved_key := at.alias_dispatcher(node), BaseMatchKey):
# valid resolution of `BaseMatchKey` is a valid indexing leaf
if isinstance(resolved_key := at.alias_dispatcher(node), BaseKey):
# valid resolution of `BaseKey` is a valid indexing leaf
# makes it possible to dispatch on multi-leaf pytree
level_paths += [resolved_key]
return False
Expand All @@ -293,9 +241,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
# if len(level_paths) > 1 then this means that we have multiple keys
# at the same level, for example where = ("a", ("b", "c")) then this
# means that for a parent "a", select "b" and "c".
path_masks += (
[MultiMatchKey(*level_paths)] if len(level_paths) > 1 else level_paths
)
path_masks += [MultiKey(*level_paths)] if len(level_paths) > 1 else level_paths
level_paths = []
seen_tuple = False

Expand All @@ -312,6 +258,8 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
class at(Generic[T]):
"""Operate on a pytree at a given path using a path or mask in out-of-place manner.
Alias for :func:`.at`
Args:
tree: pytree to operate on.
where: one of the following:
Expand All @@ -320,6 +268,7 @@ class at(Generic[T]):
- ``int`` for positional indexing for sequences.
- ``...`` to select all leaves.
- a boolean mask of the same structure as the tree
- ``re.Pattern`` to match a leaf level path with a regex pattern.
- a tuple of the above to match multiple keys at the same level.
Note:
Expand Down Expand Up @@ -347,7 +296,7 @@ def __getitem__(self, where: Any) -> Self:
return type(self)(self.tree, [*self.where, where])

def __repr__(self) -> str:
return f"{type(self).__name__}(tree={tree_repr(self.tree)}, where={self.where})"
return f"{type(self).__name__}({tree_repr(self.tree)}, where={self.where})"

def get(
self,
Expand Down Expand Up @@ -505,7 +454,6 @@ def apply(
>>> is_parallel = dict(max_workers=2)
>>> images = sp.at(path)[...].apply(imread, is_parallel=is_parallel) # doctest: +SKIP
"""

treelib = sepes._src.backend.treelib

def leaf_apply(where: Any, leaf: Any):
Expand Down Expand Up @@ -713,15 +661,84 @@ def aggregate_subtrees(node: Any) -> bool:
treelib.flatten(tree, is_leaf=aggregate_subtrees)
return subtrees

# dispatch on type of indexer to convert input item to at indexer
# `__getitem__` to the appropriate key
# avoid using container pytree types to avoid conflict between
# matching as a mask or as an instance of `BaseMatchKey`

at.alias_dispatcher = ft.singledispatch(lambda x: x)
at.def_alias = at.alias_dispatcher.register
at.def_alias(type(...), EllipsisMatchKey)
at.def_alias(int, IndexMatchKey)
at.def_alias(str, RegexMatchKey)
# backward compatibility
# backwards compatibility
AtIndexer = at


# rules


@at.def_alias(str)
class NameMatchKey(BaseKey):
    """Key that selects a path entry by exact name.

    Registered as the ``at`` alias for ``str``.

    Args:
        name: the mapping key or attribute name to compare against.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __eq__(self, key: PathKeyEntry) -> bool:
        """Return True if ``key`` is, or wraps, the stored name."""
        # plain string path entry
        if isinstance(key, str):
            return self.name == key
        backend = sepes._src.backend.treelib
        # backend key types are looked up from sample instances
        attribute_key_type = type(backend.attribute_key(""))
        if isinstance(key, attribute_key_type):
            return self.name == key.name
        dict_key_type = type(backend.dict_key(""))
        if isinstance(key, dict_key_type):
            return self.name == key.key
        return False


@at.def_alias(int)
class IndexKey(BaseKey):
    """Key that selects a sequence entry by its integer position.

    Registered as the ``at`` alias for ``int``.

    Args:
        idx: the position to compare path entries against.
    """

    def __init__(self, idx: int) -> None:
        self.idx = idx

    def __eq__(self, key: PathKeyEntry) -> bool:
        """Return True if ``key`` is, or wraps, the stored index."""
        # plain integer path entry
        if isinstance(key, int):
            return self.idx == key
        # backend sequence key: its type is looked up from a sample instance
        sequence_key_type = type(sepes._src.backend.treelib.sequence_key(0))
        if not isinstance(key, sequence_key_type):
            return False
        return self.idx == key.idx


@at.def_alias(type(...))
class EllipsisKey(BaseKey):
    """Wildcard key: compares equal to every path entry.

    Registered as the ``at`` alias for the type of ``...``.
    """

    def __init__(self, _):
        # the constructor argument carries no information and is ignored
        pass

    def __eq__(self, _: PathKeyEntry) -> bool:
        """Always report a match, whatever the path entry is."""
        return True


@at.def_alias(re.Pattern)
class RegexKey(BaseKey):
    """Match a path entry whose name fully matches a regex pattern.

    Registered as the ``at`` alias for ``re.Pattern``.

    Args:
        pattern: the pattern handed to ``re.fullmatch``; a compiled
            ``re.Pattern`` when the key is created through the alias.
    """

    def __init__(self, pattern: re.Pattern) -> None:
        self.pattern = pattern

    def __eq__(self, key: PathKeyEntry) -> bool:
        """Return True if ``key`` names an entry fully matched by the pattern.

        ``key`` may be a plain string, a backend attribute key (compared by
        its ``name``) or a backend dict key (compared by its ``key``). Any
        other entry, or a name that is not a string, is a non-match.
        """
        if isinstance(key, str):
            return re.fullmatch(self.pattern, key) is not None
        treelib = sepes._src.backend.treelib
        if isinstance(key, type(treelib.attribute_key(""))):
            name = key.name
        elif isinstance(key, type(treelib.dict_key(""))):
            name = key.key
        else:
            return False
        # mapping keys can be any hashable (e.g. ``int``); ``re.fullmatch``
        # raises ``TypeError`` on non-strings, so report a non-match instead.
        if not isinstance(name, str):
            return False
        return re.fullmatch(self.pattern, name) is not None


class MultiKey(BaseKey):
    """Key that matches when any one of several same-level keys matches.

    Args:
        *keys: the alternative keys to try against a path entry.
    """

    def __init__(self, *keys):
        # variadic positional arguments always arrive as a tuple
        self.keys = keys

    def __eq__(self, entry: PathKeyEntry) -> bool:
        """Return True as soon as ``entry`` equals one of the stored keys."""
        for candidate in self.keys:
            if entry == candidate:
                return True
        return False
2 changes: 1 addition & 1 deletion sepes/_src/tree_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class _MaskedError(NamedTuple):

def __call__(self, *a, **k):
raise NotImplementedError(
f"Cannot apply `{self.opname}` operation to a frozen object "
f"Cannot apply `{self.opname}` operation on a masked object "
f"{', '.join(map(str, a))} "
f"{', '.join(k + '=' + str(v) for k, v in k.items())}.\n"
"Unmask the object first using `tree_unmask`"
Expand Down
Loading

0 comments on commit d5966f2

Please sign in to comment.