From 9a4711ae7471ea658cb0bcad5f26485da73ee243 Mon Sep 17 00:00:00 2001 From: Mahmoud Asem <48389287+ASEM000@users.noreply.github.com> Date: Sat, 9 Dec 2023 20:03:55 +0900 Subject: [PATCH] attention updates (#91) --- serket/_src/image/augment.py | 42 +++++++----- serket/_src/image/color.py | 10 +-- serket/_src/image/filter.py | 44 ++++++------ serket/_src/image/geometric.py | 42 +++++++----- serket/_src/nn/attention.py | 115 +++++++++++++++++++------------- serket/_src/nn/convolution.py | 28 ++++---- serket/_src/nn/dropout.py | 16 ++--- serket/_src/nn/linear.py | 13 ++-- serket/_src/nn/normalization.py | 4 +- serket/_src/nn/pooling.py | 16 ++--- serket/_src/nn/recurrent.py | 28 ++++---- serket/_src/nn/reshape.py | 22 +++--- serket/_src/utils.py | 47 ++++++------- tests/test_utils.py | 34 ---------- 14 files changed, 226 insertions(+), 235 deletions(-) diff --git a/serket/_src/image/augment.py b/serket/_src/image/augment.py index 5cdcf949..a2a6b677 100644 --- a/serket/_src/image/augment.py +++ b/serket/_src/image/augment.py @@ -25,7 +25,13 @@ from serket._src.custom_transform import tree_eval from serket._src.image.color import hsv_to_rgb_3d, rgb_to_hsv_3d from serket._src.nn.linear import Identity -from serket._src.utils import CHWArray, HWArray, IsInstance, Range, validate_spatial_nd +from serket._src.utils import ( + CHWArray, + HWArray, + IsInstance, + Range, + validate_spatial_ndim, +) def pixel_shuffle_3d(array: CHWArray, upscale_factor: tuple[int, int]) -> CHWArray: @@ -252,7 +258,7 @@ def __init__(self, upscale_factor: int = 1): self.upscale_factor = (upscale_factor, upscale_factor) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, array: CHWArray) -> CHWArray: return pixel_shuffle_3d(array, self.upscale_factor) @@ -275,7 +281,7 @@ class AdjustContrast2D(sk.TreeClass): factor: float = 1.0 - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: factor = jax.lax.stop_gradient(self.factor) return jax.vmap(adjust_contrast_2d, in_axes=(0, None))(image, factor) @@ -307,7 +313,7 @@ def __init__(self, range: tuple[float, float] = (0.5, 1)): self.range = range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: range = jax.lax.stop_gradient(self.range) in_axes = (None, 0, None) @@ -338,7 +344,7 @@ class AdjustBrightness2D(sk.TreeClass): factor: float = sk.field(on_setattr=[IsInstance(float), Range(0, 1)]) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: factor = jax.lax.stop_gradient(self.factor) return jax.vmap(adjust_brightness_2d, in_axes=(0, None))(image, factor) @@ -360,7 +366,7 @@ class RandomBrightness2D(sk.TreeClass): range: tuple[float, float] = sk.field(on_setattr=[IsInstance(tuple)]) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, array: CHWArray, *, key: jax.Array) -> CHWArray: range = jax.lax.stop_gradient(self.range) in_axes = (None, 0, None) @@ -397,7 +403,7 @@ def __init__(self, scale: int): raise ValueError(f"{scale=} must be a positive int") self.scale = scale - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) 
def __call__(self, image: CHWArray) -> CHWArray: return jax.vmap(pixelate_2d, in_axes=(0, None))(image, self.scale) @@ -434,7 +440,7 @@ class Solarize2D(sk.TreeClass): threshold: float max_val: float = 1.0 - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: threshold, max_val = jax.lax.stop_gradient((self.threshold, self.max_val)) in_axes = (0, None, None) @@ -488,7 +494,7 @@ class Posterize2D(sk.TreeClass): bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)]) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return jax.vmap(posterize_2d, in_axes=(0, None))(image, self.bits) @@ -541,7 +547,7 @@ class RandomJigSaw2D(sk.TreeClass): tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)]) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: """Mixes up tiles of an image. @@ -579,7 +585,7 @@ def __init__(self, gain: float = 1, inv: bool = False): self.gain = gain self.inv = inv - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, array: CHWArray) -> CHWArray: in_axes = (0, None, None) gain = jax.lax.stop_gradient(self.gain) @@ -614,7 +620,7 @@ def __init__(self, cutoff: float = 0.5, gain: float = 10, inv: bool = False): self.gain = gain self.inv = inv - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None, None) cutoff, gain = jax.lax.stop_gradient((self.cutoff, self.gain)) @@ -637,7 +643,7 @@ class AdjustHue2D(sk.TreeClass): def __init__(self, factor: float): self.factor = factor - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: factor = jax.lax.stop_gradient(self.factor) return adjust_hue_3d(image, factor) @@ -660,7 +666,7 @@ class RandomHue2D(sk.TreeClass): def __init__(self, range: tuple[float, float]): self.range = range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: range = jax.lax.stop_gradient(self.range) return random_hue_3d(key, image, range) @@ -681,7 +687,7 @@ class AdjustSaturation2D(sk.TreeClass): def __init__(self, factor: float): self.factor = factor - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: factor = jax.lax.stop_gradient(self.factor) return adust_saturation_3d(image, factor) @@ -704,7 +710,7 @@ class RandomSaturation2D(sk.TreeClass): def __init__(self, range: tuple[float, float]): self.range = range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: range = jax.lax.stop_gradient(self.range) return random_saturation_3d(key, image, range) @@ -743,8 +749,8 @@ class FourierDomainAdapt2D(sk.TreeClass): def __init__(self, beta: float): self.beta = beta - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim", argnum=0) - 
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim", argnum=1) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_spatial_ndim, argnum=1) def __call__(self, image: CHWArray, target: CHWArray) -> CHWArray: """Fourier Domain Adaptation diff --git a/serket/_src/image/color.py b/serket/_src/image/color.py index 2db7c40b..7cec4f5a 100644 --- a/serket/_src/image/color.py +++ b/serket/_src/image/color.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import serket as sk -from serket._src.utils import CHWArray, validate_spatial_nd +from serket._src.utils import CHWArray, validate_spatial_ndim def rgb_to_grayscale(image: CHWArray, weights: jax.Array | None = None) -> CHWArray: @@ -71,7 +71,7 @@ class RGBToGrayscale2D(sk.TreeClass): def __init__(self, weights: jax.Array | None = None): self.weights = weights - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return rgb_to_grayscale(image, self.weights) @@ -141,7 +141,7 @@ class GrayscaleToRGB2D(sk.TreeClass): (3, 5, 5) """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return grayscale_to_rgb(image) @@ -166,7 +166,7 @@ class RGBToHSV2D(sk.TreeClass): - https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return rgb_to_hsv_3d(image) @@ -189,7 +189,7 @@ class HSVToRGB2D(sk.TreeClass): - https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return hsv_to_rgb_3d(image) diff --git a/serket/_src/image/filter.py b/serket/_src/image/filter.py index 82568e81..e3a62e17 100644 --- a/serket/_src/image/filter.py +++ b/serket/_src/image/filter.py @@ -33,7 +33,7 @@ generate_conv_dim_numbers, kernel_map, resolve_string_padding, - validate_spatial_nd, + validate_spatial_ndim, ) # For filters that have fft implementation, the pattern is to inherit from @@ -793,7 +793,7 @@ class AvgBlur2D(BaseAvgBlur2D): [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) @@ -821,7 +821,7 @@ class FFTAvgBlur2D(BaseAvgBlur2D): [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) @@ -863,7 +863,7 @@ class GaussianBlur2D(BaseGaussianBlur2D): [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None) sigma = jax.lax.stop_gradient(self.sigma) @@ -893,7 +893,7 @@ class FFTGaussianBlur2D(BaseGaussianBlur2D): [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, 
argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None) sigma = jax.lax.stop_gradient(self.sigma) @@ -923,7 +923,7 @@ class UnsharpMask2D(BaseGaussianBlur2D): [1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None) sigma = jax.lax.stop_gradient(self.sigma) @@ -953,7 +953,7 @@ class FFTUnsharpMask2D(BaseGaussianBlur2D): [1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None) sigma = jax.lax.stop_gradient(self.sigma) @@ -989,7 +989,7 @@ class BoxBlur2D(BoxBlur2DBase): [0.40000004 0.53333336 0.6666667 0.53333336 0.40000004]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) @@ -1017,7 +1017,7 @@ class FFTBoxBlur2D(BoxBlur2DBase): [0.40000004 0.53333336 0.6666667 0.53333336 0.40000004]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) @@ -1055,7 +1055,7 @@ class Laplacian2D(Laplacian2DBase): The laplacian considers all the neighbors of a pixel. """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) @@ -1086,7 +1086,7 @@ class FFTLaplacian2D(Laplacian2DBase): The laplacian considers all the neighbors of a pixel. 
""" - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) @@ -1130,7 +1130,7 @@ class MotionBlur2D(MotionBlur2DBase): [ 6.472714 10.020969 10.770187 9.100007 ]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None, None) angle, direction = jax.lax.stop_gradient((self.angle, self.direction)) @@ -1160,7 +1160,7 @@ class FFTMotionBlur2D(MotionBlur2DBase): [ 6.472714 10.020969 10.770187 9.100007 ]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, array: CHWArray) -> CHWArray: in_axes = (0, None, None, None) angle, direction = jax.lax.stop_gradient((self.angle, self.direction)) @@ -1197,7 +1197,7 @@ class MedianBlur2D(sk.TreeClass): def __init__(self, kernel_size: int | tuple[int, int]): self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size") - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None) args = (image, self.kernel_size) @@ -1231,7 +1231,7 @@ class Sobel2D(Sobel2DBase): [78.24321 , 68.26419 , 72.249565, 76.23647 , 89.27486 ]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return jax.vmap(sobel_2d)(image) @@ -1254,7 +1254,7 @@ class FFTSobel2D(Sobel2DBase): [78.24321 , 68.26419 , 72.249565, 76.23647 , 89.27486 ]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return jax.vmap(fft_sobel_2d)(image) @@ -1300,7 +1300,7 @@ class ElasticTransform2D(ElasticTransform2DBase): [21. 21.659977 21.43855 21.138866 22.583244 ]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: in_axes = (None, 0, None, None, None) args = (image, self.kernel_size, self.sigma, self.alpha) @@ -1334,7 +1334,7 @@ class FFTElasticTransform2D(ElasticTransform2DBase): [21. 
21.659977 21.43855 21.138866 22.583244 ]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: in_axes = (None, 0, None, None, None) args = (image, self.kernel_size, self.sigma, self.alpha) @@ -1375,7 +1375,7 @@ def __init__( self.sigma_space = canonicalize(sigma_space, ndim=2, name="sigma_space") self.sigma_color = sigma_color - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None, None) args = (self.kernel_size, self.sigma_space, self.sigma_color) @@ -1419,7 +1419,7 @@ def __init__( self.sigma_space = canonicalize(sigma_space, ndim=2, name="sigma_space") self.sigma_color = sigma_color - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, guide: CHWArray) -> CHWArray: """Apply joint bilateral blur to a channel-first image. @@ -1466,7 +1466,7 @@ class BlurPool2D(BlurPool2DBase): [11.0625 16. 12.9375]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None) args = (image, self.kernel_size, self.strides) @@ -1493,7 +1493,7 @@ class FFTBlurPool2D(BlurPool2DBase): [11.0625 16. 12.9375]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None) args = (image, self.kernel_size, self.strides) diff --git a/serket/_src/image/geometric.py b/serket/_src/image/geometric.py index 20db7acf..ccb1e9ab 100644 --- a/serket/_src/image/geometric.py +++ b/serket/_src/image/geometric.py @@ -24,7 +24,13 @@ import serket as sk from serket._src.custom_transform import tree_eval from serket._src.nn.linear import Identity -from serket._src.utils import CHWArray, HWArray, IsInstance, Range, validate_spatial_nd +from serket._src.utils import ( + CHWArray, + HWArray, + IsInstance, + Range, + validate_spatial_ndim, +) def affine_2d(array: HWArray, matrix: HWArray) -> HWArray: @@ -206,7 +212,7 @@ class Rotate2D(sk.TreeClass): def __init__(self, angle: float): self.angle = angle - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: angle = jax.lax.stop_gradient(self.angle) return jax.vmap(rotate_2d, in_axes=(0, None))(image, angle) @@ -259,7 +265,7 @@ def __init__(self, range: tuple[float, float] = (0.0, 360.0)): self.range = range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: range = jax.lax.stop_gradient(self.range) return jax.vmap(random_rotate_2d, in_axes=(None, 0, None))(key, image, range) @@ -290,7 +296,7 @@ class HorizontalShear2D(sk.TreeClass): def __init__(self, angle: float): self.angle = angle - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: angle = jax.lax.stop_gradient(self.angle) return jax.vmap(horizontal_shear_2d, in_axes=(0, None))(image, angle) @@ -343,7 +349,7 @@ def __init__(self, range: tuple[float, float] = (0.0, 90.0)): raise 
ValueError(f"`{range=}` must be a tuple of 2 floats") self.range = range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: angle = jax.lax.stop_gradient(self.range) in_axes = (None, 0, None) @@ -375,7 +381,7 @@ class VerticalShear2D(sk.TreeClass): def __init__(self, angle: float): self.angle = angle - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: jax.Array) -> jax.Array: angle = jax.lax.stop_gradient(self.angle) return jax.vmap(vertical_shear_2d, in_axes=(0, None))(image, angle) @@ -428,7 +434,7 @@ def __init__(self, range: tuple[float, float] = (0.0, 90.0)): raise ValueError(f"`{range=}` must be a tuple of 2 floats") self.range = range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: angle = jax.lax.stop_gradient(self.range) in_axes = (None, 0, None) @@ -466,7 +472,7 @@ class RandomPerspective2D(sk.TreeClass): def __init__(self, scale: float = 1.0): self.scale = scale - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: scale = jax.lax.stop_gradient(self.scale) args = (key, image, scale) @@ -498,7 +504,7 @@ class HorizontalTranslate2D(sk.TreeClass): shift: int = sk.field(on_setattr=[IsInstance(int)]) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return jax.vmap(horizontal_translate_2d, in_axes=(0, None))(image, self.shift) @@ -528,7 +534,7 @@ class VerticalTranslate2D(sk.TreeClass): shift: int = sk.field(on_setattr=[IsInstance(int)]) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return jax.vmap(vertical_translate_2d, in_axes=(0, None))(image, self.shift) @@ -566,7 +572,7 @@ class RandomHorizontalTranslate2D(sk.TreeClass): [24 25 0 0 0]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: return jax.vmap(random_horizontal_translate_2d, in_axes=(None, 0))(key, image) @@ -605,7 +611,7 @@ class RandomVerticalTranslate2D(sk.TreeClass): [ 0 0 0 0 0]]] """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: return jax.vmap(random_vertical_translate_2d, in_axes=(None, 0))(key, image) @@ -635,7 +641,7 @@ class HorizontalFlip2D(sk.TreeClass): - https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return jax.vmap(lambda x: jnp.flip(x, axis=1))(image) @@ -670,7 +676,7 @@ class RandomHorizontalFlip2D(sk.TreeClass): rate: float = sk.field(on_setattr=[IsInstance(float), Range(0.0, 1.0)]) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: 
rate = jax.lax.stop_gradient(self.rate) prop = jax.random.bernoulli(key, rate) @@ -702,7 +708,7 @@ class VerticalFlip2D(sk.TreeClass): - https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py """ - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return jax.vmap(lambda x: jnp.flip(x, axis=0))(image) @@ -737,7 +743,7 @@ class RandomVerticalFlip2D(sk.TreeClass): rate: float = sk.field(on_setattr=[IsInstance(float), Range(0.0, 1.0)]) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: rate = jax.lax.stop_gradient(self.rate) prop = jax.random.bernoulli(key, rate) @@ -760,7 +766,7 @@ def __init__(self, length: int, amplitude: float): self.length = length self.amplitude = amplitude - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: in_axes = (0, None, None) length, amplitude = jax.lax.stop_gradient((self.length, self.amplitude)) @@ -791,7 +797,7 @@ def __init__( self.length_range = length_range self.amplitude_range = amplitude_range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: in_axes = (None, 0, None, None) L, A = jax.lax.stop_gradient((self.length_range, self.amplitude_range)) diff --git a/serket/_src/nn/attention.py b/serket/_src/nn/attention.py index 309c08cf..f58ea70d 100644 --- a/serket/_src/nn/attention.py +++ b/serket/_src/nn/attention.py @@ -20,19 +20,21 @@ import jax.numpy as jnp import jax.random as jr from typing_extensions import Annotated - +from typing import Callable import serket as sk -from serket._src.nn.initialization import InitType +from serket._src.nn.initialization import DType, InitType from serket._src.utils import maybe_lazy_call, maybe_lazy_init """Defines attention layers.""" def split_heads(input: jax.Array, num_heads: int) -> jax.Array: + """Splits the last dimension of the input into multiple heads.""" return input.reshape(*input.shape[:-1], num_heads, -1) def merge_heads(input: jax.Array) -> jax.Array: + """Merges the last two dimensions of the input.""" return input.reshape(*input.shape[:-2], -1) @@ -45,34 +47,34 @@ def is_lazy_init(_, num_heads, q_features, *__, **___) -> bool: attention_updates = dict( - q_features=lambda _, q_array, *__, **___: q_array.shape[-1], - k_features=lambda _, __, k_array, *___, **____: k_array.shape[-1], - v_features=lambda _, __, ___, v_array, *____, **_____: v_array.shape[-1], + q_features=lambda _1, q_input, *_2, **_3: q_input.shape[-1], + k_features=lambda _1, _2, k_input, *_3, **_4: k_input.shape[-1], + v_features=lambda _1, _2, _3, v_input, *_4, **__5: v_input.shape[-1], ) -def calculate_attention( +def dot_product_attention( q_heads: jax.Array, k_heads: jax.Array, v_heads: jax.Array, - mask: jax.Array, num_heads: int, - drop_layer: sk.nn.Dropout, - key: jax.Array, + mask: jax.Array | None, + drop_func: Callable[[jax.Array], jax.Array], ) -> jax.Array: """Applies multi-head attention to the given inputs. Args: - q_array: Query input. [..., q_length, q_features] - k_array: Key input. [..., k_length, k_features] - mask: Mask input. [..., num_heads, q_length, kv_length] + q_input: Query input. 
[..., q_length, q_features] + k_input: Key input. [..., k_length, k_features] + v_input: Value input. [..., v_length, v_features] + mask: Mask input. [..., num_heads, q_length, kv_length]. Use ``None`` + for no masking. num_heads: Number of attention heads. - drop_layer: Dropout layer. - key: Key for the random number generator. + drop_func: Dropout function. Takes a single input and returns a single output. + Use ``lambda input: input`` for no dropout. Reference: - - https://github.com/keras-team/keras/blob/v2.13.1/keras/layers/attention/multi_head_attention.py - - https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/attention.py + - https://keras.io/api/layers/attention_layers/multi_head_attention/ - https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html """ k_depth = k_heads.shape[-1] @@ -86,13 +88,12 @@ def calculate_attention( logits = jnp.einsum("...qhd,...khd->...hqk", q_heads, k_heads) logits /= jnp.sqrt(k_depth // num_heads) - # handle mask min_num = jnp.finfo(logits.dtype).min - logits = jnp.where(mask, logits, min_num) if mask is not None else logits + logits = logits if mask is None else jnp.where(mask, logits, min_num) attention_weight = jax.nn.softmax(logits) attention = jnp.einsum("...hqk,...khd->...qhd", attention_weight, v_heads) - return merge_heads(drop_layer(attention, key=key)) + return merge_heads(drop_func(attention)) class MultiHeadAttention(sk.TreeClass): @@ -115,15 +116,19 @@ class MultiHeadAttention(sk.TreeClass): q_weight_init: Initializer for the query weight. Defaults to ``glorot_uniform``. q_bias_init: Initializer for the query bias. Defaults to zeros. use ``None`` to disable bias. + q_dtype: Data type for the query. Defaults to ``jnp.float32``. k_weight_init: Initializer for the key weight. Defaults to ``glorot_uniform``. k_bias_init: Initializer for the key bias. Defaults to zeros. use ``None`` to disable bias. + k_dtype: Data type for the key. Defaults to ``jnp.float32``. v_weight_init: Initializer for the value weight. Defaults to ``glorot_uniform``. v_bias_init: Initializer for the value bias. Defaults to zeros. use ``None`` to disable bias. + v_dtype: Data type for the value. Defaults to ``jnp.float32``. out_weight_init: Initializer for the output weight. Defaults to ``glorot_uniform``. out_bias_init: Initializer for the output bias. Defaults to zeros. use ``None`` to disable bias. + out_dtype: Data type for the output. Defaults to ``jnp.float32``. drop_rate: Dropout rate. defaults to 0.0. drop_broadcast: Whether to broadcast the dropout mask across the batch dimension and the heads dimension. Defaults to False. @@ -139,18 +144,18 @@ class MultiHeadAttention(sk.TreeClass): >>> v_features = 6 >>> q_length = 4 >>> kv_length = 2 - >>> mask = jr.uniform(jr.PRNGKey(2), (batch, num_heads, q_length, kv_length)) + >>> mask = jr.uniform(jr.PRNGKey(0), (batch, num_heads, q_length, kv_length)) >>> mask = (mask > 0.5).astype(jnp.float32) - >>> q = jr.uniform(jr.PRNGKey(0), (batch, q_length, q_features)) - >>> k = jr.uniform(jr.PRNGKey(1), (batch, kv_length, k_features)) - >>> v = jr.uniform(jr.PRNGKey(2), (batch, kv_length, v_features)) + >>> q = jr.uniform(jr.PRNGKey(1), (batch, q_length, q_features)) + >>> k = jr.uniform(jr.PRNGKey(2), (batch, kv_length, k_features)) + >>> v = jr.uniform(jr.PRNGKey(3), (batch, kv_length, v_features)) >>> layer = sk.nn.MultiHeadAttention( ... num_heads, ... q_features, ... k_features, ... v_features, ... drop_rate=0.0, - ... key=jr.PRNGKey(0), + ... key=jr.PRNGKey(4), ... 
) >>> print(layer(q, k, v, mask=mask, key=jr.PRNGKey(0)).shape) (3, 4, 4) @@ -182,14 +187,14 @@ class MultiHeadAttention(sk.TreeClass): >>> q = jr.uniform(jr.PRNGKey(0), (3, 2, 6)) >>> k = jr.uniform(jr.PRNGKey(1), (3, 2, 6)) >>> v = jr.uniform(jr.PRNGKey(2), (3, 2, 6)) - >>> lazy_layer = sk.nn.MultiHeadAttention(2, None, key=jr.PRNGKey(0)) - >>> _, material_layer = lazy_layer.at["__call__"](q, k, v, key=jr.PRNGKey(0)) - >>> material_layer(q, k, v, key=jr.PRNGKey(0)).shape + >>> key = jr.PRNGKey(0) + >>> lazy_layer = sk.nn.MultiHeadAttention(2, None, key=key) + >>> _, material_layer = lazy_layer.at["__call__"](q, k, v, key=key) + >>> material_layer(q, k, v, key=key).shape (3, 2, 6) Reference: - - https://github.com/keras-team/keras/blob/v2.13.1/keras/layers/attention/multi_head_attention.py - - https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/attention.py + - https://keras.io/api/layers/attention_layers/multi_head_attention/ - https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html - https://arxiv.org/abs/1706.03762 """ @@ -206,12 +211,16 @@ def __init__( key: jax.Array, q_weight_init: InitType = "glorot_uniform", q_bias_init: InitType = "zeros", + q_dtype: DType = jnp.float32, k_weight_init: InitType = "glorot_uniform", k_bias_init: InitType = "zeros", + k_dtype: DType = jnp.float32, v_weight_init: InitType = "glorot_uniform", v_bias_init: InitType = "zeros", + v_dtype: DType = jnp.float32, out_weight_init: InitType = "glorot_uniform", out_bias_init: InitType = "zeros", + out_dtype: DType = jnp.float32, drop_rate: float = 0.0, drop_broadcast: bool = False, ): @@ -235,14 +244,16 @@ def __init__( qkey, kkey, vkey, okey = jr.split(key, 4) self.num_heads = num_heads - drop_axes = (-1, -2) if drop_broadcast else None - self.dropout = sk.nn.Dropout(drop_rate, drop_axes) + # while dropout == 0.0 is a no-op, still instantiate a dropout layer + # because .at[drop_rate] can be used to change the dropout rate later on. + self.dropout = sk.nn.Dropout(drop_rate, (-1, -2) if drop_broadcast else None) self.q_projection = sk.nn.Linear( in_features=q_features, out_features=head_features * num_heads, weight_init=q_weight_init, bias_init=q_bias_init, + dtype=q_dtype, key=qkey, ) @@ -251,6 +262,7 @@ def __init__( out_features=head_features * num_heads, weight_init=k_weight_init, bias_init=k_bias_init, + dtype=k_dtype, key=kkey, ) @@ -259,6 +271,7 @@ def __init__( out_features=head_features * num_heads, weight_init=v_weight_init, bias_init=v_bias_init, + dtype=v_dtype, key=vkey, ) @@ -267,44 +280,54 @@ def __init__( out_features=out_features, weight_init=out_weight_init, bias_init=out_bias_init, + dtype=out_dtype, key=okey, ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=attention_updates) def __call__( self, - q_array: Annotated[jax.Array, "..., q_length, q_features"], - k_array: Annotated[jax.Array, "..., kv_length, k_features"], - v_array: Annotated[jax.Array, "..., kv_length, v_features"], + q_input: Annotated[jax.Array, "..., q_length, q_features"], + k_input: Annotated[jax.Array, "..., kv_length, k_features"], + v_input: Annotated[jax.Array, "..., kv_length, v_features"], mask: Annotated[jax.Array, "..., num_heads, q_length, kv_length"] | None = None, *, - key: jax.Array, + key: jax.Array | None = None, ) -> Annotated[jax.Array, "..., q_length, out_features"]: """Applies multi-head attention to the given inputs. Args: - q_array: Query input. [..., q_length, q_features] - k_array: Key input. [..., kv_length, k_features] - v_array: Value input. 
[..., kv_length, v_features] - mask: Mask input. [..., num_heads, q_length, kv_length] + q_input: Query input. [..., q_length, q_features] + k_input: Key input. [..., kv_length, k_features] + v_input: Value input. [..., kv_length, v_features] + mask: Mask input. [..., num_heads, q_length, kv_length]. Defaults to ``None`` + for no masking. key: Key for the random number generator used for dropout. + Defaults to ``None`` for no dropout. """ # [..., q_length, q_features] -> [..., q_length, head_features*num_heads] - q_heads = self.q_projection(q_array) + q_heads = self.q_projection(q_input) # [..., k_length, k_features] -> [..., k_length, head_features*num_heads] - k_heads = self.k_projection(k_array) + k_heads = self.k_projection(k_input) # [..., v_length, v_features] -> [..., v_length, head_features*num_heads] - v_heads = self.v_projection(v_array) + v_heads = self.v_projection(v_input) - attention = calculate_attention( + attention = type(self).attention_op( q_heads=q_heads, k_heads=k_heads, v_heads=v_heads, - mask=mask, num_heads=self.num_heads, - drop_layer=self.dropout, - key=key, + mask=mask, + # note that if `tree_eval` is used, self.dropout is converted to an + # identity function, so the `key` argument is ignored. + # one pro of this approach is that `Identity` will be displayed in + # the repr of the layer to make it clear that dropout is disabled. + # another pro is that there is no need to thread the ``training`` flag + # through the layer. + drop_func=lambda input: self.dropout(input, key=key), ) return self.out_projection(attention) + + attention_op = dot_product_attention diff --git a/serket/_src/nn/convolution.py b/serket/_src/nn/convolution.py index 0e60976f..e9ef60cb 100644 --- a/serket/_src/nn/convolution.py +++ b/serket/_src/nn/convolution.py @@ -42,8 +42,8 @@ maybe_lazy_call, maybe_lazy_init, positive_int_cb, - validate_axis_shape, - validate_spatial_nd, + validate_in_features_shape, + validate_spatial_ndim, ) Weight = Annotated[jax.Array, "OI..."] @@ -589,8 +589,8 @@ def __init__( self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: """Apply the layer.
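Aside on the renamed validators used above: ``validate_spatial_ndim`` and ``validate_in_features_shape`` now read ``self.spatial_ndim`` and ``self.in_features`` directly instead of resolving a dotted ``attribute_name`` string, dropping the ``recursive_getattr`` indirection and making the attribute contract explicit. A minimal sketch of the new signatures on a hypothetical layer (``TinyLayer`` is illustrative, not part of serket):

    import functools as ft

    import jax
    import jax.numpy as jnp

    from serket._src.utils import validate_in_features_shape, validate_spatial_ndim

    class TinyLayer:
        # the validators read these attributes directly off the instance
        spatial_ndim = 2  # expects unbatched (in_features, rows, cols) input
        in_features = 3

        @ft.partial(validate_spatial_ndim, argnum=0)
        @ft.partial(validate_in_features_shape, axis=0)
        def __call__(self, input: jax.Array) -> jax.Array:
            return input

    TinyLayer()(jnp.ones((3, 5, 5)))  # ok: 3 channels, two spatial dims
    # TinyLayer()(jnp.ones((1, 3, 5, 5)))  # ValueError: batched input, use jax.vmap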
@@ -2481,8 +2481,8 @@ def __init__( self.pointwise_bias = resolve_init(self.pointwise_bias_init)(*args) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__( self, input: jax.Array, @@ -3093,8 +3093,8 @@ def __init__( self.weight_i = scale * jr.normal(k2, weight_shape).astype(dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__(self, input: jax.Array) -> jax.Array: return type(self).conv_op( input=input, @@ -3380,8 +3380,8 @@ def __init__( self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__(self, input: jax.Array, mask: Weight | None = None) -> jax.Array: """Apply the layer. diff --git a/serket/_src/nn/dropout.py b/serket/_src/nn/dropout.py index 27b1379c..2204f40b 100644 --- a/serket/_src/nn/dropout.py +++ b/serket/_src/nn/dropout.py @@ -30,7 +30,7 @@ canonicalize, kernel_map, positive_int_cb, - validate_spatial_nd, + validate_spatial_ndim, ) @@ -167,8 +167,8 @@ class DropoutND(sk.TreeClass): on_getattr=[jax.lax.stop_gradient_p.bind], ) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - def __call__(self, input: jax.Array, *, key): + @ft.partial(validate_spatial_ndim, argnum=0) + def __call__(self, input: jax.Array, *, key: jax.Array) -> jax.Array: """Drop some elements of the input array. Args: @@ -302,7 +302,7 @@ def __init__( self.cutout_count = positive_int_cb(cutout_count) self.fill_value = fill_value - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array, *, key: jax.Array) -> jax.Array: """Drop some elements of the input array. 
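Usage note for the now fully-annotated dropout call: an explicit key is required at training time, while ``tree_eval`` swaps the dropout layer for ``Identity`` at evaluation so no key threading is needed. A short sketch, assuming serket's public ``sk.nn.Dropout`` and ``sk.tree_eval`` entry points (only the internal modules appear in this patch):

    import jax.numpy as jnp
    import jax.random as jr

    import serket as sk

    drop = sk.nn.Dropout(0.5)
    out = drop(jnp.ones(10), key=jr.PRNGKey(0))  # training: explicit key

    # evaluation: tree_eval replaces Dropout with Identity, no key needed
    eval_drop = sk.tree_eval(drop)
    out = eval_drop(jnp.ones(10))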
@@ -311,11 +311,9 @@ def __call__(self, input: jax.Array, *, key: jax.Array) -> jax.Array: key: random number generator key """ fill_value = jax.lax.stop_gradient(self.fill_value) - - def cutout(x): - return random_cutout_nd(key, x, self.shape, self.cutout_count, fill_value) - - return jax.vmap(cutout)(input) + in_axes = (None, 0, None, None, None) + args = (key, input, self.shape, self.cutout_count, fill_value) + return jax.vmap(random_cutout_nd, in_axes=in_axes)(*args) @property @abc.abstractmethod diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py index 94886673..ca257d4c 100644 --- a/serket/_src/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -55,18 +55,17 @@ def general_linear( in_axis: tuple[int, ...], out_axis: int, ) -> jax.Array: + in_axis_shape = tuple(input.shape[i] for i in in_axis) + features_shape = weight.shape[1:] + assert in_axis_shape == features_shape, f"{in_axis_shape=} != {features_shape=}" + in_axis = sorted([axis if axis >= 0 else axis + input.ndim for axis in in_axis]) lhs = "".join(str(axis) for axis in range(input.ndim)) # 0, 1, 2, 3 rhs = "F" + "".join(str(axis) for axis in in_axis) # F, 1, 2, 3 out = "".join(str(axis) for axis in range(input.ndim) if axis not in in_axis) out_axis = out_axis if out_axis >= 0 else out_axis + len(out) + 1 out = out[:out_axis] + "F" + out[out_axis:] - - try: - einsum = f"{lhs},{rhs}->{out}" - result = jnp.einsum(einsum, input, weight) - except ValueError as error: - raise ValueError(f"{einsum=}\n{input.shape=}\n{weight.shape=}\n{error=}") + result = jnp.einsum(f"{lhs},{rhs}->{out}", input, weight) if bias is None: return result @@ -183,6 +182,7 @@ def __init__( @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__(self, input: jax.Array) -> jax.Array: + """Apply a linear transformation to the input.""" return general_linear( input=input, weight=self.weight, @@ -248,6 +248,7 @@ def scan_linear( act: ActivationFunctionType, ) -> jax.Array: # reduce the ``jaxpr`` size by using ``scan`` + # for the intermediate layers in MLP. 
This can lower the compilation time if bias is None: def scan_func(input: jax.Array, weight: Batched[jax.Array]): diff --git a/serket/_src/nn/normalization.py b/serket/_src/nn/normalization.py index 8d4b7b92..647557dc 100644 --- a/serket/_src/nn/normalization.py +++ b/serket/_src/nn/normalization.py @@ -30,7 +30,7 @@ maybe_lazy_call, maybe_lazy_init, positive_int_cb, - validate_axis_shape, + validate_in_features_shape, ) @@ -265,7 +265,7 @@ def __init__( self.bias = resolve_init(bias_init)(key, (in_features,), dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__(self, input: jax.Array) -> jax.Array: return group_norm( input=input, diff --git a/serket/_src/nn/pooling.py b/serket/_src/nn/pooling.py index 199e6c01..36648931 100644 --- a/serket/_src/nn/pooling.py +++ b/serket/_src/nn/pooling.py @@ -30,7 +30,7 @@ canonicalize, delayed_canonicalize_padding, kernel_map, - validate_spatial_nd, + validate_spatial_ndim, ) @@ -152,7 +152,7 @@ def __init__( self.strides = canonicalize(strides, self.spatial_ndim, name="strides") self.padding = padding - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: padding = delayed_canonicalize_padding( in_dim=input.shape, @@ -240,7 +240,7 @@ def __init__( self.strides = canonicalize(strides, self.spatial_ndim, name="strides") self.padding = padding - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: padding = delayed_canonicalize_padding( in_dim=input.shape, @@ -312,7 +312,7 @@ def __init__( self.strides = canonicalize(strides, self.spatial_ndim, name="strides") self.padding = padding - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: padding = delayed_canonicalize_padding( in_dim=input.shape, @@ -374,7 +374,7 @@ class GlobalAvgPoolND(sk.TreeClass): def __init__(self, keepdims: bool = True): self.keepdims = keepdims - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: axes = tuple(range(1, self.spatial_ndim + 1)) # reduce spatial dimensions return jnp.mean(input, axis=axes, keepdims=self.keepdims) @@ -419,7 +419,7 @@ class GlobalMaxPoolND(sk.TreeClass): def __init__(self, keepdims: bool = True): self.keepdims = keepdims - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: axes = tuple(range(1, self.spatial_ndim + 1)) # reduce spatial dimensions return jnp.max(input, axis=axes, keepdims=self.keepdims) @@ -468,7 +468,7 @@ def __init__(self, output_size: tuple[int, ...]): name="output_size", ) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: return adaptive_avg_pool_nd(input, self.output_size) @@ -516,7 +516,7 @@ def __init__(self, output_size: tuple[int, ...]): name="output_size", ) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: return
adaptive_max_pool_nd(input, self.output_size) diff --git a/serket/_src/nn/recurrent.py b/serket/_src/nn/recurrent.py index c6ae4e17..6985ec6d 100644 --- a/serket/_src/nn/recurrent.py +++ b/serket/_src/nn/recurrent.py @@ -43,8 +43,8 @@ maybe_lazy_call, maybe_lazy_init, positive_int_cb, - validate_axis_shape, - validate_spatial_nd, + validate_in_features_shape, + validate_spatial_ndim, ) P = ParamSpec("P") @@ -178,8 +178,8 @@ def __init__( ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__( self, input: jax.Array, @@ -274,8 +274,8 @@ def __init__( ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__( self, input: jax.Array, @@ -397,8 +397,8 @@ def __init__( ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__( self, input: jax.Array, @@ -520,8 +520,8 @@ def __init__( ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__( self, input: jax.Array, @@ -600,8 +600,8 @@ def __init__( ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__( self, input: jax.Array, @@ -1031,8 +1031,8 @@ def __init__( ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + @ft.partial(validate_spatial_ndim, argnum=0) + @ft.partial(validate_in_features_shape, axis=0) def __call__( self, input: jax.Array, diff --git a/serket/_src/nn/reshape.py b/serket/_src/nn/reshape.py index 49b9096a..c2906d1c 100644 --- a/serket/_src/nn/reshape.py +++ b/serket/_src/nn/reshape.py @@ -29,7 +29,7 @@ IsInstance, canonicalize, delayed_canonicalize_padding, - validate_spatial_nd, + validate_spatial_ndim, ) MethodKind = Literal["nearest", "linear", "cubic", "lanczos3", "lanczos5"] @@ -150,7 +150,7 @@ def __init__( self.method = method self.antialias = antialias - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array, **k) -> jax.Array: in_axes = (0, None, None, None) args = (input, self.size, self.method, self.antialias) @@ -175,7 +175,7 @@ def __init__( self.scale = canonicalize(scale, self.spatial_ndim, name="scale") self.method = method - @ft.partial(validate_spatial_nd, 
attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: resized_shape = tuple(s * input.shape[i + 1] for i, s in enumerate(self.scale)) in_axes = (0, None, None) @@ -292,7 +292,7 @@ def __init__(self, size: int | tuple[int, ...], start: int | tuple[int, ...]): self.size = canonicalize(size, self.spatial_ndim, name="size") self.start = canonicalize(start, self.spatial_ndim, name="start") - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array, **k) -> jax.Array: in_axes = (0, None, None) args = (input, self.start, self.size) @@ -374,7 +374,7 @@ def __init__(self, padding: int | tuple[int, int], value: float = 0.0): self.padding = delayed_canonicalize_padding(None, padding, kernel_size, None) self.value = value - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: value = jax.lax.stop_gradient(self.value) pad = ft.partial(jnp.pad, pad_width=self.padding, constant_values=value) @@ -610,7 +610,7 @@ class RandomCropND(sk.TreeClass): def __init__(self, size: int | tuple[int, ...]): self.size = canonicalize(size, self.spatial_ndim, name="size") - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array, *, key: jax.Array) -> jax.Array: crop_size = [input.shape[0], *self.size] return random_crop_nd(key, input, crop_size=crop_size) @@ -659,7 +659,7 @@ class ZoomND(sk.TreeClass): def __init__(self, factor: float | tuple[float, ...]): self.factor = canonicalize(factor, self.spatial_ndim, name="factor") - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: factor = jax.lax.stop_gradient(self.factor) return jax.vmap(zoom_nd, in_axes=(0, None))(input, factor) @@ -747,7 +747,7 @@ def __init__(self, length_range: tuple[int, int] = (0.0, 1.0)): self.length_range = length_range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array, *, key: jax.Array) -> jax.Array: k1, k2 = jr.split(key, 2) low, high = jax.lax.stop_gradient(self.length_range) @@ -787,7 +787,7 @@ def __init__( self.height_range = height_range self.width_range = width_range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array, *, key: jax.Array) -> jax.Array: k1, k2, k3 = jr.split(key, 3) factors = (self.height_range, self.width_range) @@ -834,7 +834,7 @@ def __init__( self.width_range = width_range self.depth_range = depth_range - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array, *, key: jax.Array) -> jax.Array: k1, k2, k3, k4 = jr.split(key, 4) factors = (self.height_range, self.width_range, self.depth_range) @@ -852,7 +852,7 @@ class CenterCropND(sk.TreeClass): def __init__(self, size: int | tuple[int, ...]): self.size = canonicalize(size, self.spatial_ndim, name="size") - @ft.partial(validate_spatial_nd, attribute_name="spatial_ndim") + @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, input: jax.Array) -> jax.Array: return jax.vmap(ft.partial(center_crop_nd, 
sizes=self.size))(input) diff --git a/serket/_src/utils.py b/serket/_src/utils.py index e0ec6fc4..1b48f097 100644 --- a/serket/_src/utils.py +++ b/serket/_src/utils.py @@ -262,32 +262,27 @@ def recursive_getattr(obj, attr: Sequence[str]): ) -def validate_spatial_nd( - func: Callable[P, T], - attribute_name: str, - argnum: int = 0, -) -> Callable[P, T]: +def validate_spatial_ndim(func: Callable[P, T], argnum: int = 0) -> Callable[P, T]: """Decorator to validate spatial input shape.""" - attribute_list: Sequence[str] = attribute_name.split(".") @ft.wraps(func) def wrapper(self, *args, **kwargs): - array = args[argnum] - spatial_ndim = recursive_getattr(self, attribute_list) + input = args[argnum] + spatial_ndim = self.spatial_ndim - if array.ndim != spatial_ndim + 1: + if input.ndim != spatial_ndim + 1: spatial = ", ".join(("rows", "cols", "depths")[:spatial_ndim]) name = type(self).__name__ raise ValueError( f"Dimension mismatch error in inputs of {name}\n" f"Input should satisfy:\n" - f" - {(spatial_ndim + 1) = } dimension, but got {array.ndim = }.\n" - f" - shape of (in_features, {spatial}), but got {array.shape = }.\n" + f" - {(spatial_ndim + 1) = } dimension, but got {input.ndim = }.\n" + f" - shape of (in_features, {spatial}), but got {input.shape = }.\n" + ( # maybe the user applied the layer on a batched input "The input should be unbatched (no batch dimension).\n" "To apply on batched input, use `jax.vmap(...)(input)`." - if array.ndim == spatial_ndim + 2 + if input.ndim == spatial_ndim + 2 else "" ) ) return func(self, *args, **kwargs) return wrapper -def validate_axis_shape( - func: Callable[P, T], - *, - attribute_name: str, - axis: int = 0, -) -> Callable[P, T]: +def validate_in_features_shape(func: Callable[P, T], axis: int) -> Callable[P, T]: """Decorator to validate input features.""" - attribute_list = attribute_name.split(".") - def check_axis_shape(x, in_features: int, axis: int) -> None: - if x.shape[axis] != in_features: - raise ValueError(f"Specified {in_features=}, got {x.shape[axis]=}.") - return x + def check_axis_shape(input, in_features: int, axis: int) -> jax.Array: + if input.shape[axis] != in_features: - raise ValueError(f"Specified {in_features=}, got {input.shape[axis]=}.") + raise ValueError(f"Specified {in_features=}, got {input.shape[axis]=}.") + return input @ft.wraps(func) def wrapper(self, array, *a, **k): - check_axis_shape(array, recursive_getattr(self, attribute_list), axis) + check_axis_shape(array, self.in_features, axis) return func(self, array, *a, **k) return wrapper + + @ft.lru_cache(maxsize=128) def get_params(func: MethodType) -> tuple[inspect.Parameter, ...]: """Get the arguments of func.""" @@ -455,19 +446,19 @@ def inner(instance, *a, **k): LAZY_CALL_ERROR = """\ Cannot call ``{func_name}`` directly on a lazy layer. -use ``layer.at['{func_name}'](...)`` instead to return a tuple of: +use ``layer.at["{func_name}"](...)`` instead to return a tuple of: - Layer output. - Materialized layer. Example: >>> layer = {class_name}(...) - >>> layer(x) # this will raise an error + >>> layer(input) # this will raise an error ... Instead use the following pattern: - >>> layer_output, material_layer = layer.at['{func_name}'](x) - >>> material_layer(x) + >>> layer_output, material_layer = layer.at["{func_name}"](input) + >>> material_layer(input) ... """ diff --git a/tests/test_utils.py b/tests/test_utils.py index 64c615c2..2ec2e740 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License.
-import functools as ft -from typing import Any import jax import jax.random as jr @@ -31,8 +29,6 @@ positive_int_cb, resolve_string_padding, resolve_tuple_padding, - validate_axis_shape, - validate_spatial_nd, ) @@ -138,36 +134,6 @@ def test_positive_int_cb_error(): positive_int_cb(1.0) -def test_validate_spatial_nd_error(): - with pytest.raises(ValueError): - - class T: - @ft.partial(validate_spatial_nd, attribute_name="ndim") - def __call__(self, x) -> Any: - return x - - @property - def ndim(self): - return 1 - - T()(jax.numpy.ones([5])) - - -def test_validate_axis_shape_error(): - with pytest.raises(ValueError): - - class T: - @ft.partial(validate_axis_shape, attribute_name="in_dim") - def __call__(self, x) -> Any: - return x - - @property - def in_dim(self): - return 1 - - T()(jax.numpy.ones([5, 5])) - - def test_lazy_call(): layer = sk.nn.Linear(None, 1, key=jax.random.PRNGKey(0))
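End note on the new ``attention_op`` hook: since ``MultiHeadAttention.__call__`` now dispatches through ``type(self).attention_op`` (a plain-function class attribute), a subclass can swap the attention computation without touching the projection layers. A sketch of the idea, where ``relu_attention`` is a hypothetical variant that is not part of this patch:

    import jax
    import jax.numpy as jnp
    import jax.random as jr

    import serket as sk

    def relu_attention(q_heads, k_heads, v_heads, num_heads, mask, drop_func):
        # same contract as dot_product_attention, with ReLU replacing softmax
        depth = k_heads.shape[-1] // num_heads  # per-head feature size

        def split(x):  # [..., length, h*d] -> [..., length, h, d]
            return x.reshape(*x.shape[:-1], num_heads, -1)

        q_heads, k_heads, v_heads = split(q_heads), split(k_heads), split(v_heads)
        logits = jnp.einsum("...qhd,...khd->...hqk", q_heads, k_heads)
        logits /= jnp.sqrt(depth)
        logits = logits if mask is None else jnp.where(mask, logits, 0.0)
        attention = jnp.einsum("...hqk,...khd->...qhd", jax.nn.relu(logits), v_heads)
        attention = drop_func(attention)
        return attention.reshape(*attention.shape[:-2], -1)  # merge heads back

    class ReLUMultiHeadAttention(sk.nn.MultiHeadAttention):
        attention_op = relu_attention

    # assumes k/v/out features default to q_features in the constructor
    layer = ReLUMultiHeadAttention(2, 6, key=jr.PRNGKey(0))
    q = k = v = jnp.ones((3, 4, 6))
    print(layer(q, k, v, key=jr.PRNGKey(1)).shape)  # (3, 4, 6)

Because ``attention_op`` is looked up on the class rather than the instance, the function is called unbound with keyword arguments, so any replacement only has to match the ``dot_product_attention`` signature.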