Skip to content

Commit

Permalink
tree_*** -> ***
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Mar 20, 2024
1 parent 9ac7f27 commit 6e09014
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 111 deletions.
10 changes: 5 additions & 5 deletions sepes/_src/backend/treelib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class AbstractTreeLib(abc.ABC):

@staticmethod
@abc.abstractmethod
def tree_map(
def map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
Expand All @@ -72,7 +72,7 @@ def tree_map(

@staticmethod
@abc.abstractmethod
def tree_path_map(
def path_map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
Expand All @@ -83,7 +83,7 @@ def tree_path_map(

@staticmethod
@abc.abstractmethod
def tree_flatten(
def flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
Expand All @@ -92,7 +92,7 @@ def tree_flatten(

@staticmethod
@abc.abstractmethod
def tree_path_flatten(
def path_flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
Expand All @@ -101,7 +101,7 @@ def tree_path_flatten(

@staticmethod
@abc.abstractmethod
def tree_unflatten(treedef: Any, leaves: Iterable[Any]) -> Any:
def unflatten(treedef: Any, leaves: Iterable[Any]) -> Any:
...

@staticmethod
Expand Down
10 changes: 5 additions & 5 deletions sepes/_src/backend/treelib/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __str__(self):

class JaxTreeLib(AbstractTreeLib):
@staticmethod
def tree_map(
def map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
Expand All @@ -51,7 +51,7 @@ def tree_map(
return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config))

@staticmethod
def tree_path_map(
def path_map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
Expand All @@ -66,23 +66,23 @@ def tree_path_map(
return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config))

@staticmethod
def tree_flatten(
def flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
) -> tuple[Iterable[Any], jtu.PyTreeDef]:
return jtu.tree_flatten(tree, is_leaf=is_leaf)

@staticmethod
def tree_path_flatten(
def path_flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
) -> tuple[Iterable[KeyPathLeaf], jtu.PyTreeDef]:
return jtu.tree_flatten_with_path(tree, is_leaf=is_leaf)

@staticmethod
def tree_unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any:
def unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any:
return jtu.tree_unflatten(treedef, leaves)

@staticmethod
Expand Down
10 changes: 5 additions & 5 deletions sepes/_src/backend/treelib/optree.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __str__(self) -> str:

class OpTreeTreeLib(AbstractTreeLib):
@staticmethod
def tree_map(
def map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
Expand All @@ -76,7 +76,7 @@ def tree_map(
return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config))

@staticmethod
def tree_path_map(
def path_map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
Expand All @@ -92,7 +92,7 @@ def tree_path_map(
return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config))

@staticmethod
def tree_flatten(
def flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
Expand All @@ -101,7 +101,7 @@ def tree_flatten(
return (leaves, treedef)

@staticmethod
def tree_path_flatten(
def path_flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
Expand All @@ -110,7 +110,7 @@ def tree_path_flatten(
return (list(zip(ot.treespec_paths(treedef), leaves)), treedef)

@staticmethod
def tree_unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any:
def unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any:
return ot.tree_unflatten(treedef, leaves)

@staticmethod
Expand Down
32 changes: 16 additions & 16 deletions sepes/_src/tree_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def is_leaf_func(node) -> bool:
return False
return True

return treelib.tree_path_map(func, tree, is_leaf=is_leaf_func)
return treelib.path_map(func, tree, is_leaf=is_leaf_func)

if any(isinstance(mask, EllipsisMatchKey) for mask in where):
# should the selected subtree be broadcasted to the full tree
Expand All @@ -160,8 +160,8 @@ def is_leaf_func(node) -> bool:
# and without broadcast the result will be [100, 3, 4]

def bool_tree(value: bool, tree: Any):
leaves, treedef = treelib.tree_flatten(tree, is_leaf=is_leaf)
return treelib.tree_unflatten(treedef, [value] * len(leaves))
leaves, treedef = treelib.flatten(tree, is_leaf=is_leaf)
return treelib.unflatten(treedef, [value] * len(leaves))

true_tree = ft.partial(bool_tree, True)
false_tree = ft.partial(bool_tree, False)
Expand Down Expand Up @@ -249,7 +249,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
nonlocal seen_tuple, level_paths, bool_masks
# used to check if a pytree is a valid indexing pytree
# used with `is_leaf` argument of any `tree_*` function
leaves, _ = treelib.tree_flatten(node)
leaves, _ = treelib.flatten(node)

if all(map(is_bool_leaf, leaves)):
# if all leaves are boolean then this is maybe a boolean mask.
Expand Down Expand Up @@ -289,7 +289,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:
# each for loop iteration is a level in the where path
# this means that if where = ("a", "b", "c") then this means
# we are travering the tree at level "a" then level "b" then level "c"
treelib.tree_flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf)
treelib.flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf)
# if len(level_paths) > 1 then this means that we have multiple keys
# at the same level, for example where = ("a", ("b", "c")) then this
# means that for a parent "a", select "b" and "c".
Expand All @@ -304,7 +304,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool:

if bool_masks:
all_masks = [mask, *bool_masks] if mask else bool_masks
mask = treelib.tree_map(combine_bool_leaves, *all_masks)
mask = treelib.map(combine_bool_leaves, *all_masks)

return mask

Expand Down Expand Up @@ -390,7 +390,7 @@ def leaf_get(where: Any, leaf: Any):
# and `None` otherwise
return leaf if where else None

return treelib.tree_map(
return treelib.map(
leaf_get,
resolve_where(self.where, self.tree, is_leaf),
self.tree,
Expand Down Expand Up @@ -440,16 +440,16 @@ def leaf_set(where: Any, leaf: Any, set_value: Any):
return arraylib.where(where, set_value, leaf)
return set_value if where else leaf

_, lhsdef = treelib.tree_flatten(self.tree, is_leaf=is_leaf)
_, rhsdef = treelib.tree_flatten(set_value, is_leaf=is_leaf)
_, lhsdef = treelib.flatten(self.tree, is_leaf=is_leaf)
_, rhsdef = treelib.flatten(set_value, is_leaf=is_leaf)

if lhsdef == rhsdef:
# do not broadcast set_value if it is a pytree of same structure
# for example tree.at[where].set(tree2) will set all tree leaves
# to tree2 leaves if tree2 is a pytree of same structure as tree
# instead of making each leaf of tree a copy of tree2
# is design is similar to ``numpy`` design `np.at[...].set(Array)`
return treelib.tree_map(
return treelib.map(
leaf_set,
resolve_where(self.where, self.tree, is_leaf),
self.tree,
Expand All @@ -458,7 +458,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any):
is_parallel=is_parallel,
)

return treelib.tree_map(
return treelib.map(
ft.partial(leaf_set, set_value=set_value),
resolve_where(self.where, self.tree, is_leaf),
self.tree,
Expand Down Expand Up @@ -517,7 +517,7 @@ def leaf_apply(where: Any, leaf: Any):
return arraylib.where(where, func(leaf), leaf)
return func(leaf) if where else leaf

return treelib.tree_map(
return treelib.map(
leaf_apply,
resolve_where(self.where, self.tree, is_leaf),
self.tree,
Expand Down Expand Up @@ -578,7 +578,7 @@ def leaf_apply(where: Any, leaf: Any):
return arraylib.where(where, stateless_func(leaf), leaf)
return stateless_func(leaf) if where else leaf

out_tree = treelib.tree_map(
out_tree = treelib.map(
leaf_apply,
resolve_where(self.where, self.tree, is_leaf),
self.tree,
Expand Down Expand Up @@ -619,7 +619,7 @@ def reduce(
"""
treelib = sepes._src.backend.treelib
tree = self.get(is_leaf=is_leaf) # type: ignore
leaves, _ = treelib.tree_flatten(tree, is_leaf=is_leaf)
leaves, _ = treelib.flatten(tree, is_leaf=is_leaf)
if initializer is _no_initializer:
return ft.reduce(func, leaves)
return ft.reduce(func, leaves, initializer)
Expand Down Expand Up @@ -700,7 +700,7 @@ def aggregate_subtrees(node: Any) -> bool:
# for example if tree = dict(a=1) and mask is dict(a=True)
# then returns [1] and not [dict(a=1)]
return False
leaves, _ = treelib.tree_flatten(node, is_leaf=lambda x: x is None)
leaves, _ = treelib.flatten(node, is_leaf=lambda x: x is None)
# in essence if the subtree does not contain any None leaves
# then it is a valid subtree to be plucked
# this because `get` sets the non-selected leaves to None
Expand All @@ -710,7 +710,7 @@ def aggregate_subtrees(node: Any) -> bool:
count -= 1
return True

treelib.tree_flatten(tree, is_leaf=aggregate_subtrees)
treelib.flatten(tree, is_leaf=aggregate_subtrees)
return subtrees


Expand Down
4 changes: 2 additions & 2 deletions sepes/_src/tree_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _tree_mask_map(
):
treelib = sepes._src.backend.treelib
# apply func to leaves satisfying mask pytree/condtion
_, lhsdef = treelib.tree_flatten(tree, is_leaf=is_leaf)
_, lhsdef = treelib.flatten(tree, is_leaf=is_leaf)

if not isinstance(cond, Callable):
# a callable that accepts a leaf and returns a boolean
Expand All @@ -244,7 +244,7 @@ def _tree_mask_map(
def map_func(x):
return func(x) if cond(x) else x

return treelib.tree_map(map_func, tree, is_leaf=is_leaf)
return treelib.map(map_func, tree, is_leaf=is_leaf)


def tree_mask(
Expand Down
4 changes: 2 additions & 2 deletions sepes/_src/tree_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,14 +526,14 @@ def tree_size(tree: PyTree) -> int:
def reduce_func(acc, node):
return acc + tree_summary.size_dispatcher(node)

leaves, _ = treelib.tree_flatten(tree)
leaves, _ = treelib.flatten(tree)
return ft.reduce(reduce_func, leaves, 0)

def tree_count(tree: PyTree) -> int:
def reduce_func(acc, node):
return acc + tree_summary.count_dispatcher(node)

leaves, _ = treelib.tree_flatten(tree)
leaves, _ = treelib.flatten(tree)
return ft.reduce(reduce_func, leaves, 0)

traces_leaves = tree_type_path_leaves(
Expand Down
Loading

0 comments on commit 6e09014

Please sign in to comment.