Skip to content

Commit

Permalink
add jigsaw
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 18, 2023
1 parent 12ff464 commit e87f4c7
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANEGLOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
- `Pixelate2D`
- `Solarize2D`
- `Posterize2D`
- `JigSaw2D`

### Deprecations

Expand Down
1 change: 1 addition & 0 deletions docs/API/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Image filtering
.. autoclass:: HistogramEqualization2D
.. autoclass:: HorizontalShear2D
.. autoclass:: HorizontalTranslate2D
.. autoclass:: JigSaw2D
.. autoclass:: Pixelate2D
.. autoclass:: PixelShuffle2D
.. autoclass:: RandomContrast2D
Expand Down
2 changes: 2 additions & 0 deletions serket/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
HistogramEqualization2D,
HorizontalShear2D,
HorizontalTranslate2D,
JigSaw2D,
Pixelate2D,
PixelShuffle2D,
Posterize2D,
Expand Down Expand Up @@ -280,6 +281,7 @@
"HistogramEqualization2D",
"HorizontalShear2D",
"HorizontalTranslate2D",
"JigSaw2D",
"Pixelate2D",
"PixelShuffle2D",
"Posterize2D",
Expand Down
89 changes: 89 additions & 0 deletions serket/nn/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,12 +1174,101 @@ def spatial_ndim(self) -> int:
return 2


@ft.partial(jax.jit, inline=True, static_argnums=1)
def jigsaw(
image: jax.Array,
tiles: int = 1,
key: jr.KeyArray = jr.PRNGKey(0),
) -> jax.Array:
"""Jigsaw channel-first image
Args:
image: channel-first image (CHW)
tiles: number of tiles per side
key: random key
"""
channels, height, width = image.shape
tile_height = height // tiles
tile_width = width // tiles

image_ = image[:, : height - height % tiles, : width - width % tiles]

image_ = image_.reshape(channels, tiles, tile_height, tiles, tile_width)
image_ = image_.transpose(1, 3, 0, 2, 4)
image_ = image_.reshape(-1, channels, tile_height, tile_width)

indices = jr.permutation(key, len(image_))
image_ = jax.vmap(lambda x: image_[x])(indices)

image_ = image_.reshape(tiles, tiles, channels, tile_height, tile_width)
image_ = image_.transpose(2, 0, 3, 1, 4)
image_ = image_.reshape(channels, tiles * tile_height, tiles * tile_width)

image = image.at[:, : height - height % tiles, : width - width % tiles].set(image_)

return image


@sk.autoinit
class JigSaw2D(sk.TreeClass):
"""Mixes up tiles of an image.
Args:
tiles: number of tiles per side
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> print(x)
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
>>> print(sk.nn.JigSaw(2)(x))
[[[ 9 10 3 4]
[13 14 7 8]
[11 12 1 2]
[15 16 5 6]]]
Note:
- Use :func:`tree_eval` to replace this layer with :class:`Identity` during
evaluation.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> layer = sk.nn.JigSaw2D(2)
>>> eval_layer = sk.tree_eval(layer)
>>> print(eval_layer(x))
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
Reference:
- https://imgaug.readthedocs.io/en/latest/source/overview/geometric.html#jigsaw
"""

tiles: int = sk.field(callbacks=[IsInstance(int), Range(1)])

def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
"""Mixes up tiles of an image.
Args:
x: channel-first image (CHW)
key: random key
"""
return jigsaw(x, self.tiles, key)


@tree_eval.def_eval(RandomContrast2D)
@tree_eval.def_eval(RandomRotate2D)
@tree_eval.def_eval(RandomHorizontalShear2D)
@tree_eval.def_eval(RandomVerticalShear2D)
@tree_eval.def_eval(RandomPerspective2D)
@tree_eval.def_eval(RandomHorizontalTranslate2D)
@tree_eval.def_eval(RandomVerticalTranslate2D)
@tree_eval.def_eval(JigSaw2D)
def random_image_transform(_):
return Identity()
10 changes: 10 additions & 0 deletions tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Filter2D,
GaussianBlur2D,
HorizontalTranslate2D,
JigSaw2D,
Solarize2D,
VerticalTranslate2D,
)
Expand Down Expand Up @@ -204,3 +205,12 @@ def test_vertical_translate():
layer = VerticalTranslate2D(0)

npt.assert_allclose(layer(x), x)


def test_jigsaw():
x = jnp.arange(1, 17).reshape(1, 4, 4)
layer = JigSaw2D(2)
npt.assert_allclose(
layer(x),
jnp.array([[[9, 10, 3, 4], [13, 14, 7, 8], [11, 12, 1, 2], [15, 16, 5, 6]]]),
)

0 comments on commit e87f4c7

Please sign in to comment.