Skip to content

Commit

Permalink
custom wrapper rules
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 2, 2024
1 parent 1639900 commit 13d26b8
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 78 deletions.
41 changes: 41 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,46 @@
# Changelog

## v0.12.1


### Additions

- Add ability to register custom types for masking wrappers.

Example to define a custom masking wrapper for a specific type.

```python
import sepes as sp
import jax
import dataclasses as dc
@dc.dataclass
class MyInt:
value: int
@dc.dataclass
class MaskedInt:
value: MyInt
# define a rule of how to mask instances of MyInt
@sp.tree_mask.def_type(MyInt)
def mask_int(value):
return MaskedInt(value)
# define a rule how to unmask the MaskedInt wrapper
@sp.tree_unmask.def_type(MaskedInt)
def unmask_int(value):
return value.value
tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}]
masked_tree = sp.tree_mask(tree, cond=lambda _: True)

masked_tree
#[MaskedInt(value=MyInt(value=1)), MaskedInt(value=MyInt(value=2)), {'a': MaskedInt(value=MyInt(value=3))}]

sp.tree_unmask(masked_tree)
#[MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}]

# `is_masked` recognizes the new masked type
assert is_masked(masked_tree[0]) is True
```


## V0.12

### Deprecations
Expand Down
2 changes: 1 addition & 1 deletion sepes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"backend_context",
]

__version__ = "0.12.0"
__version__ = "0.12.1"

at.__module__ = "sepes"
TreeClass.__module__ = "sepes"
200 changes: 123 additions & 77 deletions sepes/_src/tree_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,30 @@
MaskType = Union[T, Callable[[Any], bool]]


def is_nondiff(value: Any) -> bool:
return is_nondiff.type_dispatcher(value)


is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True)
is_nondiff.def_type = is_nondiff.type_dispatcher.register


for ndarray in arraylib.ndarrays:

@is_nondiff.def_type(ndarray)
def is_nondiff_array(value) -> bool:
# return True if the node is non-inexact type, otherwise False
if arraylib.is_inexact(value):
return False
return True


@is_nondiff.def_type(float)
@is_nondiff.def_type(complex)
def _(_: float | complex) -> bool:
return False


class _MaskedError(NamedTuple):
opname: str

Expand Down Expand Up @@ -123,86 +147,13 @@ def __eq__(self, other) -> bool:
return arraylib.array_equal(lhs, rhs)


def mask(value: T) -> _MaskBase[T]:
# dispatching is used to customize the type of the wrapper based on the type
# of the value. For instance, hashable values dont need custom hash and
# equality implementations, so they are wrapped with a simpler wrapper.
# this approach avoids type logic in the wrapper equality and hash methods,
# thus effectively improving performance of the wrapper.
return mask.type_dispatcher(value)


mask.type_dispatcher = ft.singledispatch(_MaskedHashable)
mask.def_type = mask.type_dispatcher.register


for ndarray in arraylib.ndarrays:

@mask.def_type(ndarray)
def mask_array(value: T) -> _MaskedArray[T]:
# wrap arrays with a custom wrapper that implements hash and equality
# arrays can be hashed by converting them to bytes and hashing the bytes
return _MaskedArray(value)


@mask.def_type(_MaskBase)
def _(value: _MaskBase[T]) -> _MaskBase[T]:
# idempotent mask operation, meaning that mask(mask(x)) == mask(x)
# this is useful to avoid recursive unwrapping of frozen values, plus its
# meaningless to mask a frozen value.
return value


def is_masked(value: Any) -> bool:
"""Returns True if the value is a frozen wrapper."""
return isinstance(value, _MaskBase)


def unmask(value: T) -> T:
return unmask.type_dispatcher(value)


unmask.type_dispatcher = ft.singledispatch(lambda x: x)
unmask.def_type = unmask.type_dispatcher.register


@unmask.def_type(_MaskBase)
def _(value: _MaskBase[T]) -> T:
return getattr(value, "__wrapped__")


def is_nondiff(value: Any) -> bool:
return is_nondiff.type_dispatcher(value)


is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True)
is_nondiff.def_type = is_nondiff.type_dispatcher.register


for ndarray in arraylib.ndarrays:

@is_nondiff.def_type(ndarray)
def is_nondiff_array(value) -> bool:
# return True if the node is non-inexact type, otherwise False
if arraylib.is_inexact(value):
return False
return True


@is_nondiff.def_type(float)
@is_nondiff.def_type(complex)
def _(_: float | complex) -> bool:
return False


def _tree_mask_map(
tree: T,
cond: Callable[[Any], bool],
func: type | Callable[[Any], Any],
func: Callable[[Any], Any],
*,
is_leaf: Callable[[Any], None] | None = None,
):

