Skip to content

Commit

Permalink
add def_rule for at indexer
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Mar 29, 2024
1 parent 07af95f commit 07e66c0
Showing 1 changed file with 66 additions and 61 deletions.
127 changes: 66 additions & 61 deletions sepes/_src/tree_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import abc
import functools as ft
import re
from typing import Any, Callable, Generic, Hashable, TypeVar
from typing import Any, Callable, Generic, Hashable, TypeVar, Sequence

from typing_extensions import Self

Expand All @@ -65,6 +65,10 @@ class BaseKey(abc.ABC):
def __eq__(self, entry: PathKeyEntry) -> bool:
pass

@property
@abc.abstractmethod
def broadcast(self): ...


_INVALID_INDEXER = """\
Indexing with {indexer} is not implemented, supported indexing types are:
Expand Down Expand Up @@ -108,7 +112,7 @@ def is_leaf_func(node) -> bool:

return treelib.path_map(func, tree, is_leaf=is_leaf_func)

if any(isinstance(mask, EllipsisKey) for mask in where):
if any(where_i.broadcast for where_i in where):
# should the selected subtree be broadcasted to the full tree
# e.g. tree = [[1, 2], 3, 4] and where = [0], then
# broadcast with True will be [[True, True], False, False]
Expand Down Expand Up @@ -228,7 +232,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
bool_masks += [node]
return True

if isinstance(resolved_key := at.key_dispatcher(node), BaseKey):
if isinstance(resolved_key := at.dispatcher(node), BaseKey):
# valid resolution of `BaseKey` is a valid indexing leaf
# makes it possible to dispatch on multi-leaf pytree
level_paths += [resolved_key]
Expand Down Expand Up @@ -689,73 +693,76 @@ def aggregate_subtrees(node: Any) -> bool:


# pass through for boolean pytrees masks and tuple of keys
at.key_dispatcher = ft.singledispatch(lambda x: x)
at.def_key = at.key_dispatcher.register
at.dispatcher = ft.singledispatch(lambda x: x)

# key rules

def def_rule(
user_type: type[T],
path_compare_func: Callable[[T, PathKeyEntry], bool],
*,
broadcastable: bool = False,
) -> None:
# remove the BaseKey abstraction from the user-facing function
class UserKey(BaseKey):
broadcast: bool = broadcastable

@at.def_key(str)
class NameMatchKey(BaseKey):
"""Match a leaf with a given name."""
def __init__(self, input: T):
self.input = input

def __init__(self, name: str) -> None:
self.name = name
def __eq__(self, key: PathKeyEntry) -> bool:
return path_compare_func(self.input, key)

def __eq__(self, key: PathKeyEntry) -> bool:
if isinstance(key, str):
return self.name == key
treelib = sepes._src.backend.treelib
if isinstance(key, type(treelib.attribute_key(""))):
return self.name == key.name
if isinstance(key, type(treelib.dict_key(""))):
return self.name == key.key
return False
at.dispatcher.register(user_type, UserKey)


@at.def_key(int)
class IndexKey(BaseKey):
"""Match a leaf with a given index."""
at.def_rule = def_rule

def __init__(self, idx: int) -> None:
self.idx = idx

def __eq__(self, key: PathKeyEntry) -> bool:
if isinstance(key, int):
return self.idx == key
treelib = sepes._src.backend.treelib
if isinstance(key, type(treelib.sequence_key(0))):
return self.idx == key.idx
return False
# key rules to match user input to with the path entry


@at.def_key(type(...))
class EllipsisKey(BaseKey):
"""Match all leaves."""
def str_compare(name: str, key: PathKeyEntry):
"""Match a leaf with a given name."""
if isinstance(key, str):
return name == key
treelib = sepes._src.backend.treelib
if isinstance(key, type(treelib.attribute_key(""))):
return name == key.name
if isinstance(key, type(treelib.dict_key(""))):
return name == key.key
return False

def __init__(self, _):
del _

def __eq__(self, _: PathKeyEntry) -> bool:
return True
def int_compare(idx: int, key: PathKeyEntry) -> bool:
"""Match a leaf with a given index."""
if isinstance(key, int):
return idx == key
treelib = sepes._src.backend.treelib
if isinstance(key, type(treelib.sequence_key(0))):
return idx == key.idx
return False


@at.def_key(re.Pattern)
class RegexKey(BaseKey):
def regex_compare(pattern: re.Pattern, key: PathKeyEntry) -> bool:
"""Match a path with a regex pattern inside 'at' property."""
if isinstance(key, str):
return re.fullmatch(pattern, key) is not None
treelib = sepes._src.backend.treelib
if isinstance(key, type(treelib.attribute_key(""))):
return re.fullmatch(pattern, key.name) is not None
if isinstance(key, type(treelib.dict_key(""))):
return re.fullmatch(pattern, key.key) is not None
return False

def __init__(self, pattern: str) -> None:
self.pattern = pattern

def __eq__(self, key: PathKeyEntry) -> bool:
if isinstance(key, str):
return re.fullmatch(self.pattern, key) is not None
treelib = sepes._src.backend.treelib
if isinstance(key, type(treelib.attribute_key(""))):
return re.fullmatch(self.pattern, key.name) is not None
if isinstance(key, type(treelib.dict_key(""))):
return re.fullmatch(self.pattern, key.key) is not None
return False
def ellipsis_compare(_, __):
return True


at.def_rule(str, str_compare, broadcastable=False)
at.def_rule(int, int_compare, broadcastable=False)
at.def_rule(re.Pattern, regex_compare, broadcastable=False)
at.def_rule(type(...), ellipsis_compare, broadcastable=True)


class MultiKey(BaseKey):
Expand All @@ -767,18 +774,16 @@ def __init__(self, *keys):
def __eq__(self, entry: PathKeyEntry) -> bool:
return any(entry == key for key in self.keys)

broadcast: bool = False


if is_package_avaiable("jax"):
import jax.tree_util as jtu

@at.def_key(jtu.SequenceKey)
@at.def_key(jtu.GetAttrKey)
@at.def_key(jtu.DictKey)
class JaxKey(BaseKey):
def jax_key_compare(input, key: PathKeyEntry) -> bool:
"""Enable indexing with jax keys directly in `at`."""
return input == key

def __init__(self, key) -> None:
self.key = key

def __eq__(self, value) -> bool:
return self.key == value
at.def_rule(jtu.SequenceKey, jax_key_compare, broadcastable=False)
at.def_rule(jtu.GetAttrKey, jax_key_compare, broadcastable=False)
at.def_rule(jtu.DictKey, jax_key_compare, broadcastable=False)

0 comments on commit 07e66c0

Please sign in to comment.