Skip to content

Commit

Permalink
add changelog entry
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Oct 11, 2023
1 parent 71b0633 commit 61ebeba
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 15 deletions.
43 changes: 43 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,49 @@
assert tree.lookup["a"] is tree.shared
```

- Revamp the backend mechanism:

- Rewrite array backend via dispatch to work with `numpy`,`jax`, and `torch` simultaneously. for example the following recognize both `jax` and `torch` entries without backend changes.

```python
import sepes as sp
import jax.numpy as jnp
import torch
tree = [[1, 2], 2, [3, 4], jnp.ones((2, 2)), torch.ones((2, 2))]
print(sp.tree_repr(tree))
# [
# [1, 2],
# 2,
# [3, 4],
# f32[2,2](μ=1.00, σ=0.00, ∈[1.00,1.00]),
# torch.f32[2,2](μ=1.00, σ=0.00, ∈[1.00,1.00])
# ]
```

- Introduce `backend_context` to switch between `jax`/`optree` backend registration and tree utilities. the following example shows how to register with different backends:

```python
import sepes
import jax
import optree

with sepes.backend_context("jax"):
class JaxTree(sepes.TreeClass):
def __init__(self):
self.l1 = 1.0
self.l2 = 2.0
print(jax.tree_util.tree_leaves(JaxTree()))

