Skip to content

Commit

Permalink
docs , add fft rnn back
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 25, 2023
1 parent dd5c522 commit 37d950f
Show file tree
Hide file tree
Showing 15 changed files with 405 additions and 303 deletions.
9 changes: 8 additions & 1 deletion docs/API/api.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
``Serket`` NN API
======================

.. currentmodule:: serket

.. autofunction:: tree_state
.. autofunction:: tree_evaluation

.. toctree::
:maxdepth: 2
:caption: API Documentation
Expand All @@ -16,4 +21,6 @@
misc
random_transforms
activations
recurrent
recurrent


3 changes: 3 additions & 0 deletions docs/API/pytreeclass.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
``PyTreeClass`` exported API
=============================

Serket relies on ``PyTreeClass`` :ref: for its module system (``TreeClass``).
The entire ``PyTreeClass`` is reexported for convinence and is defiend below.

.. toctree::
:maxdepth: 2
:caption: API Documentation
Expand Down
3 changes: 2 additions & 1 deletion docs/API/pytreeclass_core.rst
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
Core API
=============================


.. currentmodule:: serket

.. autoclass:: TreeClass
.. autoclass:::members: at
.. autofunction:: autoinit
.. autofunction:: leafwise
.. autofunction:: is_tree_equal
.. autofunction:: field
.. autofunction:: fields
9 changes: 9 additions & 0 deletions docs/API/recurrent.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@ Recurrent
.. autoclass:: GRUCell
.. autoclass:: SimpleRNNCell
.. autoclass:: DenseCell

.. autoclass:: ConvLSTM1DCell
.. autoclass:: ConvLSTM2DCell
.. autoclass:: ConvLSTM3DCell
.. autoclass:: ConvGRU1DCell
.. autoclass:: ConvGRU2DCell
.. autoclass:: ConvGRU3DCell

.. autoclass:: FFTConvLSTM1DCell
.. autoclass:: FFTConvLSTM2DCell
.. autoclass:: FFTConvLSTM3DCell
.. autoclass:: FFTConvGRU1DCell
.. autoclass:: FFTConvGRU2DCell
.. autoclass:: FFTConvGRU3DCell

.. autoclass:: ScanRNN
:members:
__call__
3 changes: 3 additions & 0 deletions docs/_static/kol.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
61 changes: 56 additions & 5 deletions docs/index.rst
Original file line number Diff line number Diff line change
@@ -1,18 +1,68 @@



Serket
|logo| Serket
==============

- ``serket`` aims to be the most intuitive and easy-to-use neural network library in ``jax``.
- ``serket`` is fully transparent to ``jax`` transformation (e.g. ``vmap``, ``grad``, ``jit``,...).

.. |logo| image:: _static/kol.svg
:height: 40px

Installation
------------

Install from pip::
🛠️ Installation
----------------

Install from github::

pip install git+https://github.com/ASEM000/serket


🏃 Quick example
------------------

.. code-block:: python
import jax, jax.numpy as jnp
import serket as sk
import optax
x_train, y_train = ..., ...
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
nn = sk.nn.Sequential(
sk.nn.Linear(28 * 28, 64, key=k1), jax.nn.relu,
sk.nn.Linear(64, 64, key=k2), jax.nn.relu,
sk.nn.Linear(64, 10, key=k3),
)
nn = sk.tree_mask(nn) # pass non-jaxtype through jax-transforms
optim = optax.adam(LR)
optim_state = optim.init(nn)
@ft.partial(jax.grad, has_aux=True)
def loss_func(nn, x, y):
nn = sk.tree_unmask(nn)
logits = jax.vmap(nn)(x)
onehot = jax.nn.one_hot(y, 10)
loss = jnp.mean(optax.softmax_cross_entropy(logits, onehot))
return loss, (loss, logits)
@jax.jit
def train_step(nn, optim_state, x, y):
grads, (loss, logits) = loss_func(nn, x, y)
updates, optim_state = optim.update(grads, optim_state)
nn = optax.apply_updates(nn, updates)
return nn, optim_state, (loss, logits)
for j, (xb, yb) in enumerate(zip(x_train, y_train)):
nn, optim_state, (loss, logits) = train_step(nn, optim_state, xb, yb)
accuracy = accuracy_func(logits, y_train)
nn = sk.tree_unmask(nn)
pip install serket
.. toctree::
:caption: Examples
Expand All @@ -25,6 +75,7 @@ Install from pip::
:caption: API Documentation
:maxdepth: 1


API/api
API/pytreeclass

