Commit 9897deb: docs
ASEM000 committed Aug 2, 2023 (1 parent: 632c72b)
Showing 2 changed files with 105 additions and 5 deletions.
102 changes: 100 additions & 2 deletions README.md
@@ -9,7 +9,7 @@
|[**Quick Example**](#QuickExample)

![Tests](https://github.com/ASEM000/serket/actions/workflows/tests.yml/badge.svg)
- ![pyver](https://img.shields.io/badge/python-3.7%203.8%203.9%203.10-red)
+ ![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11-red)
![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![Downloads](https://pepy.tech/badge/serket)](https://pepy.tech/project/serket)
[![codecov](https://codecov.io/gh/ASEM000/serket/branch/main/graph/badge.svg?token=C6NXOK9EVS)](https://codecov.io/gh/ASEM000/serket)
@@ -75,4 +75,102 @@ for j, (xb, yb) in enumerate(zip(x_train, y_train)):
accuracy = accuracy_func(logits, y_train)

nn = sk.tree_unmask(nn)
```
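
The training loop above is only partially visible in this diff. As a rough sketch, not part of the commit and assuming `sk.tree_mask` is the counterpart of the `sk.tree_unmask` call shown, masking freezes non-array leaves so the model can pass through `jax.jit` and back:

```python
import jax
import serket as sk

nn = sk.nn.Linear(1, 1)
nn = sk.tree_mask(nn)  # assumed: freeze non-array leaves before jax transforms

@jax.jit
def forward(nn, x):
    nn = sk.tree_unmask(nn)  # restore the frozen leaves inside the transform
    return nn(x)

print(forward(nn, jax.numpy.ones([1])).shape)  # (1,)
```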

#### Notable features:

<details><summary>🥱 Functional lazy initialization </summary>

Lazy initialization is useful when the dimensions of certain input features are not known in advance. For instance, the number of input features for a linear layer applied to a flattened image may be uncertain (**Example 1**), or the shape of the output of a flattened convolutional layer may not be straightforward to calculate (**Example 2**). In such cases, lazy initialization defers the creation of the affected parameters until the first input is seen, at which point the missing dimensions are inferred from it. This lets the same model definition adapt to different input sizes without manual shape bookkeeping.

_Example 1_

```python
import jax
import serket as sk

# 5 MNIST-sized images
x = jax.numpy.ones([5, 1, 28, 28])

layer = sk.nn.Sequential(
    jax.numpy.ravel,
    # pass `None` to infer `in_features` lazily from the first input
    sk.nn.Linear(None, 10),
    jax.nn.relu,
    sk.nn.Linear(10, 10),
    jax.nn.softmax,
)
# materialize the layer with a single image
_, layer = layer.at["__call__"](x[0])
# apply on batch
y = jax.vmap(layer)(x)
y.shape
# (5, 10)
```

_Example 2_

```python
import jax
import serket as sk

# 5 MNIST-sized images
x = jax.numpy.ones([5, 1, 28, 28])

layer = sk.nn.Sequential(
    sk.nn.Conv2D(1, 10, 3),
    jax.nn.relu,
    sk.nn.MaxPool2D(2),
    jax.numpy.ravel,
    # the linear layer's input size is inferred from
    # the previous layer's output
    sk.nn.Linear(None, 10),
    jax.nn.softmax,
)

# materialize the layer with a single image
_, layer = layer.at["__call__"](x[0])

# apply on batch
y = jax.vmap(layer)(x)

y.shape
# (5, 10)
```
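
Once materialized, the layer is an ordinary pytree of arrays, so it composes with the usual JAX transformations. A minimal sketch, not part of the diff, assuming the materialized layer can simply be closed over under `jax.jit`:

```python
import jax
import serket as sk

x = jax.numpy.ones([5, 1, 28, 28])

layer = sk.nn.Sequential(
    jax.numpy.ravel,
    sk.nn.Linear(None, 10),  # in_features inferred lazily
    jax.nn.softmax,
)
# materialize with a single (unbatched) sample, then jit the batched forward pass
_, layer = layer.at["__call__"](x[0])

@jax.jit
def forward(batch):
    return jax.vmap(layer)(batch)

print(forward(x).shape)  # (5, 10)
```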

</details>

<!-- <details><summary>Evaluation behavior handling</summary>
`serket` uses `functools` dispatching to modify a tree of layers, disabling any training-related behavior during evaluation. It replaces certain layers, such as `Dropout` and `BatchNorm`, with evaluation-time counterparts.
For example:
```python
# dropout is replaced by an identity layer in evaluation mode
# by registering `tree_eval.def_eval(sk.nn.Dropout, sk.nn.Identity)`
import jax.numpy as jnp
import serket as sk
layer = sk.nn.Dropout(0.5)
sk.tree_eval(layer)
# Identity()
```
Let's break down the code snippet and its purpose:
1. The function `tree_eval(tree)` takes a tree of layers as input.
2. The function replaces specific layers in the tree with evaluation-specific layers:
   - A `Dropout` layer is replaced with an `Identity` layer. `Identity` passes its input through unchanged, so dropout becomes a no-op during evaluation.
   - A `BatchNorm` layer is replaced with an `EvalNorm` layer, which reproduces `BatchNorm`'s evaluation-time behavior without its training-time updates.
These replacements keep the structure of the tree during evaluation the same as during training, while removing the randomness introduced by dropout and the batch-statistics updates performed by batch normalization.
</details> -->
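
The commented-out draft above describes `sk.tree_eval`. A small hedged sketch of how it might be used on a composite model, assuming nested layers are visited as the draft describes:

```python
import serket as sk

net = sk.nn.Sequential(
    sk.nn.Linear(1, 10),
    sk.nn.Dropout(0.5),
    sk.nn.Linear(10, 1),
)

# training-only layers nested inside containers are swapped for their
# evaluation counterparts; other layers are left untouched
eval_net = sk.tree_eval(net)  # Dropout -> Identity, Linear layers unchanged
```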
8 changes: 5 additions & 3 deletions serket/nn/utils.py
@@ -263,13 +263,15 @@ def validate_spatial_ndim(func: Callable[P, T], attribute_name: str) -> Callable

 def check_spatial_in_shape(x, spatial_ndim: int) -> None:
     if x.ndim != spatial_ndim + 1:
-        spatial = {", ".join(("rows", "cols", "depths")[:spatial_ndim])}
+        spatial = ", ".join(("rows", "cols", "depths")[:spatial_ndim])
         raise ValueError(
             f"Dimesion mismatch error.\n"
             f"Input should satisfy:\n"
-            f"- {(spatial_ndim + 1)=} dimension, got {x.ndim=}.\n"
-            f"- shape of (in_features, {spatial}), got {x.shape=}.\n"
+            f" - {(spatial_ndim + 1)=} dimension, but got {x.ndim=}.\n"
+            f" - shape of (in_features, {spatial}), but got {x.shape=}.\n"
             + (
                 # maybe the user apply the layer on a batched input
                 "\nThe input should be unbatched (no batch dimension).\n"
                 "To apply on batched input, use `jax.vmap(...)(input)`."
                 if x.ndim == spatial_ndim + 2
                 else ""
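As a hypothetical illustration of the check above (not part of the commit): applying a spatial layer to a batched input trips this error, and `jax.vmap` is the suggested fix.

```python
import jax
import serket as sk

conv = sk.nn.Conv2D(1, 4, 3)              # expects unbatched (in_features, rows, cols)
batched = jax.numpy.ones([8, 1, 28, 28])  # has an extra leading batch axis

# conv(batched)                # would raise the dimension-mismatch ValueError above,
#                              # including the hint since x.ndim == spatial_ndim + 2
y = jax.vmap(conv)(batched)    # map the layer over the batch axis instead
print(y.shape)                 # (8, 4, ...) -- spatial dims depend on padding defaults
```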
