diff --git a/docs/API/api.rst b/docs/API/api.rst index 2a5f9f6..89491a7 100644 --- a/docs/API/api.rst +++ b/docs/API/api.rst @@ -1,6 +1,11 @@ ``Serket`` NN API ====================== +.. currentmodule:: serket + +.. autofunction:: tree_state +.. autofunction:: tree_evaluation + .. toctree:: :maxdepth: 2 :caption: API Documentation @@ -16,4 +21,6 @@ misc random_transforms activations - recurrent \ No newline at end of file + recurrent + + diff --git a/docs/API/pytreeclass.rst b/docs/API/pytreeclass.rst index edda822..2c79789 100644 --- a/docs/API/pytreeclass.rst +++ b/docs/API/pytreeclass.rst @@ -1,6 +1,9 @@ ``PyTreeClass`` exported API ============================= +Serket relies on ``PyTreeClass`` for its module system (``TreeClass``). +The entire ``PyTreeClass`` API is re-exported for convenience and is defined below. + .. toctree:: :maxdepth: 2 :caption: API Documentation diff --git a/docs/API/pytreeclass_core.rst b/docs/API/pytreeclass_core.rst index b3aa901..c97f569 100644 --- a/docs/API/pytreeclass_core.rst +++ b/docs/API/pytreeclass_core.rst @@ -1,11 +1,12 @@ Core API ============================= - .. currentmodule:: serket .. autoclass:: TreeClass :members: at +.. autofunction:: autoinit +.. autofunction:: leafwise .. autofunction:: is_tree_equal .. autofunction:: field .. autofunction:: fields \ No newline at end of file diff --git a/docs/API/recurrent.rst b/docs/API/recurrent.rst index 7069e5a..6da99de 100644 --- a/docs/API/recurrent.rst +++ b/docs/API/recurrent.rst @@ -7,12 +7,21 @@ Recurrent .. autoclass:: GRUCell .. autoclass:: SimpleRNNCell .. autoclass:: DenseCell + .. autoclass:: ConvLSTM1DCell .. autoclass:: ConvLSTM2DCell .. autoclass:: ConvLSTM3DCell .. autoclass:: ConvGRU1DCell .. autoclass:: ConvGRU2DCell .. autoclass:: ConvGRU3DCell + +.. autoclass:: FFTConvLSTM1DCell +.. autoclass:: FFTConvLSTM2DCell +.. autoclass:: FFTConvLSTM3DCell +.. autoclass:: FFTConvGRU1DCell +.. autoclass:: FFTConvGRU2DCell +.. autoclass:: FFTConvGRU3DCell + .. autoclass:: ScanRNN :members: __call__ \ No newline at end of file diff --git a/docs/_static/kol.svg b/docs/_static/kol.svg new file mode 100644 index 0000000..4a872af --- /dev/null +++ b/docs/_static/kol.svg @@ -0,0 +1,3 @@ + + + diff --git a/docs/index.rst b/docs/index.rst index 8e28971..ef1f181 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,18 +1,68 @@ -Serket +|logo| Serket +============== - - ``serket`` aims to be the most intuitive and easy-to-use neural network library in ``jax``. - ``serket`` is fully transparent to ``jax`` transformation (e.g. ``vmap``, ``grad``, ``jit``,...). +.. |logo| image:: _static/kol.svg + :height: 40px -Installation ------------- -Install from pip:: +🛠️ Installation +---------------- + +Install from GitHub:: + + pip install git+https://github.com/ASEM000/serket + + +🏃 Quick example +------------------ + +.. code-block:: python + + import functools as ft + import jax, jax.numpy as jnp + import serket as sk + import optax + + x_train, y_train = ..., ...
+ k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) + + nn = sk.nn.Sequential( + sk.nn.Linear(28 * 28, 64, key=k1), jax.nn.relu, + sk.nn.Linear(64, 64, key=k2), jax.nn.relu, + sk.nn.Linear(64, 10, key=k3), + ) + + nn = sk.tree_mask(nn) # mask non-JAX types so the tree passes through JAX transforms + LR = 1e-3 # example learning rate + optim = optax.adam(LR) + optim_state = optim.init(nn) + + @ft.partial(jax.grad, has_aux=True) + def loss_func(nn, x, y): + nn = sk.tree_unmask(nn) + logits = jax.vmap(nn)(x) + onehot = jax.nn.one_hot(y, 10) + loss = jnp.mean(optax.softmax_cross_entropy(logits, onehot)) + return loss, (loss, logits) + + @jax.jit + def train_step(nn, optim_state, x, y): + grads, (loss, logits) = loss_func(nn, x, y) + updates, optim_state = optim.update(grads, optim_state) + nn = optax.apply_updates(nn, updates) + return nn, optim_state, (loss, logits) + + for xb, yb in zip(x_train, y_train): + nn, optim_state, (loss, logits) = train_step(nn, optim_state, xb, yb) + accuracy = accuracy_func(logits, yb) # accuracy_func: user-defined metric + + nn = sk.tree_unmask(nn) + - pip install serket .. toctree:: :caption: Examples @@ -25,6 +75,7 @@ Install from pip:: :caption: API Documentation :maxdepth: 1 + API/api API/pytreeclass diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index bd2ea8d..7f29c64 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -122,6 +122,12 @@ ConvLSTM2DCell, ConvLSTM3DCell, DenseCell, + FFTConvGRU1DCell, + FFTConvGRU2DCell, + FFTConvGRU3DCell, + FFTConvLSTM1DCell, + FFTConvLSTM2DCell, + FFTConvLSTM3DCell, GRUCell, LSTMCell, ScanRNN, @@ -283,15 +289,21 @@ "GRUCell", "SimpleRNNCell", "DenseCell", + # spatial rnn "ConvLSTM1DCell", "ConvLSTM2DCell", "ConvLSTM3DCell", "ConvGRU1DCell", "ConvGRU2DCell", "ConvGRU3DCell", + # spatial fft rnn + "FFTConvGRU1DCell", + "FFTConvGRU2DCell", + "FFTConvGRU3DCell", + "FFTConvLSTM1DCell", + "FFTConvLSTM2DCell", + "FFTConvLSTM3DCell", "ScanRNN", - # Polynomial - "Polynomial", # Flatten "Flatten", "Unflatten", diff --git a/serket/nn/containers.py b/serket/nn/containers.py index 0e8c77d..bcdf717 100644 --- a/serket/nn/containers.py +++ b/serket/nn/containers.py @@ -15,7 +15,6 @@ from __future__ import annotations import functools as ft -from typing import Any import jax import jax.random as jr @@ -23,7 +22,6 @@ import serket as sk -@sk.autoinit class Sequential(sk.TreeClass): """A sequential container for layers. @@ -44,14 +42,15 @@ class Sequential(sk.TreeClass): it might have a key argument for random number generation. """ - # allow list then cast to tuple avoid mutability issues - layers: tuple[Any, ...] = sk.field(kind="VAR_POS") + def __init__(self, *layers): + self.layers = layers def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array: for key, layer in zip(jr.split(key, len(self.layers)), self.layers): try: x = layer(x, key=key) except TypeError: + # a `layer` or a `function` without a key argument x = layer(x) return x diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py index 937e6ae..c20eaa4 100644 --- a/serket/nn/convolution.py +++ b/serket/nn/convolution.py @@ -878,6 +878,10 @@ class SeparableConv1D(sk.TreeClass): in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels.
+ out_features: Number of output feature maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. @@ -980,6 +984,10 @@ class SeparableConv2D(sk.TreeClass): in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output feature maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. @@ -1081,6 +1089,10 @@ class SeparableConv3D(sk.TreeClass): in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output feature maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. @@ -1267,6 +1279,10 @@ class Conv1DLocal(ConvNDLocal): in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output feature maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. @@ -1324,6 +1340,10 @@ class Conv2DLocal(ConvNDLocal): in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output feature maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. @@ -1381,6 +1401,10 @@ class Conv3DLocal(ConvNDLocal): in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output feature maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions.
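The ``out_features`` entries added above mirror the existing constructor signatures of the separable and local convolution layers. A minimal usage sketch, assuming the documented ``SeparableConv1D(in_features, out_features, kernel_size)`` signature, serket's unbatched channel-first convention, and the default ``"SAME"`` padding:

.. code-block:: python

    import jax.numpy as jnp
    import serket as sk

    # 4 input feature maps -> 8 output feature maps, kernel of size 3
    layer = sk.nn.SeparableConv1D(in_features=4, out_features=8, kernel_size=3)
    x = jnp.ones((4, 32))  # (in_features, spatial)
    y = layer(x)           # (out_features, spatial) under "SAME" padding
    assert y.shape == (8, 32)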
diff --git a/serket/nn/evaluation.py b/serket/nn/evaluation.py index a3206c4..8e795e0 100644 --- a/serket/nn/evaluation.py +++ b/serket/nn/evaluation.py @@ -17,24 +17,23 @@ from __future__ import annotations import functools as ft -from typing import Any, Callable, TypeVar +from typing import Any, Callable import jax -T = TypeVar("T") - -def tree_evaluation(tree: T) -> T: +def tree_evaluation(tree): """Modify tree layers to disable any training-related behavior. - For example, `Dropout` layers drop probability is set to 0.0. and `BatchNorm` - layer `track_running_stats` is set to False when evaluating the tree. + For example, the :class:`nn.Dropout` layer is replaced by an :class:`nn.Identity` layer, + and the :class:`nn.BatchNorm` layer's ``evaluation`` flag is set to ``True`` when + evaluating the tree. Args: tree: A tree of layers. Returns: - A tree of layers with evaluation behavior. + A tree of layers with evaluation behavior, of the same structure as ``tree``. Example: >>> # dropout is replaced by an identity layer in evaluation mode diff --git a/serket/nn/fft_convolution.py b/serket/nn/fft_convolution.py index d19e8e5..8d02ad1 100644 --- a/serket/nn/fft_convolution.py +++ b/serket/nn/fft_convolution.py @@ -965,6 +965,10 @@ class SeparableFFTConv1D(sk.TreeClass): in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output feature maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. @@ -1076,6 +1080,10 @@ class SeparableFFTConv2D(sk.TreeClass): in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output feature maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. @@ -1186,6 +1194,10 @@ class SeparableFFTConv3D(sk.TreeClass): in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output feature maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions.
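As the updated ``tree_evaluation`` docstring above describes, evaluation mode swaps training-only behavior for inert equivalents. A minimal sketch, assuming ``Dropout`` takes the drop probability as its first argument:

.. code-block:: python

    import serket as sk

    layer = sk.nn.Dropout(0.5)
    # per the docstring, Dropout is replaced by Identity in evaluation mode
    eval_layer = sk.tree_evaluation(layer)
    print(eval_layer)  # expected: an Identity layer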
diff --git a/serket/nn/normalization.py b/serket/nn/normalization.py index d69dbc8..c63a670 100644 --- a/serket/nn/normalization.py +++ b/serket/nn/normalization.py @@ -278,6 +278,8 @@ def bn_train_step(x, state): state = jax.lax.stop_gradient(state) + x = jax.lax.cond(evaluation, jax.lax.stop_gradient, lambda x: x, x) + if gamma is not None: output *= jnp.reshape(gamma, broadcast_shape) @@ -332,15 +334,20 @@ class BatchNorm(sk.TreeClass): """Applies normalization over batched inputs - Works under ``jax.vmap(BatchNorm(...), in_axes=(0, None))``, otherwise will be a no-op. + .. warning:: + Works under + - ``jax.vmap(BatchNorm(...), in_axes=(0, None))(x, state)`` + - ``jax.vmap(BatchNorm(...))(x)`` + + otherwise will be a no-op. Evaluation behavior: - ``output = (x - running_mean) / sqrt(running_var + eps)`` + - ``output = (x - running_mean) / sqrt(running_var + eps)`` Training behavior: - ``output = (x - batch_mean) / sqrt(batch_var + eps)`` - ``running_mean = momentum * running_mean + (1 - momentum) * batch_mean`` - ``running_var = momentum * running_var + (1 - momentum) * batch_var`` + - ``output = (x - batch_mean) / sqrt(batch_var + eps)`` + - ``running_mean = momentum * running_mean + (1 - momentum) * batch_mean`` + - ``running_var = momentum * running_var + (1 - momentum) * batch_var`` Args: in_features: the shape of the input to be normalized. diff --git a/serket/nn/recurrent.py b/serket/nn/recurrent.py index f9eacab..c60fa1e 100644 --- a/serket/nn/recurrent.py +++ b/serket/nn/recurrent.py @@ -437,7 +437,6 @@ def __init__( *, strides: StridesType = 1, padding: PaddingType = "SAME", - input_dilation: DilationType = 1, kernel_dilation: DilationType = 1, weight_init_func: InitType = "glorot_uniform", bias_init_func: InitType = "zeros", @@ -445,7 +444,6 @@ act_func: ActivationType | None = "tanh", recurrent_act_func: ActivationType | None = "hard_sigmoid", key: jr.KeyArray = jr.PRNGKey(0), - conv_layer: Any = None, ): k1, k2 = jr.split(key, 2) @@ -454,26 +452,24 @@ self.act_func = resolve_activation(act_func) self.recurrent_act_func = resolve_activation(recurrent_act_func) - self.in_to_hidden = conv_layer( + self.in_to_hidden = self.convolution( in_features, hidden_features * 4, kernel_size, strides=strides, padding=padding, - input_dilation=input_dilation, kernel_dilation=kernel_dilation, weight_init_func=weight_init_func, bias_init_func=bias_init_func, key=k1, ) - self.hidden_to_hidden = conv_layer( + self.hidden_to_hidden = self.convolution( hidden_features, hidden_features * 4, kernel_size, strides=strides, padding=padding, - input_dilation=input_dilation, kernel_dilation=kernel_dilation, weight_init_func=recurrent_weight_init_func, bias_init_func=None, @@ -498,31 +494,38 @@ def __call__(self, x: jax.Array, state: ConvLSTMNDState, **k) -> ConvLSTMNDState return ConvLSTMNDState(h, c) -@tree_state.def_state(ConvLSTMNDCell) -def conv_lstm_init_state(cell: ConvLSTMNDCell, x: jax.Array | None) -> ConvLSTMNDState: - if not (hasattr(x, "ndim") and hasattr(x, "shape")): - raise TypeError( - f"Expected {x=} to have ndim and shape attributes.", - "To initialize the `ConvLSTMNDCell` state.\n" - "pass a single sample array to `tree_state` second argument.", - ) +class ConvLSTM1DCell(ConvLSTMNDCell): + """1D Convolution LSTM cell that defines the update rule for the hidden state and cell state - if x.ndim != cell.spatial_ndim + 1: - raise ValueError( - f"{x.ndim=} != {(cell.spatial_ndim + 1)=}.", - "Expected input to have shape (channel, *spatial_dim)."
- "Pass a single sample array to `tree_state", - ) + Args: + in_features: Number of input features + hidden_features: Number of output features + kernel_size: Size of the convolutional kernel + strides: Stride of the convolution + padding: Padding of the convolution + kernel_dilation: Dilation of the convolutional kernel + weight_init_func: Weight initialization function + bias_init_func: Bias initialization function + recurrent_weight_init_func: Recurrent weight initialization function + act_func: Activation function + recurrent_act_func: Recurrent activation function + key: PRNG key - spatial_dim = x.shape[1:] - if len(spatial_dim) != cell.spatial_ndim: - raise ValueError(f"{len(spatial_dim)=} != {cell.spatial_ndim=}.") - shape = (cell.hidden_features, *spatial_dim) - return ConvLSTMNDState(jnp.zeros(shape), jnp.zeros(shape)) + Note: + https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D + """ + + @property + def spatial_ndim(self) -> int: + return 1 + @property + def convolution(self): + return sk.nn.Conv1D -class ConvLSTM1DCell(ConvLSTMNDCell): - """1D Convolution LSTM cell that defines the update rule for the hidden state and cell state + +class FFTConvLSTM1DCell(ConvLSTMNDCell): + """1D FFT Convolution LSTM cell that defines the update rule for the hidden state and cell state Args: in_features: Number of input features @@ -530,7 +533,6 @@ class ConvLSTM1DCell(ConvLSTMNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - input_dilation: Dilation of the input kernel_dilation: Dilation of the convolutional kernel weight_init_func: Weight initialization function bias_init_func: Bias initialization function @@ -543,44 +545,14 @@ class ConvLSTM1DCell(ConvLSTMNDCell): https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D """ - def __init__( - self, - in_features: int, - hidden_features: int, - kernel_size: KernelSizeType, - *, - strides: StridesType = 1, - padding: PaddingType = "SAME", - input_dilation: DilationType = 1, - kernel_dilation: DilationType = 1, - weight_init_func: InitType = "glorot_uniform", - bias_init_func: InitType = "zeros", - recurrent_weight_init_func: InitType = "orthogonal", - act_func: ActivationType | None = "tanh", - recurrent_act_func: ActivationType | None = "hard_sigmoid", - key: jr.KeyArray = jr.PRNGKey(0), - ): - super().__init__( - in_features=in_features, - hidden_features=hidden_features, - kernel_size=kernel_size, - strides=strides, - padding=padding, - input_dilation=input_dilation, - kernel_dilation=kernel_dilation, - weight_init_func=weight_init_func, - bias_init_func=bias_init_func, - recurrent_weight_init_func=recurrent_weight_init_func, - act_func=act_func, - recurrent_act_func=recurrent_act_func, - key=key, - conv_layer=sk.nn.Conv1D, - ) - @property def spatial_ndim(self) -> int: return 1 + @property + def convolution(self): + return sk.nn.FFTConv1D + class ConvLSTM2DCell(ConvLSTMNDCell): """2D Convolution LSTM cell that defines the update rule for the hidden state and cell state @@ -591,7 +563,6 @@ class ConvLSTM2DCell(ConvLSTMNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - input_dilation: Dilation of the input kernel_dilation: Dilation of the convolutional kernel weight_init_func: Weight initialization function bias_init_func: Bias initialization function @@ -601,47 +572,47 @@ class ConvLSTM2DCell(ConvLSTMNDCell): key: PRNG key Note: - 
https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D + https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM2D """ - def __init__( - self, - in_features: int, - hidden_features: int, - kernel_size: KernelSizeType, - *, - strides: StridesType = 1, - padding: PaddingType = "SAME", - input_dilation: DilationType = 1, - kernel_dilation: DilationType = 1, - weight_init_func: InitType = "glorot_uniform", - bias_init_func: InitType = "zeros", - recurrent_weight_init_func: InitType = "orthogonal", - act_func: ActivationType | None = "tanh", - recurrent_act_func: ActivationType | None = "hard_sigmoid", - key: jr.KeyArray = jr.PRNGKey(0), - ): - super().__init__( - in_features=in_features, - hidden_features=hidden_features, - kernel_size=kernel_size, - strides=strides, - padding=padding, - input_dilation=input_dilation, - kernel_dilation=kernel_dilation, - weight_init_func=weight_init_func, - bias_init_func=bias_init_func, - recurrent_weight_init_func=recurrent_weight_init_func, - act_func=act_func, - recurrent_act_func=recurrent_act_func, - key=key, - conv_layer=sk.nn.Conv2D, - ) + @property + def spatial_ndim(self) -> int: + return 2 + + @property + def convolution(self): + return sk.nn.Conv2D + + +class FFTConvLSTM2DCell(ConvLSTMNDCell): + """2D FFT Convolution LSTM cell that defines the update rule for the hidden state and cell state + + Args: + in_features: Number of input features + hidden_features: Number of output features + kernel_size: Size of the convolutional kernel + strides: Stride of the convolution + padding: Padding of the convolution + kernel_dilation: Dilation of the convolutional kernel + weight_init_func: Weight initialization function + bias_init_func: Bias initialization function + recurrent_weight_init_func: Recurrent weight initialization function + act_func: Activation function + recurrent_act_func: Recurrent activation function + key: PRNG key + + Note: + https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM2D + """ @property def spatial_ndim(self) -> int: return 2 + @property + def convolution(self): + return sk.nn.FFTConv2D + class ConvLSTM3DCell(ConvLSTMNDCell): """3D Convolution LSTM cell that defines the update rule for the hidden state and cell state @@ -652,7 +623,6 @@ class ConvLSTM3DCell(ConvLSTMNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - input_dilation: Dilation of the input kernel_dilation: Dilation of the convolutional kernel weight_init_func: Weight initialization function bias_init_func: Bias initialization function @@ -662,47 +632,47 @@ class ConvLSTM3DCell(ConvLSTMNDCell): key: PRNG key Note: - https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D + https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM3D """ - def __init__( - self, - in_features: int, - hidden_features: int, - kernel_size: KernelSizeType, - *, - strides: StridesType = 1, - padding: PaddingType = "SAME", - input_dilation: DilationType = 1, - kernel_dilation: DilationType = 1, - weight_init_func: InitType = "glorot_uniform", - bias_init_func: InitType = "zeros", - recurrent_weight_init_func: InitType = "orthogonal", - act_func: ActivationType | None = "tanh", - recurrent_act_func: ActivationType | None = "hard_sigmoid", - key: jr.KeyArray = jr.PRNGKey(0), - ): - super().__init__( - in_features=in_features, - hidden_features=hidden_features, - kernel_size=kernel_size, - strides=strides, - padding=padding, - input_dilation=input_dilation, 
- kernel_dilation=kernel_dilation, - weight_init_func=weight_init_func, - bias_init_func=bias_init_func, - recurrent_weight_init_func=recurrent_weight_init_func, - act_func=act_func, - recurrent_act_func=recurrent_act_func, - key=key, - conv_layer=sk.nn.Conv3D, - ) + @property + def spatial_ndim(self) -> int: + return 3 + + @property + def convolution(self): + return sk.nn.Conv3D + + +class FFTConvLSTM3DCell(ConvLSTMNDCell): + """3D FFT Convolution LSTM cell that defines the update rule for the hidden state and cell state + + Args: + in_features: Number of input features + hidden_features: Number of output features + kernel_size: Size of the convolutional kernel + strides: Stride of the convolution + padding: Padding of the convolution + kernel_dilation: Dilation of the convolutional kernel + weight_init_func: Weight initialization function + bias_init_func: Bias initialization function + recurrent_weight_init_func: Recurrent weight initialization function + act_func: Activation function + recurrent_act_func: Recurrent activation function + key: PRNG key + + Note: + https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM3D + """ @property def spatial_ndim(self) -> int: return 3 + @property + def convolution(self): + return sk.nn.FFTConv3D + class ConvGRUNDState(RNNState): ... @@ -717,7 +687,6 @@ class ConvGRUNDCell(RNNCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - input_dilation: Dilation of the input kernel_dilation: Dilation of the convolutional kernel weight_init_func: Weight initialization function bias_init_func: Bias initialization function @@ -736,7 +705,6 @@ def __init__( *, strides: StridesType = 1, padding: PaddingType = "SAME", - input_dilation: DilationType = 1, kernel_dilation: DilationType = 1, weight_init_func: InitType = "glorot_uniform", bias_init_func: InitType = "zeros", @@ -744,7 +712,6 @@ def __init__( act_func: ActivationType | None = "tanh", recurrent_act_func: ActivationType | None = "sigmoid", key: jr.KeyArray = jr.PRNGKey(0), - conv_layer: Any = None, ): k1, k2 = jr.split(key, 2) @@ -753,26 +720,24 @@ def __init__( self.act_func = resolve_activation(act_func) self.recurrent_act_func = resolve_activation(recurrent_act_func) - self.in_to_hidden = conv_layer( + self.in_to_hidden = self.convolution( in_features, hidden_features * 3, kernel_size, strides=strides, padding=padding, - input_dilation=input_dilation, kernel_dilation=kernel_dilation, weight_init_func=weight_init_func, bias_init_func=bias_init_func, key=k1, ) - self.hidden_to_hidden = conv_layer( + self.hidden_to_hidden = self.convolution( hidden_features, hidden_features * 3, kernel_size, strides=strides, padding=padding, - input_dilation=input_dilation, kernel_dilation=kernel_dilation, weight_init_func=recurrent_weight_init_func, bias_init_func=None, @@ -794,33 +759,36 @@ def __call__(self, x: jax.Array, state: ConvGRUNDState, **k) -> ConvGRUNDState: return ConvGRUNDState(hidden_state=h) -@tree_state.def_state(ConvGRUNDCell) -def conv_gru_init_state(cell: ConvGRUNDCell, x: jax.Array | None) -> ConvGRUNDState: - if not (hasattr(x, "ndim") and hasattr(x, "shape")): - # maybe the input is not an array - raise TypeError( - f"Expected {x=} to have ndim and shape attributes.", - "To initialize the `ConvGRUNDCell` state.\n" - "pass a single sample array to `tree_state` second argument.", - ) +class ConvGRU1DCell(ConvGRUNDCell): + """1D Convolution GRU cell that defines the update rule for the hidden state and cell state - if 
x.ndim != cell.spatial_ndim + 1: - # channel, *spatial_dim - raise ValueError( - f"{x.ndim=} != {(cell.spatial_ndim + 1)=}.", - "Expected input to have shape (channel, *spatial_dim)." - "Pass a single sample array to `tree_state", - ) + Args: + in_features: Number of input features + hidden_features: Number of output features + kernel_size: Size of the convolutional kernel + strides: Stride of the convolution + padding: Padding of the convolution + kernel_dilation: Dilation of the convolutional kernel + weight_init_func: Weight initialization function + bias_init_func: Bias initialization function + recurrent_weight_init_func: Recurrent weight initialization function + act_func: Activation function + recurrent_act_func: Recurrent activation function + key: PRNG key + spatial_ndim: Number of spatial dimensions. + """ - spatial_dim = x.shape[1:] - if len(spatial_dim) != cell.spatial_ndim: - raise ValueError(f"{len(spatial_dim)=} != {cell.spatial_ndim=}.") - shape = (cell.hidden_features, *spatial_dim) - return ConvGRUNDState(jnp.zeros(shape), jnp.zeros(shape)) + @property + def spatial_ndim(self) -> int: + return 1 + @property + def convolution(self): + return sk.nn.Conv1D -class ConvGRU1DCell(ConvGRUNDCell): - """1D Convolution GRU cell that defines the update rule for the hidden state and cell state + +class FFTConvGRU1DCell(ConvGRUNDCell): + """1D FFT Convolution GRU cell that defines the update rule for the hidden state and cell state Args: in_features: Number of input features @@ -828,7 +796,6 @@ class ConvGRU1DCell(ConvGRUNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - input_dilation: Dilation of the input kernel_dilation: Dilation of the convolutional kernel weight_init_func: Weight initialization function bias_init_func: Bias initialization function @@ -839,54 +806,24 @@ class ConvGRU1DCell(ConvGRUNDCell): spatial_ndim: Number of spatial dimensions. 
""" - def __init__( - self, - in_features: int, - hidden_features: int, - kernel_size: int | tuple[int, ...], - *, - strides: StridesType = 1, - padding: PaddingType = "SAME", - input_dilation: DilationType = 1, - kernel_dilation: DilationType = 1, - weight_init_func: InitType = "glorot_uniform", - bias_init_func: InitType = "zeros", - recurrent_weight_init_func: InitType = "orthogonal", - act_func: ActivationType | None = "tanh", - recurrent_act_func: ActivationType | None = "sigmoid", - key: jr.KeyArray = jr.PRNGKey(0), - ): - super().__init__( - in_features=in_features, - hidden_features=hidden_features, - kernel_size=kernel_size, - strides=strides, - padding=padding, - input_dilation=input_dilation, - kernel_dilation=kernel_dilation, - weight_init_func=weight_init_func, - bias_init_func=bias_init_func, - recurrent_weight_init_func=recurrent_weight_init_func, - act_func=act_func, - recurrent_act_func=recurrent_act_func, - key=key, - conv_layer=sk.nn.Conv1D, - ) - @property def spatial_ndim(self) -> int: return 1 + @property + def convolution(self): + return sk.nn.FFTConv1D + class ConvGRU2DCell(ConvGRUNDCell): """2D Convolution GRU cell that defines the update rule for the hidden state and cell state + Args: in_features: Number of input features hidden_features: Number of output features kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - input_dilation: Dilation of the input kernel_dilation: Dilation of the convolutional kernel weight_init_func: Weight initialization function bias_init_func: Bias initialization function @@ -898,44 +835,42 @@ class ConvGRU2DCell(ConvGRUNDCell): """ - def __init__( - self, - in_features: int, - hidden_features: int, - kernel_size: int | tuple[int, ...], - *, - strides: StridesType = 1, - padding: PaddingType = "SAME", - input_dilation: DilationType = 1, - kernel_dilation: DilationType = 1, - weight_init_func: InitType = "glorot_uniform", - bias_init_func: InitType = "zeros", - recurrent_weight_init_func: InitType = "orthogonal", - act_func: ActivationType | None = "tanh", - recurrent_act_func: ActivationType | None = "sigmoid", - key: jr.KeyArray = jr.PRNGKey(0), - ): - super().__init__( - in_features=in_features, - hidden_features=hidden_features, - kernel_size=kernel_size, - strides=strides, - padding=padding, - input_dilation=input_dilation, - kernel_dilation=kernel_dilation, - weight_init_func=weight_init_func, - bias_init_func=bias_init_func, - recurrent_weight_init_func=recurrent_weight_init_func, - act_func=act_func, - recurrent_act_func=recurrent_act_func, - key=key, - conv_layer=sk.nn.Conv2D, - ) + @property + def spatial_ndim(self) -> int: + return 2 + + @property + def convolution(self): + return sk.nn.Conv2D + + +class FFTConvGRU2DCell(ConvGRUNDCell): + """2D FFT Convolution GRU cell that defines the update rule for the hidden state and cell state + + Args: + in_features: Number of input features + hidden_features: Number of output features + kernel_size: Size of the convolutional kernel + strides: Stride of the convolution + padding: Padding of the convolution + kernel_dilation: Dilation of the convolutional kernel + weight_init_func: Weight initialization function + bias_init_func: Bias initialization function + recurrent_weight_init_func: Recurrent weight initialization function + act_func: Activation function + recurrent_act_func: Recurrent activation function + key: PRNG key + spatial_ndim: Number of spatial dimensions. 
+ """ @property def spatial_ndim(self) -> int: return 2 + @property + def convolution(self): + return sk.nn.FFTConv2D + class ConvGRU3DCell(ConvGRUNDCell): """3D Convolution GRU cell that defines the update rule for the hidden state and cell state @@ -946,7 +881,6 @@ class ConvGRU3DCell(ConvGRUNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - input_dilation: Dilation of the input kernel_dilation: Dilation of the convolutional kernel weight_init_func: Weight initialization function bias_init_func: Bias initialization function @@ -956,45 +890,41 @@ class ConvGRU3DCell(ConvGRUNDCell): key: PRNG key """ - def __init__( - self, - in_features: int, - hidden_features: int, - kernel_size: int | tuple[int, ...], - *, - strides: StridesType = 1, - padding: PaddingType = "SAME", - input_dilation: DilationType = 1, - kernel_dilation: DilationType = 1, - weight_init_func: InitType = "glorot_uniform", - bias_init_func: InitType = "zeros", - recurrent_weight_init_func: InitType = "orthogonal", - act_func: ActivationType | None = "tanh", - recurrent_act_func: ActivationType | None = "sigmoid", - key: jr.KeyArray = jr.PRNGKey(0), - ): - super().__init__( - in_features=in_features, - hidden_features=hidden_features, - kernel_size=kernel_size, - strides=strides, - padding=padding, - input_dilation=input_dilation, - kernel_dilation=kernel_dilation, - weight_init_func=weight_init_func, - bias_init_func=bias_init_func, - recurrent_weight_init_func=recurrent_weight_init_func, - act_func=act_func, - recurrent_act_func=recurrent_act_func, - key=key, - conv_layer=sk.nn.Conv3D, - spatial_ndim=3, - ) + @property + def spatial_ndim(self) -> int: + return 3 + + @property + def convolution(self): + return sk.nn.Conv3D + + +class FFTConvGRU3DCell(ConvGRUNDCell): + """3D Convolution GRU cell that defines the update rule for the hidden state and cell state + + Args: + in_features: Number of input features + hidden_features: Number of output features + kernel_size: Size of the convolutional kernel + strides: Stride of the convolution + padding: Padding of the convolution + kernel_dilation: Dilation of the convolutional kernel + weight_init_func: Weight initialization function + bias_init_func: Bias initialization function + recurrent_weight_init_func: Recurrent weight initialization function + act_func: Activation function + recurrent_act_func: Recurrent activation function + key: PRNG key + """ @property def spatial_ndim(self) -> int: return 3 + @property + def convolution(self): + return sk.nn.FFTConv3D + # Scanning API @@ -1095,7 +1025,7 @@ def __call__( if x.ndim != cell0.spatial_ndim + 2: raise ValueError( f"Expected x to have {(cell0.spatial_ndim + 2)=} dimensions corresponds to " - f"(timesteps, in_features, {'*'*cell0.spatial_ndim})," + f"(timesteps, in_features, {','.join('...'*cell0.spatial_ndim)})," f" got {x.ndim=}" ) @@ -1171,7 +1101,46 @@ def scan_func(carry, x): return result, carry +# register state handlers + + +def _check_rnn_cell_tree_state_input(cell: RNNCell, x): + if not (hasattr(x, "ndim") and hasattr(x, "shape")): + raise TypeError( + f"Expected {x=} to have `ndim` and `shape` attributes.", + f"To initialize the `{type(cell).__name__}` state.\n", + "Pass a single sample array to `tree_state(..., array=)`.", + ) + + if x.ndim != cell.spatial_ndim + 1: + raise ValueError( + f"{x.ndim=} != {(cell.spatial_ndim + 1)=}.", + f"Expected input to have `shape` (in_features, {'...'*cell.spatial_dim})." 
+ "Pass a single sample array to `tree_state", + ) + + spatial_dim = x.shape[1:] + if len(spatial_dim) != cell.spatial_ndim: + raise ValueError(f"{len(spatial_dim)=} != {cell.spatial_ndim=}.") + + return x + + +@tree_state.def_state(ConvLSTMNDCell) +def conv_lstm_init_state(cell: ConvLSTMNDCell, x: Any) -> ConvLSTMNDState: + x = _check_rnn_cell_tree_state_input(cell, x) + shape = (cell.hidden_features, *x.shape[1:]) + return ConvLSTMNDState(jnp.zeros(shape), jnp.zeros(shape)) + + +@tree_state.def_state(ConvGRUNDCell) +def conv_gru_init_state(cell: ConvGRUNDCell, x: Any) -> ConvGRUNDState: + x = _check_rnn_cell_tree_state_input(cell, x) + shape = (cell.hidden_features, *x.shape[1:]) + return ConvGRUNDState(jnp.zeros(shape), jnp.zeros(shape)) + + @tree_state.def_state(ScanRNN) -def rnn_init_state(rnn: ScanRNN, x: jax.Array | None) -> RNNState: +def scan_rnn_init_state(rnn: ScanRNN, x: Any) -> RNNState: # should pass a single sample array to `tree_state` return _merge(tree_state(rnn.cells, array=x)) diff --git a/serket/nn/state.py b/serket/nn/state.py index d157a68..a923d42 100644 --- a/serket/nn/state.py +++ b/serket/nn/state.py @@ -37,14 +37,14 @@ def tree_state(tree: T, array: jax.Array | None = None) -> T: """Build state for a tree of layers. Some layers require state to be initialized before training. For example, - `BatchNorm` layers require `running_mean` and `running_var` to be initialized + :class:`nn.BatchNorm` layers requires ``running_mean`` and ``running_var`` to be initialized before training. This function initializes the state for a tree of layers, based on the layer defined ``state`` rule using ``tree_state.def_state``. Args: tree: A tree of layers. - array: An array to use for initializing state required by some layers - (e.g. ConvGRUNDCell). default: ``None``. + array: (Optional) array to use for initializing state required by some layers + (e.g. :class:`nn.ConvGRU1DCell`). default: ``None``. Returns: A tree of state leaves if it has state, otherwise ``None``. @@ -67,8 +67,7 @@ def tree_state(tree: T, array: jax.Array | None = None) -> T: ... pass >>> # state function accept the `layer` and optional input array as arguments >>> @sk.tree_state.def_state(LayerWithState) - ... def _(leaf, _): - ... del _ # array is not used + ... def _(leaf): ... 
return "some state" >>> sk.tree_state(LayerWithState()) 'some state' @@ -81,8 +80,13 @@ def is_leaf(x: Callable[[Any], bool]) -> bool: types.discard(object) return isinstance(x, tuple(types)) - def dispatch_func(node): - return tree_state.state_dispatcher(node, array) + def dispatch_func(leaf): + try: + # single argument + return tree_state.state_dispatcher(leaf) + except TypeError: + # with optional array argument + return tree_state.state_dispatcher(leaf, array) return jax.tree_map(dispatch_func, tree, is_leaf=is_leaf) diff --git a/tests/test_rnn.py b/tests/test_rnn.py index 0ca76cd..176fd36 100644 --- a/tests/test_rnn.py +++ b/tests/test_rnn.py @@ -71,6 +71,7 @@ from serket.nn.recurrent import ( # ConvGRU1DCell,; ConvGRU2DCell,; ConvGRU3DCell,; ConvLSTM2DCell,; ConvLSTM3DCell, ConvLSTM1DCell, DenseCell, + FFTConvLSTM1DCell, GRUCell, LSTMCell, ScanRNN, @@ -423,7 +424,8 @@ def test_gru(): npt.assert_allclose(y, ypred, atol=1e-4) -def test_conv_lstm1d(): +@pytest.mark.parametrize("layer", [ConvLSTM1DCell, FFTConvLSTM1DCell]) +def test_conv_lstm1d(layer): w_in_to_hidden = jnp.array( [ [ @@ -570,7 +572,7 @@ def test_conv_lstm1d(): # return_sequences=False,data_format='channels_first'))(inp) # rnn = tf.keras.Model(inputs=inp, outputs=rnn) - cell = ConvLSTM1DCell( + cell = layer( in_features=in_features, hidden_features=hidden_features, recurrent_act_func="sigmoid", @@ -599,7 +601,7 @@ def test_conv_lstm1d(): assert jnp.allclose(res_sk, y, atol=1e-5) - cell = ConvLSTM1DCell( + cell = layer( in_features=in_features, hidden_features=hidden_features, recurrent_act_func="sigmoid",