Skip to content

Commit

Permalink
organize geometric
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 10, 2024
1 parent a129c7b commit e76957d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 126 deletions.
36 changes: 20 additions & 16 deletions docs/API/geometric.rst
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
Geometric API
---------------------------------
.. currentmodule:: serket.image

.. autoclass:: HorizontalFlip2D
.. autoclass:: RandomHorizontalFlip2D

.. autoclass:: HorizontalShear2D
.. autoclass:: RandomHorizontalShear2D

.. autoclass:: HorizontalTranslate2D
.. autoclass:: RandomHorizontalFlip2D
.. autoclass:: RandomHorizontalShear2D
.. autoclass:: RandomHorizontalTranslate2D
.. autoclass:: RandomPerspective2D

.. autoclass:: Rotate2D
.. autoclass:: RandomRotate2D

.. autoclass:: VerticalShear2D
.. autoclass:: RandomVerticalFlip2D
.. autoclass:: RandomVerticalShear2D

.. autoclass:: RandomVerticalTranslate2D
.. autoclass:: Rotate2D
.. autoclass:: VerticalFlip2D
.. autoclass:: RandomVerticalFlip2D

.. autoclass:: VerticalShear2D
.. autoclass:: VerticalTranslate2D
.. autoclass:: RandomVerticalTranslate2D

.. autoclass:: ElasticTransform2D
.. autoclass:: FFTElasticTransform2D
.. autofunction:: horizontal_flip_2d
.. autofunction:: horizontal_shear_2d
.. autofunction:: horizontal_translate_2d
.. autofunction:: random_horizontal_flip_2d
.. autofunction:: random_horizontal_shear_2d
.. autofunction:: random_horizontal_translate_2d
.. autofunction:: random_rotate_2d
.. autofunction:: random_vertical_flip_2d
.. autofunction:: random_vertical_shear_2d
.. autofunction:: random_vertical_translate_2d
.. autofunction:: rotate_2d
.. autofunction:: vertical_flip_2d
.. autofunction:: vertical_shear_2d
136 changes: 29 additions & 107 deletions serket/_src/image/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,34 @@


