Commit 348a412

docstring rnn/MLP

ASEM000 committed Dec 3, 2023
1 parent d811469 commit 348a412

Showing 2 changed files with 63 additions and 14 deletions.
30 changes: 21 additions & 9 deletions serket/_src/nn/linear.py
@@ -250,15 +250,15 @@ def scan_linear(
     # reduce the ``jaxpr`` size by using ``scan``
     if bias is None:

-        def scan_func(x: jax.Array, weight: Batched[jax.Array]):
-            return act(x @ weight.T), None
+        def scan_func(input: jax.Array, weight: Batched[jax.Array]):
+            return act(input @ weight.T), None

-        input, _ = jax.lax.scan(scan_func, input, weight)
-        return input
+        output, _ = jax.lax.scan(scan_func, input, weight)
+        return output

-    def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]):
+    def scan_func(input: jax.Array, weight_bias: Batched[jax.Array]):
         weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
-        return act(x @ weight.T + bias), None
+        return act(input @ weight.T + bias), None

     weight_bias = jnp.concatenate([weight, bias[:, :, None]], axis=-1)
     output, _ = jax.lax.scan(scan_func, input, weight_bias)
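
The hunk above fuses each layer's weight and bias into one stacked operand and
scans over it, so the traced program contains a single loop body instead of one
set of equations per layer. A minimal standalone sketch of the same trick
(``scan_mlp`` and the fixed ``tanh`` activation are illustrative assumptions,
not serket API):

    import jax
    import jax.numpy as jnp

    def scan_mlp(input, weight, bias):
        # weight: (num_layers, features, features), bias: (num_layers, features)
        def scan_func(x, weight_bias):
            w, b = weight_bias[..., :-1], weight_bias[..., -1]
            return jnp.tanh(x @ w.T + b), None

        # fuse weight and bias so ``scan`` carries one stacked operand
        weight_bias = jnp.concatenate([weight, bias[:, :, None]], axis=-1)
        output, _ = jax.lax.scan(scan_func, input, weight_bias)
        return output

    x = jnp.ones([4, 8])
    y = scan_mlp(x, jnp.ones([10, 8, 8]), jnp.zeros([10, 8]))  # 10 layers, one scan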
@@ -271,9 +271,9 @@ class MLP(sk.TreeClass):
     Args:
         in_features: Number of input features.
         out_features: Number of output features.
-        key: Random number generator key.
         hidden_features: Number of hidden units in each hidden layer.
         num_hidden_layers: Number of hidden layers including the output layer.
+        key: Random number generator key.
         act: Activation function. Defaults to ``tanh``.
         weight_init: Weight initialization function. Defaults to ``glorot_uniform``.
         bias_init: Bias initialization function. Defaults to ``zeros``.
@@ -317,16 +317,28 @@ class MLP(sk.TreeClass):
     Note:
         :class:`.MLP` uses ``jax.lax.scan`` to reduce the ``jaxpr`` size,
         leading to faster compilation times.
+        >>> import serket as sk
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>> # 10 hidden layers
+        >>> mlp1 = sk.nn.MLP(1, 2, 5, 10, key=jax.random.PRNGKey(0))
+        >>> # 50 hidden layers
+        >>> mlp2 = sk.nn.MLP(1, 2, 5, 50, key=jax.random.PRNGKey(0))
+        >>> jaxpr1 = jax.make_jaxpr(mlp1)(jnp.ones([10, 1]))
+        >>> jaxpr2 = jax.make_jaxpr(mlp2)(jnp.ones([10, 1]))
+        >>> # same number of equations irrespective of the number of hidden layers
+        >>> assert len(jaxpr1.jaxpr.eqns) == len(jaxpr2.jaxpr.eqns)
     """

     def __init__(
         self,
         in_features: int,
         out_features: int,
-        *,
-        key: jax.Array,
         hidden_features: int,
         num_hidden_layers: int,
+        *,
+        key: jax.Array,
         act: ActivationType = "tanh",
         weight_init: InitType = "glorot_uniform",
         bias_init: InitType = "zeros",
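
This reorder makes ``hidden_features`` and ``num_hidden_layers`` positional
while ``key`` stays keyword-only, so call sites change. A sketch inferred from
the signature diff, not code taken from the commit:

    import jax
    import serket as sk

    key = jax.random.PRNGKey(0)
    # before this commit, the hidden sizes had to be passed by keyword:
    # mlp = sk.nn.MLP(1, 2, key=key, hidden_features=5, num_hidden_layers=10)
    # after it, they may be positional and ``key`` remains keyword-only:
    mlp = sk.nn.MLP(1, 2, 5, 10, key=key)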
47 changes: 42 additions & 5 deletions serket/_src/nn/recurrent.py
@@ -1385,7 +1385,12 @@ class FFTConvGRU3DCell(ConvGRUNDCell):
     convolution_layer = FFTConv3D


-def scan_cell(cell, in_axis=0, out_axis=0, reverse=False):
+def scan_cell(
+    cell,
+    in_axis: int = 0,
+    out_axis: int = 0,
+    reverse: bool = False,
+) -> Callable[[jax.Array, S], tuple[jax.Array, S]]:
     """Scan an RNN cell over a sequence.

     Args:
@@ -1420,21 +1425,53 @@ def scan_cell(cell, in_axis=0, out_axis=0, reverse=False):
         >>> k1, k2 = jr.split(jr.PRNGKey(0))
         >>> cell1 = sk.nn.SimpleRNNCell(1, 2, key=k1)
         >>> cell2 = sk.nn.SimpleRNNCell(1, 2, key=k2)
-        >>> state1 = sk.tree_state(cell1)
-        >>> state2 = sk.tree_state(cell2)
+        >>> state1, state2 = sk.tree_state((cell1, cell2))
         >>> input = jnp.ones([10, 1])
         >>> output1, state1 = sk.nn.scan_cell(cell1)(input, state1)
         >>> output2, state2 = sk.nn.scan_cell(cell2, reverse=True)(input, state2)
         >>> output = jnp.concatenate((output1, output2), axis=1)
         >>> print(output.shape)
         (10, 4)

+    Example:
+        Combining multiple RNN cells:
+        >>> import serket as sk
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>> import jax.random as jr
+        >>> import numpy.testing as npt
+        >>> k1, k2 = jr.split(jr.PRNGKey(0))
+        >>> cell1 = sk.nn.LSTMCell(1, 2, bias_init=None, key=k1)
+        >>> cell2 = sk.nn.LSTMCell(2, 1, bias_init=None, key=k2)
+        >>> def cell(input, state):
+        ...     state1, state2 = state
+        ...     output, state1 = cell1(input, state1)
+        ...     output, state2 = cell2(output, state2)
+        ...     return output, (state1, state2)
+        >>> state = sk.tree_state((cell1, cell2))
+        >>> input = jnp.ones([2, 1])
+        >>> output1, state = sk.nn.scan_cell(cell)(input, state)
+        <BLANKLINE>
+        >>> # This is equivalent to:
+        >>> state1, state2 = sk.tree_state((cell1, cell2))
+        >>> output2 = jnp.zeros([2, 1])
+        >>> # first step
+        >>> output, state1 = cell1(input[0], state1)
+        >>> output, state2 = cell2(output, state2)
+        >>> output2 = output2.at[0].set(output)
+        >>> # second step
+        >>> output, state1 = cell1(input[1], state1)
+        >>> output, state2 = cell2(output, state2)
+        >>> output2 = output2.at[1].set(output)
+        >>> npt.assert_allclose(output1, output2, atol=1e-6)
     """

-    def scan_func(state, input):
+    def scan_func(state: S, input: jax.Array) -> tuple[S, jax.Array]:
         output, state = cell(input, state)
         return state, output

-    def wrapper(input: T, state: S) -> tuple[T, S]:
+    def wrapper(input: jax.Array, state: S) -> tuple[jax.Array, S]:
         # push the scan axis to the front
         input = jnp.moveaxis(input, in_axis, 0)
         state, output = jax.lax.scan(scan_func, state, input, reverse=reverse)
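
The hunk cuts off before ``wrapper`` returns. For orientation, a self-contained
sketch of the full pattern; ``scan_cell_sketch`` is a hypothetical stand-in,
and the final moveaxis/return line is an assumption inferred from ``out_axis``,
not code from this commit:

    import jax
    import jax.numpy as jnp

    def scan_cell_sketch(cell, in_axis=0, out_axis=0, reverse=False):
        # ``cell``: (input, state) -> (output, state), applied once per step
        def scan_func(state, input):
            output, state = cell(input, state)
            return state, output

        def wrapper(input, state):
            input = jnp.moveaxis(input, in_axis, 0)  # push the scan axis to the front
            state, output = jax.lax.scan(scan_func, state, input, reverse=reverse)
            # assumed continuation: move the stacked outputs to the requested axis
            return jnp.moveaxis(output, 0, out_axis), state

        return wrapper

    step = lambda x, s: (jnp.tanh(x + s), jnp.tanh(x + s))  # toy stand-in cell
    output, state = scan_cell_sketch(step)(jnp.ones([10, 1]), jnp.zeros([1]))
    print(output.shape)  # (10, 1)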