Skip to content

Commit

Permalink
Update test_convolution.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 16, 2023
1 parent b032b24 commit 59c27c9
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,23 +1179,24 @@ def test_groups_error():
)
def test_lazy_conv(layer, array, expected_shape):
lazy = layer(None, 1, 3, key=jax.random.PRNGKey(0))
value, material = lazy.at["__call__"](array)
value, material = sk.value_and_tree(lambda layer: layer(array))(lazy)

assert value.shape == expected_shape
assert material.in_features == 10


def test_lazy_conv_local():
layer = sk.nn.Conv1DLocal(None, 1, 3, in_size=(3,), key=jax.random.PRNGKey(0))
_, layer = layer.at["__call__"](jnp.ones([10, 3]))
_, layer = sk.value_and_tree(lambda layer: layer(jnp.ones([10, 3])))(layer)
assert layer.in_features == 10
layer = sk.nn.Conv1DLocal(2, 1, 2, in_size=None, key=jax.random.PRNGKey(0))

with pytest.raises(ValueError):
# should raise error because in_features is specified = 2 and
# input in_features is 10
_, layer = layer.at["__call__"](jnp.ones([10, 3]))
_, layer = layer.at["__call__"](jnp.ones([2, 3]))
_, layer = sk.value_and_tree(lambda layer: layer(jnp.ones([10, 3])))(layer)

_, layer = sk.value_and_tree(lambda layer: layer(jnp.ones([2, 3])))(layer)
assert layer.in_features == 2


Expand Down

0 comments on commit 59c27c9

Please sign in to comment.