Skip to content

Commit

Permalink
skip root node for pluck
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Oct 8, 2023
1 parent 7a0f4a6 commit 9d8e49d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
22 changes: 17 additions & 5 deletions sepes/_src/tree_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,11 @@ def pluck(
) -> list[Any]:
"""Extract subtrees at the specified location.
Note:
``pluck`` first applies ``get`` to the specified location and then
extracts the immediate subtrees of the selected leaves. ``is_leaf``
and ``is_parallel`` are passed to ``get``.
Args:
count: number of subtrees to extract, Default to ``None`` to
extract all subtrees.
Expand Down Expand Up @@ -963,16 +968,23 @@ def pluck(
def aggregate_subtrees(node: Any) -> bool:
nonlocal subtrees, count
if count < 1:
# stop traversing the tree
# if total number of subtrees is reached
return True
if id(node) == id(tree):
# skip the root node
# for example if tree = dict(a=1) and mask is dict(a=True)
# then returns [1] and not [dict(a=1)]
return False
leaves, _ = treelib.tree_flatten(node, is_leaf=lambda x: x is None)
# in essence if the subtree does not contain any None leaves
# then it is a valid subtree to be plucked
# this because `get` sets the non-selected leaves to None
if all(leaf is not None for leaf in leaves):
subtrees += [node]
count -= 1
return True
return False
if any(leaf is None for leaf in leaves):
return False
subtrees += [node]
count -= 1
return True

treelib.tree_flatten(tree, is_leaf=aggregate_subtrees)
return subtrees
3 changes: 3 additions & 0 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,3 +589,6 @@ def test_pluck():
assert subtrees[0] == [3, 4]
assert AtIndexer(tree)[0, 1].pluck(1) == [1]
assert AtIndexer(tree)[0, 1].pluck(2) == [1, 2]

tree = dict(a=1, b=2)
assert AtIndexer(tree)[...].pluck() == [1, 2]

0 comments on commit 9d8e49d

Please sign in to comment.