diff --git a/serket/_src/image/augment.py b/serket/_src/image/augment.py index 1bbd57a..dd7c12c 100644 --- a/serket/_src/image/augment.py +++ b/serket/_src/image/augment.py @@ -19,19 +19,23 @@ import jax import jax.numpy as jnp import jax.random as jr -from typing_extensions import Annotated import serket as sk from serket._src.custom_transform import tree_eval from serket._src.nn.linear import Identity -from serket._src.utils import IsInstance, Range, validate_spatial_nd +from serket._src.utils import CHWArray, HWArray, IsInstance, Range, validate_spatial_nd -def pixel_shuffle_3d( - array: Annotated[jax.Array, "CHW"], - upscale_factor: tuple[int, int], -) -> Annotated[jax.Array, "CHW"]: - """Rearrange elements in a tensor.""" +def pixel_shuffle_3d(array: CHWArray, upscale_factor: tuple[int, int]) -> CHWArray: + """Rearrange elements in a tensor. + + Args: + array: input array with shape (channels, height, width) + upscale_factor: factor to increase spatial resolution by. + + Reference: + - https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html + """ channels, _, _ = array.shape sr, sw = upscale_factor @@ -48,16 +52,16 @@ def pixel_shuffle_3d( def solarize_2d( - image: Annotated[jax.Array, "HW"], + image: HWArray, threshold: float | int, max_val: float | int, -) -> Annotated[jax.Array, "HW"]: +) -> HWArray: """Inverts all values above a given threshold.""" _, _ = image.shape return jnp.where(image < threshold, image, max_val - image) -def adjust_contrast_2d(image: Annotated[jax.Array, "HW"], contrast_factor: float): +def adjust_contrast_2d(image: HWArray, contrast_factor: float): """Adjusts the contrast of an image by scaling the pixel values by a factor. Args: @@ -70,10 +74,10 @@ def adjust_contrast_2d(image: Annotated[jax.Array, "HW"], contrast_factor: float def random_contrast_2d( - array: Annotated[jax.Array, "HW"], + array: HWArray, contrast_range: tuple[float, float], - key: jr.KeyArray = jr.PRNGKey(0), -) -> Annotated[jax.Array, "HW"]: + key: jr.KeyArray, +) -> HWArray: """Randomly adjusts the contrast of an image by scaling the pixel values by a factor.""" _, _ = array.shape minval, maxval = contrast_range @@ -81,9 +85,7 @@ def random_contrast_2d( return adjust_contrast_2d(array, contrast_factor) -def pixelate_2d( - image: Annotated[jax.Array, "HW"], scale: int = 16 -) -> Annotated[jax.Array, "HW"]: +def pixelate_2d(image: HWArray, scale: int = 16) -> HWArray: """Return a pixelated image by downsizing and upsizing""" dtype = image.dtype h, w = image.shape @@ -95,12 +97,14 @@ def pixelate_2d( @ft.partial(jax.jit, inline=True, static_argnums=1) -def jigsaw_2d( - image: Annotated[jax.Array, "HW"], - tiles: int = 1, - key: jr.KeyArray = jr.PRNGKey(0), -) -> Annotated[jax.Array, "HW"]: - """Jigsaw channel-first image""" +def jigsaw_2d(image: HWArray, tiles: int, key: jr.KeyArray) -> HWArray: + """Jigsaw an image by mixing up tiles. + + Args: + image: The image to jigsaw in shape of (height, width). + tiles: The number of tiles per side. + key: The random key to use for shuffling. + """ height, width = image.shape tile_height = height // tiles tile_width = width // tiles @@ -117,10 +121,7 @@ def jigsaw_2d( return image -def posterize_2d( - image: Annotated[jax.Array, "HW"], - bits: int, -) -> Annotated[jax.Array, "HW"]: +def posterize_2d(image: HWArray, bits: int) -> HWArray: """Reduce the number of bits for each color channel. Args: @@ -141,8 +142,7 @@ class PixelShuffle2D(sk.TreeClass): .. image:: ../_static/pixelshuffle2d.png Args: - upscale_factor: factor to increase spatial resolution by. accepts a - single integer or a tuple of length 2. defaults to 1. + upscale_factor: factor to increase spatial resolution by. Reference: - https://arxiv.org/abs/1609.05158 @@ -164,15 +164,14 @@ def __init__(self, upscale_factor: int | tuple[int, int] = 1): self.upscale_factor = upscale_factor return - raise ValueError("upscale_factor must be an integer or tuple of length 2") + raise ValueError("`upscale_factor` must be an integer or tuple of length 2") @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: return pixel_shuffle_3d(x, self.upscale_factor) @property def spatial_ndim(self) -> int: - """Number of spatial dimensions of the image.""" return 2 @@ -193,13 +192,12 @@ class AdjustContrast2D(sk.TreeClass): contrast_factor: float = 1.0 @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: contrast_factor = jax.lax.stop_gradient(self.contrast_factor) return jax.vmap(adjust_contrast_2d, in_axes=(0, None))(x, contrast_factor) @property def spatial_ndim(self) -> int: - """Number of spatial dimensions of the image.""" return 2 @@ -232,7 +230,7 @@ def __init__(self, contrast_range: tuple[float, float] = (0.5, 1)): self.contrast_range = contrast_range @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array: + def __call__(self, x: CHWArray, *, key: jr.KeyArray) -> CHWArray: contrast_range = jax.lax.stop_gradient(self.contrast_range) in_axes = (0, None, None) return jax.vmap(random_contrast_2d, in_axes=in_axes)(x, contrast_range, key) @@ -271,9 +269,8 @@ def __init__(self, scale: int): self.scale = scale @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: - scale = jax.lax.stop_gradient(self.scale) - return jax.vmap(pixelate_2d, in_axes=(0, None))(x, scale) + def __call__(self, x: CHWArray) -> CHWArray: + return jax.vmap(pixelate_2d, in_axes=(0, None))(x, self.scale) @property def spatial_ndim(self) -> int: @@ -311,7 +308,7 @@ class Solarize2D(sk.TreeClass): max_val: float = 1.0 @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: threshold, max_val = jax.lax.stop_gradient((self.threshold, self.max_val)) return jax.vmap(solarize_2d, in_axes=(0, None, None))(x, threshold, max_val) @@ -365,9 +362,8 @@ class Posterize2D(sk.TreeClass): bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)]) @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: - bits = jax.lax.stop_gradient(self.bits) - return jax.vmap(posterize_2d, in_axes=(0, None))(x, bits) + def __call__(self, x: CHWArray) -> CHWArray: + return jax.vmap(posterize_2d, in_axes=(0, None))(x, self.bits) @property def spatial_ndim(self) -> int: @@ -386,13 +382,14 @@ class JigSaw2D(sk.TreeClass): Example: >>> import serket as sk >>> import jax.numpy as jnp + >>> import jax.random as jr >>> x = jnp.arange(1, 17).reshape(1, 4, 4) >>> print(x) [[[ 1 2 3 4] [ 5 6 7 8] [ 9 10 11 12] [13 14 15 16]]] - >>> print(sk.image.JigSaw2D(2)(x)) + >>> print(sk.image.JigSaw2D(2)(x, key=jr.PRNGKey(0))) [[[ 9 10 3 4] [13 14 7 8] [11 12 1 2] @@ -420,7 +417,7 @@ class JigSaw2D(sk.TreeClass): tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)]) @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array: + def __call__(self, x: CHWArray, *, key: jr.KeyArray) -> CHWArray: """Mixes up tiles of an image. Args: diff --git a/serket/_src/image/filter.py b/serket/_src/image/filter.py index 6bcdcee..b8ee16e 100644 --- a/serket/_src/image/filter.py +++ b/serket/_src/image/filter.py @@ -19,13 +19,14 @@ import jax import jax.numpy as jnp import kernex as kex -from typing_extensions import Annotated import serket as sk from serket._src.image.geometric import rotate_2d from serket._src.nn.convolution import fft_conv_general_dilated from serket._src.nn.initialization import DType from serket._src.utils import ( + CHWArray, + HWArray, canonicalize, generate_conv_dim_numbers, resolve_string_padding, @@ -33,15 +34,12 @@ ) -def filter_2d( - array: Annotated[jax.Array, "HW"], - weight: Annotated[jax.Array, "HW"], -) -> Annotated[jax.Array, "HW"]: +def filter_2d(array: HWArray, weight: HWArray) -> HWArray: """Filtering wrapping ``jax.lax.conv_general_dilated``. Args: - array: 2D input array. shape is (row, col). - weight: convolutional kernel. shape is (row, col). + array: 2D input array. shape is (height, width). + weight: convolutional kernel. shape is (height, width). Note: - To filter 3D array, channel-wise use ``jax.vmap(filter_2d, in_axes=(0, None))``. @@ -56,7 +54,7 @@ def filter_2d( lhs=jnp.expand_dims(array, 0), rhs=weight, window_strides=(1, 1), - padding="SAME", + padding="same", rhs_dilation=(1, 1), dimension_numbers=generate_conv_dim_numbers(2), feature_group_count=array.shape[0], # in_features @@ -64,15 +62,12 @@ def filter_2d( return jnp.squeeze(x, (0, 1)) -def fft_filter_2d( - array: Annotated[jax.Array, "HW"], - weight: Annotated[jax.Array, "HW"], -) -> Annotated[jax.Array, "HW"]: +def fft_filter_2d(array: HWArray, weight: HWArray) -> HWArray: """Filtering wrapping ``serket`` ``fft_conv_general_dilated`` Args: - array: 2D input array. shape is (row, col). - weight: convolutional kernel. shape is (row, col). + array: 2D input array. shape is (height, width). + weight: convolutional kernel. shape is (height, width). Note: - To filter 3D array, channel-wise use ``jax.vmap(filter_2d, in_axes=(0, None))``. @@ -85,7 +80,7 @@ def fft_filter_2d( padding = resolve_string_padding( in_dim=array.shape[1:], - padding="SAME", + padding="same", kernel_size=weight.shape[2:], strides=(1, 1), ) @@ -101,10 +96,7 @@ def fft_filter_2d( return jnp.squeeze(x, (0, 1)) -def calculate_average_kernel( - kernel_size: int, - dtype: DType, -) -> Annotated[jax.Array, "HW"]: +def calculate_average_kernel(kernel_size: int, dtype: DType) -> HWArray: """Calculate average kernel. Args: @@ -121,11 +113,7 @@ def calculate_average_kernel( return kernel -def calculate_gaussian_kernel( - kernel_size: int, - sigma: float, - dtype: DType, -) -> Annotated[jax.Array, "HW"]: +def calculate_gaussian_kernel(kernel_size: int, sigma: float, dtype: DType) -> HWArray: """Calculate gaussian kernel. Args: @@ -145,7 +133,7 @@ def calculate_gaussian_kernel( return kernel -def calculate_box_kernel(kernel_size: int, dtype: DType) -> Annotated[jax.Array, "HW"]: +def calculate_box_kernel(kernel_size: int, dtype: DType) -> HWArray: """Calculate box kernel. Args: @@ -161,10 +149,7 @@ def calculate_box_kernel(kernel_size: int, dtype: DType) -> Annotated[jax.Array, return kernel / kernel_size -def calculate_laplacian_kernel( - kernel_size: tuple[int, int], - dtype: DType, -) -> Annotated[jax.Array, "HW"]: +def calculate_laplacian_kernel(kernel_size: tuple[int, int], dtype: DType) -> HWArray: """Calculate laplacian kernel. Args: @@ -183,9 +168,9 @@ def calculate_laplacian_kernel( def calculate_motion_kernel( kernel_size: int, angle: float, - direction=0.0, - dtype: DType = jnp.float32, -) -> Annotated[jax.Array, "HW"]: + direction, + dtype: DType, +) -> HWArray: """Returns 2D motion blur filter. Args: @@ -207,11 +192,9 @@ def calculate_motion_kernel( @ft.partial(jax.jit, inline=True, static_argnums=1) -def median_blur_2d( - array: Annotated[jax.Array, "HW"], - kernel_size: tuple[int, int], -) -> Annotated[jax.Array, "HW"]: - _, _ = array.shape +def median_blur_2d(array: HWArray, kernel_size: tuple[int, int]) -> HWArray: + """Median blur""" + assert array.ndim == 2 @kex.kmap(kernel_size=kernel_size, padding="same") def median_kernel(array: jax.Array) -> jax.Array: @@ -259,7 +242,7 @@ class AvgBlur2D(AvgBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel_x, kernel_y = jax.lax.stop_gradient((self.kernel_x, self.kernel_y)) x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel_x) x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel_y.T) @@ -288,7 +271,7 @@ class FFTAvgBlur2D(AvgBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel_x, kernel_y = jax.lax.stop_gradient((self.kernel_x, self.kernel_y)) x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel_x) x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel_y.T) @@ -338,7 +321,7 @@ class GaussianBlur2D(GaussianBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel_x, kernel_y = jax.lax.stop_gradient((self.kernel_x, self.kernel_y)) x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel_x) x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel_y.T) @@ -368,7 +351,7 @@ class FFTGaussianBlur2D(GaussianBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel_x, kernel_y = jax.lax.stop_gradient((self.kernel_x, self.kernel_y)) x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel_x) x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel_y.T) @@ -398,7 +381,7 @@ class UnsharpMask2D(GaussianBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel_x, kernel_y = jax.lax.stop_gradient((self.kernel_x, self.kernel_y)) blur = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel_x) blur = jax.vmap(filter_2d, in_axes=(0, None))(blur, kernel_y.T) @@ -428,7 +411,7 @@ class FFTUnsharpMask2D(GaussianBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel_x, kernel_y = jax.lax.stop_gradient((self.kernel_x, self.kernel_y)) blur = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel_x) blur = jax.vmap(fft_filter_2d, in_axes=(0, None))(blur, kernel_y.T) @@ -474,7 +457,7 @@ class BoxBlur2D(BoxBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel_x, kernel_y = jax.lax.stop_gradient((self.kernel_x, self.kernel_y)) x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel_x) x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel_y.T) @@ -503,7 +486,7 @@ class FFTBoxBlur2D(BoxBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel_x, kernel_y = jax.lax.stop_gradient((self.kernel_x, self.kernel_y)) x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel_x) x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel_y.T) @@ -550,7 +533,7 @@ class Laplacian2D(Laplacian2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel = jax.lax.stop_gradient_p.bind(self.kernel) x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel) return x @@ -581,7 +564,7 @@ class FFTLaplacian2D(Laplacian2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel = jax.lax.stop_gradient_p.bind(self.kernel) x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel) return x @@ -630,7 +613,7 @@ class MotionBlur2D(MotionBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel = jax.lax.stop_gradient_p.bind(self.kernel) x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel) return x @@ -659,7 +642,7 @@ class FFTMotionBlur2D(MotionBlur2DBase): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel = jax.lax.stop_gradient_p.bind(self.kernel) x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel) return x @@ -695,7 +678,7 @@ def __init__(self, kernel_size: int | tuple[int, int]): self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size") @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: x = jax.vmap(median_blur_2d, in_axes=(0, None))(x, self.kernel_size) return x @@ -725,11 +708,11 @@ class Filter2D(sk.TreeClass): [4. 6. 6. 6. 4.]]] """ - def __init__(self, kernel: jax.Array, *, dtype: DType = jnp.float32): + def __init__(self, kernel: HWArray, *, dtype: DType = jnp.float32): self.kernel = kernel.astype(dtype) @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel = jax.lax.stop_gradient_p.bind(self.kernel) x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel) return x @@ -760,11 +743,11 @@ class FFTFilter2D(sk.TreeClass): [4. 6.0000005 6.0000005 6.0000005 4. ]]] """ - def __init__(self, kernel: jax.Array, *, dtype: DType = jnp.float32): + def __init__(self, kernel: HWArray, *, dtype: DType = jnp.float32): self.kernel = kernel.astype(dtype) @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: kernel = jax.lax.stop_gradient_p.bind(self.kernel) x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel) return x diff --git a/serket/_src/image/geometric.py b/serket/_src/image/geometric.py index f21c01f..86bc8ba 100644 --- a/serket/_src/image/geometric.py +++ b/serket/_src/image/geometric.py @@ -20,17 +20,14 @@ import jax.numpy as jnp import jax.random as jr from jax.scipy.ndimage import map_coordinates -from typing_extensions import Annotated import serket as sk from serket._src.custom_transform import tree_eval from serket._src.nn.linear import Identity -from serket._src.utils import IsInstance, validate_spatial_nd +from serket._src.utils import CHWArray, HWArray, IsInstance, validate_spatial_nd -def affine_2d( - array: Annotated[jax.Array, "HW"], matrix: Annotated[jax.Array, "HW"] -) -> Annotated[jax.Array, "HW"]: +def affine_2d(array: HWArray, matrix: HWArray) -> HWArray: h, w = array.shape center = jnp.array((h // 2, w // 2)) coords = jnp.indices((h, w)).reshape(2, -1) - center.reshape(2, 1) @@ -38,9 +35,7 @@ def affine_2d( return map_coordinates(array, coords, order=1).reshape((h, w)) -def horizontal_shear_2d( - image: Annotated[jax.Array, "HW"], angle: float -) -> Annotated[jax.Array, "HW"]: +def horizontal_shear_2d(image: HWArray, angle: float) -> HWArray: """shear rows by an angle in degrees""" shear = jnp.tan(jnp.deg2rad(angle)) matrix = jnp.array([[1, 0], [shear, 1]]) @@ -50,7 +45,7 @@ def horizontal_shear_2d( def random_horizontal_shear_2d( image: jax.Array, angle_range: tuple[float, float], - key: jax.random.KeyArray, + key: jr.KeyArray, ) -> jax.Array: """shear rows by an angle in degrees""" minval, maxval = angle_range @@ -59,9 +54,9 @@ def random_horizontal_shear_2d( def vertical_shear_2d( - image: Annotated[jax.Array, "HW"], + image: HWArray, angle: float, -) -> Annotated[jax.Array, "HW"]: +) -> HWArray: """shear cols by an angle in degrees""" shear = jnp.tan(jnp.deg2rad(angle)) matrix = jnp.array([[1, shear], [0, 1]]) @@ -69,19 +64,17 @@ def vertical_shear_2d( def random_vertical_shear_2d( - image: Annotated[jax.Array, "HW"], + image: HWArray, angle_range: tuple[float, float], - key: jax.random.KeyArray, -) -> Annotated[jax.Array, "HW"]: + key: jr.KeyArray, +) -> HWArray: """shear cols by an angle in degrees""" minval, maxval = angle_range angle = jr.uniform(key=key, shape=(), minval=minval, maxval=maxval) return vertical_shear_2d(image, angle) -def rotate_2d( - image: Annotated[jax.Array, "HW"], angle: float -) -> Annotated[jax.Array, "HW"]: +def rotate_2d(image: HWArray, angle: float) -> HWArray: """rotate an image by an angle in degrees in CCW direction.""" θ = jnp.deg2rad(-angle) matrix = jnp.array([[jnp.cos(θ), -jnp.sin(θ)], [jnp.sin(θ), jnp.cos(θ)]]) @@ -89,19 +82,16 @@ def rotate_2d( def random_rotate_2d( - image: Annotated[jax.Array, "HW"], + image: HWArray, angle_range: tuple[float, float], - key: jax.random.KeyArray, -) -> Annotated[jax.Array, "HW"]: + key: jr.KeyArray, +) -> HWArray: minval, maxval = angle_range angle = jr.uniform(key=key, shape=(), minval=minval, maxval=maxval) return rotate_2d(image, angle) -def perspective_transform_2d( - image: Annotated[jax.Array, "HW"], - coeffs: jax.Array, -) -> Annotated[jax.Array, "HW"]: +def perspective_transform_2d(image: HWArray, coeffs: jax.Array) -> HWArray: """Apply a perspective transform to an image.""" rows, cols = image.shape @@ -115,12 +105,12 @@ def perspective_transform_2d( def random_perspective_2d( - image: Annotated[jax.Array, "HW"], - key: jax.random.KeyArray, - scale: float = 1.0, -) -> Annotated[jax.Array, "HW"]: + image: HWArray, + key: jr.KeyArray, + scale: float, +) -> HWArray: """Applies a random perspective transform to a channel-first image""" - _, _ = image.shape + assert image.ndim == 2 a = e = 1.0 b = d = 0.0 c = f = 0.0 # no translation @@ -129,12 +119,9 @@ def random_perspective_2d( return perspective_transform_2d(image, coeffs) -def horizontal_translate_2d( - image: Annotated[jax.Array, "HW"], - shift: int, -) -> Annotated[jax.Array, "HW"]: +def horizontal_translate_2d(image: HWArray, shift: int) -> HWArray: """Translate an image horizontally by a pixel value.""" - _, _ = image.shape + assert image.ndim == 2 if shift > 0: return jnp.zeros_like(image).at[:, shift:].set(image[:, :-shift]) if shift < 0: @@ -142,10 +129,7 @@ def horizontal_translate_2d( return image -def vertical_translate_2d( - image: Annotated[jax.Array, "HW"], - shift: int, -) -> Annotated[jax.Array, "HW"]: +def vertical_translate_2d(image: HWArray, shift: int) -> HWArray: """Translate an image vertically by a pixel value.""" _, _ = image.shape if shift > 0: @@ -155,18 +139,13 @@ def vertical_translate_2d( return image -def random_horizontal_translate_2d( - image: Annotated[jax.Array, "HW"], key: jr.KeyArray -) -> Annotated[jax.Array, "HW"]: +def random_horizontal_translate_2d(image: HWArray, key: jr.KeyArray) -> HWArray: _, w = image.shape shift = jr.randint(key, shape=(), minval=-w, maxval=w) return horizontal_translate_2d(image, shift) -def random_vertical_translate_2d( - image: Annotated[jax.Array, "HW"], - key: jr.KeyArray, -) -> Annotated[jax.Array, "HW"]: +def random_vertical_translate_2d(image: HWArray, key: jr.KeyArray) -> HWArray: h, _ = image.shape shift = jr.randint(key, shape=(), minval=-h, maxval=h) return vertical_translate_2d(image, shift) @@ -196,7 +175,7 @@ def __init__(self, angle: float): self.angle = angle @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: angle = jax.lax.stop_gradient(self.angle) return jax.vmap(rotate_2d, in_axes=(0, None))(x, angle) @@ -251,11 +230,7 @@ def __init__(self, angle_range: tuple[float, float] = (0.0, 360.0)): self.angle_range = angle_range @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__( - self, - x: jax.Array, - key: jax.random.KeyArray = jax.random.PRNGKey(0), - ) -> jax.Array: + def __call__(self, x: CHWArray, *, key: jr.KeyArray) -> CHWArray: angle_range = jax.lax.stop_gradient(self.angle_range) return jax.vmap(random_rotate_2d, in_axes=(0, None, None))(x, angle_range, key) @@ -288,7 +263,7 @@ def __init__(self, angle: float): self.angle = angle @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: angle = jax.lax.stop_gradient(self.angle) return jax.vmap(horizontal_shear_2d, in_axes=(0, None))(x, angle) @@ -322,8 +297,9 @@ class RandomHorizontalShear2D(sk.TreeClass): >>> import serket as sk >>> import jax >>> import jax.numpy as jnp + >>> import jax.random as jr >>> x = jnp.arange(1, 26).reshape(1, 5, 5) - >>> print(sk.image.RandomHorizontalShear2D((45, 45))(x)) + >>> print(sk.image.RandomHorizontalShear2D((45, 45))(x, key=jr.PRNGKey(0))) [[[ 0 0 1 2 3] [ 0 6 7 8 9] [11 12 13 14 15] @@ -342,11 +318,7 @@ def __init__(self, angle_range: tuple[float, float] = (0.0, 90.0)): self.angle_range = angle_range @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__( - self, - x: jax.Array, - key: jax.random.KeyArray = jax.random.PRNGKey(0), - ) -> jax.Array: + def __call__(self, x: CHWArray, *, key: jr.KeyArray) -> CHWArray: angle = jax.lax.stop_gradient(self.angle_range) in_axes = (0, None, None) return jax.vmap(random_horizontal_shear_2d, in_axes=in_axes)(x, angle, key) @@ -414,8 +386,9 @@ class RandomVerticalShear2D(sk.TreeClass): >>> import serket as sk >>> import jax >>> import jax.numpy as jnp + >>> import jax.random as jr >>> x = jnp.arange(1, 26).reshape(1, 5, 5) - >>> print(sk.image.RandomVerticalShear2D((45, 45))(x)) + >>> print(sk.image.RandomVerticalShear2D((45, 45))(x, key=jr.PRNGKey(0))) [[[ 0 0 3 9 15] [ 0 2 8 14 20] [ 1 7 13 19 25] @@ -434,11 +407,7 @@ def __init__(self, angle_range: tuple[float, float] = (0.0, 90.0)): self.angle_range = angle_range @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__( - self, - x: jax.Array, - key: jax.random.KeyArray = jax.random.PRNGKey(0), - ) -> jax.Array: + def __call__(self, x: CHWArray, *, key: jr.KeyArray) -> CHWArray: angle = jax.lax.stop_gradient(self.angle_range) in_axes = (0, None, None) return jax.vmap(random_vertical_shear_2d, in_axes=in_axes)(x, angle, key) @@ -551,7 +520,7 @@ def __init__(self, scale: float = 1.0): self.scale = scale @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array: + def __call__(self, x: CHWArray, *, key: jr.KeyArray) -> CHWArray: scale = jax.lax.stop_gradient(self.scale) return jax.vmap(random_perspective_2d, in_axes=(0, None, None))(x, key, scale) @@ -584,7 +553,7 @@ class HorizontalTranslate2D(sk.TreeClass): shift: int = sk.field(on_setattr=[IsInstance(int)]) @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: return jax.vmap(horizontal_translate_2d, in_axes=(0, None))(x, self.shift) @property @@ -647,7 +616,7 @@ class RandomHorizontalTranslate2D(sk.TreeClass): >>> import serket as sk >>> import jax.numpy as jnp >>> x = jnp.arange(1, 26).reshape(1, 5, 5) - >>> print(sk.image.RandomHorizontalTranslate2D()(x)) + >>> print(sk.image.RandomHorizontalTranslate2D()(x, key=jr.PRNGKey(0))) [[[ 4 5 0 0 0] [ 9 10 0 0 0] [14 15 0 0 0] @@ -656,11 +625,7 @@ class RandomHorizontalTranslate2D(sk.TreeClass): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__( - self, - x: jax.Array, - key: jr.KeyArray = jr.PRNGKey(0), - ) -> jax.Array: + def __call__(self, x: CHWArray, *, key: jr.KeyArray) -> CHWArray: return jax.vmap(random_horizontal_translate_2d, in_axes=(0, None))(x, key) @property @@ -699,7 +664,7 @@ class RandomVerticalTranslate2D(sk.TreeClass): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array: + def __call__(self, x: CHWArray, *, key: jr.KeyArray = jr.PRNGKey(0)) -> CHWArray: return jax.vmap(random_vertical_translate_2d, in_axes=(0, None))(x, key) @property @@ -731,7 +696,7 @@ class HorizontalFlip2D(sk.TreeClass): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array, **k) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: return jax.vmap(lambda x: jnp.flip(x, axis=1))(x) @property @@ -763,7 +728,7 @@ class VerticalFlip2D(sk.TreeClass): """ @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, x: jax.Array, **k) -> jax.Array: + def __call__(self, x: CHWArray) -> CHWArray: return jax.vmap(lambda x: jnp.flip(x, axis=0))(x) @property diff --git a/serket/_src/utils.py b/serket/_src/utils.py index a5788f4..f07c430 100644 --- a/serket/_src/utils.py +++ b/serket/_src/utils.py @@ -23,7 +23,7 @@ import jax import jax.numpy as jnp import numpy as np -from typing_extensions import ParamSpec +from typing_extensions import Annotated, ParamSpec import serket as sk @@ -33,6 +33,8 @@ DilationType = Union[int, Sequence[int]] P = ParamSpec("P") T = TypeVar("T") +HWArray = Annotated[jax.Array, "HW"] +CHWArray = Annotated[jax.Array, "CHW"] @ft.lru_cache(maxsize=None) diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 385fcfb..4988795 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -193,7 +193,7 @@ def test_jigsaw(): x = jnp.arange(1, 17).reshape(1, 4, 4) layer = sk.image.JigSaw2D(2) npt.assert_allclose( - layer(x), + layer(x, key=jax.random.PRNGKey(0)), jnp.array([[[9, 10, 3, 4], [13, 14, 7, 8], [11, 12, 1, 2], [15, 16, 5, 6]]]), ) @@ -222,7 +222,7 @@ def test_rotate(): layer = sk.image.RandomRotate2D((90, 90)) - npt.assert_allclose(layer(x), rot) + npt.assert_allclose(layer(x, key=jax.random.PRNGKey(0)), rot) npt.assert_allclose(sk.tree_eval(layer)(x), x) @@ -244,7 +244,7 @@ def test_horizontal_shear(): npt.assert_allclose(layer(x), shear) layer = sk.image.RandomHorizontalShear2D((45, 45)) - npt.assert_allclose(layer(x), shear) + npt.assert_allclose(layer(x, key=jax.random.PRNGKey(0)), shear) npt.assert_allclose(sk.tree_eval(layer)(x), x) @@ -267,7 +267,7 @@ def test_vertical_shear(): npt.assert_allclose(layer(x), shear) layer = sk.image.RandomVerticalShear2D((45, 45)) - npt.assert_allclose(layer(x), shear) + npt.assert_allclose(layer(x, key=jax.random.PRNGKey(0)), shear) npt.assert_allclose(sk.tree_eval(layer)(x), x)