docs
ASEM000 committed Aug 2, 2023
1 parent 40c6f17 commit f501a40
Showing 2 changed files with 195 additions and 4 deletions.
197 changes: 194 additions & 3 deletions README.md
@@ -42,7 +42,7 @@

```
import jax, jax.numpy as jnp
import serket as sk
import optax

- x_train, y_train = ..., ... # samples, 1, 28, 28
+ x_train, y_train = ..., ...
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

net = sk.nn.Sequential(

@@ -80,11 +80,202 @@ for j, (xb, yb) in enumerate(zip(x_train, y_train)):
net = sk.tree_unmask(net)
```
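The collapsed training loop above closes with `sk.tree_unmask(net)`, which pairs with a `tree_mask` call applied before training: masking hides non-array leaves (strings, ints, callables) from JAX transformations, and unmasking restores them afterwards. A framework-agnostic sketch of the underlying split idea, done by hand (the `params` tree, `loss`, and the manual split are hypothetical illustrations, not serket's implementation):

```python
import jax
import jax.numpy as jnp

# Hypothetical toy pytree mixing trainable arrays with static metadata.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(()), "name": "toy"}

def loss(p, x):
    # Simple scalar loss over the toy parameters.
    return jnp.sum(p["w"] * x + p["b"])

# Split the tree by hand: array leaves are differentiated, everything
# else is held static. This is the bookkeeping that masking automates.
arrays = {k: v for k, v in params.items() if isinstance(v, jnp.ndarray)}
static = {k: v for k, v in params.items() if not isinstance(v, jnp.ndarray)}

def loss_on_arrays(a, x):
    return loss({**a, **static}, x)

grads = jax.grad(loss_on_arrays)(arrays, jnp.arange(3.0))
# grads contains only "w" and "b"; the string leaf never reaches jax.grad.
```

With the non-array leaf held static, the loss can also be safely passed through `jax.jit` or `jax.grad` over the array tree alone.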

- #### Notable features:
+ ### 🧠 Neural network package: `serket.nn`

<table>

<tr>

<td>

[Linear](https://serket.readthedocs.io/en/latest/API/linear.html)

</td>

<td>

- `Linear`, `Multilinear`, `GeneralLinear`, `Identity`, `FNN`, `MLP`, `Embedding`

</td>
</tr>

<tr>

<td>

[Convolution](https://serket.readthedocs.io/en/latest/API/convolution.html)

</td>

<td>

- `{Conv,FFTConv}{1D,2D,3D}`
- `{Conv,FFTConv}{1D,2D,3D}Transpose`
- `{Depthwise,Separable}{Conv,FFTConv}{1D,2D,3D}`
- `Conv{1D,2D,3D}Local`

</td>

</tr>

<tr>

<td>

[Containers](https://serket.readthedocs.io/en/latest/API/containers.html)

</td>

<td>

- `Sequential`, `RandomApply`

</td>

</tr>

<tr>

<td>

[Pooling](https://serket.readthedocs.io/en/latest/API/pooling.html)

</td>

<td>

- `{Avg,Max,LP}Pool{1D,2D,3D}`
- `Global{Avg,Max}Pool{1D,2D,3D}`
- `Adaptive{Avg,Max}Pool{1D,2D,3D}`

</td>

</tr>

<tr>

<td>

[Reshaping](https://serket.readthedocs.io/en/latest/API/reshaping.html)

</td>

<td>

- `Flatten`, `Unflatten`
- `Repeat{1D,2D,3D}`
- `Resize{1D,2D,3D}`
- `Upsample{1D,2D,3D}`
- `Pad{1D,2D,3D}`
- `{Crop,RandomCrop}{1D,2D,3D}`
- `RandomZoom2D`

</td>

</tr>

<tr>

<td>

[Normalization](https://serket.readthedocs.io/en/latest/API/normalization.html)

</td>

<td>

- `{Layer,Instance,Group,Batch}Norm`

</td>

</tr>

<tr>

<td>

[Image](https://serket.readthedocs.io/en/latest/API/image.html#)

</td>

<td>

- `{Avg,Gaussian}Blur2D`
- `{Filter,FFTFilter}2D`
- `HistogramEqualization2D`
- `{Adjust,Random}Contrast2D`
- `PixelShuffle2D`

</td>

</tr>

<tr>

<td>

[Dropout](https://serket.readthedocs.io/en/latest/API/dropout.html)

</td>

<td>

- `Dropout`
- `Dropout{1D,2D,3D}`
- `RandomCutout{1D,2D}`

</td>

</tr>

<tr>

<td>

[Activations](https://serket.readthedocs.io/en/latest/API/activations.html)

</td>

<td>

- `Adaptive{LeakyReLU,ReLU,Sigmoid,Tanh}`
- `CeLU`, `ELU`, `GELU`, `GLU`
- `Hard{Shrink,Sigmoid,Swish,Tanh}`
- `Soft{Plus,Sign,Shrink}`
- `LeakyReLU`, `LogSigmoid`, `LogSoftmax`, `Mish`, `PReLU`
- `ReLU`, `ReLU6`, `SeLU`, `Sigmoid`
- `Swish`, `Tanh`, `TanhShrink`, `ThresholdedReLU`, `Snake`

</td>

</tr>

<tr>

<td>

[Recurrent](https://serket.readthedocs.io/en/latest/API/recurrent.html)

</td>

<td>

- `{Dense,SimpleRNN,LSTM,GRU}Cell`
- `{Conv,FFTConv}{LSTM,GRU}{1D,2D,3D}Cell`
- `ScanRNN`

</td>

</tr>

</table>

#### Other features:

<details><summary>🥱 Functional lazy initialization </summary>

- Lazy initialization is particularly useful in scenarios where the dimensions of certain input features are not known in advance. For instance, consider a situation where the number of neurons required for a flattened image input is uncertain (**Example 1**), or the shape of the output from a flattened convolutional layer is not straightforward to calculate (**Example 2**). In such cases, lazy initialization allows the model to defer the allocation of memory for these uncertain dimensions until they are explicitly computed during the training process. This flexibility ensures that the model can handle varying input sizes and adapt its architecture accordingly, making it more versatile and efficient when dealing with different data samples or changing conditions.
+ Lazy initialization is particularly useful when the dimensions of certain input features are not known in advance. For instance, the number of input neurons required for a flattened image may be uncertain (**Example 1**), or the output shape of a flattened convolutional layer may not be straightforward to calculate (**Example 2**).
+
+ In such cases, lazy initialization lets the model defer allocating the parameters that depend on these unknown dimensions until the first input reveals them. The model can therefore adapt to varying input sizes without manual shape bookkeeping.
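The idea can be sketched outside any framework: a layer stores only its output size and materializes its weight on the first call, when the input reveals the missing dimension. This is a minimal NumPy illustration with a hypothetical `LazyLinear` class, not serket's actual API:

```python
import numpy as np

class LazyLinear:
    """Toy layer that defers weight allocation until the first call."""

    def __init__(self, out_features):
        self.out_features = out_features
        self.weight = None  # shape unknown until the first input arrives

    def __call__(self, x):
        if self.weight is None:
            # Materialize now that the input reveals in_features.
            in_features = x.shape[-1]
            self.weight = np.zeros((in_features, self.out_features))
        return x @ self.weight

layer = LazyLinear(10)
y = layer(np.ones((4, 28 * 28)))  # in_features inferred from the input
```

After the first call the layer behaves like an ordinary linear layer; only the point of allocation has moved.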

_Example 1_

2 changes: 1 addition & 1 deletion serket/nn/utils.py
@@ -409,7 +409,7 @@ def maybe_lazy_call(
as ``func``.
"""

- # @ft.wraps(func)
+ @ft.wraps(func)
def inner(instance, *a, **k):
if not is_lazy(instance, *a, **k):
return func(instance, *a, **k)
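The one-line fix above (uncommenting `@ft.wraps(func)`) matters because without it the returned `inner` wrapper replaces the decorated function's name and docstring. The effect can be demonstrated with plain `functools`, independent of serket (`passthrough` is a hypothetical helper):

```python
import functools as ft

def passthrough(use_wraps):
    """Build a do-nothing decorator, with or without functools.wraps."""
    def decorator(func):
        def inner(*a, **k):
            return func(*a, **k)
        if use_wraps:
            # Copy __name__, __doc__, __module__, etc. from func onto inner.
            inner = ft.wraps(func)(inner)
        return inner
    return decorator

@passthrough(use_wraps=True)
def documented():
    """This docstring survives wrapping."""
    return 1

@passthrough(use_wraps=False)
def anonymous():
    """This docstring is lost."""
    return 2

print(documented.__name__)  # documented
print(anonymous.__name__)   # inner
```

Preserved metadata keeps introspection, documentation tools, and error messages pointing at the wrapped function rather than at the anonymous wrapper.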