with sepes.backend_context("optree"):
class OpTreeTree(sepes.TreeClass):
def __init__(self):
self.l1 = 1.0
self.l2 = 2.0
print(optree.tree_leaves(OpTreeTree(), namespace="sepes"))
# [1.0, 2.0]
# [1.0, 2.0]
```

## v0.10.0

- successor of the `jax`-specific `pytreeclass`
Expand Down
6 changes: 6 additions & 0 deletions docs/API/backend.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
🏰 Backend API
----------------------------------------------

.. currentmodule:: sepes

.. autofunction:: backend_context
1 change: 1 addition & 0 deletions docs/API/sepes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
masking
advanced_api
pretty_print
backend
3 changes: 3 additions & 0 deletions sepes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
tree_summary,
)
from sepes._src.tree_util import Partial, bcmap, is_tree_equal, leafwise
from sepes._src.backend import backend_context

__all__ = (
# general utils
Expand Down Expand Up @@ -61,6 +62,8 @@
"bcmap",
"Partial",
"leafwise",
# backend utils
"backend_context",
)

__version__ = "0.11.0"
Expand Down
4 changes: 2 additions & 2 deletions sepes/_src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def jax_backend():

@contextmanager
def backend_context(backend_name: BackendLiteral):
"""Context manager for switching the tree backend temporarily.
"""Context manager for switching the tree backend within a context.
Args:
backend_name: The name of the backend to switch to. available backends are
Expand All @@ -105,7 +105,7 @@ def backend_context(backend_name: BackendLiteral):
>>> optree.tree_flatten(tree, namespace="sepes")
[1, 2]
"""
global treelib
global treelib, backend
old_treelib = treelib
old_backend = backend
try:
Expand Down
2 changes: 1 addition & 1 deletion sepes/_src/backend/arraylib/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ class ArrayLib:
is_integer = staticmethod(is_integer)
is_inexact = staticmethod(is_inexact)
is_bool = staticmethod(is_bool)
types: tuple[type, ...] = ()
ndarrays: tuple[type, ...] = ()
2 changes: 1 addition & 1 deletion sepes/_src/backend/arraylib/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@
ArrayLib.is_integer.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.integer))
ArrayLib.is_inexact.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.inexact))
ArrayLib.is_bool.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.bool_))
ArrayLib.types += (Array,)
ArrayLib.ndarrays += (Array,)
2 changes: 1 addition & 1 deletion sepes/_src/backend/arraylib/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@
ArrayLib.is_integer.register(ndarray, lambda x: np.issubdtype(x.dtype, np.integer))
ArrayLib.is_inexact.register(ndarray, lambda x: np.issubdtype(x.dtype, np.inexact))
ArrayLib.is_bool.register(ndarray, lambda x: np.issubdtype(x.dtype, np.bool_))
ArrayLib.types += (ndarray,)
ArrayLib.ndarrays += (ndarray,)
2 changes: 1 addition & 1 deletion sepes/_src/backend/arraylib/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@
ArrayLib.is_integer.register(Tensor, lambda x: x.dtype in integers)
ArrayLib.is_inexact.register(Tensor, lambda x: x.dtype in floatings + complexes)
ArrayLib.is_bool.register(Tensor, lambda x: x.dtype == torch.bool)
ArrayLib.types += (Tensor,)
ArrayLib.ndarrays += (Tensor,)
10 changes: 5 additions & 5 deletions sepes/_src/tree_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def combine_bool_leaves(*leaves):
return verdict

def is_bool_leaf(leaf: Any) -> bool:
if isinstance(leaf, arraylib.types):
if isinstance(leaf, arraylib.ndarrays):
return arraylib.is_bool(leaf)
return isinstance(leaf, bool)

Expand Down Expand Up @@ -598,7 +598,7 @@ def leaf_get(where: Any, leaf: Any):
# for array boolean mask we select **parts** of the array that
# matches the mask, for example if the mask is Array([True, False, False])
# and the leaf is Array([1, 2, 3]) then the result is Array([1])
if isinstance(where, arraylib.types) and len(arraylib.shape(where)):
if isinstance(where, arraylib.ndarrays) and len(arraylib.shape(where)):
return leaf[where]
# non-array boolean mask we select the leaf if the mask is True
# and `None` otherwise
Expand Down Expand Up @@ -663,7 +663,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any):
# matches the mask, for example if the mask is Array([True, False, False])
# and the leaf is Array([1, 2, 3]) then the result is Array([1, 100, 100])
# with set_value = 100
if isinstance(where, arraylib.types):
if isinstance(where, arraylib.ndarrays):
return arraylib.where(where, set_value, leaf)
return set_value if where else leaf

Expand Down Expand Up @@ -751,7 +751,7 @@ def leaf_apply(where: Any, leaf: Any):
# one thing to note is that, the where mask select an array
# then the function needs work properly when applied to the selected
# array elements
if isinstance(where, arraylib.types):
if isinstance(where, arraylib.ndarrays):
return arraylib.where(where, func(leaf), leaf)
return func(leaf) if where else leaf

Expand Down Expand Up @@ -831,7 +831,7 @@ def stateless_func(leaf):
return leaf

def leaf_apply(where: Any, leaf: Any):
if isinstance(where, arraylib.types):
if isinstance(where, arraylib.ndarrays):
return arraylib.where(where, stateless_func(leaf), leaf)
return stateless_func(leaf) if where else leaf

Expand Down
4 changes: 2 additions & 2 deletions sepes/_src/tree_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def freeze(value: T) -> _FrozenBase[T]:
freeze.def_type = freeze.type_dispatcher.register


for ndarray in arraylib.types:
for ndarray in arraylib.ndarrays:

@freeze.def_type(ndarray)
def freeze_array(value: T) -> _FrozenArray[T]:
Expand Down Expand Up @@ -254,7 +254,7 @@ def is_nondiff(value: Any) -> bool:
is_nondiff.def_type = is_nondiff.type_dispatcher.register


for ndarray in arraylib.types:
for ndarray in arraylib.ndarrays:

@is_nondiff.def_type(ndarray)
def is_nondiff_array(value) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions sepes/_src/tree_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _(func: Callable, **spec: Unpack[PPSpec]) -> str:
return text


for ndarray in arraylib.types:
for ndarray in arraylib.ndarrays:

@tree_repr.def_type(ndarray)
def array_pp(node, **spec: Unpack[PPSpec]) -> str:
Expand Down Expand Up @@ -672,7 +672,7 @@ def tree_summary(
tree_summary.def_type = tree_summary.type_dispatcher.register


for ndarray in arraylib.types:
for ndarray in arraylib.ndarrays:

@tree_summary.def_size(ndarray)
def _(node) -> int:
Expand Down

0 comments on commit 61ebeba

Please sign in to comment.