
v0.11.3

@ASEM000 released this on 16 Dec 23:00
ea1ad24

V0.11.3

  • Raise an error if `autoinit` is applied to a class that already defines an `__init__` method.
  • Avoid applying `copy.copy` to `jax.Array` leaves during flatten/unflatten or `AtIndexer` operations.
  • Add `at` as a shorter alias for `AtIndexer`.
  • Deprecate `AtIndexer.__call__` in favor of `value_and_tree`, which applies a function in a functional manner by calling it on a copy of the input argument. For example:
```python
import sepes as sp


class Counter(sp.TreeClass):
    def __init__(self, count: int):
        self.count = count

    def increment(self, value):
        self.count += value
        return self.count


counter = Counter(0)

# value_and_tree follows jax.value_and_grad-like semantics: it returns
# (function output, tree), where tree is the copied input argument carrying
# any mutations the function made. The original `counter` is not modified.
sp.value_and_tree(lambda C: C.increment(1))(counter)
# (1, Counter(count=1))
```
  • Update docstrings, e.g. how to construct a FLOPs counter in `tree_summary` using `jax.jit`.

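The copy-then-call semantics described for `value_and_tree` can be approximated in plain Python as a sketch. `toy_value_and_tree` is a hypothetical helper: it only handles the first positional argument and uses `copy.deepcopy`, whereas sepes' real function operates on pytrees and its copying strategy may differ.

```python
import copy
from typing import Any, Callable, Tuple


def toy_value_and_tree(func: Callable[..., Any]) -> Callable[..., Tuple[Any, Any]]:
    """Hypothetical sketch: call `func` on a copy of its first argument.

    Returns (func's output, the possibly mutated copy); the caller's object is untouched.
    """

    def wrapper(tree: Any, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
        tree_copy = copy.deepcopy(tree)  # mutations land on the copy only
        value = func(tree_copy, *args, **kwargs)
        return value, tree_copy

    return wrapper


class Counter:
    """Plain-Python stand-in for the sepes.TreeClass counter in the notes."""

    def __init__(self, count: int):
        self.count = count

    def increment(self, value: int) -> int:
        self.count += value
        return self.count

    def __repr__(self) -> str:
        return f"Counter(count={self.count})"


counter = Counter(0)
value, new_counter = toy_value_and_tree(lambda c: c.increment(1))(counter)
print(value, new_counter)  # 1 Counter(count=1)
print(counter)  # Counter(count=0): the original is unchanged
```

Copying before the call is what lets the function mutate its argument freely while the caller keeps an unchanged tree, which is the functional style the deprecation note refers to.
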
What's Changed

Full Changelog: v0.11.2...v0.11.3