From 4912117168f8e7ffc6e0c96eaf247621eaac7e8e Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Mon, 24 Jul 2023 08:27:29 +0900 Subject: [PATCH] fix rnn --- assets/logo.svg | 5 ++++- assets/logo_full.svg | 1 + serket/nn/recurrent.py | 19 ++++++++++--------- 3 files changed, 15 insertions(+), 10 deletions(-) create mode 100644 assets/logo_full.svg diff --git a/assets/logo.svg b/assets/logo.svg index a9f0903..3358e2a 100644 --- a/assets/logo.svg +++ b/assets/logo.svg @@ -1 +1,4 @@ - \ No newline at end of file + + + + diff --git a/assets/logo_full.svg b/assets/logo_full.svg new file mode 100644 index 0000000..a9f0903 --- /dev/null +++ b/assets/logo_full.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/serket/nn/recurrent.py b/serket/nn/recurrent.py index 3e7fd39..469ba1d 100644 --- a/serket/nn/recurrent.py +++ b/serket/nn/recurrent.py @@ -1040,17 +1040,18 @@ def __init__( reverse: tuple[bool, ...] | bool = False, ): cell0, *_ = cells - indim0,outdim0 = cell0.in_features, cell0.hidden_features for cell in cells: if not isinstance(cell, RNNCell): raise TypeError(f"Expected {cell=} to be an instance of `RNNCell`.") - - if cell.in_features != indim0: - raise ValueError(f"{cell.in_features=} != {indim0=}.") - - if cell.hidden_features != outdim0: - raise ValueError(f"{cell.hidden_features=} != {outdim0=}.") + + if cell.in_features != cell0.in_features: + raise ValueError(f"{cell.in_features=} != {cell0.in_features=}.") + + if cell.hidden_features != cell0.hidden_features: + raise ValueError( + f"{cell.hidden_features=} != {cell0.hidden_features=}." + ) if isinstance(reverse, bool): reverse = (reverse,) * len(cells) @@ -1102,7 +1103,7 @@ def __call__( ) splits = len(self.cells) - state: RNNState = tree_state(self, array=x) if state is None else state + state: RNNState = tree_state(self, array=x[0]) if state is None else state scan_func = _accumulate_scan if self.return_sequences else _no_accumulate_scan result_states: list[tuple[jax.Array, RNNState]] = [ @@ -1119,7 +1120,7 @@ def __call__( return result -def _split(state: RNNState, splits:int) -> list[RNNState]: +def _split(state: RNNState, splits: int) -> list[RNNState]: flat_arrays: list[jax.Array] = jtu.tree_leaves(state) return [type(state)(*x) for x in zip(*(jnp.split(x, splits) for x in flat_arrays))]