Skip to content

Commit

Permalink
fix test_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 2, 2024
1 parent 13d26b8 commit 249ae02
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
12 changes: 7 additions & 5 deletions sepes/_src/tree_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,12 @@ def tree_mask(
>>> sp.tree_unmask(masked_tree)
[MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}]
"""
return _tree_mask_map(tree, cond=cond, func=tree_mask.dispatcher, is_leaf=is_leaf)
return _tree_mask_map(
tree,
cond=cond,
func=tree_mask.dispatcher,
is_leaf=is_leaf,
)


tree_mask.dispatcher = ft.singledispatch(_MaskedHashable)
Expand Down Expand Up @@ -314,10 +319,7 @@ def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True):
[MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}]
"""
return _tree_mask_map(
tree,
cond=cond,
func=tree_unmask.dispatcher,
is_leaf=is_masked,
tree, cond=cond, func=tree_unmask.dispatcher, is_leaf=is_masked
)


Expand Down
1 change: 1 addition & 0 deletions tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sepes._src.tree_base import TreeClass
from sepes._src.tree_mask import is_masked, tree_mask, tree_unmask
from sepes._src.tree_util import is_tree_equal, leafwise, tree_hash
import dataclasses as dc

test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy")
freeze = ft.partial(tree_mask, cond=lambda _: True)
Expand Down

0 comments on commit 249ae02

Please sign in to comment.