docs
ASEM000 committed Aug 3, 2023
1 parent b886369 commit 9a82c09
Showing 4 changed files with 201 additions and 38 deletions.
37 changes: 1 addition & 36 deletions README.md
@@ -335,39 +335,4 @@ y.shape
# (5, 10)
```

</details>

<!-- <details><summary>Evaluation behavior handling</summary>
`serket` uses `functools` dispatching to modify a tree of layers so that any training-related behavior is disabled during evaluation. It replaces certain layers, such as `Dropout` and `BatchNorm`, with equivalent layers that don't affect the model's output during evaluation.
For example:
```python
# dropout is replaced by an identity layer in evaluation mode
# by registering `tree_eval.def_eval(sk.nn.Dropout, sk.nn.Identity)`
import jax.numpy as jnp
import serket as sk
layer = sk.nn.Dropout(0.5)
sk.tree_eval(layer)
# Identity()
```
Let's break down the code snippet and its purpose:
1. The function `tree_eval(tree)` takes a tree of layers as input.
2. The function replaces specific layers in the tree with evaluation-specific layers.
Here are the modifications it makes to the tree:
- If a `Dropout` layer is encountered in the tree, it is replaced with an `Identity` layer. The `Identity` layer is a simple layer that doesn't introduce any changes to the input, making it effectively a no-op during evaluation.
- If a `BatchNorm` layer is encountered in the tree, it is replaced with an `EvalNorm` layer. The `EvalNorm` layer is designed to have the same behavior as `BatchNorm` during evaluation but not during training.
The purpose of these replacements is to ensure that the tree's behavior during evaluation matches its structure during training, without the additional randomness introduced by dropout layers or the batch-statistics updates performed by batch normalization.
</details> -->
</details>
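The dispatch-based replacement described in the commented-out section above can be sketched in plain Python. This is an illustrative stand-in, not serket's actual implementation: the `Dropout`, `Identity`, `BatchNorm`, and `EvalNorm` classes below are hypothetical placeholders, and serket's real `tree_eval` traverses registered pytrees rather than a plain list.

```python
# Sketch of functools-based layer dispatch, assuming stand-in layer classes;
# serket's real tree_eval operates on registered pytrees, not Python lists.
import functools

class Dropout: pass
class Identity: pass
class BatchNorm: pass
class EvalNorm: pass

@functools.singledispatch
def eval_rule(layer):
    # Default: layers without a registered rule pass through unchanged.
    return layer

@eval_rule.register(Dropout)
def _(layer):
    # Dropout is a no-op at evaluation time, so swap in Identity.
    return Identity()

@eval_rule.register(BatchNorm)
def _(layer):
    # Use frozen statistics instead of batch statistics at evaluation time.
    return EvalNorm()

def tree_eval_sketch(layers):
    # Map the evaluation rule over every layer in the (flat) "tree".
    return [eval_rule(layer) for layer in layers]

out = tree_eval_sketch([Dropout(), BatchNorm(), Identity()])
print([type(x).__name__ for x in out])  # ['Identity', 'EvalNorm', 'Identity']
```

Registering a new pair is a single `register` call, which mirrors the `tree_eval.def_eval(...)` registration mentioned in the snippet above.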
3 changes: 2 additions & 1 deletion docs/examples.rst
@@ -6,4 +6,5 @@
:maxdepth: 1

notebooks/train_mnist
notebooks/train_bilstm
notebooks/train_bilstm
notebooks/train_eval
197 changes: 197 additions & 0 deletions docs/notebooks/train_eval.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion serket/nn/custom_transform.py
@@ -47,7 +47,7 @@ def tree_state(tree: T, array: jax.Array | None = None) -> T:
(e.g. :class:`nn.ConvGRU1DCell`). default: ``None``.
Returns:
A tree of state leaves if it has state, otherwise ``None``.
A tree of state leaves if it has state, otherwise ``NoState`` leaf.
Example:
>>> import jax.numpy as jnp
