Skip to content

Commit

Permalink
make is_tree_equal work with array_like
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 29, 2023
1 parent f17a010 commit c7f0f59
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions sepes/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,13 @@ def tree_copy(tree: T) -> T:
return treelib.tree_map(lambda x: copy(x), tree)


def is_array_like(node) -> bool:
return hasattr(node, "shape") and hasattr(node, "dtype")


def _is_leaf_rhs_equal(leaf, rhs) -> bool | arraylib.ndarray:
if isinstance(leaf, arraylib.ndarray):
if isinstance(rhs, arraylib.ndarray):
if is_array_like(leaf):
if is_array_like(rhs):
if leaf.shape != rhs.shape:
return False
if leaf.dtype != rhs.dtype:
Expand All @@ -70,9 +74,6 @@ def is_tree_equal(*trees: Any) -> bool | arraylib.ndarray:
Note:
trees are compared using their leaves and treedefs.
Note:
Under boolean ``Array`` if compiled otherwise ``bool``.
"""
tree0, *rest = trees
leaves0, treedef0 = treelib.tree_flatten(tree0)
Expand Down

0 comments on commit c7f0f59

Please sign in to comment.