Expand Down
16 changes: 14 additions & 2 deletions serket/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@
ConvLSTM2DCell,
ConvLSTM3DCell,
DenseCell,
FFTConvGRU1DCell,
FFTConvGRU2DCell,
FFTConvGRU3DCell,
FFTConvLSTM1DCell,
FFTConvLSTM2DCell,
FFTConvLSTM3DCell,
GRUCell,
LSTMCell,
ScanRNN,
Expand Down Expand Up @@ -283,15 +289,21 @@
"GRUCell",
"SimpleRNNCell",
"DenseCell",
# spatial rnn
"ConvLSTM1DCell",
"ConvLSTM2DCell",
"ConvLSTM3DCell",
"ConvGRU1DCell",
"ConvGRU2DCell",
"ConvGRU3DCell",
# spatial fft rnn
"FFTConvGRU1DCell",
"FFTConvGRU2DCell",
"FFTConvGRU3DCell",
"FFTConvLSTM1DCell",
"FFTConvLSTM2DCell",
"FFTConvLSTM3DCell",
"ScanRNN",
# Polynomial
"Polynomial",
# Flatten
"Flatten",
"Unflatten",
Expand Down
7 changes: 3 additions & 4 deletions serket/nn/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@
from __future__ import annotations

import functools as ft
from typing import Any

import jax
import jax.random as jr

import serket as sk


@sk.autoinit
class Sequential(sk.TreeClass):
"""A sequential container for layers.
Expand All @@ -44,14 +42,15 @@ class Sequential(sk.TreeClass):
it might have a key argument for random number generation.
"""

# allow list then cast to tuple avoid mutability issues
layers: tuple[Any, ...] = sk.field(kind="VAR_POS")
def __init__(self, *layers):
self.layers = layers

def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
for key, layer in zip(jr.split(key, len(self.layers)), self.layers):
try:
x = layer(x, key=key)
except TypeError:
# a `layer` or a `function` without a key argument
x = layer(x)
return x

Expand Down
24 changes: 24 additions & 0 deletions serket/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,10 @@ class SeparableConv1D(sk.TreeClass):
in_features: Number of input feature maps, for 1D convolution this is the
length of the input, for 2D convolution this is the number of input
channels, for 3D convolution this is the number of input channels.
out_features: Number of output features maps, for 1D convolution this is
the length of the output, for 2D convolution this is the number of
output channels, for 3D convolution this is the number of output
channels.
kernel_size: Size of the convolutional kernel. accepts:
- single integer for same kernel size in all dimensions.
Expand Down Expand Up @@ -980,6 +984,10 @@ class SeparableConv2D(sk.TreeClass):
in_features: Number of input feature maps, for 1D convolution this is the
length of the input, for 2D convolution this is the number of input
channels, for 3D convolution this is the number of input channels.
out_features: Number of output features maps, for 1D convolution this is
the length of the output, for 2D convolution this is the number of
output channels, for 3D convolution this is the number of output
channels.
kernel_size: Size of the convolutional kernel. accepts:
- single integer for same kernel size in all dimensions.
Expand Down Expand Up @@ -1081,6 +1089,10 @@ class SeparableConv3D(sk.TreeClass):
in_features: Number of input feature maps, for 1D convolution this is the
length of the input, for 2D convolution this is the number of input
channels, for 3D convolution this is the number of input channels.
out_features: Number of output features maps, for 1D convolution this is
the length of the output, for 2D convolution this is the number of
output channels, for 3D convolution this is the number of output
channels.
kernel_size: Size of the convolutional kernel. accepts:
- single integer for same kernel size in all dimensions.
Expand Down Expand Up @@ -1267,6 +1279,10 @@ class Conv1DLocal(ConvNDLocal):
in_features: Number of input feature maps, for 1D convolution this is the
length of the input, for 2D convolution this is the number of input
channels, for 3D convolution this is the number of input channels.
out_features: Number of output features maps, for 1D convolution this is
the length of the output, for 2D convolution this is the number of
output channels, for 3D convolution this is the number of output
channels.
kernel_size: Size of the convolutional kernel. accepts:
- single integer for same kernel size in all dimensions.
Expand Down Expand Up @@ -1324,6 +1340,10 @@ class Conv2DLocal(ConvNDLocal):
in_features: Number of input feature maps, for 1D convolution this is the
length of the input, for 2D convolution this is the number of input
channels, for 3D convolution this is the number of input channels.
out_features: Number of output features maps, for 1D convolution this is
the length of the output, for 2D convolution this is the number of
output channels, for 3D convolution this is the number of output
channels.
kernel_size: Size of the convolutional kernel. accepts:
- single integer for same kernel size in all dimensions.
Expand Down Expand Up @@ -1381,6 +1401,10 @@ class Conv3DLocal(ConvNDLocal):
in_features: Number of input feature maps, for 1D convolution this is the
length of the input, for 2D convolution this is the number of input
channels, for 3D convolution this is the number of input channels.
out_features: Number of output features maps, for 1D convolution this is
the length of the output, for 2D convolution this is the number of
output channels, for 3D convolution this is the number of output
channels.
kernel_size: Size of the convolutional kernel. accepts:
- single integer for same kernel size in all dimensions.
Expand Down
13 changes: 6 additions & 7 deletions serket/nn/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,23 @@
from __future__ import annotations

