v0.11.3
- Raise an error if `autoinit` is used on a class that defines its own `__init__` method.
- Avoid applying `copy.copy` to `jax.Array` during flatten/unflatten or `AtIndexer` operations.
- Add `at` as an alias for `AtIndexer` for shorter syntax.
- Deprecate `AtIndexer.__call__` in favor of `value_and_tree`, which applies a function in a functional manner by copying the input argument.
```python
import sepes as sp

class Counter(sp.TreeClass):
    def __init__(self, count: int):
        self.count = count

    def increment(self, value):
        self.count += value
        return self.count

counter = Counter(0)
# the function follows `jax.value_and_grad` semantics: the returned tree is a
# copy of the input argument, carrying any mutations the function made to it
sp.value_and_tree(lambda C: C.increment(1))(counter)
# (1, Counter(count=1))
```
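To make the "copy, then mutate the copy" semantics concrete without depending on sepes, here is a minimal plain-Python sketch. `value_and_tree_sketch` is a hypothetical name, not part of sepes, and it uses `copy.deepcopy` on the whole argument as a simplification; the library may copy the tree differently (for example, this release avoids copying `jax.Array` leaves).

```python
import copy
from typing import Any, Callable, Tuple


def value_and_tree_sketch(func: Callable[[Any], Any]) -> Callable[[Any], Tuple[Any, Any]]:
    """Hypothetical analogue of `value_and_tree` (not the sepes implementation).

    The input is deep-copied before `func` runs, so any in-place mutation
    lands on the copy and the caller's object is left untouched.
    """

    def wrapper(tree: Any) -> Tuple[Any, Any]:
        tree_copy = copy.deepcopy(tree)  # protect the caller's object
        value = func(tree_copy)          # func may mutate tree_copy freely
        return value, tree_copy          # (function output, mutated copy)

    return wrapper


class Counter:
    # ordinary mutable class standing in for a TreeClass
    def __init__(self, count: int):
        self.count = count

    def increment(self, value: int) -> int:
        self.count += value
        return self.count


counter = Counter(0)
value, new_counter = value_and_tree_sketch(lambda c: c.increment(1))(counter)
print(value, new_counter.count, counter.count)  # 1 1 0
```

The original `counter` still holds `0`; only the returned copy reflects the increment, which is what lets a mutating method be used inside otherwise functional code.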
- Updated docstrings, e.g. how to construct a FLOPs counter in `tree_summary` using `jax.jit`.
Full Changelog: v0.11.2...v0.11.3