Skip to content

Commit

Permalink
DenseCell => LinearCell
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 1, 2024
1 parent ff93087 commit 892bcf0
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/API/recurrent.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Recurrent
.. autoclass:: LSTMCell
.. autoclass:: GRUCell
.. autoclass:: SimpleRNNCell
.. autoclass:: DenseCell
.. autoclass:: LinearCell

.. autoclass:: ConvLSTM1DCell
.. autoclass:: ConvLSTM2DCell
Expand Down
30 changes: 12 additions & 18 deletions serket/_src/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ class RNNState(sk.TreeClass):
hidden_state: jax.Array


class SimpleRNNState(RNNState):
...
class SimpleRNNState(RNNState): ...


class SimpleRNNCell(sk.TreeClass):
Expand Down Expand Up @@ -196,11 +195,10 @@ def __call__(
spatial_ndim: int = 0


class DenseState(RNNState):
...
class DenseState(RNNState): ...


class DenseCell(sk.TreeClass):
class LinearCell(sk.TreeClass):
"""No hidden state cell that applies a dense(Linear+activation) layer to the input
Args:
Expand All @@ -218,7 +216,7 @@ class DenseCell(sk.TreeClass):
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> # 10-dimensional input, 20-dimensional hidden state
>>> cell = sk.nn.DenseCell(10, 20, key=jr.PRNGKey(0))
>>> cell = sk.nn.LinearCell(10, 20, key=jr.PRNGKey(0))
>>> # 20-dimensional hidden state
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(cell)
Expand All @@ -227,7 +225,7 @@ class DenseCell(sk.TreeClass):
(20,)
Note:
:class:`.DenseCell` supports lazy initialization, meaning that the
:class:`.LinearCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
Expand All @@ -238,7 +236,7 @@ class DenseCell(sk.TreeClass):
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.DenseCell(None, 20, key=jr.PRNGKey(0))
>>> lazy = sk.nn.LinearCell(None, 20, key=jr.PRNGKey(0))
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(lazy)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
Expand Down Expand Up @@ -423,8 +421,7 @@ def __call__(
spatial_ndim: int = 0


class GRUState(RNNState):
...
class GRUState(RNNState): ...


class GRUCell(sk.TreeClass):
Expand Down Expand Up @@ -623,8 +620,7 @@ def __call__(

@property
@abc.abstractmethod
def conv_layer(self):
...
def conv_layer(self): ...

spatial_ndim = property(abc.abstractmethod(lambda _: ...))

Expand Down Expand Up @@ -971,8 +967,7 @@ class FFTConvLSTM3DCell(ConvLSTMNDCell):
conv_layer = FFTConv3D


class ConvGRUNDState(RNNState):
...
class ConvGRUNDState(RNNState): ...


class ConvGRUNDCell(sk.TreeClass):
Expand Down Expand Up @@ -1049,8 +1044,7 @@ def __call__(

@property
@abc.abstractmethod
def conv_layer(self):
...
def conv_layer(self): ...

spatial_ndim = property(abc.abstractmethod(lambda _: ...))

Expand Down Expand Up @@ -1484,8 +1478,8 @@ def _(cell: SimpleRNNCell) -> SimpleRNNState:
return SimpleRNNState(jnp.zeros([cell.hidden_features]))


@tree_state.def_state(DenseCell)
def _(cell: DenseCell) -> DenseState:
@tree_state.def_state(LinearCell)
def _(cell: LinearCell) -> DenseState:
return DenseState(jnp.empty([cell.hidden_features]))


Expand Down
4 changes: 2 additions & 2 deletions serket/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,14 @@
ConvLSTM1DCell,
ConvLSTM2DCell,
ConvLSTM3DCell,
DenseCell,
FFTConvGRU1DCell,
FFTConvGRU2DCell,
FFTConvGRU3DCell,
FFTConvLSTM1DCell,
FFTConvLSTM2DCell,
FFTConvLSTM3DCell,
GRUCell,
LinearCell,
LSTMCell,
SimpleRNNCell,
scan_cell,
Expand Down Expand Up @@ -303,7 +303,7 @@
"ConvLSTM1DCell",
"ConvLSTM2DCell",
"ConvLSTM3DCell",
"DenseCell",
"LinearCell",
"FFTConvGRU1DCell",
"FFTConvGRU2DCell",
"FFTConvGRU3DCell",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_conv_lstm(sk_layer, keras_layer, ndim):


def test_dense_cell():
cell = sk.nn.DenseCell(
cell = sk.nn.LinearCell(
in_features=10,
hidden_features=10,
act=lambda x: x,
Expand Down

0 comments on commit 892bcf0

Please sign in to comment.