if not isinstance(cond, Callable):
# a callable that accepts a leaf and returns a boolean
# but *not* a tree with the same structure as tree with boolean values.
Expand Down Expand Up @@ -266,8 +217,39 @@ def tree_mask(
>>> tree = (1., 2) # contains a non-differentiable node
>>> square(sp.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
Example:
Define a custom masking wrapper for a specific type.
>>> import sepes as sp
>>> import jax
>>> import dataclasses as dc
>>> @dc.dataclass
... class MyInt:
... value: int
>>> @dc.dataclass
... class MaskedInt:
... value: MyInt
>>> # define a rule of how to mask an integer
>>> @sp.tree_mask.def_type(MyInt)
... def mask_int(value):
... return MaskedInt(value)
>>> # define a rule how to unmask the wrapper
>>> @sp.tree_unmask.def_type(MaskedInt)
... def unmask_int(value):
... return value.value
>>> tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}]
>>> masked_tree = sp.tree_mask(tree, cond=lambda _: True)
>>> masked_tree
[MaskedInt(value=MyInt(value=1)), MaskedInt(value=MyInt(value=2)), {'a': MaskedInt(value=MyInt(value=3))}]
>>> sp.tree_unmask(masked_tree)
[MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}]
"""
return _tree_mask_map(tree, cond=cond, func=mask, is_leaf=is_leaf)
return _tree_mask_map(tree, cond=cond, func=tree_mask.dispatcher, is_leaf=is_leaf)


tree_mask.dispatcher = ft.singledispatch(_MaskedHashable)
tree_mask.def_type = tree_mask.dispatcher.register


def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True):
Expand Down Expand Up @@ -303,8 +285,72 @@ def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True):
>>> tree = (1., 2) # contains a non-differentiable node
>>> square(sp.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
Example:
Define a custom masking wrapper for a specific type.
>>> import sepes as sp
>>> import jax
>>> import dataclasses as dc
>>> @dc.dataclass
... class MyInt:
... value: int
>>> @dc.dataclass
... class MaskedInt:
... value: MyInt
>>> # define a rule of how to mask an integer
>>> @sp.tree_mask.def_type(MyInt)
... def mask_int(value):
... return MaskedInt(value)
>>> # define a rule how to unmask the wrapper
>>> @sp.tree_unmask.def_type(MaskedInt)
... def unmask_int(value):
... return value.value
>>> tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}]
>>> masked_tree = sp.tree_mask(tree, cond=lambda _: True)
>>> masked_tree
[MaskedInt(value=MyInt(value=1)), MaskedInt(value=MyInt(value=2)), {'a': MaskedInt(value=MyInt(value=3))}]
>>> sp.tree_unmask(masked_tree)
[MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}]
"""
return _tree_mask_map(tree, cond=cond, func=unmask, is_leaf=is_masked)
return _tree_mask_map(
tree,
cond=cond,
func=tree_unmask.dispatcher,
is_leaf=is_masked,
)


tree_unmask.dispatcher = ft.singledispatch(lambda x: x)
tree_unmask.def_type = tree_unmask.dispatcher.register


for ndarray in arraylib.ndarrays:

@tree_mask.def_type(ndarray)
def mask_array(value: T) -> _MaskedArray[T]:
# wrap arrays with a custom wrapper that implements hash and equality
# arrays can be hashed by converting them to bytes and hashing the bytes
return _MaskedArray(value)


@tree_mask.def_type(_MaskBase)
def _(value: _MaskBase[T]) -> _MaskBase[T]:
# idempotent mask operation, meaning that mask(mask(x)) == mask(x)
# this is useful to avoid recursive unwrapping of frozen values, plus its
# meaningless to mask a frozen value.
return value


def is_masked(value: Any) -> bool:
"""Returns True if the value is a frozen wrapper."""
types = tuple(set(tree_unmask.dispatcher.registry) - {object})
return isinstance(value, types)


@tree_unmask.def_type(_MaskBase)
def _(value: _MaskBase[T]) -> T:
return getattr(value, "__wrapped__")


if is_package_avaiable("jax"):
Expand All @@ -314,6 +360,6 @@ def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True):
# otherwise calling `freeze` inside a jax transformation on
# a tracer will hide the tracer from jax and will cause leaked tracer
# error.
@mask.def_type(jax.core.Tracer)
@tree_mask.def_type(jax.core.Tracer)
def _(value: jax.core.Tracer) -> jax.core.Tracer:
return value
24 changes: 24 additions & 0 deletions tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,27 @@ def test_array_tree_mask_tree_unmask():
assert not (frozen_array == freeze(arraylib.ones((5, 6))))
# assert not (frozen_array == freeze(arraylib.ones((5, 5)).astype(arraylib.uint8)))
assert hash(frozen_array) == hash(frozen_array)


def test_custom_mask_unmask_wrappers():

@dc.dataclass
class MyInt:
value: int
@dc.dataclass
class MaskedInt:
value: MyInt
# define a rule of how to mask instances of MyInt
@tree_mask.def_type(MyInt)
def mask_int(value):
return MaskedInt(value)
# define a rule how to unmask the MaskedInt wrapper
@tree_unmask.def_type(MaskedInt)
def unmask_int(value):
return value.value
tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}]
masked_tree = tree_mask(tree, cond=lambda _: True)
masked_tree
assert tree_unmask(masked_tree) == [MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}]
# check the mask type is recognized by is_masked
assert is_masked(masked_tree[0]) is True

0 comments on commit 13d26b8

Please sign in to comment.