Skip to content

Commit

Permalink
v0.12 (#16)
Browse files Browse the repository at this point in the history
* Shard info (#10)

* add shard info in `tree_summary` for jax arrays (if exists)

* add export xla flag in jax ci

* shorter syntax

* Update test_pprint.py

* Simplify tree pprint (#11)

* Update CHANGELOG.md

* remove tree_mermaid tree_graph

* simplify the masking API. (#12)

* simplify the masking API.

- remove freeze/unfreeze
- rename is_frozen to is_masked.
- restrict the cond in tree_mask to callable only

* fix failing tests

* remove partial is_tree_equal from public API

* Update CHANGELOG.md

* add __format__

* Update CHANGELOG.md

* tree mask edits

* add broadcast_to to bcmap (#13)

* Remove distinction between regex and string match key (#14)

* changelog edit

* revert `__format__`

* [AtIndexer] make string key points to regex by default, remove BaseKey

* tree_*** -> ***

* print tracer type in tree repr/str

* fix `is_leaf` typing

* revert #14 plus some mods

- in favor of more explicit
- fails if dicts haves keys similar to re.Pattern

* remove __format__

* Add `fill_value` for `at[...].get(fill_value=...)`

* bump version

* changelog

* Add `jax.tree_util.{SequenceKey,GetAttrKey,DictKey}` as valid path keys in `at[...]`.

* define arraylib.array_equal

* fix numpy test failing

* AtIndexer -> at

* tuple -> tuple[type1, ... typen] in tree_summary type row

* add def_rule for at indexer

* remove is_nondiff

* docs organization

* list tree summary pp rule

* Update tree_mask.py

* nits

* fix at docstring

* fix no leaf match error

* fix doctest errors

* docs

* nits

* typing
  • Loading branch information
ASEM000 committed Mar 30, 2024
1 parent ea1ad24 commit 69a4d8e
Show file tree
Hide file tree
Showing 31 changed files with 769 additions and 1,221 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
run: |
export SEPES_TEST_ARRAYLIB=jax
export SEPES_BACKEND=jax
export XLA_FLAGS=--xla_force_host_platform_device_count=8
python -m pip install .
coverage run -m pytest tests
Expand Down
97 changes: 83 additions & 14 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,26 +1,95 @@
# Changelog

## V0.12

### Deprecations

- Reduce the core API size by removing:
1) `tree_graph` (for graphviz)
2) `tree_mermaid` (mermaidjs)
3) `Partial/partial` -> Use `jax.tree_util.Partial` instead.
4) `is_tree_equal` -> Use `bcmap(numpy.testing.*)(pytree1, pytree2)` instead.
5) `freeze` -> Use `ft.partial(tree_mask, lambda _: True)` instead.
6) `unfreeze` -> Use `tree_unmask` instead.
7) `is_nondiff`
8) `BaseKey`


### Changes

- `tree_{mask,unmask}` now accepts only callable `cond` argument.

For masking using pytree boolean mask use the following pattern:

```python
import jax
import sepes as sp
import functools as ft
tree = [[1, 2], 3] # the nested tree
where = [[True, False], True] # mask tree[0][1] and tree[1]
mask = ft.partial(sp.tree_mask, cond=lambda _: True)
sp.at(tree)[where].apply(mask) # apply using `at`
# [[#1, 2], #3]
# or simply apply to the node directly
tree = [[mask(1), 2], mask(3)]
# [[#1, 2], #3]
```

- Rename `is_frozen` to `is_masked`
- frozen could mean non-trainable array, however the masking is not only for arrays but also for other types that will be hidden across jax transformations.

- Rename `AtIndexer` to `at` for shorter syntax.

### Additions

- Add `fill_value` in `at[...].get(fill_value=...)` to add default value for non
selected leaves. Useful for arrays under `jax.jit` to avoid variable size related errors.
- Add `jax.tree_util.{SequenceKey,GetAttrKey,DictKey}` as valid path keys in `at[...]`.

## V0.11.3

