- Reduce the core API size by removing:
  - `tree_graph` (for graphviz)
  - `tree_mermaid` (mermaidjs)
  - `Partial/partial` -> Use `jax.tree_util.Partial` instead.
  - `is_tree_equal` -> Use `bcmap(numpy.testing.*)(pytree1, pytree2)` instead.
  - `freeze` -> Use `ft.partial(tree_mask, cond=lambda _: True)` instead.
  - `unfreeze` -> Use `tree_unmask` instead.
  - `is_nondiff`
  - `BaseKey`
- `tree_{mask,unmask}` now accepts only a callable `cond` argument. For masking with a pytree boolean mask, use the following pattern:
```python
import functools as ft

import sepes as sp

tree = [[1, 2], 3]  # the nested tree
where = [[True, False], True]  # mask tree[0][0] and tree[1]
mask = ft.partial(sp.tree_mask, cond=lambda _: True)
sp.at(tree)[where].apply(mask)  # apply using `at`
# [[#1, 2], #3]

# or simply apply to the node directly
tree = [[mask(1), 2], mask(3)]
# [[#1, 2], #3]
```
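Here is a pure-Python sketch of what `at(tree)[where].apply(mask)` does conceptually for nested lists. It walks the tree and the boolean mask in lockstep and applies the function only where the mask is `True`. `apply_where` and `Masked` are hypothetical names used for illustration. They are not part of sepes, where real masking goes through `tree_mask`.

```python
from typing import Any, Callable


class Masked:
    """Hypothetical stand-in for a masked leaf; its repr mimics the `#` prefix."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def __repr__(self) -> str:
        return f"#{self.value!r}"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Masked) and other.value == self.value


def apply_where(tree: Any, where: Any, fn: Callable[[Any], Any]) -> Any:
    """Return a new tree with `fn` applied to every leaf whose mask entry is True."""
    if isinstance(tree, (list, tuple)):
        # walk the tree and the mask in lockstep; both must share one structure
        if not isinstance(where, (list, tuple)) or len(where) != len(tree):
            raise ValueError("tree and mask must have the same structure")
        return type(tree)(apply_where(t, w, fn) for t, w in zip(tree, where))
    return fn(tree) if where else tree


tree = [[1, 2], 3]
where = [[True, False], True]
masked = apply_where(tree, where, Masked)
print(masked)  # [[#1, 2], #3]
```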
- Rename `is_frozen` to `is_masked`
  - "Frozen" could suggest a non-trainable array. However, masking is not limited to arrays: it applies to any type that should be hidden across JAX transformations.
- Rename `AtIndexer` to `at` for shorter syntax.
- Add `fill_value` in `at[...].get(fill_value=...)` to set a default value for non-selected leaves. This is useful for arrays under `jax.jit`, where it avoids errors caused by variable-size outputs.
- Add `jax.tree_util.{SequenceKey,GetAttrKey,DictKey}` as valid path keys in `at[...]`.
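The sketch below shows the idea behind `fill_value` in plain Python. `get` keeps the tree structure and puts a placeholder at every non-selected position. Without a fill value that placeholder is `None`. For array leaves, a fill value is what would let the result keep the input's shape instead of a mask-dependent size, and `jax.jit` requires static shapes. `get_where` is a hypothetical helper, not the sepes implementation.

```python
from typing import Any


def get_where(tree: Any, where: Any, fill_value: Any = None) -> Any:
    """Keep selected leaves and put `fill_value` in place of every other leaf."""
    if isinstance(tree, dict):
        return {key: get_where(tree[key], where[key], fill_value) for key in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(get_where(t, w, fill_value) for t, w in zip(tree, where))
    return tree if where else fill_value


tree = {"a": 1, "b": [2, 3]}
where = {"a": False, "b": [True, False]}
print(get_where(tree, where))                # {'a': None, 'b': [2, None]}
print(get_where(tree, where, fill_value=0))  # {'a': 0, 'b': [2, 0]}
```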
- Raise an error if `autoinit` is used on a class that defines an `__init__` method.
- Avoid applying `copy.copy` to `jax.Array` during flatten/unflatten or `AtIndexer` operations.
- Add `at` as an alias for `AtIndexer` for shorter syntax.
- Deprecate `AtIndexer.__call__` in favor of `value_and_tree`
to apply a function in a functional manner by copying the input argument.

```python
import sepes as sp

class Counter(sp.TreeClass):
    def __init__(self, count: int):
        self.count = count

    def increment(self, value):
        self.count += value
        return self.count

counter = Counter(0)
# the function follows jax.value_and_grad semantics, where the tree is the
# copied, mutated input argument if the function mutates its input arguments
sp.value_and_tree(lambda C: C.increment(1))(counter)
# (1, Counter(count=1))
```
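To show the copy-then-mutate semantics in isolation, here is a stdlib-only sketch. It is not the sepes implementation. `value_and_tree_sketch` is a hypothetical name, it uses `copy.deepcopy` for simplicity, and it returns only the copy of the first positional argument, matching the single-argument call in the example above.

```python
import copy
from typing import Any, Callable, Tuple


def value_and_tree_sketch(func: Callable[..., Any]) -> Callable[..., Tuple[Any, Any]]:
    """Call `func` on deep copies of its arguments; return (output, copied first argument)."""

    def wrapper(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
        args, kwargs = copy.deepcopy((args, kwargs))  # the caller's objects stay untouched
        value = func(*args, **kwargs)
        return value, args[0]  # the copy may have been mutated by `func`

    return wrapper


class Counter:
    def __init__(self, count: int) -> None:
        self.count = count

    def increment(self, value: int) -> int:
        self.count += value
        return self.count


counter = Counter(0)
value, new_counter = value_and_tree_sketch(lambda c: c.increment(1))(counter)
print(value, new_counter.count, counter.count)  # 1 1 0
```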
- Add sharding info in `tree_summary`: `G` for the global shape and `S` for the sharded shape.

```python
import os

# set before JAX initializes its backend; setting it before importing jax is safest
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import numpy as np
import sepes as sp
from jax.sharding import Mesh, NamedSharding as N, PartitionSpec as P

x = jax.numpy.ones([4 * 4, 2 * 2])
mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=["i", "j"])
sharding = N(mesh=mesh, spec=P("i", "j"))
x = jax.device_put(x, device=sharding)

print(sp.tree_summary(x))
```

```
┌────┬───────────┬─────┬───────┐
│Name│Type       │Count│Size   │
├────┼───────────┼─────┼───────┤
│Σ   │G:f32[16,4]│64   │256.00B│
│    │S:f32[4,2] │     │       │
└────┴───────────┴─────┴───────┘
```
- Update docstrings, e.g. how to construct a FLOPs counter in `tree_summary` using `jax.jit`.
- No freezing rule for `jax.Tracer` in `sp.freeze`
- Add a pprint rule for `jax.Tracer` in `sp.tree_repr`/`sp.tree_str`
- Add a no-op warning if the user adds `autoinit` to a class with an `__init__` method.
- Add a warning if the user adds fields in an incorrect kind order.
- Add a warning if any base of an `autoinit` class has an `__init__` method.
- Add a `CLASS_VAR` kind in `field` to support class variables in `autoinit`.
- `__call__` is added to `AtIndexer` to enable methods that work on a copied instance, avoiding in-place mutation. This makes it possible to write methods in a stateful style and use `AtIndexer` to operate on them functionally. This feature was previously enabled only for `TreeClass`, but now it is enabled for any class.

The following shows how to use `AtIndexer` to call a method that mutates the tree in-place, in an out-of-place manner (i.e. execute the method on a copy of the tree):

```python
import sepes as sp

class Counter:
    def __init__(self, count: int):
        self.count = count

    def increment_count(self, count: int) -> int:
        # mutates the tree
        self.count += count
        return self.count

    def __repr__(self) -> str:
        return f"Tree(count={self.count})"

counter = Counter(0)
indexer = sp.AtIndexer(counter)
cur_count, new_counter = indexer["increment_count"](count=1)
assert counter.count == 0  # did not mutate in-place
assert cur_count == 1  # the method returned the current count
assert new_counter.count == 1  # the copied instance where the method was executed
assert not (counter is new_counter)  # the old and new instance are not the same
```
If the instance is frozen (e.g. `dataclasses.dataclass(frozen=True)`) or implements a custom `__setattr__`/`__delattr__`, then `__call__` may not work. Register a custom mutator/immutator with `AtIndexer` to enable `__call__`. For more details, see the `AtIndexer.__call__` docstring.
- Mark full subtrees for replacement.

```python
import sepes as sp

tree = [1, 2, [3, 4]]
tree_ = sp.AtIndexer(tree)[[False, False, True]].set(10)
assert tree_ == [1, 2, 10]
```
That is, inside a mask, marking a subtree with a single bool leaf replaces the whole subtree. In this example, the subtree `[3, 4]` is marked with `True` in the mask, which indicates that it should be replaced.

If the subtree is populated with `True` leaves, then the set value is broadcast to all subtree leaves.

```python
import sepes as sp

tree = [1, 2, [3, 4]]
tree_ = sp.AtIndexer(tree)[[False, False, [True, True]]].set(10)
assert tree_ == [1, 2, [10, 10]]
```
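The two mask rules described above can be sketched in plain Python for nested lists. A bool in the mask covers whatever sits at that position. A nested mask recurses leaf by leaf. `set_where` is a hypothetical helper written to illustrate the rules, not sepes code.

```python
from typing import Any


def set_where(tree: Any, where: Any, value: Any) -> Any:
    """Hypothetical helper following the mask rules described above."""
    if isinstance(where, bool):
        # a single bool covers everything at this position:
        # True swaps the whole leaf or subtree for `value`, False keeps it
        return value if where else tree
    # otherwise the mask mirrors the tree: recurse in lockstep
    if len(tree) != len(where):
        raise ValueError("tree and mask must have the same structure")
    return type(tree)(set_where(t, w, value) for t, w in zip(tree, where))


tree = [1, 2, [3, 4]]
print(set_where(tree, [False, False, True], 10))          # [1, 2, 10]
print(set_where(tree, [False, False, [True, True]], 10))  # [1, 2, [10, 10]]
```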
- Do not broadcast path-based masks

```python
import sepes as sp

tree = [1, 2, [3, 4]]
tree_ = sp.AtIndexer(tree)[2].set(10)
assert tree_ == [1, 2, 10]
```
To broadcast to the subtree, use `...`:

```python
import sepes as sp

tree = [1, 2, [3, 4]]
tree_ = sp.AtIndexer(tree)[2][...].set(10)
assert tree_ == [1, 2, [10, 10]]
```
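The path-versus-broadcast distinction can be sketched in plain Python. A path selects one node. `set` replaces that node as-is unless broadcasting asks for its leaves to be filled instead. `set_at_path` and its `broadcast` flag are hypothetical names that stand in for `[2].set(...)` and `[2][...].set(...)`. They are not sepes API.

```python
from typing import Any, Sequence


def fill_leaves(node: Any, value: Any) -> Any:
    """Replace every leaf under `node` with `value`, keeping the structure."""
    if isinstance(node, (list, tuple)):
        return type(node)(fill_leaves(n, value) for n in node)
    return value


def set_at_path(tree: Any, path: Sequence[int], value: Any, broadcast: bool = False) -> Any:
    """Replace the node at `path`; with broadcast=True, fill its leaves instead."""
    if not path:
        return fill_leaves(tree, value) if broadcast else value
    head, rest = path[0], path[1:]
    new = list(tree)  # shallow copy so the input tree is not mutated
    new[head] = set_at_path(tree[head], rest, value, broadcast)
    return type(tree)(new)


tree = [1, 2, [3, 4]]
print(set_at_path(tree, (2,), 10))                  # [1, 2, 10]        like [2].set(10)
print(set_at_path(tree, (2,), 10, broadcast=True))  # [1, 2, [10, 10]]  like [2][...].set(10)
```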
- Better lookup errors

```python
import sepes as sp

tree = {"a": {"b": 1, "c": 2}, "d": 3}
sp.AtIndexer(tree)["a"]["d"].set(100)
```

```
LookupError: No leaf match is found for where=[a, d]. Available keys are ['a']['b'], ['a']['c'], ['d'].
Check the following:
  - If where is `str` then check if the key exists as a key or attribute.
  - If where is `int` then check if the index is in range.
  - If where is `re.Pattern` then check if the pattern matches any key.
  - If where is a `tuple` of the above types then check if any of the tuple elements match.
```
- Extract subtrees with `pluck`

```python
import sepes as sp

tree = {"a": 1, "b": [1, 2, 3]}
indexer = sp.AtIndexer(tree)  # construct an indexer
# `pluck` returns a list of selected subtrees
indexer["b"].pluck()
# [[1, 2, 3]]

# in comparison, `get` returns the same pytree structure
indexer["b"].get()
# {'a': None, 'b': [1, 2, 3]}
```
`pluck` with a mask:

```python
import sepes as sp

tree = {"a": 1, "b": [2, 3, 4]}
mask = {"a": True, "b": [False, True, False]}
indexer = sp.AtIndexer(tree)
indexer[mask].pluck()
# [1, 3]
```

This is equivalent to the following:

```python
[tree["a"], tree["b"][1]]
```
To get the first `n` matches, use `pluck(n)`.

A simple application of `pluck` is to share a reference using a descriptor-based approach:

```python
import sepes as sp

marker = object()

class Tie:
    def __set_name__(self, owner, name):
        self.name = name

    def __set__(self, instance, indexer):
        self.where = indexer.where
        where_str = "/".join(map(str, self.where))
        vars(instance)[self.name] = f"Ref:{where_str}"

    def __get__(self, instance, owner):
        (subtree,) = sp.AtIndexer(instance, self.where).pluck(1)
        return subtree

class Tree(sp.TreeClass):
    shared: Tie = Tie()

    def __init__(self):
        self.lookup = dict(a=marker, b=2)
        self.shared = self.at["lookup"]["a"]

tree = Tree()
assert tree.lookup["a"] is tree.shared
```
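Conceptually, `pluck` collects the selected subtrees in traversal order and can stop early after the first `n` matches. Here is a pure-Python sketch of that for dict/list trees with a boolean mask. `pluck_where` and its `count` argument are hypothetical and only mirror the behavior shown in the examples above.

```python
from typing import Any, List, Optional


def pluck_where(tree: Any, where: Any, count: Optional[int] = None) -> List[Any]:
    """Collect subtrees whose mask entry is True, in traversal order."""
    out: List[Any] = []

    def visit(node: Any, mask: Any) -> None:
        if count is not None and len(out) >= count:
            return  # already collected enough matches
        if isinstance(mask, bool):
            if mask:
                out.append(node)
            return
        if isinstance(mask, dict):
            for key in mask:
                visit(node[key], mask[key])
        else:
            for sub_node, sub_mask in zip(node, mask):
                visit(sub_node, sub_mask)

    visit(tree, where)
    return out


tree = {"a": 1, "b": [2, 3, 4]}
mask = {"a": True, "b": [False, True, False]}
print(pluck_where(tree, mask))           # [1, 3]
print(pluck_where(tree, mask, count=1))  # [1]
```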
- Revamp the backend mechanism:
  - Rewrite the array backend via dispatch to work with `numpy`, `jax`, and `torch` simultaneously. For example, the following recognizes both `jax` and `torch` entries without backend changes.

```python
import jax.numpy as jnp
import sepes as sp
import torch

tree = [[1, 2], 2, [3, 4], jnp.ones((2, 2)), torch.ones((2, 2))]
print(sp.tree_repr(tree))
# [
#   [1, 2],
#   2,
#   [3, 4],
#   f32[2,2](μ=1.00, σ=0.00, ∈[1.00,1.00]),
#   torch.f32[2,2](μ=1.00, σ=0.00, ∈[1.00,1.00])
# ]
```
  - Introduce `backend_context` to switch between `jax`/`optree` backend registration and tree utilities. The following example shows how to register with different backends:

```python
import jax
import optree
import sepes

with sepes.backend_context("jax"):
    class JaxTree(sepes.TreeClass):
        def __init__(self):
            self.l1 = 1.0
            self.l2 = 2.0

    print(jax.tree_util.tree_leaves(JaxTree()))

with sepes.backend_context("optree"):
    class OpTreeTree(sepes.TreeClass):
        def __init__(self):
            self.l1 = 1.0
            self.l2 = 2.0

    print(optree.tree_leaves(OpTreeTree(), namespace="sepes"))

# [1.0, 2.0]
# [1.0, 2.0]
```
- Successor of the `jax`-specific `pytreeclass`
- Supports multiple backends:
  - `numpy` + `optree` via `export SEPES_BACKEND=numpy` (lightweight option)
  - `jax` via `export SEPES_BACKEND=jax` (the default)
  - `torch` + `optree` via `export SEPES_BACKEND=torch`
  - no array + `optree` via `export SEPES_BACKEND=default`
- Drop the `callback` option from the parallel options in `is_parallel`
- Add parallel processing via `is_parallel` to `.{get,set}`
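The general idea behind parallel leaf processing can be sketched with the standard library. The sketch flattens the leaves, maps a function over them in a thread pool, and rebuilds the tree. This is a conceptual illustration only. `parallel_tree_map` and its `max_workers` argument are hypothetical, and the sketch does not reflect how sepes configures its executor. Threads mainly help when the per-leaf work releases the GIL, for example I/O.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List


def tree_leaves(tree: Any) -> List[Any]:
    """Flatten nested lists/tuples into a list of leaves (left to right)."""
    if isinstance(tree, (list, tuple)):
        return [leaf for sub in tree for leaf in tree_leaves(sub)]
    return [tree]


def tree_unflatten(template: Any, leaves: List[Any]) -> Any:
    """Rebuild `template`'s structure, filling it with `leaves` in order."""
    it = iter(leaves)

    def build(node: Any) -> Any:
        if isinstance(node, (list, tuple)):
            return type(node)(build(n) for n in node)
        return next(it)

    return build(template)


def parallel_tree_map(fn: Callable[[Any], Any], tree: Any, max_workers: int = 4) -> Any:
    leaves = tree_leaves(tree)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        new_leaves = list(pool.map(fn, leaves))  # map preserves input order
    return tree_unflatten(tree, new_leaves)


print(parallel_tree_map(lambda x: x * 10, [[1, 2], 3, (4, 5)]))  # [[10, 20], 30, (40, 50)]
```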
- Add `register_excluded_type` to `autoinit` to exclude certain types from being treated as `field` defaults.
- Add `doc` to `field` to attach extra documentation to the descriptor's `__doc__`.
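The following plain-Python descriptor illustrates what attaching documentation to a descriptor means. It takes a `doc` argument and exposes it through `__doc__`. `DocumentedField` and `Layer` are hypothetical. This is a sketch of the same idea, not how `sepes.field` is implemented.

```python
from typing import Any, Optional


class DocumentedField:
    """Hypothetical descriptor: stores a value per instance and carries a `doc` string."""

    def __init__(self, default: Any = None, doc: Optional[str] = None) -> None:
        self.default = default
        self.__doc__ = doc  # instance attribute shadows the class docstring

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name

    def __get__(self, instance: Any, owner: Optional[type] = None) -> Any:
        if instance is None:
            return self  # class access returns the descriptor, so its __doc__ is reachable
        return vars(instance).get(self.name, self.default)

    def __set__(self, instance: Any, value: Any) -> None:
        vars(instance)[self.name] = value


class Layer:
    rate = DocumentedField(default=0.5, doc="Dropout probability in [0, 1].")


print(Layer.rate.__doc__)  # Dropout probability in [0, 1].
print(Layer().rate)        # 0.5
```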