diff --git a/dm_pix/_src/augment.py b/dm_pix/_src/augment.py index 923130a..d749f61 100644 --- a/dm_pix/_src/augment.py +++ b/dm_pix/_src/augment.py @@ -28,6 +28,8 @@ import jax import jax.numpy as jnp +# DO NOT REMOVE - Logging lib. + def adjust_brightness(image: chex.Array, delta: chex.Numeric) -> chex.Array: """Shifts the brightness of an RGB image by a given amount. @@ -41,6 +43,8 @@ def adjust_brightness(image: chex.Array, delta: chex.Numeric) -> chex.Array: Returns: The brightness-adjusted image. May be outside of the [0, 1] range. """ + # DO NOT REMOVE - Logging usage. + return image + jnp.asarray(delta, image.dtype) @@ -62,6 +66,8 @@ def adjust_contrast( Returns: The contrast-adjusted image. May be outside of the [0, 1] range. """ + # DO NOT REMOVE - Logging usage. + if _channels_last(image, channel_axis): spatial_axes = (-3, -2) else: @@ -93,6 +99,8 @@ def adjust_gamma( Returns: The gamma-adjusted image. """ + # DO NOT REMOVE - Logging usage. + if not assume_in_bounds: image = jnp.clip(image, 0., 1.) # Clip image for safety. return jnp.asarray(gain, image.dtype) * ( @@ -119,6 +127,8 @@ def adjust_hue( Returns: The saturation-adjusted image. """ + # DO NOT REMOVE - Logging usage. + rgb = color_conversion.split_channels(image, channel_axis) hue, saturation, value = color_conversion.rgb_planes_to_hsv_planes(*rgb) rgb_adjusted = color_conversion.hsv_planes_to_rgb_planes((hue + delta) % 1.0, @@ -144,6 +154,8 @@ def adjust_saturation( Returns: The saturation-adjusted image. """ + # DO NOT REMOVE - Logging usage. + rgb = color_conversion.split_channels(image, channel_axis) hue, saturation, value = color_conversion.rgb_planes_to_hsv_planes(*rgb) factor = jnp.asarray(factor, image.dtype) @@ -196,6 +208,8 @@ def elastic_deformation( Returns: The transformed image. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank(image, 3) if channel_axis != -1: image = jnp.moveaxis(image, source=channel_axis, destination=-1) @@ -264,6 +278,8 @@ def center_crop( Returns: The cropped image(s). """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank(image, {3, 4}) batch, current_height, current_width, channel = _get_dimension_values( image=image, channel_axis=channel_axis @@ -325,6 +341,8 @@ def pad_to_size( Returns: The padded image(s). """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank(image, {3, 4}) batch, height, width, _ = _get_dimension_values( image=image, channel_axis=channel_axis @@ -374,6 +392,8 @@ def resize_with_crop_or_pad( Returns: The image(s) resized by crop or pad to the desired target size. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank(image, {3, 4}) image = center_crop( image, @@ -408,6 +428,8 @@ def flip_left_right( Returns: The flipped image. """ + # DO NOT REMOVE - Logging usage. + if _channels_last(image, channel_axis): flip_axis = -2 # Image is ...HWC else: @@ -432,6 +454,8 @@ def flip_up_down( Returns: The flipped image. """ + # DO NOT REMOVE - Logging usage. + if _channels_last(image, channel_axis): flip_axis = -3 # Image is ...HWC else: @@ -461,6 +485,8 @@ def gaussian_blur( Returns: The blurred image. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank(image, {3, 4}) data_format = "NHWC" if _channels_last(image, channel_axis) else "NCHW" dimension_numbers = (data_format, "HWIO", data_format) @@ -516,6 +542,8 @@ def rot90( Returns: The rotated image. """ + # DO NOT REMOVE - Logging usage. + if _channels_last(image, channel_axis): spatial_axes = (-3, -2) # Image is ...HWC else: @@ -535,6 +563,8 @@ def solarize(image: chex.Array, threshold: chex.Numeric) -> chex.Array: Returns: The solarized image. """ + # DO NOT REMOVE - Logging usage. + return jnp.where(image < threshold, image, 1. - image) @@ -654,6 +684,8 @@ def affine_transform( >>> matrix = rotation_matrix.dot(translation_matrix) """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank(image, 3) chex.assert_rank(matrix, {1, 2}) chex.assert_rank(offset, {0, 1}) @@ -716,6 +748,8 @@ def rotate( Returns: The rotated image. """ + # DO NOT REMOVE - Logging usage. + # Calculate inverse transform matrix assuming clockwise rotation. c = jnp.cos(angle) s = jnp.sin(angle) @@ -747,6 +781,8 @@ def random_flip_left_right( Returns: A left-right flipped image if condition is met, otherwise original image. """ + # DO NOT REMOVE - Logging usage. + should_transform = jax.random.bernoulli(key=key, p=probability) return jax.lax.cond(should_transform, flip_left_right, lambda x: x, image) @@ -769,6 +805,8 @@ def random_flip_up_down( Returns: An up-down flipped image if condition is met, otherwise original image. """ + # DO NOT REMOVE - Logging usage. + should_transform = jax.random.bernoulli(key=key, p=probability) return jax.lax.cond(should_transform, flip_up_down, lambda x: x, image) @@ -779,6 +817,8 @@ def random_brightness( max_delta: chex.Numeric, ) -> chex.Array: """`adjust_brightness(...)` with random delta in `[-max_delta, max_delta)`.""" + # DO NOT REMOVE - Logging usage. + delta = jax.random.uniform(key, (), minval=-max_delta, maxval=max_delta) return adjust_brightness(image, delta) @@ -793,6 +833,8 @@ def random_gamma( assume_in_bounds: bool = False, ) -> chex.Array: """`adjust_gamma(...)` with random gamma in [min_gamma, max_gamma)`.""" + # DO NOT REMOVE - Logging usage. + gamma = jax.random.uniform(key, (), minval=min_gamma, maxval=max_gamma) return adjust_gamma( image, gamma, gain=gain, assume_in_bounds=assume_in_bounds) @@ -806,6 +848,8 @@ def random_hue( channel_axis: int = -1, ) -> chex.Array: """`adjust_hue(...)` with random delta in `[-max_delta, max_delta)`.""" + # DO NOT REMOVE - Logging usage. + delta = jax.random.uniform(key, (), minval=-max_delta, maxval=max_delta) return adjust_hue(image, delta, channel_axis=channel_axis) @@ -819,6 +863,8 @@ def random_contrast( channel_axis: int = -1, ) -> chex.Array: """`adjust_contrast(...)` with random factor in `[lower, upper)`.""" + # DO NOT REMOVE - Logging usage. + factor = jax.random.uniform(key, (), minval=lower, maxval=upper) return adjust_contrast(image, factor, channel_axis=channel_axis) @@ -832,6 +878,8 @@ def random_saturation( channel_axis: int = -1, ) -> chex.Array: """`adjust_saturation(...)` with random factor in `[lower, upper)`.""" + # DO NOT REMOVE - Logging usage. + factor = jax.random.uniform(key, (), minval=lower, maxval=upper) return adjust_saturation(image, factor, channel_axis=channel_axis) @@ -859,6 +907,7 @@ def random_crop( Returns: A cropped image, a JAX array whose shape is same as `crop_sizes`. """ + # DO NOT REMOVE - Logging usage. image_shape = image.shape assert len(image_shape) == len(crop_sizes), ( diff --git a/dm_pix/_src/color_conversion.py b/dm_pix/_src/color_conversion.py index 0f37487..39bc4ac 100644 --- a/dm_pix/_src/color_conversion.py +++ b/dm_pix/_src/color_conversion.py @@ -21,6 +21,8 @@ import chex import jax.numpy as jnp +# DO NOT REMOVE - Logging lib. + def split_channels( image: chex.Array, @@ -37,6 +39,8 @@ def split_channels( A tuple of 3 images, with float values in range [0, 1], stacked along channel_axis. """ + # DO NOT REMOVE - Logging usage. + chex.assert_axis_dimension(image, axis=channel_axis, expected=3) split_axes = jnp.split(image, 3, axis=channel_axis) return tuple(map(lambda x: jnp.squeeze(x, axis=channel_axis), split_axes)) @@ -58,6 +62,8 @@ def rgb_to_hsv( Returns: An HSV image, with float values in range [0, 1], stacked along channel_axis. """ + # DO NOT REMOVE - Logging usage. + eps = jnp.finfo(image_rgb.dtype).eps image_rgb = jnp.where(jnp.abs(image_rgb) < eps, 0., image_rgb) red, green, blue = split_channels(image_rgb, channel_axis) @@ -81,6 +87,7 @@ def hsv_to_rgb( Returns: An RGB image, with float values in range [0, 1], stacked along channel_axis. """ + # DO NOT REMOVE - Logging usage. hue, saturation, value = split_channels(image_hsv, channel_axis) return jnp.stack( hsv_planes_to_rgb_planes(hue, saturation, value), axis=channel_axis) @@ -106,6 +113,8 @@ def rgb_planes_to_hsv_planes( Returns: A tuple of (hue, saturation, value) planes, as float values in range [0, 1]. """ + # DO NOT REMOVE - Logging usage. + value = jnp.maximum(jnp.maximum(red, green), blue) minimum = jnp.minimum(jnp.minimum(red, green), blue) range_ = value - minimum @@ -148,6 +157,8 @@ def hsv_planes_to_rgb_planes( Returns: A tuple of (red, green, blue) planes, as float values in range [0, 1]. """ + # DO NOT REMOVE - Logging usage. + dh = (hue % 1.0) * 6. # Wrap when hue >= 360°. dr = jnp.clip(jnp.abs(dh - 3.) - 1., 0., 1.) dg = jnp.clip(2. - jnp.abs(dh - 2.), 0., 1.) @@ -177,6 +188,8 @@ def rgb_to_hsl( Returns: An HSL image, with float values in range [0, 1], stacked along channel_axis. """ + # DO NOT REMOVE - Logging usage. + red, green, blue = split_channels(image_rgb, channel_axis) c_max = jnp.maximum(red, jnp.maximum(green, blue)) @@ -218,6 +231,8 @@ def hsl_to_rgb( Returns: An RGB image, with float values in range [0, 1], stacked along channel_axis. """ + # DO NOT REMOVE - Logging usage. + h, s, l = split_channels(image_hsl, channel_axis) m2 = jnp.where(l <= 0.5, l * (1 + s), l + s - l * s) @@ -257,6 +272,8 @@ def rgb_to_grayscale( Returns: The grayscale image. """ + # DO NOT REMOVE - Logging usage. + assert luma_standard in ["rec601", "rec709", "bt2001"] if luma_standard == "rec601": # TensorFlow's default. diff --git a/dm_pix/_src/depth_and_space.py b/dm_pix/_src/depth_and_space.py index 3f9ae61..79ee416 100644 --- a/dm_pix/_src/depth_and_space.py +++ b/dm_pix/_src/depth_and_space.py @@ -17,6 +17,8 @@ import jax import jax.numpy as jnp +# DO NOT REMOVE - Logging lib. + def depth_to_space(inputs: chex.Array, block_size: int) -> chex.Array: """Rearranges data from depth into blocks of spatial data. @@ -31,6 +33,8 @@ def depth_to_space(inputs: chex.Array, block_size: int) -> chex.Array: [H * B, W * B, C / (B ** 2)], where B is `block_size`. If there's a leading batch dimension, it stays unchanged. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank(inputs, {3, 4}) if inputs.ndim == 4: # Batched case. return jax.vmap(depth_to_space, in_axes=(0, None))(inputs, block_size) @@ -38,7 +42,8 @@ def depth_to_space(inputs: chex.Array, block_size: int) -> chex.Array: height, width, depth = inputs.shape if depth % (block_size**2) != 0: raise ValueError( - f'Number of channels {depth} must be divisible by block_size ** 2 {block_size**2}.' + f"Number of channels {depth} must be divisible by block_size ** 2" + f" {block_size**2}." ) new_depth = depth // (block_size**2) outputs = jnp.reshape(inputs, @@ -64,6 +69,8 @@ def space_to_depth(inputs: chex.Array, block_size: int) -> chex.Array: [H / B, W / B, C * (B ** 2)], where B is `block_size`. If there's a leading batch dimension, it stays unchanged. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank(inputs, {3, 4}) if inputs.ndim == 4: # Batched case. return jax.vmap(space_to_depth, in_axes=(0, None))(inputs, block_size) @@ -71,10 +78,10 @@ def space_to_depth(inputs: chex.Array, block_size: int) -> chex.Array: height, width, depth = inputs.shape if height % block_size != 0: raise ValueError( - f'Height {height} must be divisible by block size {block_size}.') + f"Height {height} must be divisible by block size {block_size}.") if width % block_size != 0: raise ValueError( - f'Width {width} must be divisible by block size {block_size}.') + f"Width {width} must be divisible by block size {block_size}.") new_depth = depth * (block_size**2) new_height = height // block_size new_width = width // block_size diff --git a/dm_pix/_src/interpolation.py b/dm_pix/_src/interpolation.py index 61192fb..4091a4a 100644 --- a/dm_pix/_src/interpolation.py +++ b/dm_pix/_src/interpolation.py @@ -23,6 +23,8 @@ from jax import lax import jax.numpy as jnp +# DO NOT REMOVE - Logging lib. + def _round_half_away_from_zero(a: chex.Array) -> chex.Array: return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a) @@ -73,8 +75,8 @@ def _make_linear_interpolation_indices_flat_nd( if shape.shape[0] != coordinates.shape[0]: raise ValueError( - (f'{coordinates.shape[0]}-dimensional coordinates provided for ' - f'{shape.shape[0]}-dimensional input')) + (f"{coordinates.shape[0]}-dimensional coordinates provided for " + f"{shape.shape[0]}-dimensional input")) lower_nd, upper_nd, weights_nd = _make_linear_interpolation_indices_nd( coordinates, shape) @@ -153,6 +155,8 @@ def flat_nd_linear_interpolate( The resulting mapped coordinates. The shape of the output is `M_coordinates` (derived from `coordinates` by dropping the first axis). """ + # DO NOT REMOVE - Logging usage. + if unflattened_vol_shape is None: unflattened_vol_shape = volume.shape volume = volume.flatten() @@ -191,6 +195,8 @@ def flat_nd_linear_interpolate_constant( The resulting mapped coordinates. The shape of the output is `M_coordinates` (derived from `coordinates` by dropping the first axis). """ + # DO NOT REMOVE - Logging usage. + volume_shape = volume.shape if unflattened_vol_shape is not None: volume_shape = unflattened_vol_shape diff --git a/dm_pix/_src/metrics.py b/dm_pix/_src/metrics.py index 0b9607c..e3a3802 100644 --- a/dm_pix/_src/metrics.py +++ b/dm_pix/_src/metrics.py @@ -23,6 +23,8 @@ import jax import jax.numpy as jnp +# DO NOT REMOVE - Logging lib. + def mae(a: chex.Array, b: chex.Array) -> chex.Numeric: """Returns the Mean Absolute Error between `a` and `b`. @@ -34,6 +36,8 @@ def mae(a: chex.Array, b: chex.Array) -> chex.Numeric: Returns: MAE between `a` and `b`. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank([a, b], {3, 4}) chex.assert_type([a, b], float) chex.assert_equal_shape([a, b]) @@ -50,6 +54,8 @@ def mse(a: chex.Array, b: chex.Array) -> chex.Numeric: Returns: MSE between `a` and `b`. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank([a, b], {3, 4}) chex.assert_type([a, b], float) chex.assert_equal_shape([a, b]) @@ -69,6 +75,8 @@ def psnr(a: chex.Array, b: chex.Array) -> chex.Numeric: Returns: PSNR in decibels between `a` and `b`. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank([a, b], {3, 4}) chex.assert_type([a, b], float) chex.assert_equal_shape([a, b]) @@ -85,6 +93,8 @@ def rmse(a: chex.Array, b: chex.Array) -> chex.Numeric: Returns: RMSE between `a` and `b`. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank([a, b], {3, 4}) chex.assert_type([a, b], float) chex.assert_equal_shape([a, b]) @@ -114,6 +124,8 @@ def simse(a: chex.Array, b: chex.Array) -> chex.Numeric: Returns: SIMSE between `a` and `b`. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank([a, b], {3, 4}) chex.assert_type([a, b], float) chex.assert_equal_shape([a, b]) @@ -168,6 +180,8 @@ def ssim( Returns: Each image's mean SSIM, or a tensor of individual values if `return_map`. """ + # DO NOT REMOVE - Logging usage. + chex.assert_rank([a, b], {3, 4}) chex.assert_type([a, b], float) chex.assert_equal_shape([a, b]) diff --git a/dm_pix/_src/patch.py b/dm_pix/_src/patch.py index 482a54c..2d9f593 100644 --- a/dm_pix/_src/patch.py +++ b/dm_pix/_src/patch.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp +# DO NOT REMOVE - Logging lib. + def extract_patches( images: chex.Array, @@ -26,7 +28,7 @@ def extract_patches( strides: Sequence[int], rates: Sequence[int], *, - padding: str = 'VALID', + padding: str = "VALID", ) -> jnp.ndarray: """Extract patches from images. @@ -58,17 +60,19 @@ def extract_patches( Tensor of shape [B, patch_rows, patch_cols, ..., size_rows * size_cols * ... * C]. """ + # DO NOT REMOVE - Logging usage. + ndim = images.ndim if len(sizes) != ndim or sizes[0] != 1 or sizes[-1] != 1: - raise ValueError('Input `sizes` must be [1, size_rows, size_cols, ..., 1] ' - f'and same length as `images.ndim` {ndim}. Got {sizes}.') + raise ValueError("Input `sizes` must be [1, size_rows, size_cols, ..., 1] " + f"and same length as `images.ndim` {ndim}. Got {sizes}.") if len(strides) != ndim or strides[0] != 1 or strides[-1] != 1: - raise ValueError('Input `strides` must be [1, size_rows, size_cols, ..., 1]' - f'and same length as `images.ndim` {ndim}. Got {strides}.') + raise ValueError("Input `strides` must be [1, size_rows, size_cols, ..., 1]" + f"and same length as `images.ndim` {ndim}. Got {strides}.") if len(rates) != ndim or rates[0] != 1 or rates[-1] != 1: - raise ValueError('Input `rates` must be [1, size_rows, size_cols, ..., 1] ' - f'and same length as `images.ndim` {ndim}. Got {rates}.') + raise ValueError("Input `rates` must be [1, size_rows, size_cols, ..., 1] " + f"and same length as `images.ndim` {ndim}. Got {rates}.") channels = images.shape[-1] lhs_spec = out_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))