import functools as ft
from typing import Any, Callable, TypeVar
from typing import Any, Callable

import jax

T = TypeVar("T")


def tree_evaluation(tree: T) -> T:
def tree_evaluation(tree):
"""Modify tree layers to disable any trainning related behavior.
For example, `Dropout` layers drop probability is set to 0.0. and `BatchNorm`
layer `track_running_stats` is set to False when evaluating the tree.
For example, :class:`nn.Dropout` layer is replaced by an :class:`nn.Identity` layer
and :class:`nn.BatchNorm` layer ``evaluation`` is set to ``True`` when
evaluating the tree.
Args:
tree: A tree of layers.
Returns:
A tree of layers with evaluation behavior.
A tree of layers with evaluation behavior of same structure as ``tree``.
Example:
>>> # dropout is replaced by an identity layer in evaluation mode
Expand Down
12 changes: 12 additions & 0 deletions serket/nn/fft_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,10 @@ class SeparableFFTConv1D(sk.TreeClass):
in_features: Number of input feature maps, for 1D convolution this is the
length of the input, for 2D convolution this is the number of input
channels, for 3D convolution this is the number of input channels.
out_features: Number of output features maps, for 1D convolution this is
the length of the output, for 2D convolution this is the number of
output channels, for 3D convolution this is the number of output
channels.
kernel_size: Size of the convolutional kernel. accepts:
- single integer for same kernel size in all dimnsions.
Expand Down Expand Up @@ -1076,6 +1080,10 @@ class SeparableFFTConv2D(sk.TreeClass):
in_features: Number of input feature maps, for 1D convolution this is the
length of the input, for 2D convolution this is the number of input
channels, for 3D convolution this is the number of input channels.
out_features: Number of output features maps, for 1D convolution this is
the length of the output, for 2D convolution this is the number of
output channels, for 3D convolution this is the number of output
channels.
kernel_size: Size of the convolutional kernel. accepts:
- single integer for same kernel size in all dimnsions.
Expand Down Expand Up @@ -1186,6 +1194,10 @@ class SeparableFFTConv3D(sk.TreeClass):
in_features: Number of input feature maps, for 1D convolution this is the
length of the input, for 2D convolution this is the number of input
channels, for 3D convolution this is the number of input channels.
out_features: Number of output features maps, for 1D convolution this is
the length of the output, for 2D convolution this is the number of
output channels, for 3D convolution this is the number of output
channels.
kernel_size: Size of the convolutional kernel. accepts:
- single integer for same kernel size in all dimnsions.
Expand Down
17 changes: 12 additions & 5 deletions serket/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ def bn_train_step(x, state):

state = jax.lax.stop_gradient(state)

x = jax.lax.cond(evalution, jax.lax.stop_gradient, lambda x: x, x)

if gamma is not None:
output *= jnp.reshape(gamma, broadcast_shape)

Expand Down Expand Up @@ -332,15 +334,20 @@ def _(
class BatchNorm(sk.TreeClass):
"""Applies normalization over batched inputs`
Works under ``jax.vmap(BatchNorm(...), in_axes=(0, None))``, otherwise will be a no-op.
.. warning::
Works under
- ``jax.vmap(BatchNorm(...), in_axes=(0, None))(x, state)``
- ``jax.vmap(BatchNorm(...))(x)``
otherwise will be a no-op.
Evaluation behavior:
``output = (x - running_mean) / sqrt(running_var + eps)``
- ``output = (x - running_mean) / sqrt(running_var + eps)``
Training behavior:
``output = (x - batch_mean) / sqrt(batch_var + eps)``
``running_mean = momentum * running_mean + (1 - momentum) * batch_mean``
``running_var = momentum * running_var + (1 - momentum) * batch_var``
- ``output = (x - batch_mean) / sqrt(batch_var + eps)``
- ``running_mean = momentum * running_mean + (1 - momentum) * batch_mean``
- ``running_var = momentum * running_var + (1 - momentum) * batch_var``
Args:
in_features: the shape of the input to be normalized.
Expand Down
Loading

0 comments on commit 37d950f

Please sign in to comment.