Commit

minor

ASEM000 committed Aug 6, 2023
1 parent 59ec944 commit 756acd0
Showing 2 changed files with 18 additions and 16 deletions.
12 changes: 6 additions & 6 deletions serket/nn/recurrent.py
@@ -170,8 +170,8 @@ def __init__(
dtype=dtype,
)

- self.ih2h_weight = jnp.concatenate([i2h.weight, h2h.weight], axis=0)
- self.ih2h_bias = i2h.bias
+ self.in_hidden_to_hidden_weight = jnp.concatenate([i2h.weight, h2h.weight])
+ self.in_hidden_to_hidden_bias = i2h.bias

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@@ -181,7 +181,7 @@ def __call__(self, x: jax.Array, state: SimpleRNNState, **k) -> SimpleRNNState:
raise TypeError(f"Expected {state=} to be an instance of `SimpleRNNState`")

ih = jnp.concatenate([x, state.hidden_state], axis=-1)
- h = ih @ self.ih2h_weight + self.ih2h_bias
+ h = ih @ self.in_hidden_to_hidden_weight + self.in_hidden_to_hidden_bias
return SimpleRNNState(self.act_func(h))

@property
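For reference, the renamed attributes hold the input-to-hidden and hidden-to-hidden matrices fused into a single weight, applied in one matmul over the concatenated input and hidden state. A minimal standalone sketch of that step in plain JAX (hypothetical function name and shapes, not serket's actual class):

```python
import jax.numpy as jnp

def simple_rnn_step(x, h, in_hidden_to_hidden_weight, in_hidden_to_hidden_bias,
                    act=jnp.tanh):
    # x: (in_features,), h: (hidden_features,)
    # weight: (in_features + hidden_features, hidden_features)
    ih = jnp.concatenate([x, h], axis=-1)
    return act(ih @ in_hidden_to_hidden_weight + in_hidden_to_hidden_bias)
```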
@@ -368,8 +368,8 @@ def __init__(
dtype=dtype,
)

- self.ih2h_weight = jnp.concatenate([i2h.weight, h2h.weight], axis=0)
- self.ih2h_bias = i2h.bias
+ self.in_hidden_to_hidden_weight = jnp.concatenate([i2h.weight, h2h.weight])
+ self.in_hidden_to_hidden_bias = i2h.bias

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@@ -380,7 +380,7 @@ def __call__(self, x: jax.Array, state: LSTMState, **k) -> LSTMState:

h, c = state.hidden_state, state.cell_state
ih = jnp.concatenate([x, h], axis=-1)
- h = ih @ self.ih2h_weight + self.ih2h_bias
+ h = ih @ self.in_hidden_to_hidden_weight + self.in_hidden_to_hidden_bias
i, f, g, o = jnp.split(h, 4, axis=-1)
i = self.recurrent_act_func(i)
f = self.recurrent_act_func(f)
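The LSTM cell reuses the same fused weight, producing all four gate pre-activations in one matmul before splitting them. The tail of this hunk is cut off by the diff, so the last few lines of the sketch below assume the textbook LSTM update rather than being copied from the file:

```python
import jax
import jax.numpy as jnp

def lstm_step(x, h, c, in_hidden_to_hidden_weight, in_hidden_to_hidden_bias,
              act=jnp.tanh, recurrent_act=jax.nn.sigmoid):
    ih = jnp.concatenate([x, h], axis=-1)
    z = ih @ in_hidden_to_hidden_weight + in_hidden_to_hidden_bias
    i, f, g, o = jnp.split(z, 4, axis=-1)
    i, f, o = recurrent_act(i), recurrent_act(f), recurrent_act(o)
    g = act(g)
    c = f * c + i * g  # assumed cell-state update (truncated in the hunk above)
    h = o * act(c)
    return h, c
```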
22 changes: 12 additions & 10 deletions tests/test_rnn.py
@@ -115,7 +115,7 @@ def test_vanilla_rnn():
)

w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0)
- cell = cell.at["ih2h_weight"].set(w_combined)
+ cell = cell.at["in_hidden_to_hidden_weight"].set(w_combined)
sk_layer = ScanRNN(cell)
y = jnp.array([0.9637042, -0.8282256, 0.7314449])
npt.assert_allclose(sk_layer(x), y)
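The tests build the combined weight the same way the cell does: the (in_features, hidden) input matrix is stacked on top of the (hidden, hidden) recurrent matrix along axis 0, which is also jnp.concatenate's default axis (hence the dropped explicit axis=0 in the cell code above). A quick shape check with hypothetical sizes:

```python
import jax.numpy as jnp

in_features, hidden_features = 3, 3  # hypothetical sizes for illustration
w_in_to_hidden = jnp.zeros((in_features, hidden_features))
w_hidden_to_hidden = jnp.zeros((hidden_features, hidden_features))
w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0)
assert w_combined.shape == (in_features + hidden_features, hidden_features)
```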
@@ -229,8 +229,8 @@ def test_lstm():
recurrent_weight_init="glorot_uniform",
)
w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0)
- cell = cell.at["ih2h_weight"].set(w_combined)
- cell = cell.at["ih2h_bias"].set(b_hidden_to_hidden)
+ cell = cell.at["in_hidden_to_hidden_weight"].set(w_combined)
+ cell = cell.at["in_hidden_to_hidden_bias"].set(b_hidden_to_hidden)

sk_layer = ScanRNN(cell, return_sequences=False)

@@ -329,8 +329,8 @@ def test_lstm():

w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0)

- cell = cell.at["ih2h_weight"].set(w_combined)
- cell = cell.at["ih2h_bias"].set(b_hidden_to_hidden)
+ cell = cell.at["in_hidden_to_hidden_weight"].set(w_combined)
+ cell = cell.at["in_hidden_to_hidden_bias"].set(b_hidden_to_hidden)

sk_layer = ScanRNN(cell, return_sequences=True)

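ScanRNN unrolls a cell over the time axis, and return_sequences selects between returning every hidden state or only the last one. A minimal sketch of that idea with jax.lax.scan (an illustration, not serket's actual implementation):

```python
import jax

def scan_rnn(step, h0, xs, return_sequences=False):
    # step: (hidden, x_t) -> new hidden; xs has shape (time, in_features)
    def body(h, x):
        h = step(h, x)
        return h, h  # carry the new state and also emit it per step
    last_h, all_h = jax.lax.scan(body, h0, xs)
    return all_h if return_sequences else last_h
```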
@@ -751,14 +751,16 @@ def test_bilstm():
b_hidden_to_hidden_reverse = jnp.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0])

combined_w = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0)
- cell = cell.at["ih2h_weight"].set(combined_w)
- cell = cell.at["ih2h_bias"].set(b_hidden_to_hidden)
+ cell = cell.at["in_hidden_to_hidden_weight"].set(combined_w)
+ cell = cell.at["in_hidden_to_hidden_bias"].set(b_hidden_to_hidden)

combined_w_reverse = jnp.concatenate(
- [w_in_to_hidden_reverse, w_hidden_to_hidden_reverse], axis=0
+ [w_in_to_hidden_reverse, w_hidden_to_hidden_reverse]
)
- reverse_cell = reverse_cell.at["ih2h_weight"].set(combined_w_reverse)
- reverse_cell = reverse_cell.at["ih2h_bias"].set(b_hidden_to_hidden_reverse)
+ reverse_cell = reverse_cell.at["in_hidden_to_hidden_weight"].set(combined_w_reverse)
+ reverse_cell = reverse_cell.at["in_hidden_to_hidden_bias"].set(
+ b_hidden_to_hidden_reverse
+ )

res = ScanRNN(cell, reverse_cell, return_sequences=False)(x)

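For the bidirectional case tested here, a common construction (assumed for illustration, not necessarily serket's exact behavior) runs the reverse cell over the time-reversed sequence and concatenates the two results; reusing the scan_rnn sketch above:

```python
import jax.numpy as jnp

def bidirectional_rnn(step_fwd, step_bwd, h0_fwd, h0_bwd, xs):
    h_fwd = scan_rnn(step_fwd, h0_fwd, xs)        # forward pass over xs
    h_bwd = scan_rnn(step_bwd, h0_bwd, xs[::-1])  # backward pass over reversed xs
    return jnp.concatenate([h_fwd, h_bwd], axis=-1)
```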
