Changelog
V0.12
Deprecations
- Reduce the core API size by removing:
  - tree_graph (for graphviz)
  - tree_mermaid (for mermaidjs)
  - Partial/partial -> Use jax.tree_util.Partial instead.
  - is_tree_equal -> Use bcmap(numpy.testing.*)(pytree1, pytree2) instead.
  - freeze -> Use ft.partial(tree_mask, cond=lambda _: True) instead.
  - unfreeze -> Use tree_unmask instead.
  - is_nondiff
  - BaseKey
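For the Partial/partial and is_tree_equal removals, here is a minimal migration sketch. It assumes jax and numpy are installed. To keep it runnable without sepes, it uses jax.tree_util.tree_map with numpy.testing as a stand-in for bcmap. The names scale and apply_fn are hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scale(x, factor):
    return x * factor


# jax.tree_util.Partial is itself a pytree: the bound `factor` becomes a leaf,
# so the partial can be passed straight through jax.jit as an argument.
double = jax.tree_util.Partial(scale, factor=2.0)


@jax.jit
def apply_fn(fn, x):  # hypothetical helper
    return fn(x)


out = apply_fn(double, jnp.arange(3.0))  # [0., 2., 4.]

# Leaf-wise comparison in place of is_tree_equal:
# raises AssertionError on the first mismatching leaf, returns a tree of None otherwise.
tree1 = {"a": np.array([1.0, 2.0]), "b": [np.array(3.0)]}
tree2 = {"a": np.array([1.0, 2.0]), "b": [np.array(3.0)]}
jax.tree_util.tree_map(np.testing.assert_allclose, tree1, tree2)
```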
Changes
- tree_{mask,unmask} now accept only a callable cond argument. For masking with a pytree boolean mask, use the following pattern:

      import jax
      import sepes as sp
      import functools as ft

      tree = [[1, 2], 3]  # the nested tree
      where = [[True, False], True]  # mask tree[0][0] and tree[1]
      mask = ft.partial(sp.tree_mask, cond=lambda _: True)
      sp.at(tree)[where].apply(mask)  # apply using `at`
      # [[#1, 2], #3]

      # or simply apply to the node directly
      tree = [[mask(1), 2], mask(3)]
      # [[#1, 2], #3]
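The boolean where tree in the pattern above has the same structure as tree, and each True entry selects the matching leaf. The following sketch illustrates only that selection behavior and is not how sepes implements at. It assumes jax is installed and uses a hypothetical apply_where helper that pairs the two trees leaf by leaf with jax.tree_util.tree_map:

```python
import jax


def apply_where(fn, tree, where):
    """Hypothetical helper: apply `fn` only to leaves whose mask entry is True."""
    return jax.tree_util.tree_map(lambda leaf, m: fn(leaf) if m else leaf, tree, where)


tree = [[1, 2], 3]
where = [[True, False], True]  # selects tree[0][0] and tree[1]

result = apply_where(lambda x: -x, tree, where)
print(result)  # [[-1, 2], -3]
```

tree_map returns a new tree, so the original tree is left unchanged.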
- Rename is_frozen to is_masked.
  - "Frozen" suggests a non-trainable array. However, masking applies not only to arrays but also to other types that should be hidden from jax transformations.
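Masking covers more than arrays. The sketch below makes this concrete with a hypothetical Masked wrapper, which is not the sepes implementation. The wrapper is registered as a pytree with no children, so its value travels as static metadata and jax transformations skip it. The sketch assumes jax is installed and that the wrapped value is hashable, such as a string.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Masked:
    """Hypothetical wrapper: hides `value` from jax by exposing no leaves."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # No children; the value is stored as (hashable) auxiliary data.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)

    def __repr__(self):
        return f"#{self.value!r}"


def is_masked(x):  # hypothetical counterpart of a masking predicate
    return isinstance(x, Masked)


def loss(params, x):
    # The activation name is plain Python data, read outside of tracing.
    act = jax.nn.relu if params["act"].value == "relu" else jnp.tanh
    return act(params["w"] * x)


params = {"w": jnp.array(2.0), "act": Masked("relu")}
grads = jax.grad(loss)(params, 3.0)
# grads["w"] is 3.0; grads["act"] comes back as the untouched #'relu'.
# With a bare "relu" string leaf instead, jax.grad would reject the input.
```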
- Rename AtIndexer to at for shorter syntax.
Additions
- Add fill_value in at[...].get(fill_value=...) to set a default value for non-selected leaves. This is useful for arrays under jax.jit, where it avoids errors caused by variable-size outputs.
- Add jax.tree_util.{SequenceKey,GetAttrKey,DictKey} as valid path keys in at[...].
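The following two sketches relate to the additions above and assume jax is installed. The first uses plain jax.numpy rather than the sepes API to show the JAX constraint behind fill_value. Boolean indexing produces a data-dependent shape that jax.jit cannot trace, while filling unselected entries keeps the shape fixed. The second shows the jax.tree_util key types that tree_flatten_with_path produces. These are the same kinds of keys that at[...] now accepts. Layer is a hypothetical NamedTuple used only to produce a GetAttrKey.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


@jax.jit
def keep_positive(x):
    # x[x > 0] would have a data-dependent shape and fail under jit;
    # filling unselected entries with 0.0 keeps the output shape fixed.
    return jnp.where(x > 0, x, 0.0)


filtered = keep_positive(jnp.array([-1.0, 2.0, -3.0, 4.0]))  # [0., 2., 0., 4.]


class Layer(NamedTuple):  # hypothetical container; namedtuple fields map to GetAttrKey
    weight: jax.Array
    bias: jax.Array


# Paths into a nested tree are tuples of jax.tree_util key objects:
# dict entries -> DictKey, list items -> SequenceKey, namedtuple fields -> GetAttrKey.
tree = {"layers": [Layer(jnp.ones(2), jnp.zeros(2))]}
paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
first_path = paths_and_leaves[0][0]
# first_path: (DictKey(key='layers'), SequenceKey(idx=0), GetAttrKey(name='weight'))
```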