diff --git a/CHANEGLOG.md b/CHANEGLOG.md
index e07c7d7..e1e8034 100644
--- a/CHANEGLOG.md
+++ b/CHANEGLOG.md
@@ -47,6 +47,7 @@
 - `VerticalShear2D`
 - `Pixelate2D`
 - `Solarize2D`
+- `Posterize2D`
 
 ### Deprecations
 
diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py
index 695d578..7f706bc 100644
--- a/serket/nn/__init__.py
+++ b/serket/nn/__init__.py
@@ -96,6 +96,7 @@
     HorizontalShear2D,
     Pixelate2D,
     PixelShuffle2D,
+    Posterize2D,
     RandomContrast2D,
     RandomHorizontalShear2D,
     RandomPerspective2D,
@@ -274,6 +275,7 @@
     "HorizontalShear2D",
     "Pixelate2D",
     "PixelShuffle2D",
+    "Posterize2D",
     "RandomContrast2D",
     "RandomHorizontalShear2D",
     "RandomPerspective2D",
diff --git a/serket/nn/image.py b/serket/nn/image.py
index 30165fb..44e6ee3 100644
--- a/serket/nn/image.py
+++ b/serket/nn/image.py
@@ -28,6 +28,8 @@
 from serket.nn.custom_transform import tree_eval
 from serket.nn.linear import Identity
 from serket.nn.utils import (
+    IsInstance,
+    Range,
     maybe_lazy_call,
     maybe_lazy_init,
     positive_int_cb,
@@ -955,6 +957,75 @@ def spatial_ndim(self) -> int:
         return 2
 
 
+def posterize(image: jax.Array, bits: int) -> jax.Array:
+    """Reduce the number of bits for each color channel.
+
+    Args:
+        image: Image to posterize; presumably integer 8-bit values -- TODO confirm.
+        bits: Number of high-order bits to keep (1-8); the low ``8 - bits`` are zeroed.
+
+    Reference:
+        - https://github.com/tensorflow/models/blob/v2.13.1/official/vision/ops/augment.py#L859-L862
+        - https://github.com/python-pillow/Pillow/blob/6651a3143621181d94cc92d36e1490721ef0b44f/src/PIL/ImageOps.py#L547
+    """
+    shift = 8 - bits  # number of low-order bits to discard
+    return jnp.left_shift(jnp.right_shift(image, shift), shift)
+
+
+@sk.autoinit
+class Posterize2D(sk.TreeClass):
+    """Reduce the number of bits for each color channel.
+
+    Args:
+        bits: Bits to keep per channel (1-8); an ``int`` checked by ``Range(1, 8)``.
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> layer = sk.nn.Posterize2D(4)
+        >>> x = jnp.arange(1, 51).reshape(2, 5, 5)
+        >>> print(x)
+        [[[ 1  2  3  4  5]
+          [ 6  7  8  9 10]
+          [11 12 13 14 15]
+          [16 17 18 19 20]
+          [21 22 23 24 25]]
+
+         [[26 27 28 29 30]
+          [31 32 33 34 35]
+          [36 37 38 39 40]
+          [41 42 43 44 45]
+          [46 47 48 49 50]]]
+        >>> print(layer(x))
+        [[[ 0  0  0  0  0]
+          [ 0  0  0  0  0]
+          [ 0  0  0  0  0]
+          [16 16 16 16 16]
+          [16 16 16 16 16]]
+
+         [[16 16 16 16 16]
+          [16 32 32 32 32]
+          [32 32 32 32 32]
+          [32 32 32 32 32]
+          [32 32 48 48 48]]]
+
+    Reference:
+        - https://www.tensorflow.org/api_docs/python/tfm/vision/augment/posterize
+        - https://github.com/python-pillow/Pillow/blob/main/src/PIL/ImageOps.py#L547
+    """
+
+    bits: int = sk.field(callbacks=[IsInstance(int), Range(1, 8)])
+
+    @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
+    def __call__(self, x: jax.Array, **k) -> jax.Array:
+        bits = jax.lax.stop_gradient(self.bits)  # keep bits out of autodiff
+        return jax.vmap(posterize, in_axes=(0, None))(x, bits)  # map over axis 0 of x
+
+    @property
+    def spatial_ndim(self) -> int:
+        return 2
+
+
+@tree_eval.def_eval(RandomContrast2D)
 @tree_eval.def_eval(RandomRotate2D)
 @tree_eval.def_eval(RandomHorizontalShear2D)