Skip to content

Commit

Permalink
Update recurrent.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 2, 2023
1 parent e760ac6 commit dc6779c
Showing 1 changed file with 5 additions and 18 deletions.
23 changes: 5 additions & 18 deletions serket/_src/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
)

P = ParamSpec("P")
I = TypeVar("I")
T = TypeVar("T")
S = TypeVar("S")

State = Any
Expand Down Expand Up @@ -1385,23 +1385,8 @@ class FFTConvGRU3DCell(ConvGRUNDCell):
convolution_layer = FFTConv3D


# Scanning API


def is_lazy_init(_, cell, backward_cell=None, **__) -> bool:
lhs = getattr(cell, "in_features", False) is None
rhs = getattr(backward_cell, "in_features", False) is None
return lhs or rhs


def is_lazy_call(instance, x, state=None, **_) -> bool:
lhs = getattr(instance.cell, "in_features", False) is None
rhs = getattr(instance.backward_cell, "in_features", False) is None
return lhs or rhs


def scan_cell(cell, in_axis=0, out_axis=0, reverse=False):
"""Scan a RNN cell over a sequence.
"""Scan am RNN cell over a sequence.
Args:
cell: the RNN cell to scan. The cell should have the following signature:
Expand Down Expand Up @@ -1449,9 +1434,11 @@ def scan_func(state, input):
output, state = cell(input, state)
return state, output

def wrapper(input: I, state: S) -> tuple[I, S]:
def wrapper(input: T, state: S) -> tuple[T, S]:
# push the scan axis to the front
input = jnp.moveaxis(input, in_axis, 0)
state, output = jax.lax.scan(scan_func, state, input, reverse=reverse)
# move the output axis to the desired location
output = jnp.moveaxis(output, 0, out_axis)
return output, state

Expand Down

0 comments on commit dc6779c

Please sign in to comment.