Skip to content

Commit

Permalink
Update tree_mask.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Mar 29, 2024
1 parent 0735d06 commit a98a98b
Showing 1 changed file with 3 additions and 30 deletions.
33 changes: 3 additions & 30 deletions sepes/_src/tree_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,37 +172,10 @@ def _(value: _MaskBase[T]) -> T:


def is_nondiff(value: Any) -> bool:
"""Returns True for non-inexact types, False otherwise.
Args:
value: A value to check.
Note:
- :func:`.is_nondiff` uses single dispatch to support custom types. To define
a custom behavior for a certain type, use ``is_nondiff.def_type(type, func)``.
Example:
>>> import sepes as sp
>>> import jax.numpy as jnp
>>> sp.is_nondiff(jnp.array(1)) # int array is non-diff type
True
>>> sp.is_nondiff(jnp.array(1.)) # float array is diff type
False
>>> sp.is_nondiff(1) # int is non-diff type
True
>>> sp.is_nondiff(1.) # float is diff type
False
Note:
This function is meant to be used with ``jax.tree_map`` to
create a mask for non-differentiable nodes in a tree, that can be used
to freeze the non-differentiable nodes before passing the tree to a
``jax`` transformation.
"""
return is_nondiff.type_dispatcher(value)


is_nondiff.type_dispatcher = ft.singledispatch(lambda x: True)
is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True)
is_nondiff.def_type = is_nondiff.type_dispatcher.register


Expand Down Expand Up @@ -261,8 +234,8 @@ def tree_mask(
Args:
tree: A pytree of values.
cond: A callable that accepts a leaf and returns a boolean to mark the leaf
for masking. Defaults to masking non-differentiable leaf nodes that
are not instances of of python float, python complex, or inexact
for masking. Defaults to masking non-differentiable leaf nodes that
are not instances of of python float, python complex, or inexact
array types.
is_leaf: A callable that accepts a leaf and returns a boolean. If
provided, it is used to determine if a value is a leaf. for example,
Expand Down

0 comments on commit a98a98b

Please sign in to comment.