diff --git a/sepes/_src/backend/treelib/__init__.py b/sepes/_src/backend/treelib/__init__.py index 90655b7..a6cb835 100644 --- a/sepes/_src/backend/treelib/__init__.py +++ b/sepes/_src/backend/treelib/__init__.py @@ -61,7 +61,7 @@ class AbstractTreeLib(abc.ABC): @staticmethod @abc.abstractmethod - def tree_map( + def map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -72,7 +72,7 @@ def tree_map( @staticmethod @abc.abstractmethod - def tree_path_map( + def path_map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -83,7 +83,7 @@ def tree_path_map( @staticmethod @abc.abstractmethod - def tree_flatten( + def flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -92,7 +92,7 @@ def tree_flatten( @staticmethod @abc.abstractmethod - def tree_path_flatten( + def path_flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -101,7 +101,7 @@ def tree_path_flatten( @staticmethod @abc.abstractmethod - def tree_unflatten(treedef: Any, leaves: Iterable[Any]) -> Any: + def unflatten(treedef: Any, leaves: Iterable[Any]) -> Any: ... @staticmethod diff --git a/sepes/_src/backend/treelib/jax.py b/sepes/_src/backend/treelib/jax.py index cf49f8a..43bc9d7 100644 --- a/sepes/_src/backend/treelib/jax.py +++ b/sepes/_src/backend/treelib/jax.py @@ -36,7 +36,7 @@ def __str__(self): class JaxTreeLib(AbstractTreeLib): @staticmethod - def tree_map( + def map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -51,7 +51,7 @@ def tree_map( return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_path_map( + def path_map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -66,7 +66,7 @@ def tree_path_map( return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_flatten( + def flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -74,7 +74,7 @@ def tree_flatten( return jtu.tree_flatten(tree, is_leaf=is_leaf) @staticmethod - def tree_path_flatten( + def path_flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -82,7 +82,7 @@ def tree_path_flatten( return jtu.tree_flatten_with_path(tree, is_leaf=is_leaf) @staticmethod - def tree_unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any: + def unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any: return jtu.tree_unflatten(treedef, leaves) @staticmethod diff --git a/sepes/_src/backend/treelib/optree.py b/sepes/_src/backend/treelib/optree.py index 78015ad..4747494 100644 --- a/sepes/_src/backend/treelib/optree.py +++ b/sepes/_src/backend/treelib/optree.py @@ -61,7 +61,7 @@ def __str__(self) -> str: class OpTreeTreeLib(AbstractTreeLib): @staticmethod - def tree_map( + def map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -76,7 +76,7 @@ def tree_map( return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_path_map( + def path_map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -92,7 +92,7 @@ def tree_path_map( return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_flatten( + def flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -101,7 +101,7 @@ def tree_flatten( return (leaves, treedef) @staticmethod - def tree_path_flatten( + def path_flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -110,7 +110,7 @@ def tree_path_flatten( return (list(zip(ot.treespec_paths(treedef), leaves)), treedef) @staticmethod - def tree_unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any: + def unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any: return ot.tree_unflatten(treedef, leaves) @staticmethod diff --git a/sepes/_src/tree_index.py b/sepes/_src/tree_index.py index f49b4d2..5d08b2e 100644 --- a/sepes/_src/tree_index.py +++ b/sepes/_src/tree_index.py @@ -147,7 +147,7 @@ def is_leaf_func(node) -> bool: return False return True - return treelib.tree_path_map(func, tree, is_leaf=is_leaf_func) + return treelib.path_map(func, tree, is_leaf=is_leaf_func) if any(isinstance(mask, EllipsisMatchKey) for mask in where): # should the selected subtree be broadcasted to the full tree @@ -160,8 +160,8 @@ def is_leaf_func(node) -> bool: # and without broadcast the result will be [100, 3, 4] def bool_tree(value: bool, tree: Any): - leaves, treedef = treelib.tree_flatten(tree, is_leaf=is_leaf) - return treelib.tree_unflatten(treedef, [value] * len(leaves)) + leaves, treedef = treelib.flatten(tree, is_leaf=is_leaf) + return treelib.unflatten(treedef, [value] * len(leaves)) true_tree = ft.partial(bool_tree, True) false_tree = ft.partial(bool_tree, False) @@ -249,7 +249,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: nonlocal seen_tuple, level_paths, bool_masks # used to check if a pytree is a valid indexing pytree # used with `is_leaf` argument of any `tree_*` function - leaves, _ = treelib.tree_flatten(node) + leaves, _ = treelib.flatten(node) if all(map(is_bool_leaf, leaves)): # if all leaves are boolean then this is maybe a boolean mask. @@ -289,7 +289,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: # each for loop iteration is a level in the where path # this means that if where = ("a", "b", "c") then this means # we are travering the tree at level "a" then level "b" then level "c" - treelib.tree_flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf) + treelib.flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf) # if len(level_paths) > 1 then this means that we have multiple keys # at the same level, for example where = ("a", ("b", "c")) then this # means that for a parent "a", select "b" and "c". @@ -304,7 +304,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: if bool_masks: all_masks = [mask, *bool_masks] if mask else bool_masks - mask = treelib.tree_map(combine_bool_leaves, *all_masks) + mask = treelib.map(combine_bool_leaves, *all_masks) return mask @@ -390,7 +390,7 @@ def leaf_get(where: Any, leaf: Any): # and `None` otherwise return leaf if where else None - return treelib.tree_map( + return treelib.map( leaf_get, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -440,8 +440,8 @@ def leaf_set(where: Any, leaf: Any, set_value: Any): return arraylib.where(where, set_value, leaf) return set_value if where else leaf - _, lhsdef = treelib.tree_flatten(self.tree, is_leaf=is_leaf) - _, rhsdef = treelib.tree_flatten(set_value, is_leaf=is_leaf) + _, lhsdef = treelib.flatten(self.tree, is_leaf=is_leaf) + _, rhsdef = treelib.flatten(set_value, is_leaf=is_leaf) if lhsdef == rhsdef: # do not broadcast set_value if it is a pytree of same structure @@ -449,7 +449,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any): # to tree2 leaves if tree2 is a pytree of same structure as tree # instead of making each leaf of tree a copy of tree2 # is design is similar to ``numpy`` design `np.at[...].set(Array)` - return treelib.tree_map( + return treelib.map( leaf_set, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -458,7 +458,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any): is_parallel=is_parallel, ) - return treelib.tree_map( + return treelib.map( ft.partial(leaf_set, set_value=set_value), resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -517,7 +517,7 @@ def leaf_apply(where: Any, leaf: Any): return arraylib.where(where, func(leaf), leaf) return func(leaf) if where else leaf - return treelib.tree_map( + return treelib.map( leaf_apply, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -578,7 +578,7 @@ def leaf_apply(where: Any, leaf: Any): return arraylib.where(where, stateless_func(leaf), leaf) return stateless_func(leaf) if where else leaf - out_tree = treelib.tree_map( + out_tree = treelib.map( leaf_apply, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -619,7 +619,7 @@ def reduce( """ treelib = sepes._src.backend.treelib tree = self.get(is_leaf=is_leaf) # type: ignore - leaves, _ = treelib.tree_flatten(tree, is_leaf=is_leaf) + leaves, _ = treelib.flatten(tree, is_leaf=is_leaf) if initializer is _no_initializer: return ft.reduce(func, leaves) return ft.reduce(func, leaves, initializer) @@ -700,7 +700,7 @@ def aggregate_subtrees(node: Any) -> bool: # for example if tree = dict(a=1) and mask is dict(a=True) # then returns [1] and not [dict(a=1)] return False - leaves, _ = treelib.tree_flatten(node, is_leaf=lambda x: x is None) + leaves, _ = treelib.flatten(node, is_leaf=lambda x: x is None) # in essence if the subtree does not contain any None leaves # then it is a valid subtree to be plucked # this because `get` sets the non-selected leaves to None @@ -710,7 +710,7 @@ def aggregate_subtrees(node: Any) -> bool: count -= 1 return True - treelib.tree_flatten(tree, is_leaf=aggregate_subtrees) + treelib.flatten(tree, is_leaf=aggregate_subtrees) return subtrees diff --git a/sepes/_src/tree_mask.py b/sepes/_src/tree_mask.py index b485eb9..37addb2 100644 --- a/sepes/_src/tree_mask.py +++ b/sepes/_src/tree_mask.py @@ -231,7 +231,7 @@ def _tree_mask_map( ): treelib = sepes._src.backend.treelib # apply func to leaves satisfying mask pytree/condtion - _, lhsdef = treelib.tree_flatten(tree, is_leaf=is_leaf) + _, lhsdef = treelib.flatten(tree, is_leaf=is_leaf) if not isinstance(cond, Callable): # a callable that accepts a leaf and returns a boolean @@ -244,7 +244,7 @@ def _tree_mask_map( def map_func(x): return func(x) if cond(x) else x - return treelib.tree_map(map_func, tree, is_leaf=is_leaf) + return treelib.map(map_func, tree, is_leaf=is_leaf) def tree_mask( diff --git a/sepes/_src/tree_pprint.py b/sepes/_src/tree_pprint.py index e0f5813..198a266 100644 --- a/sepes/_src/tree_pprint.py +++ b/sepes/_src/tree_pprint.py @@ -526,14 +526,14 @@ def tree_size(tree: PyTree) -> int: def reduce_func(acc, node): return acc + tree_summary.size_dispatcher(node) - leaves, _ = treelib.tree_flatten(tree) + leaves, _ = treelib.flatten(tree) return ft.reduce(reduce_func, leaves, 0) def tree_count(tree: PyTree) -> int: def reduce_func(acc, node): return acc + tree_summary.count_dispatcher(node) - leaves, _ = treelib.tree_flatten(tree) + leaves, _ = treelib.flatten(tree) return ft.reduce(reduce_func, leaves, 0) traces_leaves = tree_type_path_leaves( diff --git a/sepes/_src/tree_util.py b/sepes/_src/tree_util.py index b6d7094..bbf705b 100644 --- a/sepes/_src/tree_util.py +++ b/sepes/_src/tree_util.py @@ -42,7 +42,7 @@ def tree_hash(*trees: PyTree) -> int: treelib = sepes._src.backend.treelib - leaves, treedef = treelib.tree_flatten(trees) + leaves, treedef = treelib.flatten(trees) return hash((*leaves, treedef)) @@ -57,7 +57,7 @@ def tree_copy(tree: T) -> T: def is_leaf(node) -> bool: return isinstance(node, types) - return treelib.tree_map(tree_copy.copy_dispatcher, tree, is_leaf=is_leaf) + return treelib.map(tree_copy.copy_dispatcher, tree, is_leaf=is_leaf) # default behavior is to copy the tree elements except for registered types @@ -94,11 +94,11 @@ def is_tree_equal(*trees: Any) -> bool: """ treelib = sepes._src.backend.treelib tree0, *rest = trees - leaves0, treedef0 = treelib.tree_flatten(tree0) + leaves0, treedef0 = treelib.flatten(tree0) verdict = True for tree in rest: - leaves, treedef = treelib.tree_flatten(tree) + leaves, treedef = treelib.flatten(tree) if (treedef != treedef0) or verdict is False: return False verdict = ft.reduce(op.and_, map(_is_leaf_rhs_equal, leaves0, leaves), verdict) @@ -171,21 +171,21 @@ def wrapper(*args, **kwargs): treedef0 = ( # reference treedef is the first positional argument - treelib.tree_flatten(args[bdcst_to], is_leaf=is_leaf)[1] + treelib.flatten(args[bdcst_to], is_leaf=is_leaf)[1] if len(args) # reference treedef is the first keyword argument - else treelib.tree_flatten(kwargs[bdcst_to], is_leaf=is_leaf)[1] + else treelib.flatten(kwargs[bdcst_to], is_leaf=is_leaf)[1] ) for arg in args: - if treedef0 == treelib.tree_flatten(arg, is_leaf=is_leaf)[1]: + if treedef0 == treelib.flatten(arg, is_leaf=is_leaf)[1]: cargs += [...] leaves += [treedef0.flatten_up_to(arg)] else: cargs += [arg] for key in kwargs: - if treedef0 == treelib.tree_flatten(kwargs[key], is_leaf=is_leaf)[1]: + if treedef0 == treelib.flatten(kwargs[key], is_leaf=is_leaf)[1]: ckwargs[key] = ... leaves += [treedef0.flatten_up_to(kwargs[key])] kwargs_keys += [key] @@ -199,7 +199,7 @@ def wrapper(*args, **kwargs): args = args_kwargs_values[:split_index] kwargs = dict(zip(kwargs_keys, args_kwargs_values[split_index:])) all_leaves += [bfunc(*args, **kwargs)] - return treelib.tree_unflatten(treedef0, all_leaves) + return treelib.unflatten(treedef0, all_leaves) return wrapper @@ -281,15 +281,15 @@ def leafwise(klass: type[T]) -> type[T]: def uop(func): def wrapper(self): - return treelib.tree_map(func, self) + return treelib.map(func, self) return ft.wraps(func)(wrapper) def bop(func): def wrapper(leaf, rhs=None): if isinstance(rhs, type(leaf)): - return treelib.tree_map(func, leaf, rhs) - return treelib.tree_map(lambda x: func(x, rhs), leaf) + return treelib.map(func, leaf, rhs) + return treelib.map(lambda x: func(x, rhs), leaf) return ft.wraps(func)(wrapper) @@ -351,7 +351,7 @@ def tree_type_path_leaves( is_path_leaf: Callable[[KeyTypePath], bool] | None = None, ) -> Sequence[tuple[KeyTypePath, Any]]: treelib = sepes._src.backend.treelib - _, atomicdef = treelib.tree_flatten(1) + _, atomicdef = treelib.flatten(1) # mainly used for visualization def flatten_one_level(type_path: KeyTypePath, tree: PyTree): @@ -367,7 +367,7 @@ def one_level_is_leaf(node) -> bool: return False return True - path_leaf, treedef = treelib.tree_path_flatten(tree, is_leaf=one_level_is_leaf) + path_leaf, treedef = treelib.path_flatten(tree, is_leaf=one_level_is_leaf) if treedef == atomicdef: yield type_path, tree @@ -578,11 +578,11 @@ def stateless_func(*args, **kwargs) -> tuple[Any, PyTree | tuple[PyTree, ...]]: # copy the incoming inputs (args, kwargs) = tree_copy((args, kwargs)) # and edit the node/record to make it mutable (if there is a rule for it) - treelib.tree_map(lambda _: _, (args, kwargs), is_leaf=mutate_is_leaf) + treelib.map(lambda _: _, (args, kwargs), is_leaf=mutate_is_leaf) output = func(*args, **kwargs) # traverse each node in the tree depth-first manner # to undo the mutation (if there is a rule for it) - treelib.tree_map(lambda _: _, (args, kwargs), is_leaf=immutate_is_leaf) + treelib.map(lambda _: _, (args, kwargs), is_leaf=immutate_is_leaf) out_args = tuple(a for i, a in enumerate(args) if i in argnums) out_args = out_args[0] if is_int_argnum else out_args return output, out_args diff --git a/tests/test_index.py b/tests/test_index.py index 580533d..3b5eeae 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -450,8 +450,8 @@ def __call__(self, x): a = A(1) _, b = value_and_tree(lambda A: A(2))(a) - assert treelib.tree_flatten(a)[0] == [1] - assert treelib.tree_flatten(b)[0] == [3] + assert treelib.flatten(a)[0] == [1] + assert treelib.flatten(b)[0] == [3] with pytest.raises(TypeError): a.at[0](1) @@ -563,7 +563,7 @@ def __eq__(self, other): return other == (self.name, self.type) return False - assert treelib.tree_flatten(tree.at[MatchNameType("a", int)].get())[0] == [1] + assert treelib.flatten(tree.at[MatchNameType("a", int)].get())[0] == [1] def test_repr_str(): diff --git a/tests/test_mask.py b/tests/test_mask.py index 789460d..44d5ab5 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -60,9 +60,9 @@ class A(TreeClass): .apply(unfreeze, is_leaf=is_masked) ) - assert treelib.tree_flatten(a)[0] == [1, 2] - assert treelib.tree_flatten(b)[0] == [] - assert treelib.tree_flatten(c)[0] == [1, 2] + assert treelib.flatten(a)[0] == [1, 2] + assert treelib.flatten(b)[0] == [] + assert treelib.flatten(c)[0] == [1, 2] assert unfreeze(freeze(1.0)) == 1.0 @autoinit @@ -81,7 +81,7 @@ class A(TreeClass): b: int a = A(1, 2) - b = treelib.tree_map(freeze, a) + b = treelib.map(freeze, a) c = ( a.at["a"] .apply(unfreeze, is_leaf=is_masked) @@ -89,9 +89,9 @@ class A(TreeClass): .apply(unfreeze, is_leaf=is_masked) ) - assert treelib.tree_flatten(a)[0] == [1, 2] - assert treelib.tree_flatten(b)[0] == [] - assert treelib.tree_flatten(c)[0] == [1, 2] + assert treelib.flatten(a)[0] == [1, 2] + assert treelib.flatten(b)[0] == [] + assert treelib.flatten(c)[0] == [1, 2] @autoinit class L0(TreeClass): @@ -105,11 +105,11 @@ class L1(TreeClass): class L2(TreeClass): c: L1 = L1() - t = treelib.tree_map(freeze, L2()) + t = treelib.map(freeze, L2()) - assert treelib.tree_flatten(t)[0] == [] - assert treelib.tree_flatten(t.c)[0] == [] - assert treelib.tree_flatten(t.c.b)[0] == [] + assert treelib.flatten(t)[0] == [] + assert treelib.flatten(t.c)[0] == [] + assert treelib.flatten(t.c.b)[0] == [] class L1(TreeClass): def __init__(self): @@ -119,9 +119,9 @@ class L2(TreeClass): def __init__(self): self.c = L1() - t = treelib.tree_map(freeze, L2()) - assert treelib.tree_flatten(t.c)[0] == [] - assert treelib.tree_flatten(t.c.b)[0] == [] + t = treelib.map(freeze, L2()) + assert treelib.flatten(t.c)[0] == [] + assert treelib.flatten(t.c.b)[0] == [] def test_freeze_errors(): @@ -161,25 +161,25 @@ class Test(TreeClass): c: str = freeze("test") t = Test() - assert treelib.tree_flatten(t)[0] == [1] + assert treelib.flatten(t)[0] == [1] with pytest.raises(AttributeError): - treelib.tree_map(freeze, t).a = 1 + treelib.map(freeze, t).a = 1 with pytest.raises(AttributeError): - treelib.tree_map(unfreeze, t).a = 1 + treelib.map(unfreeze, t).a = 1 hash(t) t = Test() - treelib.tree_map(unfreeze, t, is_leaf=is_masked) - treelib.tree_map(freeze, t) + treelib.map(unfreeze, t, is_leaf=is_masked) + treelib.map(freeze, t) @autoinit class Test(TreeClass): a: int - t = treelib.tree_map(freeze, (Test(100))) + t = treelib.map(freeze, (Test(100))) class Test(TreeClass): def __init__(self, x): @@ -224,7 +224,7 @@ class Test(TreeClass): t = Test() - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] def test_freeze_nondiff(): @@ -235,10 +235,10 @@ class Test(TreeClass): t = Test() - assert treelib.tree_flatten(t)[0] == ["a"] - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] - assert treelib.tree_flatten( - (treelib.tree_map(freeze, t)).at["b"].apply(unfreeze, is_leaf=is_masked) + assert treelib.flatten(t)[0] == ["a"] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] + assert treelib.flatten( + (treelib.map(freeze, t)).at["b"].apply(unfreeze, is_leaf=is_masked) )[0] == ["a"] @autoinit @@ -247,11 +247,11 @@ class T0(TreeClass): t = T0() - assert treelib.tree_flatten(t)[0] == ["a"] - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] + assert treelib.flatten(t)[0] == ["a"] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] - assert treelib.tree_flatten(t)[0] == ["a"] - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] + assert treelib.flatten(t)[0] == ["a"] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] def test_freeze_nondiff_with_mask(): @@ -279,11 +279,11 @@ class L2(TreeClass): t = t.at["d"]["d"]["a"].apply(freeze) t = t.at["d"]["d"]["b"].apply(freeze) - assert treelib.tree_flatten(t)[0] == [10, 20, 30, 1, 2, 3, 3] + assert treelib.flatten(t)[0] == [10, 20, 30, 1, 2, 3, 3] def test_non_dataclass_input_to_freeze(): - assert treelib.tree_flatten(freeze(1))[0] == [] + assert treelib.flatten(freeze(1))[0] == [] def test_tree_mask(): @@ -300,18 +300,18 @@ class L1(TreeClass): tree = L1() - assert treelib.tree_flatten(tree)[0] == [1, 2, 3] - assert treelib.tree_flatten(treelib.tree_map(freeze, tree))[0] == [] - assert treelib.tree_flatten(treelib.tree_map(freeze, tree))[0] == [] - assert treelib.tree_flatten(tree.at[...].apply(freeze))[0] == [] - assert treelib.tree_flatten(tree.at[tree > 1].apply(freeze))[0] == [1] - assert treelib.tree_flatten(tree.at[tree == 1].apply(freeze))[0] == [2, 3] - assert treelib.tree_flatten(tree.at[tree < 1].apply(freeze))[0] == [1, 2, 3] + assert treelib.flatten(tree)[0] == [1, 2, 3] + assert treelib.flatten(treelib.map(freeze, tree))[0] == [] + assert treelib.flatten(treelib.map(freeze, tree))[0] == [] + assert treelib.flatten(tree.at[...].apply(freeze))[0] == [] + assert treelib.flatten(tree.at[tree > 1].apply(freeze))[0] == [1] + assert treelib.flatten(tree.at[tree == 1].apply(freeze))[0] == [2, 3] + assert treelib.flatten(tree.at[tree < 1].apply(freeze))[0] == [1, 2, 3] - assert treelib.tree_flatten(tree.at["a"].apply(freeze))[0] == [2, 3] - assert treelib.tree_flatten(tree.at["b"].apply(freeze))[0] == [1] - assert treelib.tree_flatten(tree.at["b"]["x"].apply(freeze))[0] == [1, 3] - assert treelib.tree_flatten(tree.at["b"]["y"].apply(freeze))[0] == [1, 2] + assert treelib.flatten(tree.at["a"].apply(freeze))[0] == [2, 3] + assert treelib.flatten(tree.at["b"].apply(freeze))[0] == [1] + assert treelib.flatten(tree.at["b"]["x"].apply(freeze))[0] == [1, 3] + assert treelib.flatten(tree.at["b"]["y"].apply(freeze))[0] == [1, 2] def test_tree_unmask(): @@ -329,21 +329,21 @@ class L1(TreeClass): tree = L1() frozen_tree = tree.at[...].apply(freeze) - assert treelib.tree_flatten(frozen_tree)[0] == [] + assert treelib.flatten(frozen_tree)[0] == [] mask = tree == tree unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked) - assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3] + assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3] mask = tree > 1 unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked) - assert treelib.tree_flatten(unfrozen_tree)[0] == [2, 3] + assert treelib.flatten(unfrozen_tree)[0] == [2, 3] unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_masked) - # assert treelib.tree_flatten(unfrozen_tree)[0] == [1] + # assert treelib.flatten(unfrozen_tree)[0] == [1] # unfrozen_tree = frozen_tree.at["b"].apply(unfreeze, is_leaf=is_masked) - # assert treelib.tree_flatten(unfrozen_tree)[0] == [2, 3] + # assert treelib.flatten(unfrozen_tree)[0] == [2, 3] def test_tree_mask_unfreeze(): @@ -363,11 +363,11 @@ class L1(TreeClass): mask = tree == tree frozen_tree = tree.at[...].apply(freeze) unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked) - assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3] + assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3] # frozen_tree = tree.at["a"].apply(freeze) # unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_masked) - # assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3] + # assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3] def test_wrapper(): @@ -404,11 +404,11 @@ def test_wrapper(): @pytest.mark.skipif(backend == "default", reason="no array backend installed") def test_tree_mask_tree_unmask(): tree = [1, 2, 3.0] - assert treelib.tree_flatten(tree_mask(tree))[0] == [3.0] - assert treelib.tree_flatten(tree_unmask(tree_mask(tree)))[0] == [1, 2, 3.0] + assert treelib.flatten(tree_mask(tree))[0] == [3.0] + assert treelib.flatten(tree_unmask(tree_mask(tree)))[0] == [1, 2, 3.0] mask_func = lambda x: x < 2 - assert treelib.tree_flatten(tree_mask(tree, mask_func))[0] == [2, 3.0] + assert treelib.flatten(tree_mask(tree, mask_func))[0] == [2, 3.0] assert freeze(freeze(1)) == freeze(1) diff --git a/tests/test_treeclass.py b/tests/test_treeclass.py index 69130ed..154e2d2 100644 --- a/tests/test_treeclass.py +++ b/tests/test_treeclass.py @@ -149,7 +149,7 @@ def __init__( test = Test() - assert treelib.tree_flatten(test)[0] == [] + assert treelib.flatten(test)[0] == [] class Test(TreeClass): def __init__(self, a=arraylib.array([1, 2, 3]), b=arraylib.array([4, 5, 6])): @@ -157,7 +157,7 @@ def __init__(self, a=arraylib.array([1, 2, 3]), b=arraylib.array([4, 5, 6])): self.b = b test = Test() - npt.assert_allclose(treelib.tree_flatten(test)[0][0], arraylib.array([4, 5, 6])) + npt.assert_allclose(treelib.flatten(test)[0][0], arraylib.array([4, 5, 6])) def test_post_init(): @@ -202,7 +202,7 @@ def inc(self, x): l1 = L1() - assert treelib.tree_flatten(l1)[0] == [2, 4, 5, 5] + assert treelib.flatten(l1)[0] == [2, 4, 5, 5] assert l1.inc(10) == 20 assert l1.sub(10) == 0 assert l1.d == 5 @@ -214,7 +214,7 @@ class L1(L0): l1 = L1() - assert treelib.tree_flatten(l1)[0] == [2, 4, 5] + assert treelib.flatten(l1)[0] == [2, 4, 5] def test_registering_state(): @@ -416,7 +416,7 @@ class Test(TreeClass): t = Test(1) assert t.a == freeze(1) - assert treelib.tree_flatten(t)[0] == [] + assert treelib.flatten(t)[0] == [] def test_super():