From dc6779c87ca6f64dd4865e2c486aae4d54a7f553 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Sat, 2 Dec 2023 15:19:04 +0900 Subject: [PATCH] Update recurrent.py --- serket/_src/nn/recurrent.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/serket/_src/nn/recurrent.py b/serket/_src/nn/recurrent.py index c25f789..c34688b 100644 --- a/serket/_src/nn/recurrent.py +++ b/serket/_src/nn/recurrent.py @@ -48,7 +48,7 @@ ) P = ParamSpec("P") -I = TypeVar("I") +T = TypeVar("T") S = TypeVar("S") State = Any @@ -1385,23 +1385,8 @@ class FFTConvGRU3DCell(ConvGRUNDCell): convolution_layer = FFTConv3D -# Scanning API - - -def is_lazy_init(_, cell, backward_cell=None, **__) -> bool: - lhs = getattr(cell, "in_features", False) is None - rhs = getattr(backward_cell, "in_features", False) is None - return lhs or rhs - - -def is_lazy_call(instance, x, state=None, **_) -> bool: - lhs = getattr(instance.cell, "in_features", False) is None - rhs = getattr(instance.backward_cell, "in_features", False) is None - return lhs or rhs - - def scan_cell(cell, in_axis=0, out_axis=0, reverse=False): - """Scan a RNN cell over a sequence. + """Scan am RNN cell over a sequence. Args: cell: the RNN cell to scan. The cell should have the following signature: @@ -1449,9 +1434,11 @@ def scan_func(state, input): output, state = cell(input, state) return state, output - def wrapper(input: I, state: S) -> tuple[I, S]: + def wrapper(input: T, state: S) -> tuple[T, S]: + # push the scan axis to the front input = jnp.moveaxis(input, in_axis, 0) state, output = jax.lax.scan(scan_func, state, input, reverse=reverse) + # move the output axis to the desired location output = jnp.moveaxis(output, 0, out_axis) return output, state