diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py
index 2fd28ca..2a80bfd 100644
--- a/serket/_src/nn/linear.py
+++ b/serket/_src/nn/linear.py
@@ -258,7 +258,7 @@ def scan_func(x: jax.Array, weight: Batched[jax.Array]):
 
     def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]):
         weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
-        return act(x @ weight + bias), None
+        return act(x @ weight.T + bias), None
 
     weight_bias = jnp.concatenate([weight, bias[:, :, None]], axis=-1)
     output, _ = jax.lax.scan(scan_func, input, weight_bias)
diff --git a/tests/test_linear.py b/tests/test_linear.py
index 6111ec3..e9a0b61 100644
--- a/tests/test_linear.py
+++ b/tests/test_linear.py
@@ -142,3 +142,38 @@ def test_mlp():
     layer = layer.at["out_linear"]["weight"].set(w3.T)
 
     npt.assert_allclose(layer(x), y)
+
+
+def test_mlp_bias():
+    x = jax.random.normal(jax.random.PRNGKey(0), (10, 1))
+    w1 = jax.random.normal(jax.random.PRNGKey(1), (1, 10))
+    w2 = jax.random.normal(jax.random.PRNGKey(2), (10, 10))
+    w3 = jax.random.normal(jax.random.PRNGKey(3), (10, 4))
+    b1 = jax.random.normal(jax.random.PRNGKey(4), (10,))
+    b2 = jax.random.normal(jax.random.PRNGKey(5), (10,))
+    b3 = jax.random.normal(jax.random.PRNGKey(6), (4,))
+
+    y = x @ w1 + b1
+    y = jax.nn.tanh(y)
+    y = y @ w2 + b2
+    y = jax.nn.tanh(y)
+    y = y @ w3 + b3
+
+    layer = sk.nn.MLP(
+        1,
+        4,
+        hidden_features=10,
+        num_hidden_layers=2,
+        act="tanh",
+        bias_init="zeros",
+        key=jax.random.PRNGKey(0),
+    )
+
+    layer = layer.at["in_linear"]["weight"].set(w1.T)
+    layer = layer.at["in_linear"]["bias"].set(b1)
+    layer = layer.at["mid_linear"]["weight"].set(w2.T[None])
+    layer = layer.at["mid_linear"]["bias"].set(b2[None])
+    layer = layer.at["out_linear"]["weight"].set(w3.T)
+    layer = layer.at["out_linear"]["bias"].set(b3)
+
+    npt.assert_allclose(layer(x), y)
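
For context, here is a minimal standalone sketch (illustrative only, not part of
the patch) of the scan pattern the first hunk fixes: the hidden-layer weights
are stacked along a leading layer axis and concatenated with the bias as an
extra trailing column, so each scanned slice must be transposed before the
matmul for x @ weight.T + bias to match the convention of a single Linear
layer. The names num_layers and features are placeholders for this sketch.

    import jax
    import jax.numpy as jnp

    num_layers, features = 2, 10
    keys = jax.random.split(jax.random.PRNGKey(0), 3)
    x = jax.random.normal(keys[0], (5, features))  # (batch, features)
    weight = jax.random.normal(keys[1], (num_layers, features, features))
    bias = jax.random.normal(keys[2], (num_layers, features))

    def scan_func(x, weight_bias):
        # Split the stacked (weight | bias) slice back into its two parts.
        weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
        return jax.nn.tanh(x @ weight.T + bias), None

    # Carry the weight and bias through the scan as a single stacked array.
    weight_bias = jnp.concatenate([weight, bias[:, :, None]], axis=-1)
    scanned, _ = jax.lax.scan(scan_func, x, weight_bias)

    # Reference: the same layers applied in an unrolled Python loop.
    unrolled = x
    for w, b in zip(weight, bias):
        unrolled = jax.nn.tanh(unrolled @ w.T + b)

    assert jnp.allclose(scanned, unrolled)

Without the .T, the scan body would multiply by the untransposed weight, which
only goes unnoticed when the hidden layers are square; the new test_mlp_bias
checks the biased path end to end against an explicit matmul chain.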