Skip to content

Commit

Permalink
passing explicit prg
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 17, 2023
1 parent 16d3bef commit c28c764
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 177 deletions.
83 changes: 40 additions & 43 deletions serket/_src/image/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,23 @@
import jax
import jax.numpy as jnp
import jax.random as jr
from typing_extensions import Annotated

import serket as sk
from serket._src.custom_transform import tree_eval
from serket._src.nn.linear import Identity
from serket._src.utils import IsInstance, Range, validate_spatial_nd
from serket._src.utils import CHWArray, HWArray, IsInstance, Range, validate_spatial_nd


def pixel_shuffle_3d(
array: Annotated[jax.Array, "CHW"],
upscale_factor: tuple[int, int],
) -> Annotated[jax.Array, "CHW"]:
"""Rearrange elements in a tensor."""
def pixel_shuffle_3d(array: CHWArray, upscale_factor: tuple[int, int]) -> CHWArray:
"""Rearrange elements in a tensor.
Args:
array: input array with shape (channels, height, width)
upscale_factor: factor to increase spatial resolution by.
Reference:
- https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
"""
channels, _, _ = array.shape

sr, sw = upscale_factor
Expand All @@ -48,16 +52,16 @@ def pixel_shuffle_3d(


def solarize_2d(
image: Annotated[jax.Array, "HW"],
image: HWArray,
threshold: float | int,
max_val: float | int,
) -> Annotated[jax.Array, "HW"]:
) -> HWArray:
"""Inverts all values above a given threshold."""
_, _ = image.shape
return jnp.where(image < threshold, image, max_val - image)


def adjust_contrast_2d(image: Annotated[jax.Array, "HW"], contrast_factor: float):
def adjust_contrast_2d(image: HWArray, contrast_factor: float):
"""Adjusts the contrast of an image by scaling the pixel values by a factor.
Args:
Expand All @@ -70,20 +74,18 @@ def adjust_contrast_2d(image: Annotated[jax.Array, "HW"], contrast_factor: float


def random_contrast_2d(
array: Annotated[jax.Array, "HW"],
array: HWArray,
contrast_range: tuple[float, float],
key: jr.KeyArray = jr.PRNGKey(0),
) -> Annotated[jax.Array, "HW"]:
key: jr.KeyArray,
) -> HWArray:
"""Randomly adjusts the contrast of an image by scaling the pixel values by a factor."""
_, _ = array.shape
minval, maxval = contrast_range
contrast_factor = jr.uniform(key=key, shape=(), minval=minval, maxval=maxval)
return adjust_contrast_2d(array, contrast_factor)


def pixelate_2d(
image: Annotated[jax.Array, "HW"], scale: int = 16
) -> Annotated[jax.Array, "HW"]:
def pixelate_2d(image: HWArray, scale: int = 16) -> HWArray:
"""Return a pixelated image by downsizing and upsizing"""
dtype = image.dtype
h, w = image.shape
Expand All @@ -95,12 +97,14 @@ def pixelate_2d(


@ft.partial(jax.jit, inline=True, static_argnums=1)
def jigsaw_2d(
image: Annotated[jax.Array, "HW"],
tiles: int = 1,
key: jr.KeyArray = jr.PRNGKey(0),
) -> Annotated[jax.Array, "HW"]:
"""Jigsaw channel-first image"""
def jigsaw_2d(image: HWArray, tiles: int, key: jr.KeyArray) -> HWArray:
"""Jigsaw an image by mixing up tiles.
Args:
image: The image to jigsaw in shape of (height, width).
tiles: The number of tiles per side.
key: The random key to use for shuffling.
"""
height, width = image.shape
tile_height = height // tiles
tile_width = width // tiles
Expand All @@ -117,10 +121,7 @@ def jigsaw_2d(
return image


def posterize_2d(
image: Annotated[jax.Array, "HW"],
bits: int,
) -> Annotated[jax.Array, "HW"]:
def posterize_2d(image: HWArray, bits: int) -> HWArray:
"""Reduce the number of bits for each color channel.
Args:
Expand All @@ -141,8 +142,7 @@ class PixelShuffle2D(sk.TreeClass):
.. image:: ../_static/pixelshuffle2d.png
Args:
upscale_factor: factor to increase spatial resolution by. accepts a
single integer or a tuple of length 2. defaults to 1.
upscale_factor: factor to increase spatial resolution by.
Reference:
- https://arxiv.org/abs/1609.05158
Expand All @@ -164,15 +164,14 @@ def __init__(self, upscale_factor: int | tuple[int, int] = 1):
self.upscale_factor = upscale_factor
return

