From 13d26b8fad2f695d4dff188422136b5e006b939f Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Tue, 2 Apr 2024 19:14:46 +0900 Subject: [PATCH] custom wrapper rules --- CHANGELOG.md | 41 ++++++++ sepes/__init__.py | 2 +- sepes/_src/tree_mask.py | 200 ++++++++++++++++++++++++---------------- tests/test_mask.py | 24 +++++ 4 files changed, 189 insertions(+), 78 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82fa907..011b25f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,46 @@ # Changelog +## v0.12.1 + + +### Additions + +- Add ability to register custom types for masking wrappers. + + Example to define a custom masking wrapper for a specific type. + + ```python + import sepes as sp + import jax + import dataclasses as dc + @dc.dataclass + class MyInt: + value: int + @dc.dataclass + class MaskedInt: + value: MyInt + # define a rule of how to mask instances of MyInt + @sp.tree_mask.def_type(MyInt) + def mask_int(value): + return MaskedInt(value) + # define a rule how to unmask the MaskedInt wrapper + @sp.tree_unmask.def_type(MaskedInt) + def unmask_int(value): + return value.value + tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}] + masked_tree = sp.tree_mask(tree, cond=lambda _: True) + + masked_tree + #[MaskedInt(value=MyInt(value=1)), MaskedInt(value=MyInt(value=2)), {'a': MaskedInt(value=MyInt(value=3))}] + + sp.tree_unmask(masked_tree) + #[MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}] + + # `is_masked` recognizes the new masked type + assert is_masked(masked_tree[0]) is True + ``` + + ## V0.12 ### Deprecations diff --git a/sepes/__init__.py b/sepes/__init__.py index d4ff6a3..fc69484 100644 --- a/sepes/__init__.py +++ b/sepes/__init__.py @@ -45,7 +45,7 @@ "backend_context", ] -__version__ = "0.12.0" +__version__ = "0.12.1" at.__module__ = "sepes" TreeClass.__module__ = "sepes" diff --git a/sepes/_src/tree_mask.py b/sepes/_src/tree_mask.py index 5f14370..7f57ab6 100644 --- a/sepes/_src/tree_mask.py +++ b/sepes/_src/tree_mask.py @@ -30,6 +30,30 @@ MaskType = Union[T, Callable[[Any], bool]] +def is_nondiff(value: Any) -> bool: + return is_nondiff.type_dispatcher(value) + + +is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True) +is_nondiff.def_type = is_nondiff.type_dispatcher.register + + +for ndarray in arraylib.ndarrays: + + @is_nondiff.def_type(ndarray) + def is_nondiff_array(value) -> bool: + # return True if the node is non-inexact type, otherwise False + if arraylib.is_inexact(value): + return False + return True + + +@is_nondiff.def_type(float) +@is_nondiff.def_type(complex) +def _(_: float | complex) -> bool: + return False + + class _MaskedError(NamedTuple): opname: str @@ -123,86 +147,13 @@ def __eq__(self, other) -> bool: return arraylib.array_equal(lhs, rhs) -def mask(value: T) -> _MaskBase[T]: - # dispatching is used to customize the type of the wrapper based on the type - # of the value. For instance, hashable values dont need custom hash and - # equality implementations, so they are wrapped with a simpler wrapper. - # this approach avoids type logic in the wrapper equality and hash methods, - # thus effectively improving performance of the wrapper. - return mask.type_dispatcher(value) - - -mask.type_dispatcher = ft.singledispatch(_MaskedHashable) -mask.def_type = mask.type_dispatcher.register - - -for ndarray in arraylib.ndarrays: - - @mask.def_type(ndarray) - def mask_array(value: T) -> _MaskedArray[T]: - # wrap arrays with a custom wrapper that implements hash and equality - # arrays can be hashed by converting them to bytes and hashing the bytes - return _MaskedArray(value) - - -@mask.def_type(_MaskBase) -def _(value: _MaskBase[T]) -> _MaskBase[T]: - # idempotent mask operation, meaning that mask(mask(x)) == mask(x) - # this is useful to avoid recursive unwrapping of frozen values, plus its - # meaningless to mask a frozen value. - return value - - -def is_masked(value: Any) -> bool: - """Returns True if the value is a frozen wrapper.""" - return isinstance(value, _MaskBase) - - -def unmask(value: T) -> T: - return unmask.type_dispatcher(value) - - -unmask.type_dispatcher = ft.singledispatch(lambda x: x) -unmask.def_type = unmask.type_dispatcher.register - - -@unmask.def_type(_MaskBase) -def _(value: _MaskBase[T]) -> T: - return getattr(value, "__wrapped__") - - -def is_nondiff(value: Any) -> bool: - return is_nondiff.type_dispatcher(value) - - -is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True) -is_nondiff.def_type = is_nondiff.type_dispatcher.register - - -for ndarray in arraylib.ndarrays: - - @is_nondiff.def_type(ndarray) - def is_nondiff_array(value) -> bool: - # return True if the node is non-inexact type, otherwise False - if arraylib.is_inexact(value): - return False - return True - - -@is_nondiff.def_type(float) -@is_nondiff.def_type(complex) -def _(_: float | complex) -> bool: - return False - - def _tree_mask_map( tree: T, cond: Callable[[Any], bool], - func: type | Callable[[Any], Any], + func: Callable[[Any], Any], *, is_leaf: Callable[[Any], None] | None = None, ): - if not isinstance(cond, Callable): # a callable that accepts a leaf and returns a boolean # but *not* a tree with the same structure as tree with boolean values. @@ -266,8 +217,39 @@ def tree_mask( >>> tree = (1., 2) # contains a non-differentiable node >>> square(sp.tree_mask(tree)) (Array(2., dtype=float32, weak_type=True), #2) + + Example: + Define a custom masking wrapper for a specific type. + + >>> import sepes as sp + >>> import jax + >>> import dataclasses as dc + >>> @dc.dataclass + ... class MyInt: + ... value: int + >>> @dc.dataclass + ... class MaskedInt: + ... value: MyInt + >>> # define a rule of how to mask an integer + >>> @sp.tree_mask.def_type(MyInt) + ... def mask_int(value): + ... return MaskedInt(value) + >>> # define a rule how to unmask the wrapper + >>> @sp.tree_unmask.def_type(MaskedInt) + ... def unmask_int(value): + ... return value.value + >>> tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}] + >>> masked_tree = sp.tree_mask(tree, cond=lambda _: True) + >>> masked_tree + [MaskedInt(value=MyInt(value=1)), MaskedInt(value=MyInt(value=2)), {'a': MaskedInt(value=MyInt(value=3))}] + >>> sp.tree_unmask(masked_tree) + [MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}] """ - return _tree_mask_map(tree, cond=cond, func=mask, is_leaf=is_leaf) + return _tree_mask_map(tree, cond=cond, func=tree_mask.dispatcher, is_leaf=is_leaf) + + +tree_mask.dispatcher = ft.singledispatch(_MaskedHashable) +tree_mask.def_type = tree_mask.dispatcher.register def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True): @@ -303,8 +285,72 @@ def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True): >>> tree = (1., 2) # contains a non-differentiable node >>> square(sp.tree_mask(tree)) (Array(2., dtype=float32, weak_type=True), #2) + + Example: + Define a custom masking wrapper for a specific type. + + >>> import sepes as sp + >>> import jax + >>> import dataclasses as dc + >>> @dc.dataclass + ... class MyInt: + ... value: int + >>> @dc.dataclass + ... class MaskedInt: + ... value: MyInt + >>> # define a rule of how to mask an integer + >>> @sp.tree_mask.def_type(MyInt) + ... def mask_int(value): + ... return MaskedInt(value) + >>> # define a rule how to unmask the wrapper + >>> @sp.tree_unmask.def_type(MaskedInt) + ... def unmask_int(value): + ... return value.value + >>> tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}] + >>> masked_tree = sp.tree_mask(tree, cond=lambda _: True) + >>> masked_tree + [MaskedInt(value=MyInt(value=1)), MaskedInt(value=MyInt(value=2)), {'a': MaskedInt(value=MyInt(value=3))}] + >>> sp.tree_unmask(masked_tree) + [MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}] """ - return _tree_mask_map(tree, cond=cond, func=unmask, is_leaf=is_masked) + return _tree_mask_map( + tree, + cond=cond, + func=tree_unmask.dispatcher, + is_leaf=is_masked, + ) + + +tree_unmask.dispatcher = ft.singledispatch(lambda x: x) +tree_unmask.def_type = tree_unmask.dispatcher.register + + +for ndarray in arraylib.ndarrays: + + @tree_mask.def_type(ndarray) + def mask_array(value: T) -> _MaskedArray[T]: + # wrap arrays with a custom wrapper that implements hash and equality + # arrays can be hashed by converting them to bytes and hashing the bytes + return _MaskedArray(value) + + +@tree_mask.def_type(_MaskBase) +def _(value: _MaskBase[T]) -> _MaskBase[T]: + # idempotent mask operation, meaning that mask(mask(x)) == mask(x) + # this is useful to avoid recursive unwrapping of frozen values, plus its + # meaningless to mask a frozen value. + return value + + +def is_masked(value: Any) -> bool: + """Returns True if the value is a frozen wrapper.""" + types = tuple(set(tree_unmask.dispatcher.registry) - {object}) + return isinstance(value, types) + + +@tree_unmask.def_type(_MaskBase) +def _(value: _MaskBase[T]) -> T: + return getattr(value, "__wrapped__") if is_package_avaiable("jax"): @@ -314,6 +360,6 @@ def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True): # otherwise calling `freeze` inside a jax transformation on # a tracer will hide the tracer from jax and will cause leaked tracer # error. - @mask.def_type(jax.core.Tracer) + @tree_mask.def_type(jax.core.Tracer) def _(value: jax.core.Tracer) -> jax.core.Tracer: return value diff --git a/tests/test_mask.py b/tests/test_mask.py index 9b8de81..8d53d1b 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -425,3 +425,27 @@ def test_array_tree_mask_tree_unmask(): assert not (frozen_array == freeze(arraylib.ones((5, 6)))) # assert not (frozen_array == freeze(arraylib.ones((5, 5)).astype(arraylib.uint8))) assert hash(frozen_array) == hash(frozen_array) + + +def test_custom_mask_unmask_wrappers(): + + @dc.dataclass + class MyInt: + value: int + @dc.dataclass + class MaskedInt: + value: MyInt + # define a rule of how to mask instances of MyInt + @tree_mask.def_type(MyInt) + def mask_int(value): + return MaskedInt(value) + # define a rule how to unmask the MaskedInt wrapper + @tree_unmask.def_type(MaskedInt) + def unmask_int(value): + return value.value + tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}] + masked_tree = tree_mask(tree, cond=lambda _: True) + masked_tree + assert tree_unmask(masked_tree) == [MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}] + # check the mask type is recognized by is_masked + assert is_masked(masked_tree[0]) is True