Skip to content

Commit

Permalink
.T mlp mid, add bias test
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 2, 2023
1 parent 81b4aae commit d811469
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
2 changes: 1 addition & 1 deletion serket/_src/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def scan_func(x: jax.Array, weight: Batched[jax.Array]):

def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]):
    """Apply one hidden layer; used as the step function of ``jax.lax.scan``.

    Args:
        x: the scan carry, i.e. the activations entering this layer.
        weight_bias: this layer's slice of the packed parameters, with the
            bias stored as the last entry along the final axis and the
            weight occupying everything before it.

    Returns:
        A ``(carry, None)`` pair: the activated layer output becomes the next
        carry, and nothing is accumulated per step.
    """
    # Split the packed parameters back into the weight and its bias column.
    weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
    # After the split the weight's last axis is the one the bias was appended
    # to, so it has to be transposed before right-multiplying the carry.
    return act(x @ weight.T + bias), None

weight_bias = jnp.concatenate([weight, bias[:, :, None]], axis=-1)
output, _ = jax.lax.scan(scan_func, input, weight_bias)
Expand Down
37 changes: 37 additions & 0 deletions tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,40 @@ def test_mlp():
layer = layer.at["out_linear"]["weight"].set(w3.T)

npt.assert_allclose(layer(x), y)


def test_mlp_bias():
    """Check ``sk.nn.MLP`` with biases against a hand-written forward pass.

    The reference is ``tanh(tanh(x @ w1 + b1) @ w2 + b2) @ w3 + b3``. The
    same weights and biases are written into the layer, which must then
    reproduce the reference output.
    """
    x = jax.random.normal(jax.random.PRNGKey(0), (10, 1))
    w1 = jax.random.normal(jax.random.PRNGKey(1), (1, 10))
    w2 = jax.random.normal(jax.random.PRNGKey(2), (10, 10))
    w3 = jax.random.normal(jax.random.PRNGKey(3), (10, 4))
    b1 = jax.random.normal(jax.random.PRNGKey(4), (10,))
    b2 = jax.random.normal(jax.random.PRNGKey(5), (10,))
    b3 = jax.random.normal(jax.random.PRNGKey(6), (4,))

    # Reference forward pass written out by hand.
    y = x @ w1 + b1
    y = jax.nn.tanh(y)
    y = y @ w2 + b2
    y = jax.nn.tanh(y)
    y = y @ w3 + b3

    layer = sk.nn.MLP(
        1,
        4,
        hidden_features=10,
        num_hidden_layers=2,
        act="tanh",
        bias_init="zeros",
        key=jax.random.PRNGKey(0),
    )

    # Weights are written transposed relative to the reference matrices.
    layer = layer.at["in_linear"]["weight"].set(w1.T)
    layer = layer.at["in_linear"]["bias"].set(b1)
    # The mid-layer parameters get an extra leading axis of length 1
    # (presumably the stacked-layers axis -- confirm against MLP).
    layer = layer.at["mid_linear"]["weight"].set(w2.T[None])
    layer = layer.at["mid_linear"]["bias"].set(b2[None])
    layer = layer.at["out_linear"]["weight"].set(w3.T)
    layer = layer.at["out_linear"]["bias"].set(b3)

    npt.assert_allclose(layer(x), y)

0 comments on commit d811469

Please sign in to comment.