Skip to content

Commit

Permalink
nits
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Mar 29, 2024
1 parent a98a98b commit 323058b
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions sepes/_src/tree_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,6 @@ def _tree_mask_map(
*,
is_leaf: Callable[[Any], None] | None = None,
):
treelib = sepes._src.backend.treelib
# apply func to leaves satisfying mask pytree/condtion
_, lhsdef = treelib.flatten(tree, is_leaf=is_leaf)

if not isinstance(cond, Callable):
# a callable that accepts a leaf and returns a boolean
Expand All @@ -214,6 +211,8 @@ def _tree_mask_map(
f" Got {cond=} and {tree=}."
)

treelib = sepes._src.backend.treelib

def map_func(x):
return func(x) if cond(x) else x

Expand Down

0 comments on commit 323058b

Please sign in to comment.