Skip to content

Commit

Permalink
edits
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 12, 2023
1 parent a6e0faa commit 39dde5b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
9 changes: 6 additions & 3 deletions serket/_src/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def filter_2d(
dimension_numbers=generate_conv_dim_numbers(2),
feature_group_count=array.shape[0], # in_features
)
return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, (0, 1)))
return jnp.squeeze(x, (0, 1))


def fft_filter_2d(
Expand Down Expand Up @@ -97,7 +97,7 @@ def fft_filter_2d(
dilation=(1, 1),
groups=array.shape[0], # in_features
)
return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, (0, 1)))
return jnp.squeeze(x, (0, 1))


def calculate_average_kernel(
Expand Down Expand Up @@ -180,7 +180,10 @@ def calculate_laplacian_kernel(


def calculate_motion_kernel(
kernel_size: int, angle: float, direction=0.0, dtype: DType = jnp.float32
kernel_size: int,
angle: float,
direction=0.0,
dtype: DType = jnp.float32,
) -> Annotated[jax.Array, "HW"]:
"""Returns 2D motion blur filter.
Expand Down
8 changes: 4 additions & 4 deletions serket/_src/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,16 @@ class MultiHeadAttention(sk.TreeClass):
k_features: Number of features for the key.
v_features: Number of features for the value.
out_features: Number of features for the output.
q_weight_init: Initializer for the query weight. Defaults to glorot_uniform.
q_weight_init: Initializer for the query weight. Defaults to ``glorot_uniform``.
q_bias_init: Initializer for the query bias. Defaults to zeros. use
``None`` to disable bias.
k_weight_init: Initializer for the key weight. Defaults to glorot_uniform.
k_weight_init: Initializer for the key weight. Defaults to ``glorot_uniform``.
k_bias_init: Initializer for the key bias. Defaults to zeros. use
``None`` to disable bias.
v_weight_init: Initializer for the value weight. Defaults to glorot_uniform.
v_weight_init: Initializer for the value weight. Defaults to ``glorot_uniform``.
v_bias_init: Initializer for the value bias. Defaults to zeros. use
``None`` to disable bias.
out_weight_init: Initializer for the output weight. Defaults to glorot_uniform.
out_weight_init: Initializer for the output weight. Defaults to ``glorot_uniform``.
out_bias_init: Initializer for the output bias. Defaults to zeros. use
``None`` to disable bias.
drop_rate: Dropout rate. defaults to 0.0.
Expand Down
28 changes: 28 additions & 0 deletions tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def test_AvgBlur2D():
)
npt.assert_allclose(y, z, atol=1e-5)

layer = sk.tree_mask(sk.image.FFTAvgBlur2D((3, 5)))
grads = jax.grad(lambda node: jnp.sum(node(x)))(layer)
npt.assert_allclose(grads.kernel_x, jnp.zeros_like(grads.kernel_x))
npt.assert_allclose(grads.kernel_y, jnp.zeros_like(grads.kernel_y))


def test_GaussBlur2D():
layer = sk.image.GaussianBlur2D(kernel_size=3, sigma=1.0)
Expand Down Expand Up @@ -81,6 +86,11 @@ def test_filter2d():

npt.assert_allclose(layer(x), layer2(x), atol=1e-4)

layer = sk.tree_mask(sk.image.AvgBlur2D((3, 5)))
grads = jax.grad(lambda node: jnp.sum(node(x)))(layer)
npt.assert_allclose(grads.kernel_x, jnp.zeros_like(grads.kernel_x))
npt.assert_allclose(grads.kernel_y, jnp.zeros_like(grads.kernel_y))


def test_solarize2d():
x = jnp.arange(1, 26).reshape(1, 5, 5)
Expand Down Expand Up @@ -412,6 +422,11 @@ def test_unsharp_mask():
atol=1e-5,
)

layer = sk.tree_mask(sk.image.UnsharpMask2D((3, 5)))
grads = jax.grad(lambda node: jnp.sum(node(x)))(layer)
npt.assert_allclose(grads.kernel_x, jnp.zeros_like(grads.kernel_x))
npt.assert_allclose(grads.kernel_y, jnp.zeros_like(grads.kernel_y))


def test_box_blur():
x = jnp.arange(1, 17).reshape(1, 4, 4).astype(jnp.float32)
Expand All @@ -429,6 +444,11 @@ def test_box_blur():
npt.assert_allclose(sk.image.BoxBlur2D((3, 5))(x), y, atol=1e-6)
npt.assert_allclose(sk.image.FFTBoxBlur2D((3, 5))(x), y, atol=1e-6)

layer = sk.tree_mask(sk.image.BoxBlur2D((3, 5)))
grads = jax.grad(lambda node: jnp.sum(node(x)))(layer)
npt.assert_allclose(grads.kernel_x, jnp.zeros_like(grads.kernel_x))
npt.assert_allclose(grads.kernel_y, jnp.zeros_like(grads.kernel_y))


def test_laplacian():
x = jax.random.uniform(jax.random.PRNGKey(0), (2, 10, 10))
Expand All @@ -447,6 +467,10 @@ def test_laplacian():
atol=1e-5,
)

layer = sk.tree_mask(sk.image.Laplacian2D((3, 5)))
grads = jax.grad(lambda node: jnp.sum(node(x)))(layer)
npt.assert_allclose(grads.kernel, jnp.zeros_like(grads.kernel))


def test_center_crop():
x = jnp.arange(1, 26).reshape(1, 5, 5)
Expand All @@ -469,3 +493,7 @@ def test_motion():
]
)
npt.assert_allclose(y, ytrue)

layer = sk.tree_mask(sk.image.MotionBlur2D(3))
grads = jax.grad(lambda node: jnp.sum(node(x)))(layer)
npt.assert_allclose(grads.kernel, jnp.zeros_like(grads.kernel))

0 comments on commit 39dde5b

Please sign in to comment.