raise ValueError("upscale_factor must be an integer or tuple of length 2")
raise ValueError("`upscale_factor` must be an integer or tuple of length 2")

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
def __call__(self, x: CHWArray) -> CHWArray:
return pixel_shuffle_3d(x, self.upscale_factor)

@property
def spatial_ndim(self) -> int:
"""Number of spatial dimensions of the image."""
return 2


Expand All @@ -193,13 +192,12 @@ class AdjustContrast2D(sk.TreeClass):
contrast_factor: float = 1.0

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
def __call__(self, x: CHWArray) -> CHWArray:
contrast_factor = jax.lax.stop_gradient(self.contrast_factor)
return jax.vmap(adjust_contrast_2d, in_axes=(0, None))(x, contrast_factor)

@property
def spatial_ndim(self) -> int:
"""Number of spatial dimensions of the image."""
return 2


Expand Down Expand Up @@ -232,7 +230,7 @@ def __init__(self, contrast_range: tuple[float, float] = (0.5, 1)):
self.contrast_range = contrast_range

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
def __call__(self, x: CHWArray, *, key: jr.KeyArray) -> CHWArray:
contrast_range = jax.lax.stop_gradient(self.contrast_range)
in_axes = (0, None, None)
return jax.vmap(random_contrast_2d, in_axes=in_axes)(x, contrast_range, key)
Expand Down Expand Up @@ -271,9 +269,8 @@ def __init__(self, scale: int):
self.scale = scale

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
scale = jax.lax.stop_gradient(self.scale)
return jax.vmap(pixelate_2d, in_axes=(0, None))(x, scale)
def __call__(self, x: CHWArray) -> CHWArray:
return jax.vmap(pixelate_2d, in_axes=(0, None))(x, self.scale)

@property
def spatial_ndim(self) -> int:
Expand Down Expand Up @@ -311,7 +308,7 @@ class Solarize2D(sk.TreeClass):
max_val: float = 1.0

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
def __call__(self, x: CHWArray) -> CHWArray:
threshold, max_val = jax.lax.stop_gradient((self.threshold, self.max_val))
return jax.vmap(solarize_2d, in_axes=(0, None, None))(x, threshold, max_val)

Expand Down Expand Up @@ -365,9 +362,8 @@ class Posterize2D(sk.TreeClass):
bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)])

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
bits = jax.lax.stop_gradient(self.bits)
return jax.vmap(posterize_2d, in_axes=(0, None))(x, bits)
def __call__(self, x: CHWArray) -> CHWArray:
return jax.vmap(posterize_2d, in_axes=(0, None))(x, self.bits)

@property
def spatial_ndim(self) -> int:
Expand All @@ -386,13 +382,14 @@ class JigSaw2D(sk.TreeClass):
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> print(x)
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
>>> print(sk.image.JigSaw2D(2)(x))
>>> print(sk.image.JigSaw2D(2)(x, key=jr.PRNGKey(0)))
[[[ 9 10 3 4]
[13 14 7 8]
[11 12 1 2]
Expand Down Expand Up @@ -420,7 +417,7 @@ class JigSaw2D(sk.TreeClass):
tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)])

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
def __call__(self, x: CHWArray, *, key: jr.KeyArray) -> CHWArray:
"""Mixes up tiles of an image.
Args:
Expand Down
Loading

0 comments on commit c28c764

Please sign in to comment.