Skip to content

Commit

Permalink
fix rnn
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 23, 2023
1 parent e9bd5bb commit 4912117
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
5 changes: 4 additions & 1 deletion assets/logo.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions assets/logo_full.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 10 additions & 9 deletions serket/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,17 +1040,18 @@ def __init__(
reverse: tuple[bool, ...] | bool = False,
):
cell0, *_ = cells
indim0,outdim0 = cell0.in_features, cell0.hidden_features

for cell in cells:
if not isinstance(cell, RNNCell):
raise TypeError(f"Expected {cell=} to be an instance of `RNNCell`.")

if cell.in_features != indim0:
raise ValueError(f"{cell.in_features=} != {indim0=}.")

if cell.hidden_features != outdim0:
raise ValueError(f"{cell.hidden_features=} != {outdim0=}.")

if cell.in_features != cell0.in_features:
raise ValueError(f"{cell.in_features=} != {cell0.in_features=}.")

if cell.hidden_features != cell0.hidden_features:
raise ValueError(
f"{cell.hidden_features=} != {cell0.hidden_features=}."
)

if isinstance(reverse, bool):
reverse = (reverse,) * len(cells)
Expand Down Expand Up @@ -1102,7 +1103,7 @@ def __call__(
)

splits = len(self.cells)
state: RNNState = tree_state(self, array=x) if state is None else state
state: RNNState = tree_state(self, array=x[0]) if state is None else state
scan_func = _accumulate_scan if self.return_sequences else _no_accumulate_scan

result_states: list[tuple[jax.Array, RNNState]] = [
Expand All @@ -1119,7 +1120,7 @@ def __call__(
return result


def _split(state: RNNState, splits:int) -> list[RNNState]:
def _split(state: RNNState, splits: int) -> list[RNNState]:
flat_arrays: list[jax.Array] = jtu.tree_leaves(state)
return [type(state)(*x) for x in zip(*(jnp.split(x, splits) for x in flat_arrays))]

Expand Down

0 comments on commit 4912117

Please sign in to comment.