- Raise error if `autoinit` is used with `__init__` method defined.
- Avoid applying `copy.copy` `jax.Array` during flatten/unflatten or `AtIndexer` operations.
- Add `at` as an alias for `AtIndexer` for shorter syntax.
- Deprecate `AtIndexer.__call__` in favor of `value_and_tree` to apply function in a functional manner by copying the input argument.

```python
import sepes as sp
class Counter(sp.TreeClass):
def __init__(self, count: int):
self.count = count
def increment(self, value):
self.count += value
return self.count
counter = Counter(0)
# the function follow jax.value_and_grad semantics where the tree is the
# copied mutated input argument, if the function mutates the input arguments
sp.value_and_tree(lambda C: C.increment(1))(counter)
# (1, Counter(count=1))
```
```python
import sepes as sp
class Counter(sp.TreeClass):
def __init__(self, count: int):
self.count = count
def increment(self, value):
self.count += value
return self.count
counter = Counter(0)
# the function follow jax.value_and_grad semantics where the tree is the
# copied mutated input argument, if the function mutates the input arguments
sp.value_and_tree(lambda C: C.increment(1))(counter)
# (1, Counter(count=1))
```

- Add sharding info in `tree_summary`, `G` for global, `S` for sharded shape.

```python
import jax
import sepes as sp
from jax.sharding import Mesh, NamedSharding as N, PartitionSpec as P
import numpy as np
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
x = jax.numpy.ones([4 * 4, 2 * 2])
mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=["i", "j"])
sharding = N(mesh=mesh, spec=P("i", "j"))
x = jax.device_put(x, device=sharding)

print(sp.tree_summary(x))
┌────┬───────────┬─────┬───────┐
│Name│Type │Count│Size │
├────┼───────────┼─────┼───────┤
│Σ │G:f32[16,4]│64256.00B
│ │S:f32[4,2] │ │ │
└────┴───────────┴─────┴───────┘
```

- Updated docstrings. e.g. How to construct flops counter in `tree_summary` using `jax.jit`

Expand Down
10 changes: 10 additions & 0 deletions docs/API/constructor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
🏗️ Constructor utils API
=============================


.. currentmodule:: sepes

.. autofunction:: field
.. autofunction:: fields
.. autofunction:: autoinit
.. autofunction:: leafwise
31 changes: 0 additions & 31 deletions docs/API/core.rst

This file was deleted.

5 changes: 1 addition & 4 deletions docs/API/masking.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

.. currentmodule:: sepes

.. autofunction:: is_nondiff
.. autofunction:: freeze
.. autofunction:: unfreeze
.. autofunction:: is_frozen
.. autofunction:: is_masked
.. autofunction:: tree_mask
.. autofunction:: tree_unmask
10 changes: 10 additions & 0 deletions docs/API/module.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
📍 Module API
=============================


.. currentmodule:: sepes

.. autoclass:: TreeClass
:members:
at

2 changes: 0 additions & 2 deletions docs/API/pretty_print.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
.. currentmodule:: sepes

.. autofunction:: tree_diagram
.. autofunction:: tree_graph
.. autofunction:: tree_mermaid
.. autofunction:: tree_repr
.. autofunction:: tree_str
.. autofunction:: tree_summary
4 changes: 3 additions & 1 deletion docs/API/sepes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
:maxdepth: 2
:caption: API Documentation

core
module
masking
tree
constructor
pretty_print
backend
17 changes: 17 additions & 0 deletions docs/API/tree.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
🌲 Tree utils API
=============================


.. currentmodule:: sepes

.. autoclass:: at
:members:
get,
set,
apply,
scan,
reduce,
pluck,

.. autofunction:: value_and_tree
.. autofunction:: bcmap
67 changes: 0 additions & 67 deletions docs/_static/tree_graph.svg

This file was deleted.

67 changes: 0 additions & 67 deletions docs/_static/tree_graph_stylized.svg

This file was deleted.

Binary file removed docs/_static/tree_mermaid.jpg
Binary file not shown.
Loading

0 comments on commit 69a4d8e

Please sign in to comment.