Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal changes. #90

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions dm_pix/_src/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import jax
import jax.numpy as jnp

# DO NOT REMOVE - Logging lib.


def adjust_brightness(image: chex.Array, delta: chex.Numeric) -> chex.Array:
"""Shifts the brightness of an RGB image by a given amount.
Expand All @@ -41,6 +43,8 @@ def adjust_brightness(image: chex.Array, delta: chex.Numeric) -> chex.Array:
Returns:
The brightness-adjusted image. May be outside of the [0, 1] range.
"""
# DO NOT REMOVE - Logging usage.

return image + jnp.asarray(delta, image.dtype)


Expand All @@ -62,6 +66,8 @@ def adjust_contrast(
Returns:
The contrast-adjusted image. May be outside of the [0, 1] range.
"""
# DO NOT REMOVE - Logging usage.

if _channels_last(image, channel_axis):
spatial_axes = (-3, -2)
else:
Expand Down Expand Up @@ -93,6 +99,8 @@ def adjust_gamma(
Returns:
The gamma-adjusted image.
"""
# DO NOT REMOVE - Logging usage.

if not assume_in_bounds:
image = jnp.clip(image, 0., 1.) # Clip image for safety.
return jnp.asarray(gain, image.dtype) * (
Expand All @@ -119,6 +127,8 @@ def adjust_hue(
Returns:
The saturation-adjusted image.
"""
# DO NOT REMOVE - Logging usage.

rgb = color_conversion.split_channels(image, channel_axis)
hue, saturation, value = color_conversion.rgb_planes_to_hsv_planes(*rgb)
rgb_adjusted = color_conversion.hsv_planes_to_rgb_planes((hue + delta) % 1.0,
Expand All @@ -144,6 +154,8 @@ def adjust_saturation(
Returns:
The saturation-adjusted image.
"""
# DO NOT REMOVE - Logging usage.

rgb = color_conversion.split_channels(image, channel_axis)
hue, saturation, value = color_conversion.rgb_planes_to_hsv_planes(*rgb)
factor = jnp.asarray(factor, image.dtype)
Expand Down Expand Up @@ -196,6 +208,8 @@ def elastic_deformation(
Returns:
The transformed image.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, 3)
if channel_axis != -1:
image = jnp.moveaxis(image, source=channel_axis, destination=-1)
Expand Down Expand Up @@ -264,6 +278,8 @@ def center_crop(
Returns:
The cropped image(s).
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, {3, 4})
batch, current_height, current_width, channel = _get_dimension_values(
image=image, channel_axis=channel_axis
Expand Down Expand Up @@ -325,6 +341,8 @@ def pad_to_size(
Returns:
The padded image(s).
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, {3, 4})
batch, height, width, _ = _get_dimension_values(
image=image, channel_axis=channel_axis
Expand Down Expand Up @@ -374,6 +392,8 @@ def resize_with_crop_or_pad(
Returns:
The image(s) resized by crop or pad to the desired target size.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, {3, 4})
image = center_crop(
image,
Expand Down Expand Up @@ -408,6 +428,8 @@ def flip_left_right(
Returns:
The flipped image.
"""
# DO NOT REMOVE - Logging usage.

if _channels_last(image, channel_axis):
flip_axis = -2 # Image is ...HWC
else:
Expand All @@ -432,6 +454,8 @@ def flip_up_down(
Returns:
The flipped image.
"""
# DO NOT REMOVE - Logging usage.

if _channels_last(image, channel_axis):
flip_axis = -3 # Image is ...HWC
else:
Expand Down Expand Up @@ -461,6 +485,8 @@ def gaussian_blur(
Returns:
The blurred image.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, {3, 4})
data_format = "NHWC" if _channels_last(image, channel_axis) else "NCHW"
dimension_numbers = (data_format, "HWIO", data_format)
Expand Down Expand Up @@ -516,6 +542,8 @@ def rot90(
Returns:
The rotated image.
"""
# DO NOT REMOVE - Logging usage.

if _channels_last(image, channel_axis):
spatial_axes = (-3, -2) # Image is ...HWC
else:
Expand All @@ -535,6 +563,8 @@ def solarize(image: chex.Array, threshold: chex.Numeric) -> chex.Array:
Returns:
The solarized image.
"""
# DO NOT REMOVE - Logging usage.

return jnp.where(image < threshold, image, 1. - image)


Expand Down Expand Up @@ -654,6 +684,8 @@ def affine_transform(

>>> matrix = rotation_matrix.dot(translation_matrix)
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(image, 3)
chex.assert_rank(matrix, {1, 2})
chex.assert_rank(offset, {0, 1})
Expand Down Expand Up @@ -716,6 +748,8 @@ def rotate(
Returns:
The rotated image.
"""
# DO NOT REMOVE - Logging usage.

# Calculate inverse transform matrix assuming clockwise rotation.
c = jnp.cos(angle)
s = jnp.sin(angle)
Expand Down Expand Up @@ -747,6 +781,8 @@ def random_flip_left_right(
Returns:
A left-right flipped image if condition is met, otherwise original image.
"""
# DO NOT REMOVE - Logging usage.

should_transform = jax.random.bernoulli(key=key, p=probability)
return jax.lax.cond(should_transform, flip_left_right, lambda x: x, image)

Expand All @@ -769,6 +805,8 @@ def random_flip_up_down(
Returns:
An up-down flipped image if condition is met, otherwise original image.
"""
# DO NOT REMOVE - Logging usage.

should_transform = jax.random.bernoulli(key=key, p=probability)
return jax.lax.cond(should_transform, flip_up_down, lambda x: x, image)

Expand All @@ -779,6 +817,8 @@ def random_brightness(
max_delta: chex.Numeric,
) -> chex.Array:
"""`adjust_brightness(...)` with random delta in `[-max_delta, max_delta)`."""
# DO NOT REMOVE - Logging usage.

delta = jax.random.uniform(key, (), minval=-max_delta, maxval=max_delta)
return adjust_brightness(image, delta)

Expand All @@ -793,6 +833,8 @@ def random_gamma(
assume_in_bounds: bool = False,
) -> chex.Array:
"""`adjust_gamma(...)` with random gamma in [min_gamma, max_gamma)`."""
# DO NOT REMOVE - Logging usage.

gamma = jax.random.uniform(key, (), minval=min_gamma, maxval=max_gamma)
return adjust_gamma(
image, gamma, gain=gain, assume_in_bounds=assume_in_bounds)
Expand All @@ -806,6 +848,8 @@ def random_hue(
channel_axis: int = -1,
) -> chex.Array:
"""`adjust_hue(...)` with random delta in `[-max_delta, max_delta)`."""
# DO NOT REMOVE - Logging usage.

delta = jax.random.uniform(key, (), minval=-max_delta, maxval=max_delta)
return adjust_hue(image, delta, channel_axis=channel_axis)

Expand All @@ -819,6 +863,8 @@ def random_contrast(
channel_axis: int = -1,
) -> chex.Array:
"""`adjust_contrast(...)` with random factor in `[lower, upper)`."""
# DO NOT REMOVE - Logging usage.

factor = jax.random.uniform(key, (), minval=lower, maxval=upper)
return adjust_contrast(image, factor, channel_axis=channel_axis)

Expand All @@ -832,6 +878,8 @@ def random_saturation(
channel_axis: int = -1,
) -> chex.Array:
"""`adjust_saturation(...)` with random factor in `[lower, upper)`."""
# DO NOT REMOVE - Logging usage.

factor = jax.random.uniform(key, (), minval=lower, maxval=upper)
return adjust_saturation(image, factor, channel_axis=channel_axis)

Expand Down Expand Up @@ -859,6 +907,7 @@ def random_crop(
Returns:
A cropped image, a JAX array whose shape is same as `crop_sizes`.
"""
# DO NOT REMOVE - Logging usage.

image_shape = image.shape
assert len(image_shape) == len(crop_sizes), (
Expand Down
17 changes: 17 additions & 0 deletions dm_pix/_src/color_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import chex
import jax.numpy as jnp

# DO NOT REMOVE - Logging lib.


def split_channels(
image: chex.Array,
Expand All @@ -37,6 +39,8 @@ def split_channels(
A tuple of 3 images, with float values in range [0, 1], stacked
along channel_axis.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_axis_dimension(image, axis=channel_axis, expected=3)
split_axes = jnp.split(image, 3, axis=channel_axis)
return tuple(map(lambda x: jnp.squeeze(x, axis=channel_axis), split_axes))
Expand All @@ -58,6 +62,8 @@ def rgb_to_hsv(
Returns:
An HSV image, with float values in range [0, 1], stacked along channel_axis.
"""
# DO NOT REMOVE - Logging usage.

eps = jnp.finfo(image_rgb.dtype).eps
image_rgb = jnp.where(jnp.abs(image_rgb) < eps, 0., image_rgb)
red, green, blue = split_channels(image_rgb, channel_axis)
Expand All @@ -81,6 +87,7 @@ def hsv_to_rgb(
Returns:
An RGB image, with float values in range [0, 1], stacked along channel_axis.
"""
# DO NOT REMOVE - Logging usage.
hue, saturation, value = split_channels(image_hsv, channel_axis)
return jnp.stack(
hsv_planes_to_rgb_planes(hue, saturation, value), axis=channel_axis)
Expand All @@ -106,6 +113,8 @@ def rgb_planes_to_hsv_planes(
Returns:
A tuple of (hue, saturation, value) planes, as float values in range [0, 1].
"""
# DO NOT REMOVE - Logging usage.

value = jnp.maximum(jnp.maximum(red, green), blue)
minimum = jnp.minimum(jnp.minimum(red, green), blue)
range_ = value - minimum
Expand Down Expand Up @@ -148,6 +157,8 @@ def hsv_planes_to_rgb_planes(
Returns:
A tuple of (red, green, blue) planes, as float values in range [0, 1].
"""
# DO NOT REMOVE - Logging usage.

dh = (hue % 1.0) * 6. # Wrap when hue >= 360°.
dr = jnp.clip(jnp.abs(dh - 3.) - 1., 0., 1.)
dg = jnp.clip(2. - jnp.abs(dh - 2.), 0., 1.)
Expand Down Expand Up @@ -177,6 +188,8 @@ def rgb_to_hsl(
Returns:
An HSL image, with float values in range [0, 1], stacked along channel_axis.
"""
# DO NOT REMOVE - Logging usage.

red, green, blue = split_channels(image_rgb, channel_axis)

c_max = jnp.maximum(red, jnp.maximum(green, blue))
Expand Down Expand Up @@ -218,6 +231,8 @@ def hsl_to_rgb(
Returns:
An RGB image, with float values in range [0, 1], stacked along channel_axis.
"""
# DO NOT REMOVE - Logging usage.

h, s, l = split_channels(image_hsl, channel_axis)

m2 = jnp.where(l <= 0.5, l * (1 + s), l + s - l * s)
Expand Down Expand Up @@ -257,6 +272,8 @@ def rgb_to_grayscale(
Returns:
The grayscale image.
"""
# DO NOT REMOVE - Logging usage.

assert luma_standard in ["rec601", "rec709", "bt2001"]
if luma_standard == "rec601":
# TensorFlow's default.
Expand Down
13 changes: 10 additions & 3 deletions dm_pix/_src/depth_and_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import jax
import jax.numpy as jnp

# DO NOT REMOVE - Logging lib.


def depth_to_space(inputs: chex.Array, block_size: int) -> chex.Array:
"""Rearranges data from depth into blocks of spatial data.
Expand All @@ -31,14 +33,17 @@ def depth_to_space(inputs: chex.Array, block_size: int) -> chex.Array:
[H * B, W * B, C / (B ** 2)], where B is `block_size`. If there's a leading
batch dimension, it stays unchanged.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(inputs, {3, 4})
if inputs.ndim == 4: # Batched case.
return jax.vmap(depth_to_space, in_axes=(0, None))(inputs, block_size)

height, width, depth = inputs.shape
if depth % (block_size**2) != 0:
raise ValueError(
f'Number of channels {depth} must be divisible by block_size ** 2 {block_size**2}.'
f"Number of channels {depth} must be divisible by block_size ** 2"
f" {block_size**2}."
)
new_depth = depth // (block_size**2)
outputs = jnp.reshape(inputs,
Expand All @@ -64,17 +69,19 @@ def space_to_depth(inputs: chex.Array, block_size: int) -> chex.Array:
[H / B, W / B, C * (B ** 2)], where B is `block_size`. If there's a leading
batch dimension, it stays unchanged.
"""
# DO NOT REMOVE - Logging usage.

chex.assert_rank(inputs, {3, 4})
if inputs.ndim == 4: # Batched case.
return jax.vmap(space_to_depth, in_axes=(0, None))(inputs, block_size)

height, width, depth = inputs.shape
if height % block_size != 0:
raise ValueError(
f'Height {height} must be divisible by block size {block_size}.')
f"Height {height} must be divisible by block size {block_size}.")
if width % block_size != 0:
raise ValueError(
f'Width {width} must be divisible by block size {block_size}.')
f"Width {width} must be divisible by block size {block_size}.")
new_depth = depth * (block_size**2)
new_height = height // block_size
new_width = width // block_size
Expand Down
10 changes: 8 additions & 2 deletions dm_pix/_src/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from jax import lax
import jax.numpy as jnp

# DO NOT REMOVE - Logging lib.


def _round_half_away_from_zero(a: chex.Array) -> chex.Array:
return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)
Expand Down Expand Up @@ -73,8 +75,8 @@ def _make_linear_interpolation_indices_flat_nd(

if shape.shape[0] != coordinates.shape[0]:
raise ValueError(
(f'{coordinates.shape[0]}-dimensional coordinates provided for '
f'{shape.shape[0]}-dimensional input'))
(f"{coordinates.shape[0]}-dimensional coordinates provided for "
f"{shape.shape[0]}-dimensional input"))

lower_nd, upper_nd, weights_nd = _make_linear_interpolation_indices_nd(
coordinates, shape)
Expand Down Expand Up @@ -153,6 +155,8 @@ def flat_nd_linear_interpolate(
The resulting mapped coordinates. The shape of the output is `M_coordinates`
(derived from `coordinates` by dropping the first axis).
"""
# DO NOT REMOVE - Logging usage.

if unflattened_vol_shape is None:
unflattened_vol_shape = volume.shape
volume = volume.flatten()
Expand Down Expand Up @@ -191,6 +195,8 @@ def flat_nd_linear_interpolate_constant(
The resulting mapped coordinates. The shape of the output is `M_coordinates`
(derived from `coordinates` by dropping the first axis).
"""
# DO NOT REMOVE - Logging usage.

volume_shape = volume.shape
if unflattened_vol_shape is not None:
volume_shape = unflattened_vol_shape
Expand Down
Loading
Loading