diff --git a/sepes/_src/tree_mask.py b/sepes/_src/tree_mask.py index d6a582c..482975f 100644 --- a/sepes/_src/tree_mask.py +++ b/sepes/_src/tree_mask.py @@ -172,37 +172,10 @@ def _(value: _MaskBase[T]) -> T: def is_nondiff(value: Any) -> bool: - """Returns True for non-inexact types, False otherwise. - - Args: - value: A value to check. - - Note: - - :func:`.is_nondiff` uses single dispatch to support custom types. To define - a custom behavior for a certain type, use ``is_nondiff.def_type(type, func)``. - - Example: - >>> import sepes as sp - >>> import jax.numpy as jnp - >>> sp.is_nondiff(jnp.array(1)) # int array is non-diff type - True - >>> sp.is_nondiff(jnp.array(1.)) # float array is diff type - False - >>> sp.is_nondiff(1) # int is non-diff type - True - >>> sp.is_nondiff(1.) # float is diff type - False - - Note: - This function is meant to be used with ``jax.tree_map`` to - create a mask for non-differentiable nodes in a tree, that can be used - to freeze the non-differentiable nodes before passing the tree to a - ``jax`` transformation. - """ return is_nondiff.type_dispatcher(value) -is_nondiff.type_dispatcher = ft.singledispatch(lambda x: True) +is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True) is_nondiff.def_type = is_nondiff.type_dispatcher.register @@ -261,8 +234,8 @@ def tree_mask( Args: tree: A pytree of values. cond: A callable that accepts a leaf and returns a boolean to mark the leaf - for masking. Defaults to masking non-differentiable leaf nodes that - are not instances of of python float, python complex, or inexact + for masking. Defaults to masking non-differentiable leaf nodes that + are not instances of of python float, python complex, or inexact array types. is_leaf: A callable that accepts a leaf and returns a boolean. If provided, it is used to determine if a value is a leaf. for example,