Skip to content

Commit

Permalink
edits
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 13, 2023
1 parent 39dde5b commit eba634a
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 55 deletions.
9 changes: 5 additions & 4 deletions serket/_src/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,11 @@ def fft_conv_general_dilated(
end = [z.shape[0], z.shape[1]]
end += [max((x_shape[i] - w_shape[i] + 1), 0) for i in range(2, spatial_ndim + 2)]

if all(s == 1 for s in strides):
return jax.lax.dynamic_slice(z, start, end)

return jax.lax.slice(z, start, end, (1, 1, *strides))
return (
jax.lax.dynamic_slice(z, start, end)
if all(s == 1 for s in strides)
else jax.lax.slice(z, start, end, (1, 1, *strides))
)


def is_lazy_call(instance, *_, **__) -> bool:
Expand Down
72 changes: 33 additions & 39 deletions serket/_src/nn/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,9 @@ def __init__(

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.image.resize(
x,
shape=(x.shape[0], *self.size),
method=self.method,
antialias=self.antialias,
)
in_axes = (0, None, None, None)
args = (x, self.size, self.method, self.antialias)
return jax.vmap(jax.image.resize, in_axes=in_axes)(*args)

@property
@abc.abstractmethod
Expand All @@ -77,11 +74,9 @@ def __init__(
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
resized_shape = tuple(s * x.shape[i + 1] for i, s in enumerate(self.scale))
return jax.image.resize(
x,
shape=(x.shape[0], *resized_shape),
method=self.method,
)
in_axes = (0, None, None)
args = (x, resized_shape, self.method)
return jax.vmap(jax.image.resize, in_axes=in_axes)(*args)

@property
@abc.abstractmethod
Expand All @@ -91,7 +86,7 @@ def spatial_ndim(self) -> int:


class CropND(sk.TreeClass):
"""Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input.
"""Applies ``jax.lax.dynamic_slice_in_dim`` to the second dimension of the input.
Args:
size: size of the slice, accepted values are integers or tuples of integers.
Expand All @@ -104,8 +99,9 @@ def __init__(self, size: int | tuple[int, ...], start: int | tuple[int, ...]):

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
shape = ((0, *self.start), (x.shape[0], *self.size))
return jax.lax.stop_gradient(jax.lax.dynamic_slice(x, *shape))
in_axes = (0, None, None)
args = (x, self.start, self.size)
return jax.vmap(jax.lax.dynamic_slice, in_axes=in_axes)(*args)

@property
@abc.abstractmethod
Expand All @@ -125,10 +121,9 @@ def __init__(self, padding: int | tuple[int, int], value: float = 0.0):

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
# do not pad the channel axis
shape = ((0, 0), *self.padding)
value = jax.lax.stop_gradient(self.value)
return jnp.pad(x, shape, constant_values=value)
pad = ft.partial(jnp.pad, pad_width=self.padding, constant_values=value)
return jax.vmap(pad)(x)

@property
@abc.abstractmethod
Expand Down Expand Up @@ -372,7 +367,7 @@ def spatial_ndim(self) -> int:


class Crop1D(CropND):
"""Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input.
"""Applies ``jax.lax.dynamic_slice_in_dim`` to the second dimension of the input.
Args:
size: size of the slice, either a single int or a tuple of int.
Expand All @@ -393,7 +388,7 @@ def spatial_ndim(self) -> int:


class Crop2D(CropND):
"""Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input.
"""Applies ``jax.lax.dynamic_slice_in_dim`` to the second dimension of the input.
Args:
size: size of the slice, either a single int or a tuple of two ints
Expand Down Expand Up @@ -425,7 +420,7 @@ def spatial_ndim(self) -> int:


class Crop3D(CropND):
"""Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input.
"""Applies ``jax.lax.dynamic_slice_in_dim`` to the second dimension of the input.
Args:
size: size of the slice, either a single int or a tuple of three ints
Expand Down Expand Up @@ -469,7 +464,7 @@ def spatial_ndim(self) -> int:


class RandomCrop1D(RandomCropND):
"""Applies jax.lax.dynamic_slice_in_dim with a random start along each axis
"""Applies ``jax.lax.dynamic_slice_in_dim`` with a random start along each axis
Args:
size: size of the slice, either a single int or a tuple of int. accepted
Expand All @@ -482,7 +477,7 @@ def spatial_ndim(self) -> int:


class RandomCrop2D(RandomCropND):
"""Applies jax.lax.dynamic_slice_in_dim with a random start along each axis
"""Applies ``jax.lax.dynamic_slice_in_dim`` with a random start along each axis
Args:
size: size of the slice in each axis. accepted values are either a single int
Expand All @@ -495,7 +490,7 @@ def spatial_ndim(self) -> int:


class RandomCrop3D(RandomCropND):
"""Applies jax.lax.dynamic_slice_in_dim with a random start along each axis
"""Applies ``jax.lax.dynamic_slice_in_dim`` with a random start along each axis
Args:
size: size of the slice in each axis. accepted values are either a single int
Expand Down Expand Up @@ -577,9 +572,9 @@ def __init__(self, length_factor: tuple[int, int] = (0.0, 1.0)):
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
k1, k2 = jr.split(key, 2)
low, high = self.length_factor
low, high = jax.lax.stop_gradient(self.length_factor)
x = zoom_axis(x, jr.uniform(k1, minval=low, maxval=high), k2, axis=1)
return jax.lax.stop_gradient(x)
return x

@property
def spatial_ndim(self) -> int:
Expand Down Expand Up @@ -637,10 +632,10 @@ def __init__(
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
k1, k2, k3, k4 = jr.split(key, 4)
low, high = self.height_factor
x = zoom_axis(x, jr.uniform(k1, minval=low, maxval=high), k3, axis=1)
low, high == self.width_factor
x = zoom_axis(x, jr.uniform(k2, minval=low, maxval=high), k4, axis=2)
factors = (self.height_factor, self.width_factor)
((hfl, hfh), (wfl, wfh)) = jax.lax.stop_gradient(factors)
x = zoom_axis(x, jr.uniform(k1, minval=hfl, maxval=hfh), k3, axis=1)
x = zoom_axis(x, jr.uniform(k2, minval=wfl, maxval=wfh), k4, axis=2)
return jax.lax.stop_gradient(x)

@property
Expand All @@ -660,9 +655,9 @@ def __init__(
Positive values are zoom in, negative values are zoom out, and 0 is no zoom.
Args:
height_factor: (min, max)
width_factor: (min, max)
depth_factor: (min, max)
height_factor: (min, max) for height
width_factor: (min, max) for width
depth_factor: (min, max) for depth
Reference:
- https://www.tensorflow.org/api_docs/python/tf/keras/layers/RandomZoom
Expand All @@ -683,13 +678,12 @@ def __init__(
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
k1, k2, k3, k4, k5, k6 = jr.split(key, 6)
low, high = self.height_factor
x = zoom_axis(x, jr.uniform(k1, minval=low, maxval=high), k3, axis=1)
low, high == self.width_factor
x = zoom_axis(x, jr.uniform(k2, minval=low, maxval=high), k4, axis=2)
low, high == self.depth_factor
x = zoom_axis(x, jr.uniform(k5, minval=low, maxval=high), k6, axis=3)
return jax.lax.stop_gradient(x)
factors = (self.height_factor, self.width_factor, self.depth_factor)
((hfl, hfh), (wfl, wfh), (dfl, dfh)) = jax.lax.stop_gradient(factors)
x = zoom_axis(x, jr.uniform(k1, minval=hfl, maxval=hfh), k3, axis=1)
x = zoom_axis(x, jr.uniform(k2, minval=wfl, maxval=wfh), k4, axis=2)
x = zoom_axis(x, jr.uniform(k5, minval=dfl, maxval=dfh), k6, axis=3)
return x

@property
def spatial_ndim(self) -> int:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,3 +1621,42 @@ def test_lazy_conv(layer, array, expected_shape):

assert value.shape == expected_shape
assert materialized_layer.in_features == 10


@pytest.mark.parametrize(
"direct_layer,fft_layer,kernel_size,strides,padding,dilation,ndim",
[
[sk.nn.Conv1D, sk.nn.FFTConv1D, 3, 2, 1, 1, 1],
[sk.nn.Conv2D, sk.nn.FFTConv2D, (3, 3), (2, 2), (1, 1), (2, 1), 2],
[sk.nn.Conv3D, sk.nn.FFTConv3D, (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 2, 1), 3],
],
)
def test_direct_fft_conv(
direct_layer,
fft_layer,
kernel_size,
strides,
padding,
dilation,
ndim,
):
array = jnp.ones([10] + [10] * ndim)
npt.assert_allclose(
direct_layer(
10,
1,
kernel_size=kernel_size,
strides=strides,
padding=padding,
dilation=dilation,
)(array),
fft_layer(
10,
1,
kernel_size=kernel_size,
strides=strides,
padding=padding,
dilation=dilation,
)(array),
atol=5e-6,
)
24 changes: 12 additions & 12 deletions tests/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy.testing as npt
import pytest

import serket as sk

Expand Down Expand Up @@ -138,15 +140,13 @@ def test_padding3d():
assert layer(jnp.ones((1, 1, 1, 1))).shape == (1, 4, 8, 12)


def test_random_zoom():
npt.assert_allclose(
sk.nn.RandomZoom1D((0, 0))(jnp.ones((10, 5))), jnp.ones((10, 5))
)

npt.assert_allclose(
sk.nn.RandomZoom2D((0.5, 0.5))(jnp.ones((10, 5, 5))).shape, (10, 5, 5)
)

npt.assert_allclose(
sk.nn.RandomZoom3D((0.5, 0.5))(jnp.ones((10, 5, 5, 5))).shape, (10, 5, 5, 5)
)
@pytest.mark.parametrize(
"layer,shape,ndim",
[
[sk.nn.RandomZoom1D, (10, 5), 1],
[sk.nn.RandomZoom2D, (10, 5, 5), 2],
[sk.nn.RandomZoom3D, (10, 5, 5, 5), 3],
],
)
def test_random_zoom(layer, shape, ndim):
npt.assert_allclose(layer((0, 0))(jnp.ones(shape)).shape, shape)

0 comments on commit eba634a

Please sign in to comment.