Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 24, 2023
1 parent ae23ef5 commit 813da80
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 95 deletions.
4 changes: 2 additions & 2 deletions serket/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,10 @@ def __init__(
self.bias_init = bias_init

if not (all(isinstance(i, int) for i in in_features)):
raise ValueError(f"Expected tuple of ints for {in_features=}")
raise TypeError(f"Expected tuple of ints for {in_features=}")

if not (all(isinstance(i, int) for i in in_axes)):
raise ValueError(f"Expected tuple of ints for {in_axes=}")
raise TypeError(f"Expected tuple of ints for {in_axes=}")

if len(in_axes) != len(in_features):
raise ValueError(f"{len(in_axes)=} != {len(in_features)=}")
Expand Down
8 changes: 8 additions & 0 deletions tests/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
SoftPlus,
SoftShrink,
SoftSign,
SquarePlus,
Swish,
Tanh,
TanhShrink,
Expand Down Expand Up @@ -291,6 +292,13 @@ def test_snake():
npt.assert_allclose(actual, expected, atol=1e-4)


def test_square_plus():
x = jnp.array([-1.0, 0, 1])
expected = 0.5 * (x + jnp.sqrt(x**2 + 4))
actual = SquarePlus()(x)
npt.assert_allclose(actual, expected, atol=1e-4)


def test_resolving():
with pytest.raises(ValueError):
resolve_activation("nonexistent")
Expand Down
5 changes: 5 additions & 0 deletions tests/test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,15 @@ def test_random_apply():
layer = sk.nn.RandomApply(lambda x: x + 1, rate=0.0)
assert layer(1, key=jr.PRNGKey(0)) == 1

assert sk.tree_eval(layer)(1) == 2


def test_random_choice():
layer = sk.nn.RandomChoice(lambda x: x + 2, lambda x: x * 2)
key = jr.PRNGKey(0)
assert layer(1, key=key) == 3.0
key = jr.PRNGKey(10)
assert layer(1, key=key) == 2.0

# convert all choices to sequential
assert sk.tree_eval(layer)(1) == (1.0 + 2.0) * 2.0
92 changes: 0 additions & 92 deletions tests/test_contrast.py

This file was deleted.

74 changes: 74 additions & 0 deletions tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy.testing as npt
import pytest

import serket as sk
from serket.nn.image import (
AdjustContrast2D,
AvgBlur2D,
FFTFilter2D,
Filter2D,
Expand All @@ -28,6 +30,7 @@
HorizontalTranslate2D,
JigSaw2D,
Pixelate2D,
RandomContrast2D,
RandomHorizontalShear2D,
RandomRotate2D,
RandomVerticalShear2D,
Expand Down Expand Up @@ -320,3 +323,74 @@ def test_pixelate():
x = jnp.arange(1, 26).reshape(1, 5, 5)
layer = Pixelate2D(1)
npt.assert_allclose(layer(x), x)


def test_adjust_contrast_2d():
x = jnp.array(
[
[
[0.19165385, 0.4459561, 0.03873193],
[0.58923364, 0.0923605, 0.2597469],
[0.83097064, 0.4854728, 0.03308535],
],
[
[0.10485303, 0.10068893, 0.408355],
[0.40298176, 0.6227188, 0.8612417],
[0.52223504, 0.3363577, 0.1300546],
],
]
)
y = jnp.array(
[
[
[0.26067203, 0.38782316, 0.18421106],
[0.45946193, 0.21102534, 0.29471856],
[0.5803304, 0.4075815, 0.18138777],
],
[
[0.24628687, 0.24420482, 0.39803785],
[0.39535123, 0.50521976, 0.6244812],
[0.45497787, 0.3620392, 0.25888765],
],
]
)

npt.assert_allclose(AdjustContrast2D(contrast_factor=0.5)(x), y, atol=1e-5)


def test_random_contrast_2d():
x = jnp.array(
[
[
[0.19165385, 0.4459561, 0.03873193],
[0.58923364, 0.0923605, 0.2597469],
[0.83097064, 0.4854728, 0.03308535],
],
[
[0.10485303, 0.10068893, 0.408355],
[0.40298176, 0.6227188, 0.8612417],
[0.52223504, 0.3363577, 0.1300546],
],
]
)

y = jnp.array(
[
[
[0.23179087, 0.4121493, 0.1233343],
[0.5137658, 0.1613692, 0.28008443],
[0.68521255, 0.44017565, 0.11932957],
],
[
[0.18710288, 0.1841496, 0.40235513],
[0.39854428, 0.55438805, 0.7235553],
[0.4831221, 0.3512926, 0.20497654],
],
]
)

npt.assert_allclose(
RandomContrast2D(contrast_range=(0.5, 1))(x, key=jax.random.PRNGKey(0)),
y,
atol=1e-5,
)
32 changes: 31 additions & 1 deletion tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest

import serket as sk
from serket.nn import FNN, Embedding, GeneralLinear, Identity, Linear, Multilinear
from serket.nn import FNN, MLP, Embedding, GeneralLinear, Identity, Linear, Multilinear


def test_embed():
Expand Down Expand Up @@ -65,6 +65,10 @@ def update(NN, x, y):
y = jnp.array([[-0.31568417]])
npt.assert_allclose(layer(jnp.array([[1.0]])), y)

layer = Linear(None, 1, bias_init="zeros")
_, layer = layer.at["__call__"](jnp.ones([100, 2]))
assert layer.in_features == (2,)


def test_bilinear():
W = jnp.array(
Expand Down Expand Up @@ -135,3 +139,29 @@ def test_general_linear():

with pytest.raises(ValueError):
GeneralLinear(in_features=(1,), in_axes=(0, -3), out_features=5)

with pytest.raises(TypeError):
GeneralLinear(in_features=(1, "s"), in_axes=(0, -3), out_features=5)

with pytest.raises(TypeError):
GeneralLinear(in_features=(1, 2), in_axes=(0, "s"), out_features=3)


def test_mlp():
x = jnp.linspace(0, 1, 100)[:, None]

fnn = FNN([1, 2, 1])
mlp = MLP(in_features=1, out_features=1, hidden_size=2, num_hidden_layers=1)

npt.assert_allclose(fnn(x), mlp(x), atol=1e-4)

fnn = FNN([1, 2, 2, 1], act=("relu", "tanh"))
mlp = MLP(
in_features=1,
out_features=1,
hidden_size=2,
num_hidden_layers=2,
act=("relu", "tanh"),
)

npt.assert_allclose(fnn(x), mlp(x), atol=1e-4)

0 comments on commit 813da80

Please sign in to comment.