def affine_2d(array: HWArray, matrix: HWArray) -> HWArray:
assert array.ndim == 2
h, w = array.shape
center = jnp.array((h // 2, w // 2))
coords = jnp.indices((h, w)).reshape(2, -1) - center.reshape(2, 1)
coords = matrix @ coords + center.reshape(2, 1)
return map_coordinates(array, coords, order=1).reshape((h, w))


def horizontal_flip_2d(image: HWArray) -> HWArray:
assert image.ndim == 2
return jnp.flip(image, axis=1)


def random_horizontal_flip_2d(key: jax.Array, image: HWArray, rate: float) -> HWArray:
prop = jr.bernoulli(key, rate)
return jnp.where(prop, horizontal_flip_2d(image), image)


def vertical_flip_2d(image: HWArray) -> HWArray:
assert image.ndim == 2
return jnp.flip(image, axis=0)


def random_vertical_flip_2d(key: jax.Array, image: HWArray, rate: float) -> HWArray:
prop = jr.bernoulli(key, rate)
return jnp.where(prop, vertical_flip_2d(image), image)


def horizontal_shear_2d(image: HWArray, angle: float) -> HWArray:
"""shear rows by an angle in degrees"""
shear = jnp.tan(jnp.deg2rad(angle))
Expand Down Expand Up @@ -92,37 +113,9 @@ def random_rotate_2d(
return rotate_2d(image, angle)


def perspective_transform_2d(image: HWArray, coeffs: jax.Array) -> HWArray:
"""Apply a perspective transform to an image."""

rows, cols = image.shape
y, x = jnp.meshgrid(jnp.arange(rows), jnp.arange(cols), indexing="ij")
a, b, c, d, e, f, g, h = coeffs
w = g * x + h * y + 1.0
x_prime = (a * x + b * y + c) / w
y_prime = (d * x + e * y + f) / w
coords = [y_prime.ravel(), x_prime.ravel()]
return map_coordinates(image, coords, order=1).reshape(rows, cols)


def random_perspective_2d(
key: jax.Array,
image: HWArray,
scale: float,
) -> HWArray:
"""Applies a random perspective transform to a channel-first image"""
_, _ = image.shape
a = e = 1.0
b = d = 0.0
c = f = 0.0 # no translation
g, h = jr.uniform(key, shape=(2,), minval=-1e-4, maxval=1e-4) * scale
coeffs = jnp.array([a, b, c, d, e, f, g, h])
return perspective_transform_2d(image, coeffs)


def horizontal_translate_2d(image: HWArray, shift: int) -> HWArray:
"""Translate an image horizontally by a pixel value."""
_, _ = image.shape
assert image.ndim == 2
if shift > 0:
return jnp.zeros_like(image).at[:, shift:].set(image[:, :-shift])
if shift < 0:
Expand All @@ -132,7 +125,7 @@ def horizontal_translate_2d(image: HWArray, shift: int) -> HWArray:

def vertical_translate_2d(image: HWArray, shift: int) -> HWArray:
"""Translate an image vertically by a pixel value."""
_, _ = image.shape
assert image.ndim == 2
if shift > 0:
return jnp.zeros_like(image).at[shift:, :].set(image[:-shift, :])
if shift < 0:
Expand All @@ -152,38 +145,6 @@ def random_vertical_translate_2d(key: jax.Array, image: HWArray) -> HWArray:
return vertical_translate_2d(image, shift)


def wave_transform_2d(image: HWArray, length: float, amplitude: float) -> HWArray:
"""Transform an image with a sinusoidal wave."""
_, _ = image.shape
eps = jnp.finfo(image.dtype).eps
ny, nx = jnp.indices(image.shape)
sinx = nx + amplitude * jnp.sin(ny / (length + eps))
cosy = ny + amplitude * jnp.cos(nx / (length + eps))
return map_coordinates(image, [cosy, sinx], order=1)


def random_wave_transform_2d(
key: jax.Array,
image: HWArray,
length_range: tuple[float, float],
amplitude_range: tuple[float, float],
) -> HWArray:
"""Transform an image with a sinusoidal wave.
Args:
key: a random key.
image: a 2D image to transform.
length_range: a tuple of min length and max length to randdomly choose from.
amplitude_range: a tuple of min amplitude and max amplitude to randdomly choose from.
"""
k1, k2 = jr.split(key)
l0, l1 = length_range
length = jr.uniform(k1, shape=(), minval=l0, maxval=l1)
a0, a1 = amplitude_range
amplitude = jr.uniform(k2, shape=(), minval=a0, maxval=a1)
return wave_transform_2d(image, length, amplitude)


class Rotate2D(TreeClass):
"""Rotate_2d a 2D image by an angle in dgrees in CCW direction
Expand Down Expand Up @@ -438,44 +399,6 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
spatial_ndim: int = 2


class RandomPerspective2D(TreeClass):
"""Applies a random perspective transform to a channel-first image.
.. image:: ../_static/randomperspective2d.png
Args:
scale: the scale of the random perspective transform. Higher scale will
lead to higher degree of perspective transform. default to 1.0. 0.0
means no perspective transform.
Note:
- Use :func:`tree_eval` to replace this layer with :class:`Identity` during
evaluation.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> layer = sk.image.RandomPerspective2D(100)
>>> eval_layer = sk.tree_eval(layer)
>>> print(eval_layer(x))
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
"""

def __init__(self, scale: float = 1.0):
self.scale = scale

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
scale = jax.lax.stop_gradient(self.scale)
args = (key, image, scale)
return jax.vmap(random_perspective_2d, in_axes=(None, 0, None))(*args)

spatial_ndim: int = 2


@autoinit
class HorizontalTranslate2D(TreeClass):
"""Translate an image horizontally by a pixel value.
Expand Down Expand Up @@ -638,7 +561,7 @@ class HorizontalFlip2D(TreeClass):

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return jax.vmap(lambda x: jnp.flip(x, axis=1))(image)
return jax.vmap(horizontal_flip_2d)(image)

spatial_ndim: int = 2

Expand Down Expand Up @@ -673,9 +596,9 @@ class RandomHorizontalFlip2D(TreeClass):

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
in_axes = (None, 0, None)
rate = jax.lax.stop_gradient(self.rate)
prop = jax.random.bernoulli(key, rate)
return jnp.where(prop, jax.vmap(lambda x: jnp.flip(x, axis=1))(image), image)
return jax.vmap(random_horizontal_flip_2d, in_axes=in_axes)(key, image, rate)

spatial_ndim: int = 2

Expand Down Expand Up @@ -705,7 +628,7 @@ class VerticalFlip2D(TreeClass):

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return jax.vmap(lambda x: jnp.flip(x, axis=0))(image)
return jax.vmap(vertical_flip_2d)(image)

spatial_ndim: int = 2

Expand Down Expand Up @@ -740,9 +663,9 @@ class RandomVerticalFlip2D(TreeClass):

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
in_axes = (None, 0, None)
rate = jax.lax.stop_gradient(self.rate)
prop = jax.random.bernoulli(key, rate)
return jnp.where(prop, jax.vmap(lambda x: jnp.flip(x, axis=0))(image), image)
return jax.vmap(random_vertical_flip_2d, in_axes=in_axes)(key, image, rate)

spatial_ndim: int = 2

Expand All @@ -752,7 +675,6 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
@tree_eval.def_eval(RandomVerticalFlip2D)
@tree_eval.def_eval(RandomHorizontalShear2D)
@tree_eval.def_eval(RandomVerticalShear2D)
@tree_eval.def_eval(RandomPerspective2D)
@tree_eval.def_eval(RandomHorizontalTranslate2D)
@tree_eval.def_eval(RandomVerticalTranslate2D)
def _(_):
Expand Down
29 changes: 26 additions & 3 deletions serket/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
RandomHorizontalFlip2D,
RandomHorizontalShear2D,
RandomHorizontalTranslate2D,
RandomPerspective2D,
RandomRotate2D,
RandomVerticalFlip2D,
RandomVerticalShear2D,
Expand All @@ -119,6 +118,19 @@
VerticalFlip2D,
VerticalShear2D,
VerticalTranslate2D,
horizontal_flip_2d,
horizontal_shear_2d,
horizontal_translate_2d,
random_horizontal_flip_2d,
random_horizontal_shear_2d,
random_horizontal_translate_2d,
random_rotate_2d,
random_vertical_flip_2d,
random_vertical_shear_2d,
random_vertical_translate_2d,
rotate_2d,
vertical_flip_2d,
vertical_shear_2d,
)

__all__ = [
Expand Down Expand Up @@ -200,14 +212,12 @@
"filter_2d",
"fft_filter_2d",
# geometric
"CenterCrop2D",
"HorizontalFlip2D",
"HorizontalShear2D",
"HorizontalTranslate2D",
"RandomHorizontalFlip2D",
"RandomHorizontalShear2D",
"RandomHorizontalTranslate2D",
"RandomPerspective2D",
"RandomRotate2D",
"RandomVerticalFlip2D",
"RandomVerticalShear2D",
Expand All @@ -216,6 +226,19 @@
"VerticalFlip2D",
"VerticalShear2D",
"VerticalTranslate2D",
"horizontal_flip_2d",
"horizontal_shear_2d",
"horizontal_translate_2d",
"random_horizontal_flip_2d",
"random_horizontal_shear_2d",
"random_horizontal_translate_2d",
"random_rotate_2d",
"random_vertical_flip_2d",
"random_vertical_shear_2d",
"random_vertical_translate_2d",
"rotate_2d",
"vertical_flip_2d",
"vertical_shear_2d",
# color
"GrayscaleToRGB2D",
"HSVToRGB2D",
Expand Down

0 comments on commit e76957d

Please sign in to comment.