diff --git a/serket/nn/recurrent.py b/serket/nn/recurrent.py index 366b43a..2356b01 100644 --- a/serket/nn/recurrent.py +++ b/serket/nn/recurrent.py @@ -170,8 +170,8 @@ def __init__( dtype=dtype, ) - self.ih2h_weight = jnp.concatenate([i2h.weight, h2h.weight], axis=0) - self.ih2h_bias = i2h.bias + self.in_hidden_to_hidden_weight = jnp.concatenate([i2h.weight, h2h.weight]) + self.in_hidden_to_hidden_bias = i2h.bias @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -181,7 +181,7 @@ def __call__(self, x: jax.Array, state: SimpleRNNState, **k) -> SimpleRNNState: raise TypeError(f"Expected {state=} to be an instance of `SimpleRNNState`") ih = jnp.concatenate([x, state.hidden_state], axis=-1) - h = ih @ self.ih2h_weight + self.ih2h_bias + h = ih @ self.in_hidden_to_hidden_weight + self.in_hidden_to_hidden_bias return SimpleRNNState(self.act_func(h)) @property @@ -368,8 +368,8 @@ def __init__( dtype=dtype, ) - self.ih2h_weight = jnp.concatenate([i2h.weight, h2h.weight], axis=0) - self.ih2h_bias = i2h.bias + self.in_hidden_to_hidden_weight = jnp.concatenate([i2h.weight, h2h.weight]) + self.in_hidden_to_hidden_bias = i2h.bias @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -380,7 +380,7 @@ def __call__(self, x: jax.Array, state: LSTMState, **k) -> LSTMState: h, c = state.hidden_state, state.cell_state ih = jnp.concatenate([x, h], axis=-1) - h = ih @ self.ih2h_weight + self.ih2h_bias + h = ih @ self.in_hidden_to_hidden_weight + self.in_hidden_to_hidden_bias i, f, g, o = jnp.split(h, 4, axis=-1) i = self.recurrent_act_func(i) f = self.recurrent_act_func(f) diff --git a/tests/test_rnn.py b/tests/test_rnn.py index e6af153..313341d 100644 --- a/tests/test_rnn.py +++ b/tests/test_rnn.py @@ -115,7 +115,7 @@ def test_vanilla_rnn(): ) w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0) - cell = cell.at["ih2h_weight"].set(w_combined) + cell = cell.at["in_hidden_to_hidden_weight"].set(w_combined) sk_layer = ScanRNN(cell) y = jnp.array([0.9637042, -0.8282256, 0.7314449]) npt.assert_allclose(sk_layer(x), y) @@ -229,8 +229,8 @@ def test_lstm(): recurrent_weight_init="glorot_uniform", ) w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0) - cell = cell.at["ih2h_weight"].set(w_combined) - cell = cell.at["ih2h_bias"].set(b_hidden_to_hidden) + cell = cell.at["in_hidden_to_hidden_weight"].set(w_combined) + cell = cell.at["in_hidden_to_hidden_bias"].set(b_hidden_to_hidden) sk_layer = ScanRNN(cell, return_sequences=False) @@ -329,8 +329,8 @@ def test_lstm(): w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0) - cell = cell.at["ih2h_weight"].set(w_combined) - cell = cell.at["ih2h_bias"].set(b_hidden_to_hidden) + cell = cell.at["in_hidden_to_hidden_weight"].set(w_combined) + cell = cell.at["in_hidden_to_hidden_bias"].set(b_hidden_to_hidden) sk_layer = ScanRNN(cell, return_sequences=True) @@ -751,14 +751,16 @@ def test_bilstm(): b_hidden_to_hidden_reverse = jnp.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]) combined_w = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0) - cell = cell.at["ih2h_weight"].set(combined_w) - cell = cell.at["ih2h_bias"].set(b_hidden_to_hidden) + cell = cell.at["in_hidden_to_hidden_weight"].set(combined_w) + cell = cell.at["in_hidden_to_hidden_bias"].set(b_hidden_to_hidden) combined_w_reverse = jnp.concatenate( - [w_in_to_hidden_reverse, w_hidden_to_hidden_reverse], axis=0 + [w_in_to_hidden_reverse, w_hidden_to_hidden_reverse] + ) + reverse_cell = reverse_cell.at["in_hidden_to_hidden_weight"].set(combined_w_reverse) + reverse_cell = reverse_cell.at["in_hidden_to_hidden_bias"].set( + b_hidden_to_hidden_reverse ) - reverse_cell = reverse_cell.at["ih2h_weight"].set(combined_w_reverse) - reverse_cell = reverse_cell.at["ih2h_bias"].set(b_hidden_to_hidden_reverse) res = ScanRNN(cell, reverse_cell, return_sequences=False)(x)