Skip to content

Commit

Permalink
subtree replacement via relaxed boolean masking
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Oct 6, 2023
1 parent d442869 commit b478aee
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 28 deletions.
15 changes: 14 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# Changelog

## V0.Next

- Mark full subtrees for replacement.

```python
import sepes
tree = [1, 2, [3,4]]
tree_= sp.AtIndexer(tree)[[False,False,True]].set(10)
assert tree_ == [1, 2, 10]
```

i.e. Inside a mask, marking a _subtree_ mask with single bool leaf, will replace the whole subtree. In this example subtree `[3, 4]` marked with `True` in the mask is an indicator for replacement.

## v0.10.0

- successor of the `jax`-specific `pytreeclass`
Expand All @@ -14,4 +27,4 @@
- drop `callback` option in parallel options in `is_parallel`
- Add parallel processing via `is_parallel` to `.{get,set}`
- `register_excluded_type` to `autoinit` to exclude certain types to be in `field` defaults.
- add `doc` in `field` to add extra documentation for the descriptor `__doc__`
- add `doc` in `field` to add extra documentation for the descriptor `__doc__`
55 changes: 34 additions & 21 deletions sepes/_src/tree_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,12 @@ def __repr__(self) -> str:
BaseKey.def_alias = indexer_dispatcher.register


_NOT_IMPLEMENTED_INDEXING = """Indexing with {} is not implemented, supported indexing types are:
_INVALID_INDEXER = """\
Indexing with {indexer} is not implemented, supported indexing types are:
- `str` for mapping keys or class attributes.
- `int` for positional indexing for sequences.
- `...` to select all leaves.
- Boolean mask of the same structure as the tree
- Boolean mask of a compatible structure as the pytree.
- `re.Pattern` to index all keys matching a regex pattern.
- Instance of `BaseKey` with custom logic to index a pytree.
- `tuple` of the above types to match multiple leaves at the same level.
Expand Down Expand Up @@ -359,8 +360,8 @@ def _is_bool_leaf(leaf: Any) -> bool:


def _resolve_where(
tree: T,
where: tuple[Any, ...], # type: ignore
tree: T,
is_leaf: Callable[[Any], None] | None = None,
) -> T | None:
# given a pytree `tree` and a `where` path, that is composed of keys or
Expand All @@ -381,8 +382,20 @@ def verify_and_aggregate_is_leaf(x) -> bool:
# used with `is_leaf` argument of any `tree_*` function
leaves, treedef = treelib.tree_flatten(x)

if treedef == treedef0 and all(map(_is_bool_leaf, leaves)):
# boolean pytrees of same structure as `tree` is a valid indexing pytree
if all(map(_is_bool_leaf, leaves)):
# if all leaves are boolean then this is maybe a boolean mask.
# Maybe because the boolean mask can be a valid pytree of same structure
# as the pytree to be indexed or _compatible_ structure.
# that can be flattend up to inside tree_map.
# the following is an example showcase this:
# >>> tree = [1, 2, [3, 4]]
# >>> mask = [True, True, False]
# >>> AtIndexer(tree)[mask].get()
# in essence the user can mark full subtrees by `False` without
# needing to populate the subtree with `False` values. if treedef
# check is mandated then the user will need to populate the subtree
# with `False` values. i.e. mask = [True, True, [False, False]]
# Finally, invalid boolean mask will be caught by `jax.tree_util`
bool_masks += [x]
return True

Expand All @@ -398,7 +411,7 @@ def verify_and_aggregate_is_leaf(x) -> bool:
return False

# not a container of other keys or a pytree of same structure
raise NotImplementedError(_NOT_IMPLEMENTED_INDEXING.format(x))
raise NotImplementedError(_INVALID_INDEXER.format(indexer=x, treedef=treedef0))

