From e87f4c739a2b09f5915d3651edcb4ebf26eb8690 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Fri, 18 Aug 2023 18:56:02 +0300 Subject: [PATCH] add jigsaw --- CHANEGLOG.md | 1 + docs/API/image.rst | 1 + serket/nn/__init__.py | 2 + serket/nn/image.py | 89 ++++++++++++++++++++++++++++++++++++++ tests/test_image_filter.py | 10 +++++ 5 files changed, 103 insertions(+) diff --git a/CHANEGLOG.md b/CHANEGLOG.md index f808983..5358305 100644 --- a/CHANEGLOG.md +++ b/CHANEGLOG.md @@ -49,6 +49,7 @@ - `Pixelate2D` - `Solarize2D` - `Posterize2D` +- `JigSaw2D` ### Deprecations diff --git a/docs/API/image.rst b/docs/API/image.rst index ee86670..c2bd50e 100644 --- a/docs/API/image.rst +++ b/docs/API/image.rst @@ -10,6 +10,7 @@ Image filtering .. autoclass:: HistogramEqualization2D .. autoclass:: HorizontalShear2D .. autoclass:: HorizontalTranslate2D +.. autoclass:: JigSaw2D .. autoclass:: Pixelate2D .. autoclass:: PixelShuffle2D .. autoclass:: RandomContrast2D diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index 97ad19b..c430e72 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -96,6 +96,7 @@ HistogramEqualization2D, HorizontalShear2D, HorizontalTranslate2D, + JigSaw2D, Pixelate2D, PixelShuffle2D, Posterize2D, @@ -280,6 +281,7 @@ "HistogramEqualization2D", "HorizontalShear2D", "HorizontalTranslate2D", + "JigSaw2D", "Pixelate2D", "PixelShuffle2D", "Posterize2D", diff --git a/serket/nn/image.py b/serket/nn/image.py index 7a57f07..eb7eef6 100644 --- a/serket/nn/image.py +++ b/serket/nn/image.py @@ -1174,6 +1174,94 @@ def spatial_ndim(self) -> int: return 2 +@ft.partial(jax.jit, inline=True, static_argnums=1) +def jigsaw( + image: jax.Array, + tiles: int = 1, + key: jr.KeyArray = jr.PRNGKey(0), +) -> jax.Array: + """Jigsaw channel-first image + + Args: + image: channel-first image (CHW) + tiles: number of tiles per side + key: random key + """ + channels, height, width = image.shape + tile_height = height // tiles + tile_width = width // tiles + + image_ = 
image[:, : height - height % tiles, : width - width % tiles] + + image_ = image_.reshape(channels, tiles, tile_height, tiles, tile_width) + image_ = image_.transpose(1, 3, 0, 2, 4) + image_ = image_.reshape(-1, channels, tile_height, tile_width) + + indices = jr.permutation(key, len(image_)) + image_ = jax.vmap(lambda x: image_[x])(indices) + + image_ = image_.reshape(tiles, tiles, channels, tile_height, tile_width) + image_ = image_.transpose(2, 0, 3, 1, 4) + image_ = image_.reshape(channels, tiles * tile_height, tiles * tile_width) + + image = image.at[:, : height - height % tiles, : width - width % tiles].set(image_) + + return image + + +@sk.autoinit +class JigSaw2D(sk.TreeClass): + """Mixes up tiles of an image. + + Args: + tiles: number of tiles per side + + Example: + >>> import serket as sk + >>> import jax.numpy as jnp + >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + >>> print(x) + [[[ 1 2 3 4] + [ 5 6 7 8] + [ 9 10 11 12] + [13 14 15 16]]] + >>> print(sk.nn.JigSaw2D(2)(x)) + [[[ 9 10 3 4] + [13 14 7 8] + [11 12 1 2] + [15 16 5 6]]] + + Note: + - Use :func:`tree_eval` to replace this layer with :class:`Identity` during + evaluation. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + >>> layer = sk.nn.JigSaw2D(2) + >>> eval_layer = sk.tree_eval(layer) + >>> print(eval_layer(x)) + [[[ 1 2 3 4] + [ 5 6 7 8] + [ 9 10 11 12] + [13 14 15 16]]] + + Reference: + - https://imgaug.readthedocs.io/en/latest/source/overview/geometric.html#jigsaw + """ + + tiles: int = sk.field(callbacks=[IsInstance(int), Range(1)]) + + def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array: + """Mixes up tiles of an image. 
+ + Args: + x: channel-first image (CHW) + key: random key + """ + return jigsaw(x, self.tiles, key) + + @tree_eval.def_eval(RandomContrast2D) @tree_eval.def_eval(RandomRotate2D) @tree_eval.def_eval(RandomHorizontalShear2D) @@ -1181,5 +1269,6 @@ def spatial_ndim(self) -> int: @tree_eval.def_eval(RandomPerspective2D) @tree_eval.def_eval(RandomHorizontalTranslate2D) @tree_eval.def_eval(RandomVerticalTranslate2D) +@tree_eval.def_eval(JigSaw2D) def random_image_transform(_): return Identity() diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index bc2e7e5..dba1bdf 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -24,6 +24,7 @@ Filter2D, GaussianBlur2D, HorizontalTranslate2D, + JigSaw2D, Solarize2D, VerticalTranslate2D, ) @@ -204,3 +205,12 @@ def test_vertical_translate(): layer = VerticalTranslate2D(0) npt.assert_allclose(layer(x), x) + + +def test_jigsaw(): + x = jnp.arange(1, 17).reshape(1, 4, 4) + layer = JigSaw2D(2) + npt.assert_allclose( + layer(x), + jnp.array([[[9, 10, 3, 4], [13, 14, 7, 8], [11, 12, 1, 2], [15, 16, 5, 6]]]), + )