diff --git a/docs/API/geometric.rst b/docs/API/geometric.rst index 6a99c92..044d5f7 100644 --- a/docs/API/geometric.rst +++ b/docs/API/geometric.rst @@ -1,28 +1,32 @@ Geometric API --------------------------------- .. currentmodule:: serket.image - + .. autoclass:: HorizontalFlip2D -.. autoclass:: RandomHorizontalFlip2D - .. autoclass:: HorizontalShear2D -.. autoclass:: RandomHorizontalShear2D - .. autoclass:: HorizontalTranslate2D +.. autoclass:: RandomHorizontalFlip2D +.. autoclass:: RandomHorizontalShear2D .. autoclass:: RandomHorizontalTranslate2D -.. autoclass:: RandomPerspective2D - -.. autoclass:: Rotate2D .. autoclass:: RandomRotate2D - -.. autoclass:: VerticalShear2D +.. autoclass:: RandomVerticalFlip2D .. autoclass:: RandomVerticalShear2D - +.. autoclass:: RandomVerticalTranslate2D +.. autoclass:: Rotate2D .. autoclass:: VerticalFlip2D -.. autoclass:: RandomVerticalFlip2D - +.. autoclass:: VerticalShear2D .. autoclass:: VerticalTranslate2D -.. autoclass:: RandomVerticalTranslate2D -.. autoclass:: ElasticTransform2D -.. autoclass:: FFTElasticTransform2D \ No newline at end of file +.. autofunction:: horizontal_flip_2d +.. autofunction:: horizontal_shear_2d +.. autofunction:: horizontal_translate_2d +.. autofunction:: random_horizontal_flip_2d +.. autofunction:: random_horizontal_shear_2d +.. autofunction:: random_horizontal_translate_2d +.. autofunction:: random_rotate_2d +.. autofunction:: random_vertical_flip_2d +.. autofunction:: random_vertical_shear_2d +.. autofunction:: random_vertical_translate_2d +.. autofunction:: rotate_2d +.. autofunction:: vertical_flip_2d +.. autofunction:: vertical_shear_2d \ No newline at end of file diff --git a/serket/_src/image/geometric.py b/serket/_src/image/geometric.py index f50418e..7780dfc 100644 --- a/serket/_src/image/geometric.py +++ b/serket/_src/image/geometric.py @@ -29,6 +29,7 @@ def affine_2d(array: HWArray, matrix: HWArray) -> HWArray: + assert array.ndim == 2 h, w = array.shape center = jnp.array((h // 2, w // 2)) coords = jnp.indices((h, w)).reshape(2, -1) - center.reshape(2, 1) @@ -36,6 +37,26 @@ def affine_2d(array: HWArray, matrix: HWArray) -> HWArray: return map_coordinates(array, coords, order=1).reshape((h, w)) +def horizontal_flip_2d(image: HWArray) -> HWArray: + assert image.ndim == 2 + return jnp.flip(image, axis=1) + + +def random_horizontal_flip_2d(key: jax.Array, image: HWArray, rate: float) -> HWArray: + prop = jr.bernoulli(key, rate) + return jnp.where(prop, horizontal_flip_2d(image), image) + + +def vertical_flip_2d(image: HWArray) -> HWArray: + assert image.ndim == 2 + return jnp.flip(image, axis=0) + + +def random_vertical_flip_2d(key: jax.Array, image: HWArray, rate: float) -> HWArray: + prop = jr.bernoulli(key, rate) + return jnp.where(prop, vertical_flip_2d(image), image) + + def horizontal_shear_2d(image: HWArray, angle: float) -> HWArray: """shear rows by an angle in degrees""" shear = jnp.tan(jnp.deg2rad(angle)) @@ -92,37 +113,9 @@ def random_rotate_2d( return rotate_2d(image, angle) -def perspective_transform_2d(image: HWArray, coeffs: jax.Array) -> HWArray: - """Apply a perspective transform to an image.""" - - rows, cols = image.shape - y, x = jnp.meshgrid(jnp.arange(rows), jnp.arange(cols), indexing="ij") - a, b, c, d, e, f, g, h = coeffs - w = g * x + h * y + 1.0 - x_prime = (a * x + b * y + c) / w - y_prime = (d * x + e * y + f) / w - coords = [y_prime.ravel(), x_prime.ravel()] - return map_coordinates(image, coords, order=1).reshape(rows, cols) - - -def random_perspective_2d( - key: jax.Array, - image: HWArray, - scale: float, -) -> HWArray: - """Applies a random perspective transform to a channel-first image""" - _, _ = image.shape - a = e = 1.0 - b = d = 0.0 - c = f = 0.0 # no translation - g, h = jr.uniform(key, shape=(2,), minval=-1e-4, maxval=1e-4) * scale - coeffs = jnp.array([a, b, c, d, e, f, g, h]) - return perspective_transform_2d(image, coeffs) - - def horizontal_translate_2d(image: HWArray, shift: int) -> HWArray: """Translate an image horizontally by a pixel value.""" - _, _ = image.shape + assert image.ndim == 2 if shift > 0: return jnp.zeros_like(image).at[:, shift:].set(image[:, :-shift]) if shift < 0: @@ -132,7 +125,7 @@ def horizontal_translate_2d(image: HWArray, shift: int) -> HWArray: def vertical_translate_2d(image: HWArray, shift: int) -> HWArray: """Translate an image vertically by a pixel value.""" - _, _ = image.shape + assert image.ndim == 2 if shift > 0: return jnp.zeros_like(image).at[shift:, :].set(image[:-shift, :]) if shift < 0: @@ -152,38 +145,6 @@ def random_vertical_translate_2d(key: jax.Array, image: HWArray) -> HWArray: return vertical_translate_2d(image, shift) -def wave_transform_2d(image: HWArray, length: float, amplitude: float) -> HWArray: - """Transform an image with a sinusoidal wave.""" - _, _ = image.shape - eps = jnp.finfo(image.dtype).eps - ny, nx = jnp.indices(image.shape) - sinx = nx + amplitude * jnp.sin(ny / (length + eps)) - cosy = ny + amplitude * jnp.cos(nx / (length + eps)) - return map_coordinates(image, [cosy, sinx], order=1) - - -def random_wave_transform_2d( - key: jax.Array, - image: HWArray, - length_range: tuple[float, float], - amplitude_range: tuple[float, float], -) -> HWArray: - """Transform an image with a sinusoidal wave. - - Args: - key: a random key. - image: a 2D image to transform. - length_range: a tuple of min length and max length to randdomly choose from. - amplitude_range: a tuple of min amplitude and max amplitude to randdomly choose from. - """ - k1, k2 = jr.split(key) - l0, l1 = length_range - length = jr.uniform(k1, shape=(), minval=l0, maxval=l1) - a0, a1 = amplitude_range - amplitude = jr.uniform(k2, shape=(), minval=a0, maxval=a1) - return wave_transform_2d(image, length, amplitude) - - class Rotate2D(TreeClass): """Rotate_2d a 2D image by an angle in dgrees in CCW direction @@ -438,44 +399,6 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class RandomPerspective2D(TreeClass): - """Applies a random perspective transform to a channel-first image. - - .. image:: ../_static/randomperspective2d.png - - Args: - scale: the scale of the random perspective transform. Higher scale will - lead to higher degree of perspective transform. default to 1.0. 0.0 - means no perspective transform. - - Note: - - Use :func:`tree_eval` to replace this layer with :class:`Identity` during - evaluation. - - >>> import serket as sk - >>> import jax.numpy as jnp - >>> x = jnp.arange(1, 17).reshape(1, 4, 4) - >>> layer = sk.image.RandomPerspective2D(100) - >>> eval_layer = sk.tree_eval(layer) - >>> print(eval_layer(x)) - [[[ 1 2 3 4] - [ 5 6 7 8] - [ 9 10 11 12] - [13 14 15 16]]] - """ - - def __init__(self, scale: float = 1.0): - self.scale = scale - - @ft.partial(validate_spatial_ndim, argnum=0) - def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: - scale = jax.lax.stop_gradient(self.scale) - args = (key, image, scale) - return jax.vmap(random_perspective_2d, in_axes=(None, 0, None))(*args) - - spatial_ndim: int = 2 - - @autoinit class HorizontalTranslate2D(TreeClass): """Translate an image horizontally by a pixel value. @@ -638,7 +561,7 @@ class HorizontalFlip2D(TreeClass): @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: - return jax.vmap(lambda x: jnp.flip(x, axis=1))(image) + return jax.vmap(horizontal_flip_2d)(image) spatial_ndim: int = 2 @@ -673,9 +596,9 @@ class RandomHorizontalFlip2D(TreeClass): @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: + in_axes = (None, 0, None) rate = jax.lax.stop_gradient(self.rate) - prop = jax.random.bernoulli(key, rate) - return jnp.where(prop, jax.vmap(lambda x: jnp.flip(x, axis=1))(image), image) + return jax.vmap(random_horizontal_flip_2d, in_axes=in_axes)(key, image, rate) spatial_ndim: int = 2 @@ -705,7 +628,7 @@ class VerticalFlip2D(TreeClass): @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: - return jax.vmap(lambda x: jnp.flip(x, axis=0))(image) + return jax.vmap(vertical_flip_2d)(image) spatial_ndim: int = 2 @@ -740,9 +663,9 @@ class RandomVerticalFlip2D(TreeClass): @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: + in_axes = (None, 0, None) rate = jax.lax.stop_gradient(self.rate) - prop = jax.random.bernoulli(key, rate) - return jnp.where(prop, jax.vmap(lambda x: jnp.flip(x, axis=0))(image), image) + return jax.vmap(random_vertical_flip_2d, in_axes=in_axes)(key, image, rate) spatial_ndim: int = 2 @@ -752,7 +675,6 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: @tree_eval.def_eval(RandomVerticalFlip2D) @tree_eval.def_eval(RandomHorizontalShear2D) @tree_eval.def_eval(RandomVerticalShear2D) -@tree_eval.def_eval(RandomPerspective2D) @tree_eval.def_eval(RandomHorizontalTranslate2D) @tree_eval.def_eval(RandomVerticalTranslate2D) def _(_): diff --git a/serket/image/__init__.py b/serket/image/__init__.py index 8639ba9..15ecee5 100644 --- a/serket/image/__init__.py +++ b/serket/image/__init__.py @@ -110,7 +110,6 @@ RandomHorizontalFlip2D, RandomHorizontalShear2D, RandomHorizontalTranslate2D, - RandomPerspective2D, RandomRotate2D, RandomVerticalFlip2D, RandomVerticalShear2D, @@ -119,6 +118,19 @@ VerticalFlip2D, VerticalShear2D, VerticalTranslate2D, + horizontal_flip_2d, + horizontal_shear_2d, + horizontal_translate_2d, + random_horizontal_flip_2d, + random_horizontal_shear_2d, + random_horizontal_translate_2d, + random_rotate_2d, + random_vertical_flip_2d, + random_vertical_shear_2d, + random_vertical_translate_2d, + rotate_2d, + vertical_flip_2d, + vertical_shear_2d, ) __all__ = [ @@ -200,14 +212,12 @@ "filter_2d", "fft_filter_2d", # geometric - "CenterCrop2D", "HorizontalFlip2D", "HorizontalShear2D", "HorizontalTranslate2D", "RandomHorizontalFlip2D", "RandomHorizontalShear2D", "RandomHorizontalTranslate2D", - "RandomPerspective2D", "RandomRotate2D", "RandomVerticalFlip2D", "RandomVerticalShear2D", @@ -216,6 +226,19 @@ "VerticalFlip2D", "VerticalShear2D", "VerticalTranslate2D", + "horizontal_flip_2d", + "horizontal_shear_2d", + "horizontal_translate_2d", + "random_horizontal_flip_2d", + "random_horizontal_shear_2d", + "random_horizontal_translate_2d", + "random_rotate_2d", + "random_vertical_flip_2d", + "random_vertical_shear_2d", + "random_vertical_translate_2d", + "rotate_2d", + "vertical_flip_2d", + "vertical_shear_2d", # color "GrayscaleToRGB2D", "HSVToRGB2D",