Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v0.12 #16

Merged
merged 33 commits into from
Mar 30, 2024
Merged

v0.12 #16

merged 33 commits into from
Mar 30, 2024

Conversation

ASEM000
Copy link
Owner

@ASEM000 ASEM000 commented Mar 29, 2024

Changelog

V0.12

Deprecations

  • Reduce the core API size by removing:
    1. tree_graph (for graphviz)
    2. tree_mermaid (mermaidjs)
    3. Partial/partial -> Use jax.tree_util.Partial instead.
    4. is_tree_equal -> Use bcmap(numpy.testing.*)(pytree1, pytree2) instead.
    5. freeze -> Use ft.partial(tree_mask, lambda _: True) instead.
    6. unfreeze -> Use tree_unmask instead.
    7. is_nondiff
    8. BaseKey

Changes

  • tree_{mask,unmask} now accepts only callable cond argument.

    For masking using pytree boolean mask use the following pattern:

    import jax
    import sepes as sp
    import functools as ft
    tree = [[1, 2], 3] # the nested tree
    where = [[True, False], True]  # mask tree[0][1] and tree[1]
    mask = ft.partial(sp.tree_mask, cond=lambda _: True)
    sp.at(tree)[where].apply(mask)  # apply using `at`
    # [[#1, 2], #3]
    # or simply apply to the node directly
    tree = [[mask(1), 2], mask(3)]
    # [[#1, 2], #3]
  • Rename is_frozen to is_masked

    • frozen could mean non-trainable array, however the masking is not only for arrays but also for other types that will be hidden across jax transformations.
  • Rename AtIndexer to at for shorter syntax.

ASEM000 and others added 29 commits January 26, 2024 00:21
* add shard info in `tree_summary` for jax arrays (if exists)

* add export xla flag in jax ci

* shorter syntax

* Update test_pprint.py
* Update CHANGELOG.md

* remove tree_mermaid tree_graph
* simplify the masking API.

- remove freeze/unfreeze
- rename is_frozen to is_masked.
- restrict the cond in tree_mask to callable only

* fix failing tests

* remove partial is_tree_equal from public API

* Update CHANGELOG.md
* changelog edit

* revert `__format__`

* [AtIndexer] make string key points to regex by default, remove BaseKey
- in favor of more explicit
- fails if dicts haves keys similar to re.Pattern
@codecov-commenter
Copy link

codecov-commenter commented Mar 29, 2024

Codecov Report

Attention: Patch coverage is 96.87500% with 11 lines in your changes are missing coverage. Please review.

Project coverage is 94.65%. Comparing base (ea1ad24) to head (8cd078d).

Files Patch % Lines
sepes/_src/tree_index.py 92.15% 8 Missing ⚠️
tests/test_index.py 94.59% 2 Missing ⚠️
sepes/_src/tree_pprint.py 95.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #16      +/-   ##
==========================================
- Coverage   94.68%   94.65%   -0.03%     
==========================================
  Files          13       13              
  Lines        2370     2358      -12     
==========================================
- Hits         2244     2232      -12     
  Misses        126      126              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@ASEM000 ASEM000 merged commit 69a4d8e into main Mar 30, 2024
39 checks passed
@ASEM000 ASEM000 deleted the next branch March 30, 2024 17:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants