
v0.12

ASEM000 released this on 30 Mar 17:21
1639900

Changelog

v0.12

Deprecations

  • Reduce the core API size by removing:
    1. tree_graph (for graphviz)
    2. tree_mermaid (for mermaidjs)
    3. Partial/partial -> Use jax.tree_util.Partial instead.
    4. is_tree_equal -> Use bcmap(numpy.testing.*)(pytree1, pytree2) instead.
    5. freeze -> Use ft.partial(tree_mask, cond=lambda _: True) instead.
    6. unfreeze -> Use tree_unmask instead.
    7. is_nondiff
    8. BaseKey

Changes

  • tree_{mask,unmask} now accept only a callable cond argument.

    To mask using a pytree of booleans, use the following pattern:

    import jax
    import sepes as sp
    import functools as ft
    tree = [[1, 2], 3] # the nested tree
    where = [[True, False], True]  # mask tree[0][0] and tree[1]
    mask = ft.partial(sp.tree_mask, cond=lambda _: True)
    sp.at(tree)[where].apply(mask)  # apply using `at`
    # [[#1, 2], #3]
    # or simply apply to the node directly
    tree = [[mask(1), 2], mask(3)]
    # [[#1, 2], #3]
  • Rename is_frozen to is_masked

    • "Frozen" suggests a non-trainable array, but masking is not limited to arrays: it hides values of any type from JAX transformations.
  • Rename AtIndexer to at for a shorter syntax.

Additions

  • Add fill_value to at[...].get(fill_value=...) to set a default value for non-selected
    leaves. This is useful for arrays under jax.jit, where it avoids errors caused by variable-sized outputs.
  • Add jax.tree_util.{SequenceKey,GetAttrKey,DictKey} as valid path keys in at[...].