for level_keys in where:
# each for loop iteration is a level in the where path
Expand Down Expand Up @@ -528,10 +541,10 @@ def get(
>>> tree.at['a'].get()
Tree(a=1, b=None)
"""
where = _resolve_where(self.tree, self.where, is_leaf)
where = _resolve_where(self.where, self.tree, is_leaf)
config = dict(is_leaf=is_leaf, is_parallel=is_parallel)

def leaf_get(leaf: Any, where: Any):
def leaf_get(where: Any, leaf: Any):
# support both array and non-array leaves
# for array boolean mask we select **parts** of the array that
# matches the mask, for example if the mask is Array([True, False, False])
Expand All @@ -542,7 +555,7 @@ def leaf_get(leaf: Any, where: Any):
# and `None` otherwise
return leaf if where else None

return treelib.tree_map(leaf_get, self.tree, where, **config)
return treelib.tree_map(leaf_get, where, self.tree, **config)

def set(
self,
Expand Down Expand Up @@ -586,10 +599,10 @@ def set(
>>> tree.at['a'].set(100)
Tree(a=100, b=2)
"""
where = _resolve_where(self.tree, self.where, is_leaf)
where = _resolve_where(self.where, self.tree, is_leaf)
config = dict(is_leaf=is_leaf, is_parallel=is_parallel)

def leaf_set(leaf: Any, where: Any, set_value: Any):
def leaf_set(where: Any, leaf: Any, set_value: Any):
# support both array and non-array leaves
# for array boolean mask we select **parts** of the array that
# matches the mask, for example if the mask is Array([True, False, False])
Expand All @@ -608,12 +621,12 @@ def leaf_set(leaf: Any, where: Any, set_value: Any):
# to tree2 leaves if tree2 is a pytree of same structure as tree
# instead of making each leaf of tree a copy of tree2
# is design is similar to ``numpy`` design `np.at[...].set(Array)`
return treelib.tree_map(leaf_set, self.tree, where, set_value, **config)
return treelib.tree_map(leaf_set, where, self.tree, set_value, **config)

# set_value is broadcasted to tree leaves
# for example tree.at[where].set(1) will set all tree leaves to 1
leaf_set_ = lambda leaf, where: leaf_set(leaf, where, set_value)
return treelib.tree_map(leaf_set_, self.tree, where, **config)
leaf_set_ = lambda where, leaf: leaf_set(where, leaf, set_value)
return treelib.tree_map(leaf_set_, where, self.tree, **config)

def apply(
self,
Expand Down Expand Up @@ -664,10 +677,10 @@ def apply(
>>> indexer = sp.AtIndexer({"lenna": "lenna.png", "baboon": "baboon.png"})
>>> images = indexer[...].apply(imread, parallel=dict(max_workers=2)) # doctest: +SKIP
"""
where = _resolve_where(self.tree, self.where, is_leaf)
where = _resolve_where(self.where, self.tree, is_leaf)
config = dict(is_leaf=is_leaf, is_parallel=is_parallel)

def leaf_apply(leaf: Any, where: bool):
def leaf_apply(where: Any, leaf: Any):
# same as `leaf_set` but with `func` applied to the leaf
# one thing to note is that, the where mask select an array
# then the function needs work properly when applied to the selected
Expand All @@ -676,7 +689,7 @@ def leaf_apply(leaf: Any, where: bool):
return arraylib.where(where, func(leaf), leaf)
return func(leaf) if where else leaf

return treelib.tree_map(leaf_apply, self.tree, where, **config)
return treelib.tree_map(leaf_apply, where, self.tree, **config)

def scan(
self,
Expand Down Expand Up @@ -737,7 +750,7 @@ def scan(
them with final state. While ``reduce`` applies a binary ``func`` to the
leaf values while carrying a state and returning a single value.
"""
where = _resolve_where(self.tree, self.where, is_leaf)
where = _resolve_where(self.where, self.tree, is_leaf)

running_state = state

Expand All @@ -746,12 +759,12 @@ def stateless_func(leaf):
leaf, running_state = func(leaf, running_state)
return leaf

def leaf_apply(leaf: Any, where: bool):
def leaf_apply(where: Any, leaf: Any):
if isinstance(where, arraylib.ndarray):
return arraylib.where(where, stateless_func(leaf), leaf)
return stateless_func(leaf) if where else leaf

out = treelib.tree_map(leaf_apply, self.tree, where, is_leaf=is_leaf)
out = treelib.tree_map(leaf_apply, where, self.tree, is_leaf=is_leaf)
return out, running_state

def reduce(
Expand Down Expand Up @@ -789,7 +802,7 @@ def reduce(
>>> tree.at[...].reduce(lambda a, b: a + b, initializer=0)
3
"""
where = _resolve_where(self.tree, self.where, is_leaf)
where = _resolve_where(self.where, self.tree, is_leaf)
tree = self[where].get(is_leaf=is_leaf) # type: ignore
leaves, _ = treelib.tree_flatten(tree, is_leaf=is_leaf)
if initializer is _no_initializer:
Expand Down
14 changes: 8 additions & 6 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@

from sepes._src.backend import arraylib, backend, treelib
from sepes._src.code_build import autoinit
from sepes._src.tree_base import (
TreeClass,
add_mutable_entry,
discard_mutable_entry,
)
from sepes._src.tree_base import TreeClass, add_mutable_entry, discard_mutable_entry
from sepes._src.tree_index import AtIndexer, BaseKey
from sepes._src.tree_util import is_tree_equal, leafwise

Expand Down Expand Up @@ -477,7 +473,7 @@ def delete(self, name):
t.delete("a")


@pytest.mark.parametrize("where", [(None,), ("a", [1]), (0, [1])])
@pytest.mark.parametrize("where", [("a", [1]), (0, [1])])
def test_unsupported_where(where):
t = namedtuple("a", ["x", "y"])(1, 2)
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -579,3 +575,9 @@ class Tree(TreeClass):
assert repr(t.at["a"]) == "TreeClassIndexer(tree=Tree(a=1, b=2), where=('a',))"
assert str(t.at["a"]) == "TreeClassIndexer(tree=Tree(a=1, b=2), where=('a',))"
assert repr(t.at[...]) == "TreeClassIndexer(tree=Tree(a=1, b=2), where=(Ellipsis,))"


def test_compat_mask():
tree = [1, 2, [3, 4]]
tree_ = AtIndexer(tree)[[False, False, True]].set(10)
assert tree_ == [1, 2, 10]

0 comments on commit b478aee

Please sign in to comment.