diff --git a/dm_pix/__init__.py b/dm_pix/__init__.py index 4cd3e8b..d9b6330 100644 --- a/dm_pix/__init__.py +++ b/dm_pix/__init__.py @@ -28,6 +28,7 @@ adjust_gamma = augment.adjust_gamma adjust_hue = augment.adjust_hue adjust_saturation = augment.adjust_saturation +affine_transform = augment.affine_transform flip_left_right = augment.flip_left_right flip_up_down = augment.flip_up_down gaussian_blur = augment.gaussian_blur @@ -74,6 +75,7 @@ "adjust_gamma", "adjust_hue", "adjust_saturation", + "affine_transform", "depth_to_space", "extract_patches", "flat_nd_linear_interpolate", diff --git a/dm_pix/_src/augment.py b/dm_pix/_src/augment.py index 66c0135..2994b1b 100644 --- a/dm_pix/_src/augment.py +++ b/dm_pix/_src/augment.py @@ -19,10 +19,12 @@ that of TensorFlow. """ -from typing import Sequence, Tuple +import functools +from typing import Sequence, Tuple, Union import chex from dm_pix._src import color_conversion +from dm_pix._src import interpolation import jax import jax.numpy as jnp @@ -297,6 +299,97 @@ def solarize(image: chex.Array, threshold: chex.Numeric) -> chex.Array: return jnp.where(image < threshold, image, 1. - image) +def affine_transform( + image: chex.Array, + matrix: chex.Array, + *, + offset: Union[chex.Array, chex.Numeric] = 0., + order: int = 1, + mode: str = "nearest", + cval: float = 0.0, +) -> chex.Array: + """Applies an affine transformation given by matrix. + + Given an output image pixel index vector o, the pixel value is determined from + the input image at position jnp.dot(matrix, o) + offset. + + This does 'pull' (or 'backward') resampling, transforming the output space to + the input to locate data. Affine transformations are often described in the + 'push' (or 'forward') direction, transforming input to output. If you have a + matrix for the 'push' transformation, use its inverse (jax.numpy.linalg.inv) + in this function. + + Args: + image: a JAX array representing an image. Assumes that the image is + either HWC or CHW. + matrix: the inverse coordinate transformation matrix, mapping output + coordinates to input coordinates. If ndim is the number of dimensions of + input, the given matrix must have one of the following shapes: + - (ndim, ndim): the linear transformation matrix for each output + coordinate. + - (ndim,): assume that the 2-D transformation matrix is diagonal, with the + diagonal specified by the given value. + - (ndim + 1, ndim + 1): assume that the transformation is specified using + homogeneous coordinates [1]. In this case, any value passed to offset is + ignored. + - (ndim, ndim + 1): as above, but the bottom row of a homogeneous + transformation matrix is always [0, 0, 0, 1], and may be omitted. + offset: the offset into the array where the transform is applied. If a + float, offset is the same for each axis. If an array, offset should + contain one value for each axis. + order: the order of the spline interpolation, default is 1. The order has + to be in the range 0-1. Note that PIX interpolation will only be used for + order=1, for other values we use `jax.scipy.ndimage.map_coordinates`. + mode: the mode parameter determines how the input array is extended beyond + its boundaries. Default is "nearest", using PIX + `flat_nd_linear_interpolate` function, which is very fast on accelerators + (especially on TPUs). For all other modes, 'constant', 'wrap', 'mirror' + and 'reflect', we rely on `jax.scipy.ndimage.map_coordinates`, which + however is slow on accelerators, so use it with care. + cval: value to fill past edges of input if mode is 'constant'. Default is + 0.0. + + Returns: + The input image transformed by the given matrix. + """ + chex.assert_rank(image, 3) + chex.assert_rank(matrix, {1, 2}) + chex.assert_rank(offset, {0, 1}) + + if matrix.ndim == 1: + matrix = jnp.diag(matrix) + + if matrix.shape not in [(3, 3), (4, 4), (3, 4)]: + error_msg = ( + "Expected matrix shape must be one of (ndim, ndim), (ndim,)" + "(ndim + 1, ndim + 1) or (ndim, ndim + 1) being ndim the image.ndim. " + f"The affine matrix provided has shape {matrix.shape}.") + raise ValueError(error_msg) + + meshgrid = jnp.meshgrid(*[jnp.arange(size) for size in image.shape], + indexing="ij") + indices = jnp.concatenate( + [jnp.expand_dims(x, axis=-1) for x in meshgrid], axis=-1) + + if matrix.shape == (4, 4) or matrix.shape == (3, 4): + offset = matrix[:image.ndim, image.ndim] + matrix = matrix[:image.ndim, :image.ndim] + + coordinates = indices @ matrix.T + coordinates = jnp.moveaxis(coordinates, source=-1, destination=0) + + # Alter coordinates to account for offset. + offset = jnp.full((3,), fill_value=offset) + coordinates += jnp.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1)) + + if mode == "nearest" and order == 1: + interpolate_function = interpolation.flat_nd_linear_interpolate + else: + interpolate_function = functools.partial( + jax.scipy.ndimage.map_coordinates, mode=mode, order=order, cval=cval) + return interpolate_function(image, coordinates) + + def random_flip_left_right( key: chex.PRNGKey, image: chex.Array, diff --git a/dm_pix/_src/augment_test.py b/dm_pix/_src/augment_test.py index 2291e06..fe3800b 100644 --- a/dm_pix/_src/augment_test.py +++ b/dm_pix/_src/augment_test.py @@ -19,7 +19,9 @@ from absl.testing import parameterized from dm_pix._src import augment import jax +import jax.numpy as jnp import numpy as np +import scipy import tensorflow as tf _IMG_SHAPE = (131, 111, 3) @@ -33,10 +35,11 @@ class _ImageAugmentationTest(parameterized.TestCase): """Runs tests for the various augments with the correct arguments.""" - def _test_fn_with_random_arg(self, images_list, jax_fn, tf_fn, **kw_range): + def _test_fn_with_random_arg( + self, images_list, jax_fn, reference_fn, **kw_range): pass - def _test_fn(self, images_list, jax_fn, tf_fn): + def _test_fn(self, images_list, jax_fn, reference_fn): pass def assertAllCloseTolerant(self, x, y): @@ -51,14 +54,14 @@ def test_adjust_brightness(self, images_list): self._test_fn_with_random_arg( images_list, jax_fn=augment.adjust_brightness, - tf_fn=tf.image.adjust_brightness, + reference_fn=tf.image.adjust_brightness, delta=(-0.5, 0.5)) key = jax.random.PRNGKey(0) self._test_fn_with_random_arg( images_list, jax_fn=functools.partial(augment.random_brightness, key), - tf_fn=None, + reference_fn=None, max_delta=(0, 0.5)) @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), @@ -67,13 +70,13 @@ def test_adjust_contrast(self, images_list): self._test_fn_with_random_arg( images_list, jax_fn=augment.adjust_contrast, - tf_fn=tf.image.adjust_contrast, + reference_fn=tf.image.adjust_contrast, factor=(0.5, 1.5)) key = jax.random.PRNGKey(0) self._test_fn_with_random_arg( images_list, jax_fn=functools.partial(augment.random_contrast, key, upper=1), - tf_fn=None, + reference_fn=None, lower=(0, 0.9)) # Doesn't make sense outside of [0, 1]. @@ -82,7 +85,7 @@ def test_adjust_gamma(self, images_list): self._test_fn_with_random_arg( images_list, jax_fn=augment.adjust_gamma, - tf_fn=tf.image.adjust_gamma, + reference_fn=tf.image.adjust_gamma, gamma=(0.5, 1.5)) @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), @@ -104,13 +107,13 @@ def perturb(rgb): self._test_fn_with_random_arg( images_list, jax_fn=augment.adjust_saturation, - tf_fn=tf.image.adjust_saturation, + reference_fn=tf.image.adjust_saturation, factor=(0.5, 1.5)) key = jax.random.PRNGKey(0) self._test_fn_with_random_arg( images_list, jax_fn=functools.partial(augment.random_saturation, key, upper=1), - tf_fn=None, + reference_fn=None, lower=(0, 0.9)) # CPU TF uses a different hue adjustment method outside of the [0, 1] range. @@ -121,13 +124,13 @@ def test_adjust_hue(self, images_list): self._test_fn_with_random_arg( images_list, jax_fn=augment.adjust_hue, - tf_fn=tf.image.adjust_hue, + reference_fn=tf.image.adjust_hue, delta=(-0.5, 0.5)) key = jax.random.PRNGKey(0) self._test_fn_with_random_arg( images_list, jax_fn=functools.partial(augment.random_hue, key), - tf_fn=None, + reference_fn=None, max_delta=(0, 0.5)) @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), @@ -136,15 +139,15 @@ def test_rot90(self, images_list): self._test_fn( images_list, jax_fn=lambda img: augment.rot90(img, k=1), - tf_fn=lambda img: tf.image.rot90(img, k=1)) + reference_fn=lambda img: tf.image.rot90(img, k=1)) self._test_fn( images_list, jax_fn=lambda img: augment.rot90(img, k=2), - tf_fn=lambda img: tf.image.rot90(img, k=2)) + reference_fn=lambda img: tf.image.rot90(img, k=2)) self._test_fn( images_list, jax_fn=lambda img: augment.rot90(img, k=3), - tf_fn=lambda img: tf.image.rot90(img, k=3)) + reference_fn=lambda img: tf.image.rot90(img, k=3)) # The functions below don't have a TF equivalent to compare to, we just check # that they run. @@ -154,76 +157,167 @@ def test_flip(self, images_list): self._test_fn( images_list, jax_fn=augment.flip_left_right, - tf_fn=tf.image.flip_left_right) + reference_fn=tf.image.flip_left_right) self._test_fn( - images_list, jax_fn=augment.flip_up_down, tf_fn=tf.image.flip_up_down) + images_list, + jax_fn=augment.flip_up_down, + reference_fn=tf.image.flip_up_down) key = jax.random.PRNGKey(0) self._test_fn( images_list, jax_fn=functools.partial(augment.random_flip_left_right, key), - tf_fn=None) + reference_fn=None) self._test_fn( images_list, jax_fn=functools.partial(augment.random_flip_up_down, key), - tf_fn=None) + reference_fn=None) self._test_fn_with_random_arg( images_list, jax_fn=functools.partial(augment.random_flip_left_right, key), - tf_fn=None, + reference_fn=None, probability=(0., 1.)) self._test_fn_with_random_arg( images_list, jax_fn=functools.partial(augment.random_flip_up_down, key), - tf_fn=None, + reference_fn=None, probability=(0., 1.)) + # Due to a bug in scipy we cannot test all available modes, refer to this + # issue for more information: https://github.com/google/jax/issues/11097 + @parameterized.named_parameters( + ("in_range_nearest_0", _RAND_FLOATS_IN_RANGE, "nearest", 0), + ("in_range_nearest_1", _RAND_FLOATS_IN_RANGE, "nearest", 1), + ("in_range_mirror_1", _RAND_FLOATS_IN_RANGE, "mirror", 1), + ("out_of_range_nearest_0", _RAND_FLOATS_OUT_OF_RANGE, "nearest", 0), + ("out_of_range_nearest_1", _RAND_FLOATS_OUT_OF_RANGE, "nearest", 1), + ("out_of_range_mirror_1", _RAND_FLOATS_OUT_OF_RANGE, "mirror", 1), + ) + def test_affine_transform(self, images_list, mode, order): + # (ndim, ndim) no offset + self._test_fn( + images_list, + jax_fn=functools.partial( + augment.affine_transform, matrix=np.eye(3), mode=mode, order=order), + reference_fn=functools.partial( + scipy.ndimage.affine_transform, + matrix=np.eye(3), + order=order, + mode=mode)) + + # (ndim, ndim) with offset + matrix = jnp.array([[-0.5, 0.2, 0], [0.8, 0.5, 0], [0, 0, 1]]) + offset = jnp.array([40., 32, 0]) + self._test_fn( + images_list, + jax_fn=functools.partial( + augment.affine_transform, + matrix=matrix, + mode=mode, + offset=offset, + order=order), + reference_fn=functools.partial( + scipy.ndimage.affine_transform, + matrix=matrix, + offset=offset, + order=order, + mode=mode)) + + # (ndim + 1, ndim + 1) + matrix = jnp.array( + [[0.4, 0.2, 0, -10], [0.2, -0.5, 0, 5], [0, 0, 1, 0], [0, 0, 0, 1]]) + self._test_fn( + images_list, + jax_fn=functools.partial( + augment.affine_transform, matrix=matrix, mode=mode, order=order), + reference_fn=functools.partial( + scipy.ndimage.affine_transform, + matrix=matrix, + order=order, + mode=mode)) + + # (ndim, ndim + 1) + matrix = jnp.array([[0.4, 0.2, 0, -10], [0.2, -0.5, 0, 5], [0, 0, 1, 0]]) + self._test_fn( + images_list, + jax_fn=functools.partial( + augment.affine_transform, matrix=matrix, mode=mode, order=order), + reference_fn=functools.partial( + scipy.ndimage.affine_transform, + matrix=matrix, + order=order, + mode=mode)) + + # (ndim,) + matrix = jnp.array([0.4, 0.2, 1]) + self._test_fn( + images_list, + jax_fn=functools.partial( + augment.affine_transform, matrix=matrix, mode=mode, order=order), + reference_fn=functools.partial( + scipy.ndimage.affine_transform, + matrix=matrix, + order=order, + mode=mode)) + @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) def test_solarize(self, images_list): self._test_fn_with_random_arg( - images_list, jax_fn=augment.solarize, tf_fn=None, threshold=(0., 1.)) + images_list, + jax_fn=augment.solarize, + reference_fn=None, + threshold=(0., 1.)) @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) def test_gaussian_blur(self, images_list): blur_fn = functools.partial(augment.gaussian_blur, kernel_size=_KERNEL_SIZE) self._test_fn_with_random_arg( - images_list, jax_fn=blur_fn, tf_fn=None, sigma=(0.1, 2.0)) + images_list, + jax_fn=blur_fn, + reference_fn=None, + sigma=(0.1, 2.0)) @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) def test_random_crop(self, images_list): key = jax.random.PRNGKey(43) crop_fn = lambda img: augment.random_crop(key, img, (100, 100, 3)) - self._test_fn(images_list, jax_fn=crop_fn, tf_fn=None) + self._test_fn(images_list, jax_fn=crop_fn, reference_fn=None) -class TestMatchTensorflow(_ImageAugmentationTest): +class TestMatchReference(_ImageAugmentationTest): - def _test_fn_with_random_arg(self, images_list, jax_fn, tf_fn, **kw_range): - if tf_fn is None: + def _test_fn_with_random_arg( + self, images_list, jax_fn, reference_fn, **kw_range): + if reference_fn is None: return assert len(kw_range) == 1 kw_name, (random_min, random_max) = list(kw_range.items())[0] for image_rgb in images_list: argument = np.random.uniform(random_min, random_max, size=()) adjusted_jax = jax_fn(image_rgb, **{kw_name: argument}) - adjusted_tf = tf_fn(image_rgb, argument).numpy() - self.assertAllCloseTolerant(adjusted_jax, adjusted_tf) + adjusted_reference = reference_fn(image_rgb, argument) + if hasattr(adjusted_reference, "numpy"): + adjusted_reference = adjusted_reference.numpy() + self.assertAllCloseTolerant(adjusted_jax, adjusted_reference) - def _test_fn(self, images_list, jax_fn, tf_fn): - if tf_fn is None: + def _test_fn(self, images_list, jax_fn, reference_fn): + if reference_fn is None: return for image_rgb in images_list: adjusted_jax = jax_fn(image_rgb) - adjusted_tf = tf_fn(image_rgb).numpy() - self.assertAllCloseTolerant(adjusted_jax, adjusted_tf) + adjusted_reference = reference_fn(image_rgb) + if hasattr(adjusted_reference, "numpy"): + adjusted_reference = adjusted_reference.numpy() + self.assertAllCloseTolerant(adjusted_jax, adjusted_reference) class TestVmap(_ImageAugmentationTest): - def _test_fn_with_random_arg(self, images_list, jax_fn, tf_fn, **kw_range): - del tf_fn # unused. + def _test_fn_with_random_arg( + self, images_list, jax_fn, reference_fn, **kw_range): + del reference_fn # unused. assert len(kw_range) == 1 kw_name, (random_min, random_max) = list(kw_range.items())[0] arguments = [ @@ -241,8 +335,8 @@ def _test_fn_with_random_arg(self, images_list, jax_fn, tf_fn, **kw_range): adjusted_jax = jax_fn(image_rgb, **{kw_name: argument}) self.assertAllCloseTolerant(adjusted_jax, adjusted_vmap) - def _test_fn(self, images_list, jax_fn, tf_fn): - del tf_fn # unused. + def _test_fn(self, images_list, jax_fn, reference_fn): + del reference_fn # unused. fn_vmap = jax.vmap(jax_fn) outputs_vmaped = list(fn_vmap(np.stack(images_list, axis=0))) assert len(images_list) == len(outputs_vmaped) @@ -253,8 +347,9 @@ def _test_fn(self, images_list, jax_fn, tf_fn): class TestJit(_ImageAugmentationTest): - def _test_fn_with_random_arg(self, images_list, jax_fn, tf_fn, **kw_range): - del tf_fn # unused. + def _test_fn_with_random_arg( + self, images_list, jax_fn, reference_fn, **kw_range): + del reference_fn # unused. assert len(kw_range) == 1 kw_name, (random_min, random_max) = list(kw_range.items())[0] jax_fn_jitted = jax.jit(jax_fn) @@ -264,8 +359,8 @@ def _test_fn_with_random_arg(self, images_list, jax_fn, tf_fn, **kw_range): adjusted_jit = jax_fn_jitted(image_rgb, **{kw_name: argument}) self.assertAllCloseTolerant(adjusted_jax, adjusted_jit) - def _test_fn(self, images_list, jax_fn, tf_fn): - del tf_fn # unused. + def _test_fn(self, images_list, jax_fn, reference_fn): + del reference_fn # unused. jax_fn_jitted = jax.jit(jax_fn) for image_rgb in images_list: adjusted_jax = jax_fn(image_rgb) @@ -274,4 +369,5 @@ def _test_fn(self, images_list, jax_fn, tf_fn): if __name__ == "__main__": + jax.config.update("jax_default_matmul_precision", "float32") absltest.main() diff --git a/docs/api.rst b/docs/api.rst index 3ee4b68..0e81187 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -10,6 +10,7 @@ Augmentations adjust_gamma adjust_hue adjust_saturation + affine_transform flip_left_right flip_up_down gaussian_blur @@ -48,6 +49,11 @@ adjust_saturation .. autofunction:: adjust_saturation +affine_transform +~~~~~~~~~~~~~~~~~ + +.. autofunction:: affine_transform + flip_left_right ~~~~~~~~~~~~~~~ diff --git a/requirements_tests.txt b/requirements_tests.txt index ea78edf..aab4d65 100644 --- a/requirements_tests.txt +++ b/requirements_tests.txt @@ -1,2 +1,3 @@ pytest-xdist tensorflow +scipy