diff --git a/assets/logo.svg b/assets/logo.svg
index a9f0903..3358e2a 100644
--- a/assets/logo.svg
+++ b/assets/logo.svg
@@ -1 +1,4 @@
-
\ No newline at end of file
+
diff --git a/assets/logo_full.svg b/assets/logo_full.svg
new file mode 100644
index 0000000..a9f0903
--- /dev/null
+++ b/assets/logo_full.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/serket/nn/recurrent.py b/serket/nn/recurrent.py
index 3e7fd39..469ba1d 100644
--- a/serket/nn/recurrent.py
+++ b/serket/nn/recurrent.py
@@ -1040,17 +1040,18 @@ def __init__(
reverse: tuple[bool, ...] | bool = False,
):
cell0, *_ = cells
- indim0,outdim0 = cell0.in_features, cell0.hidden_features
for cell in cells:
if not isinstance(cell, RNNCell):
raise TypeError(f"Expected {cell=} to be an instance of `RNNCell`.")
-
- if cell.in_features != indim0:
- raise ValueError(f"{cell.in_features=} != {indim0=}.")
-
- if cell.hidden_features != outdim0:
- raise ValueError(f"{cell.hidden_features=} != {outdim0=}.")
+
+ if cell.in_features != cell0.in_features:
+ raise ValueError(f"{cell.in_features=} != {cell0.in_features=}.")
+
+ if cell.hidden_features != cell0.hidden_features:
+ raise ValueError(
+ f"{cell.hidden_features=} != {cell0.hidden_features=}."
+ )
if isinstance(reverse, bool):
reverse = (reverse,) * len(cells)
@@ -1102,7 +1103,7 @@ def __call__(
)
splits = len(self.cells)
- state: RNNState = tree_state(self, array=x) if state is None else state
+ state: RNNState = tree_state(self, array=x[0]) if state is None else state
scan_func = _accumulate_scan if self.return_sequences else _no_accumulate_scan
result_states: list[tuple[jax.Array, RNNState]] = [
@@ -1119,7 +1120,7 @@ def __call__(
return result
-def _split(state: RNNState, splits:int) -> list[RNNState]:
+def _split(state: RNNState, splits: int) -> list[RNNState]:
flat_arrays: list[jax.Array] = jtu.tree_leaves(state)
return [type(state)(*x) for x in zip(*(jnp.split(x, splits) for x in flat_arrays))]