Skip to content

Commit

Permalink
doc
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 25, 2023
1 parent 11f2f04 commit bd78a5f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
4 changes: 2 additions & 2 deletions serket/nn/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def tree_evaluation(tree):
Identity()
"""

types: set[type] = set(tree_evaluation.evaluation_dispatcher.registry) - {object}

def is_leaf(x: Callable[[Any], bool]) -> bool:
types = set(tree_evaluation.evaluation_dispatcher.registry.keys())
types.discard(object)
return isinstance(x, tuple(types))

return jax.tree_map(tree_evaluation.evaluation_dispatcher, tree, is_leaf=is_leaf)
Expand Down
28 changes: 28 additions & 0 deletions serket/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,34 @@ class BatchNorm(sk.TreeClass):
evaluation mode. In this case, this module will always use the running
estimates of the batch statistics during training.
Example:
>>> import jax
>>> import serket as sk
>>> bn = sk.nn.BatchNorm(10)
>>> state = sk.tree_state(bn)
>>> x = jax.random.uniform(jax.random.PRNGKey(0), shape=(5, 10))
>>> x, state = jax.vmap(bn, in_axes=(0, None))(x, state)
Example:
>>> # working with multiple states
>>> import jax
>>> import serket as sk
>>> @sk.autoinit
... class Tree(sk.TreeClass):
... bn1: sk.nn.BatchNorm = sk.nn.BatchNorm(10)
... bn2: sk.nn.BatchNorm = sk.nn.BatchNorm(10)
... def __call__(self, x, state):
... x, bn1 = self.bn1(x, state.bn1)
... x, bn2 = self.bn2(x, state.bn2)
... # update the output state
... state = state.at["bn1"].set(bn1).at["bn2"].set(bn2)
... return x, state
>>> tree = Tree()
>>> # initialize state as the same structure as tree
>>> state = sk.tree_state(tree)
>>> x = jax.random.uniform(jax.random.PRNGKey(0), shape=(5, 10))
>>> x, state = jax.vmap(tree, in_axes=(0, None))(x, state)
Note:
https://keras.io/api/layers/normalization_layers/batch_normalization/
"""
Expand Down
4 changes: 2 additions & 2 deletions serket/nn/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def tree_state(tree: T, array: jax.Array | None = None) -> T:
'some state'
"""

types: set[type] = set(tree_state.state_dispatcher.registry) - {object}

def is_leaf(x: Callable[[Any], bool]) -> bool:
types = set(tree_state.state_dispatcher.registry.keys())
types.discard(object)
return isinstance(x, tuple(types))

def dispatch_func(leaf):
Expand Down

0 comments on commit bd78a5f

Please sign in to comment.