Skip to content

Commit

Permalink
add Posterize2D
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 10, 2023
1 parent c475c64 commit d89d01c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANEGLOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
- `VerticalShear2D`
- `Pixelate2D`
- `Solarize2D`
- `Posterize2D`

### Deprecations

Expand Down
2 changes: 2 additions & 0 deletions serket/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
HorizontalShear2D,
Pixelate2D,
PixelShuffle2D,
Posterize2D,
RandomContrast2D,
RandomHorizontalShear2D,
RandomPerspective2D,
Expand Down Expand Up @@ -274,6 +275,7 @@
"HorizontalShear2D",
"Pixelate2D",
"PixelShuffle2D",
"Posterize2D",
"RandomContrast2D",
"RandomHorizontalShear2D",
"RandomPerspective2D",
Expand Down
71 changes: 71 additions & 0 deletions serket/nn/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from serket.nn.custom_transform import tree_eval
from serket.nn.linear import Identity
from serket.nn.utils import (
IsInstance,
Range,
maybe_lazy_call,
maybe_lazy_init,
positive_int_cb,
Expand Down Expand Up @@ -955,6 +957,75 @@ def spatial_ndim(self) -> int:
return 2


def posterize(image: jax.Array, bits: int) -> jax.Array:
    """Keep only the ``bits`` most significant bits of every value in ``image``.

    The low ``8 - bits`` bits are cleared by shifting each value right and
    then left again by the same amount.

    Args:
        image: The image to posterize. Assumed to hold integer values with
            8-bit channels; the dtype is not checked here.
        bits: The number of bits to keep for each channel (1-8). The range is
            not validated by this function.

    Returns:
        The result of the right-then-left shift applied to ``image``.

    Reference:
        - https://github.com/tensorflow/models/blob/v2.13.1/official/vision/ops/augment.py#L859-L862
        - https://github.com/python-pillow/Pillow/blob/6651a3143621181d94cc92d36e1490721ef0b44f/src/PIL/ImageOps.py#L547
    """
    # number of low-order bits to discard out of the 8-bit channel
    dropped = 8 - bits
    # shifting right throws the low bits away ...
    truncated = jnp.right_shift(image, dropped)
    # ... and shifting back restores the magnitude with zeros in their place
    return jnp.left_shift(truncated, dropped)


@sk.autoinit
class Posterize2D(sk.TreeClass):
    """Posterize an image by reducing the number of bits of each color channel.

    Calling the layer maps :func:`posterize` over the leading axis of the
    input, using the same ``bits`` value for every slice.

    Args:
        bits: The number of bits to keep for each channel (1-8). The field is
            declared with ``IsInstance(int)`` and ``Range(1, 8)`` callbacks.

    Example:
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> layer = sk.nn.Posterize2D(4)
        >>> x = jnp.arange(1, 51).reshape(2, 5, 5)
        >>> print(x)
        [[[ 1  2  3  4  5]
          [ 6  7  8  9 10]
          [11 12 13 14 15]
          [16 17 18 19 20]
          [21 22 23 24 25]]
        <BLANKLINE>
         [[26 27 28 29 30]
          [31 32 33 34 35]
          [36 37 38 39 40]
          [41 42 43 44 45]
          [46 47 48 49 50]]]
        >>> print(layer(x))
        [[[ 0  0  0  0  0]
          [ 0  0  0  0  0]
          [ 0  0  0  0  0]
          [16 16 16 16 16]
          [16 16 16 16 16]]
        <BLANKLINE>
         [[16 16 16 16 16]
          [16 32 32 32 32]
          [32 32 32 32 32]
          [32 32 32 32 32]
          [32 32 48 48 48]]]

    Reference:
        - https://www.tensorflow.org/api_docs/python/tfm/vision/augment/posterize
        - https://github.com/python-pillow/Pillow/blob/main/src/PIL/ImageOps.py#L547
    """

    # number of bits retained per channel
    bits: int = sk.field(callbacks=[IsInstance(int), Range(1, 8)])

    @property
    def spatial_ndim(self) -> int:
        # looked up by name ("spatial_ndim") by the decorator on ``__call__``
        return 2

    @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
    def __call__(self, x: jax.Array, **k) -> jax.Array:
        # map over axis 0 of ``x``; ``bits`` is shared (not mapped) across slices
        per_channel = jax.vmap(posterize, in_axes=(0, None))
        return per_channel(x, jax.lax.stop_gradient(self.bits))


@tree_eval.def_eval(RandomContrast2D)
@tree_eval.def_eval(RandomRotate2D)
@tree_eval.def_eval(RandomHorizontalShear2D)
Expand Down

0 comments on commit d89d01c

Please sign in to comment.