Skip to content

Commit

Permalink
Internal changes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 530866995
  • Loading branch information
claudiofantacci authored and PIXDev committed Dec 4, 2023
1 parent c1a933a commit 442c202
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 12 deletions.
49 changes: 49 additions & 0 deletions dm_pix/_src/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import jax
import jax.numpy as jnp

# DO NOT REMOVE - Logging lib.


def adjust_brightness(image: chex.Array, delta: chex.Numeric) -> chex.Array:
"""Shifts the brightness of an RGB image by a given amount.
Expand All @@ -41,6 +43,8 @@ def adjust_brightness(image: chex.Array, delta: chex.Numeric) -> chex.Array:
Returns:
The brightness-adjusted image. May be outside of the [0, 1] range.
"""
# DO NOT REMOVE - Logging usage.

return image + jnp.asarray(delta, image.dtype)


Expand All @@ -62,6 +66,8 @@ def adjust_contrast(
Returns:
The contrast-adjusted image. May be outside of the [0, 1] range.
"""
# DO NOT REMOVE - Logging usage.

if _channels_last(image, channel_axis):
spatial_axes = (-3, -2)
else:
Expand Down Expand Up @@ -93,6 +99,8 @@ def adjust_gamma(
Returns:
The gamma-adjusted image.
"""
# DO NOT REMOVE - Logging usage.

if not assume_in_bounds:
image = jnp.clip(image, 0., 1.) # Clip image for safety.
return jnp.asarray(gain, image.dtype) * (
Expand All @@ -119,6 +127,8 @@ def adjust_hue(
Returns:
The saturation-adjusted image.
"""
# DO NOT REMOVE - Logging usage.

rgb = color_conversion.split_channels(image, channel_axis)
hue, saturation, value = color_conversion.rgb_planes_to_hsv_planes(*rgb)
rgb_adjusted = color_conversion.hsv_planes_to_rgb_planes((hue + delta) % 1.0,
Expand All @@ -144,6 +154,8 @@ def adjust_saturation(
Returns:
The saturation-adjusted image.
"""
# DO NOT REMOVE - Logging usage.

rgb = color_conversion.split_channels(image, channel_axis)
hue, saturation, value = color_conversion.rgb_planes_to_hsv_planes(*rgb)
factor = jnp.asarray(factor, image.dtype)
Expand Down Expand Up @@ -196,6 +208,8 @@ def elastic_deformation(
Returns:
The transformed image.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, 3)
if channel_axis != -1:
image = jnp.moveaxis(image, source=channel_axis, destination=-1)
Expand Down Expand Up @@ -264,6 +278,8 @@ def center_crop(
Returns:
The cropped image(s).
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, {3, 4})
batch, current_height, current_width, channel = _get_dimension_values(
image=image, channel_axis=channel_axis
Expand Down Expand Up @@ -325,6 +341,8 @@ def pad_to_size(
Returns:
The padded image(s).
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, {3, 4})
batch, height, width, _ = _get_dimension_values(
image=image, channel_axis=channel_axis
Expand Down Expand Up @@ -374,6 +392,8 @@ def resize_with_crop_or_pad(
Returns:
The image(s) resized by crop or pad to the desired target size.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, {3, 4})
image = center_crop(
image,
Expand Down Expand Up @@ -408,6 +428,8 @@ def flip_left_right(
Returns:
The flipped image.
"""
# DO NOT REMOVE - Logging usage.

if _channels_last(image, channel_axis):
flip_axis = -2 # Image is ...HWC
else:
Expand All @@ -432,6 +454,8 @@ def flip_up_down(
Returns:
The flipped image.
"""
# DO NOT REMOVE - Logging usage.

if _channels_last(image, channel_axis):
flip_axis = -3 # Image is ...HWC
else:
Expand Down Expand Up @@ -461,6 +485,8 @@ def gaussian_blur(
Returns:
The blurred image.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, {3, 4})
data_format = "NHWC" if _channels_last(image, channel_axis) else "NCHW"
dimension_numbers = (data_format, "HWIO", data_format)
Expand Down Expand Up @@ -516,6 +542,8 @@ def rot90(
Returns:
The rotated image.
"""
# DO NOT REMOVE - Logging usage.

if _channels_last(image, channel_axis):
spatial_axes = (-3, -2) # Image is ...HWC
else:
Expand All @@ -535,6 +563,8 @@ def solarize(image: chex.Array, threshold: chex.Numeric) -> chex.Array:
Returns:
The solarized image.
"""
# DO NOT REMOVE - Logging usage.

return jnp.where(image < threshold, image, 1. - image)


Expand Down Expand Up @@ -654,6 +684,8 @@ def affine_transform(
>>> matrix = rotation_matrix.dot(translation_matrix)
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, 3)
chex.assert_rank(matrix, {1, 2})
chex.assert_rank(offset, {0, 1})
Expand Down Expand Up @@ -716,6 +748,8 @@ def rotate(
Returns:
The rotated image.
"""
# DO NOT REMOVE - Logging usage.

# Calculate inverse transform matrix assuming clockwise rotation.
c = jnp.cos(angle)
s = jnp.sin(angle)
Expand Down Expand Up @@ -747,6 +781,8 @@ def random_flip_left_right(
Returns:
A left-right flipped image if condition is met, otherwise original image.
"""
# DO NOT REMOVE - Logging usage.

should_transform = jax.random.bernoulli(key=key, p=probability)
return jax.lax.cond(should_transform, flip_left_right, lambda x: x, image)

Expand All @@ -769,6 +805,8 @@ def random_flip_up_down(
Returns:
An up-down flipped image if condition is met, otherwise original image.
"""
# DO NOT REMOVE - Logging usage.

should_transform = jax.random.bernoulli(key=key, p=probability)
return jax.lax.cond(should_transform, flip_up_down, lambda x: x, image)

Expand All @@ -779,6 +817,8 @@ def random_brightness(
max_delta: chex.Numeric,
) -> chex.Array:
"""`adjust_brightness(...)` with random delta in `[-max_delta, max_delta)`."""
# DO NOT REMOVE - Logging usage.

delta = jax.random.uniform(key, (), minval=-max_delta, maxval=max_delta)
return adjust_brightness(image, delta)

Expand All @@ -793,6 +833,8 @@ def random_gamma(
assume_in_bounds: bool = False,
) -> chex.Array:
"""`adjust_gamma(...)` with random gamma in [min_gamma, max_gamma)`."""
# DO NOT REMOVE - Logging usage.

gamma = jax.random.uniform(key, (), minval=min_gamma, maxval=max_gamma)
return adjust_gamma(
image, gamma, gain=gain, assume_in_bounds=assume_in_bounds)
Expand All @@ -806,6 +848,8 @@ def random_hue(
channel_axis: int = -1,
) -> chex.Array:
"""`adjust_hue(...)` with random delta in `[-max_delta, max_delta)`."""
# DO NOT REMOVE - Logging usage.

delta = jax.random.uniform(key, (), minval=-max_delta, maxval=max_delta)
return adjust_hue(image, delta, channel_axis=channel_axis)

Expand All @@ -819,6 +863,8 @@ def random_contrast(
channel_axis: int = -1,
) -> chex.Array:
"""`adjust_contrast(...)` with random factor in `[lower, upper)`."""
# DO NOT REMOVE - Logging usage.

factor = jax.random.uniform(key, (), minval=lower, maxval=upper)
return adjust_contrast(image, factor, channel_axis=channel_axis)

Expand All @@ -832,6 +878,8 @@ def random_saturation(
channel_axis: int = -1,
) -> chex.Array:
"""`adjust_saturation(...)` with random factor in `[lower, upper)`."""
# DO NOT REMOVE - Logging usage.

factor = jax.random.uniform(key, (), minval=lower, maxval=upper)
return adjust_saturation(image, factor, channel_axis=channel_axis)

Expand Down Expand Up @@ -859,6 +907,7 @@ def random_crop(
Returns:
A cropped image, a JAX array whose shape is same as `crop_sizes`.
"""
# DO NOT REMOVE - Logging usage.

image_shape = image.shape
assert len(image_shape) == len(crop_sizes), (
Expand Down
17 changes: 17 additions & 0 deletions dm_pix/_src/color_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import chex
import jax.numpy as jnp

# DO NOT REMOVE - Logging lib.


def split_channels(
image: chex.Array,
Expand All @@ -37,6 +39,8 @@ def split_channels(
A tuple of 3 images, with float values in range [0, 1], stacked
along channel_axis.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_axis_dimension(image, axis=channel_axis, expected=3)
split_axes = jnp.split(image, 3, axis=channel_axis)
return tuple(map(lambda x: jnp.squeeze(x, axis=channel_axis), split_axes))
Expand All @@ -58,6 +62,8 @@ def rgb_to_hsv(
Returns:
An HSV image, with float values in range [0, 1], stacked along channel_axis.
"""
# DO NOT REMOVE - Logging usage.

eps = jnp.finfo(image_rgb.dtype).eps
image_rgb = jnp.where(jnp.abs(image_rgb) < eps, 0., image_rgb)
red, green, blue = split_channels(image_rgb, channel_axis)
Expand All @@ -81,6 +87,7 @@ def hsv_to_rgb(
Returns:
An RGB image, with float values in range [0, 1], stacked along channel_axis.
"""
# DO NOT REMOVE - Logging usage.
hue, saturation, value = split_channels(image_hsv, channel_axis)
return jnp.stack(
hsv_planes_to_rgb_planes(hue, saturation, value), axis=channel_axis)
Expand All @@ -106,6 +113,8 @@ def rgb_planes_to_hsv_planes(
Returns:
A tuple of (hue, saturation, value) planes, as float values in range [0, 1].
"""
# DO NOT REMOVE - Logging usage.

value = jnp.maximum(jnp.maximum(red, green), blue)
minimum = jnp.minimum(jnp.minimum(red, green), blue)
range_ = value - minimum
Expand Down Expand Up @@ -148,6 +157,8 @@ def hsv_planes_to_rgb_planes(
Returns:
A tuple of (red, green, blue) planes, as float values in range [0, 1].
"""
# DO NOT REMOVE - Logging usage.

dh = (hue % 1.0) * 6. # Wrap when hue >= 360°.
dr = jnp.clip(jnp.abs(dh - 3.) - 1., 0., 1.)
dg = jnp.clip(2. - jnp.abs(dh - 2.), 0., 1.)
Expand Down Expand Up @@ -177,6 +188,8 @@ def rgb_to_hsl(
Returns:
An HSL image, with float values in range [0, 1], stacked along channel_axis.
"""
# DO NOT REMOVE - Logging usage.

red, green, blue = split_channels(image_rgb, channel_axis)

c_max = jnp.maximum(red, jnp.maximum(green, blue))
Expand Down Expand Up @@ -218,6 +231,8 @@ def hsl_to_rgb(
Returns:
An RGB image, with float values in range [0, 1], stacked along channel_axis.
"""
# DO NOT REMOVE - Logging usage.

h, s, l = split_channels(image_hsl, channel_axis)

m2 = jnp.where(l <= 0.5, l * (1 + s), l + s - l * s)
Expand Down Expand Up @@ -257,6 +272,8 @@ def rgb_to_grayscale(
Returns:
The grayscale image.
"""
# DO NOT REMOVE - Logging usage.

assert luma_standard in ["rec601", "rec709", "bt2001"]
if luma_standard == "rec601":
# TensorFlow's default.
Expand Down
13 changes: 10 additions & 3 deletions dm_pix/_src/depth_and_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import jax
import jax.numpy as jnp

# DO NOT REMOVE - Logging lib.


def depth_to_space(inputs: chex.Array, block_size: int) -> chex.Array:
"""Rearranges data from depth into blocks of spatial data.
Expand All @@ -31,14 +33,17 @@ def depth_to_space(inputs: chex.Array, block_size: int) -> chex.Array:
[H * B, W * B, C / (B ** 2)], where B is `block_size`. If there's a leading
batch dimension, it stays unchanged.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(inputs, {3, 4})
if inputs.ndim == 4: # Batched case.
return jax.vmap(depth_to_space, in_axes=(0, None))(inputs, block_size)

height, width, depth = inputs.shape
if depth % (block_size**2) != 0:
raise ValueError(
f'Number of channels {depth} must be divisible by block_size ** 2 {block_size**2}.'
f"Number of channels {depth} must be divisible by block_size ** 2"
f" {block_size**2}."
)
new_depth = depth // (block_size**2)
outputs = jnp.reshape(inputs,
Expand All @@ -64,17 +69,19 @@ def space_to_depth(inputs: chex.Array, block_size: int) -> chex.Array:
[H / B, W / B, C * (B ** 2)], where B is `block_size`. If there's a leading
batch dimension, it stays unchanged.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(inputs, {3, 4})
if inputs.ndim == 4: # Batched case.
return jax.vmap(space_to_depth, in_axes=(0, None))(inputs, block_size)

height, width, depth = inputs.shape
if height % block_size != 0:
raise ValueError(
f'Height {height} must be divisible by block size {block_size}.')
f"Height {height} must be divisible by block size {block_size}.")
if width % block_size != 0:
raise ValueError(
f'Width {width} must be divisible by block size {block_size}.')
f"Width {width} must be divisible by block size {block_size}.")
new_depth = depth * (block_size**2)
new_height = height // block_size
new_width = width // block_size
Expand Down
10 changes: 8 additions & 2 deletions dm_pix/_src/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from jax import lax
import jax.numpy as jnp

# DO NOT REMOVE - Logging lib.


def _round_half_away_from_zero(a: chex.Array) -> chex.Array:
return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)
Expand Down Expand Up @@ -73,8 +75,8 @@ def _make_linear_interpolation_indices_flat_nd(

if shape.shape[0] != coordinates.shape[0]:
raise ValueError(
(f'{coordinates.shape[0]}-dimensional coordinates provided for '
f'{shape.shape[0]}-dimensional input'))
(f"{coordinates.shape[0]}-dimensional coordinates provided for "
f"{shape.shape[0]}-dimensional input"))

lower_nd, upper_nd, weights_nd = _make_linear_interpolation_indices_nd(
coordinates, shape)
Expand Down Expand Up @@ -153,6 +155,8 @@ def flat_nd_linear_interpolate(
The resulting mapped coordinates. The shape of the output is `M_coordinates`
(derived from `coordinates` by dropping the first axis).
"""
# DO NOT REMOVE - Logging usage.

if unflattened_vol_shape is None:
unflattened_vol_shape = volume.shape
volume = volume.flatten()
Expand Down Expand Up @@ -191,6 +195,8 @@ def flat_nd_linear_interpolate_constant(
The resulting mapped coordinates. The shape of the output is `M_coordinates`
(derived from `coordinates` by dropping the first axis).
"""
# DO NOT REMOVE - Logging usage.

volume_shape = volume.shape
if unflattened_vol_shape is not None:
volume_shape = unflattened_vol_shape
Expand Down
Loading

0 comments on commit 442c202

Please sign in to comment.