diff --git a/serket/__init__.py b/serket/__init__.py index dbbf05a6..91e86ff5 100644 --- a/serket/__init__.py +++ b/serket/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/serket/_src/__init__.py b/serket/_src/__init__.py index dbf0b046..650d4be4 100644 --- a/serket/_src/__init__.py +++ b/serket/_src/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/serket/_src/cluster/__init__.py b/serket/_src/cluster/__init__.py index dbf0b046..650d4be4 100644 --- a/serket/_src/cluster/__init__.py +++ b/serket/_src/cluster/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/serket/_src/cluster/kmeans.py b/serket/_src/cluster/kmeans.py index 880b7c36..06c9970b 100644 --- a/serket/_src/cluster/kmeans.py +++ b/serket/_src/cluster/kmeans.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ import serket as sk from serket._src.custom_transform import tree_eval, tree_state -from serket._src.utils import IsInstance, Range +from serket._src.utils.validate import IsInstance, Range """K-means utility functions.""" diff --git a/serket/_src/containers.py b/serket/_src/containers.py index 44bb3aee..2a885c68 100644 --- a/serket/_src/containers.py +++ b/serket/_src/containers.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,9 +20,9 @@ import jax import jax.random as jr -import serket as sk +from serket import TreeClass, tree_summary from serket._src.custom_transform import tree_eval -from serket._src.utils import single_dispatch +from serket._src.utils.dispatch import single_dispatch @single_dispatch(argnum=0) @@ -53,7 +53,7 @@ def _(key: jax.Array, layers: Sequence[Callable[..., Any]], array: Any): return array -class Sequential(sk.TreeClass): +class Sequential(TreeClass): """A sequential container for layers. Args: @@ -103,7 +103,7 @@ def __reversed__(self): return reversed(self.layers) -@sk.tree_summary.def_type(Sequential) +@tree_summary.def_type(Sequential) def _(node): types = [type(x).__name__ for x in node] return f"{type(node).__name__}[{','.join(types)}]" @@ -121,7 +121,7 @@ def random_choice(key: jax.Array, layers: tuple[Callable[..., Any], ...], array: return jax.lax.switch(index, layers, array) -class RandomChoice(sk.TreeClass): +class RandomChoice(TreeClass): """Randomly selects one of the given layers/functions. Args: diff --git a/serket/_src/custom_transform.py b/serket/_src/custom_transform.py index c22f2131..61983fef 100644 --- a/serket/_src/custom_transform.py +++ b/serket/_src/custom_transform.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ import jax import serket as sk -from serket._src.utils import single_dispatch +from serket._src.utils.dispatch import single_dispatch T = TypeVar("T") diff --git a/serket/_src/image/__init__.py b/serket/_src/image/__init__.py index dbf0b046..650d4be4 100644 --- a/serket/_src/image/__init__.py +++ b/serket/_src/image/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/serket/_src/image/augment.py b/serket/_src/image/augment.py index 62c833f2..cc13a889 100644 --- a/serket/_src/image/augment.py +++ b/serket/_src/image/augment.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,17 +21,12 @@ import jax.numpy as jnp import jax.random as jr -import serket as sk +from serket import TreeClass, autoinit, field from serket._src.custom_transform import tree_eval from serket._src.image.color import hsv_to_rgb, rgb_to_hsv from serket._src.nn.linear import Identity -from serket._src.utils import ( - CHWArray, - HWArray, - IsInstance, - Range, - validate_spatial_ndim, -) +from serket._src.utils.typing import CHWArray, HWArray +from serket._src.utils.validate import IsInstance, Range, validate_spatial_ndim def pixel_shuffle_3d(array: CHWArray, upscale_factor: tuple[int, int]) -> CHWArray: @@ -73,7 +68,7 @@ def adjust_contrast_2d(image: HWArray, factor: float): """Adjusts the contrast of an image by scaling the pixel values by a factor. Args: - array: input array \in [0, 1] with shape (height, width) + array: input array in [0, 1] with shape (height, width) factor: contrast factor to adust the contrast by. """ _, _ = image.shape @@ -97,7 +92,7 @@ def adjust_brightness_2d(image: HWArray, factor: float) -> HWArray: """Adjusts the brightness of an image by adding a value to the pixel values. Args: - array: input array \in [0, 1] with shape (height, width) + array: input array in [0, 1] with shape (height, width) factor: brightness factor to adust the brightness by. """ _, _ = image.shape @@ -240,7 +235,7 @@ def fourier_domain_adapt_2d(image: HWArray, styler: HWArray, beta: float): return image_out.astype(dtype) -class PixelShuffle2D(sk.TreeClass): +class PixelShuffle2D(TreeClass): """Rearrange elements in a tensor. .. image:: ../_static/pixelshuffle2d.png @@ -265,8 +260,8 @@ def __call__(self, array: CHWArray) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class AdjustContrast2D(sk.TreeClass): +@autoinit +class AdjustContrast2D(TreeClass): """Adjusts the contrast of an 2D input by scaling the pixel values by a factor. .. image:: ../_static/adjustcontrast2d.png @@ -289,7 +284,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class RandomContrast2D(sk.TreeClass): +class RandomContrast2D(TreeClass): """Randomly adjusts the contrast of an 1D input by scaling the pixel values by a factor. Args: @@ -322,8 +317,8 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class AdjustBrightness2D(sk.TreeClass): +@autoinit +class AdjustBrightness2D(TreeClass): """Adjusts the brightness of an 2D input by adding a value to the pixel values. .. image:: ../_static/adjustbrightness2d.png @@ -342,7 +337,7 @@ class AdjustBrightness2D(sk.TreeClass): [1. 1. 1. 1. ]]] """ - factor: float = sk.field(on_setattr=[IsInstance(float), Range(0, 1)]) + factor: float = field(on_setattr=[IsInstance(float), Range(0, 1)]) @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: @@ -352,8 +347,8 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class RandomBrightness2D(sk.TreeClass): +@autoinit +class RandomBrightness2D(TreeClass): """Randomly adjusts the brightness of an 2D input by adding a value to the pixel values. Args: @@ -364,7 +359,7 @@ class RandomBrightness2D(sk.TreeClass): evaluation. """ - range: tuple[float, float] = sk.field(on_setattr=[IsInstance(tuple)]) + range: tuple[float, float] = field(on_setattr=[IsInstance(tuple)]) @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, array: CHWArray, *, key: jax.Array) -> CHWArray: @@ -375,7 +370,7 @@ def __call__(self, array: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class Pixelate2D(sk.TreeClass): +class Pixelate2D(TreeClass): """Pixelate an image by upsizing and downsizing an image .. image:: ../_static/pixelate2d.png @@ -410,8 +405,8 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class Solarize2D(sk.TreeClass): +@autoinit +class Solarize2D(TreeClass): """Inverts all values above a given threshold. .. image:: ../_static/solarize2d.png @@ -450,8 +445,8 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class Posterize2D(sk.TreeClass): +@autoinit +class Posterize2D(TreeClass): """Reduce the number of bits for each color channel. .. image:: ../_static/posterize2d.png @@ -492,7 +487,7 @@ class Posterize2D(sk.TreeClass): - https://github.com/python-pillow/Pillow/blob/main/src/PIL/ImageOps.py#L547 """ - bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)]) + bits: int = field(on_setattr=[IsInstance(int), Range(1, 8)]) @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: @@ -501,8 +496,8 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class RandomJigSaw2D(sk.TreeClass): +@autoinit +class RandomJigSaw2D(TreeClass): """Mixes up tiles of an image. .. image:: ../_static/jigsaw2d.png @@ -545,7 +540,7 @@ class RandomJigSaw2D(sk.TreeClass): - https://imgaug.readthedocs.io/en/latest/source/overview/geometric.html#jigsaw """ - tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)]) + tiles: int = field(on_setattr=[IsInstance(int), Range(1)]) @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: @@ -562,7 +557,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class AdjustLog2D(sk.TreeClass): +class AdjustLog2D(TreeClass): """Adjust log correction on the input 2D image of range [0, 1]. Args: @@ -594,7 +589,7 @@ def __call__(self, array: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class AdjustSigmoid2D(sk.TreeClass): +class AdjustSigmoid2D(TreeClass): """Adjust sigmoid correction on the input 2D image of range [0, 1]. @@ -630,7 +625,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class AdjustHue2D(sk.TreeClass): +class AdjustHue2D(TreeClass): """Adjust hue of an RGB image. .. image:: ../_static/adjusthue2d.png @@ -651,7 +646,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class RandomHue2D(sk.TreeClass): +class RandomHue2D(TreeClass): """Randomly adjust hue of an RGB image. Args: @@ -674,7 +669,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class AdjustSaturation2D(sk.TreeClass): +class AdjustSaturation2D(TreeClass): """Adjust saturation of an RGB image. .. image:: ../_static/adjustsaturation2d.png @@ -695,7 +690,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class RandomSaturation2D(sk.TreeClass): +class RandomSaturation2D(TreeClass): """Randomly adjust saturation of an RGB image. Args: @@ -718,7 +713,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class FourierDomainAdapt2D(sk.TreeClass): +class FourierDomainAdapt2D(TreeClass): """Domain adaptation via style transfer .. image:: ../_static/fourierdomainadapt2d.png diff --git a/serket/_src/image/color.py b/serket/_src/image/color.py index 92416822..d6c62fc3 100644 --- a/serket/_src/image/color.py +++ b/serket/_src/image/color.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,8 +21,9 @@ import jax import jax.numpy as jnp -import serket as sk -from serket._src.utils import CHWArray, validate_spatial_ndim +from serket import TreeClass +from serket._src.utils.typing import CHWArray +from serket._src.utils.validate import validate_spatial_ndim def rgb_to_grayscale(image: CHWArray, weights: jax.Array | None = None) -> CHWArray: @@ -50,7 +51,7 @@ def grayscale_to_rgb(image: CHWArray) -> CHWArray: return jnp.concatenate([image, image, image], axis=0) -class RGBToGrayscale2D(sk.TreeClass): +class RGBToGrayscale2D(TreeClass): """Converts a channel-first RGB image to grayscale. .. image:: ../_static/rgbtograyscale2d.png @@ -135,7 +136,7 @@ def hsv_to_rgb(image: CHWArray) -> CHWArray: return out -class GrayscaleToRGB2D(sk.TreeClass): +class GrayscaleToRGB2D(TreeClass): """Converts a grayscale image to RGB. Example: @@ -155,7 +156,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class RGBToHSV2D(sk.TreeClass): +class RGBToHSV2D(TreeClass): """Converts an RGB image to HSV. .. image:: ../_static/rgbtohsv2d.png @@ -180,7 +181,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class HSVToRGB2D(sk.TreeClass): +class HSVToRGB2D(TreeClass): """Converts an HSV image to RGB. Example: diff --git a/serket/_src/image/filter.py b/serket/_src/image/filter.py index a7e2974e..d8539bdb 100644 --- a/serket/_src/image/filter.py +++ b/serket/_src/image/filter.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,20 +22,20 @@ import jax.random as jr from jax.scipy.ndimage import map_coordinates -import serket as sk +from serket import TreeClass from serket._src.image.geometric import rotate_2d -from serket._src.nn.convolution import fft_conv_general_dilated -from serket._src.nn.initialization import DType -from serket._src.utils import ( - CHWArray, - HWArray, - canonicalize, - delayed_canonicalize_padding, +from serket._src.nn.convolution import ( + fft_conv_general_dilated, generate_conv_dim_numbers, - kernel_map, +) +from serket._src.utils.convert import canonicalize +from serket._src.utils.mapping import kernel_map +from serket._src.utils.padding import ( + delayed_canonicalize_padding, resolve_string_padding, - validate_spatial_ndim, ) +from serket._src.utils.typing import CHWArray, DType, HWArray +from serket._src.utils.validate import validate_spatial_ndim def filter_2d( @@ -762,7 +762,7 @@ def fft_blur_pool_2d( return fft_filter_2d(image, kernel, strides) -class BaseAvgBlur2D(sk.TreeClass): +class BaseAvgBlur2D(TreeClass): def __init__(self, kernel_size: int | tuple[int, int]): self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size") @@ -824,7 +824,7 @@ class FFTAvgBlur2D(BaseAvgBlur2D): filter_op = staticmethod(fft_avg_blur_2d) -class BaseGaussianBlur2D(sk.TreeClass): +class BaseGaussianBlur2D(TreeClass): def __init__( self, kernel_size: int | tuple[int, int], @@ -945,7 +945,7 @@ class FFTUnsharpMask2D(BaseGaussianBlur2D): filter_op = staticmethod(fft_unsharp_mask_2d) -class BoxBlur2DBase(sk.TreeClass): +class BoxBlur2DBase(TreeClass): def __init__(self, kernel_size: int | tuple[int, int]): self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size") @@ -1007,7 +1007,7 @@ class FFTBoxBlur2D(BoxBlur2DBase): filter_op = staticmethod(fft_box_blur_2d) -class Laplacian2DBase(sk.TreeClass): +class Laplacian2DBase(TreeClass): def __init__(self, kernel_size: int | tuple[int, int]): self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size") @@ -1075,7 +1075,7 @@ class FFTLaplacian2D(Laplacian2DBase): filter_op = staticmethod(fft_laplacian_2d) -class MotionBlur2DBase(sk.TreeClass): +class MotionBlur2DBase(TreeClass): def __init__( self, kernel_size: int, @@ -1148,7 +1148,7 @@ class FFTMotionBlur2D(MotionBlur2DBase): filter_op = staticmethod(fft_motion_blur_2d) -class MedianBlur2D(sk.TreeClass): +class MedianBlur2D(TreeClass): """Apply median filter to a channel-first image. .. image:: ../_static/medianblur2d.png @@ -1186,7 +1186,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class Sobel2DBase(sk.TreeClass): +class Sobel2DBase(TreeClass): @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: return jax.vmap(self.filter_op)(image) @@ -1240,7 +1240,7 @@ class FFTSobel2D(Sobel2DBase): filter_op = staticmethod(fft_sobel_2d) -class ElasticTransform2DBase(sk.TreeClass): +class ElasticTransform2DBase(TreeClass): def __init__( self, kernel_size: int | tuple[int, int], @@ -1321,7 +1321,7 @@ class FFTElasticTransform2D(ElasticTransform2DBase): filter_op = staticmethod(fft_elastic_transform_2d) -class BilateralBlur2D(sk.TreeClass): +class BilateralBlur2D(TreeClass): """Apply bilateral blur to a channel-first image. .. image:: ../_static/bilateralblur2d.png @@ -1364,7 +1364,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class JointBilateralBlur2D(sk.TreeClass): +class JointBilateralBlur2D(TreeClass): """Apply joint bilateral blur to a channel-first image. .. image:: ../_static/jointbilateralblur2d.png @@ -1414,7 +1414,7 @@ def __call__(self, image: CHWArray, guide: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class BlurPool2DBase(sk.TreeClass): +class BlurPool2DBase(TreeClass): def __init__( self, kernel_size: int | tuple[int, int], diff --git a/serket/_src/image/geometric.py b/serket/_src/image/geometric.py index ccb1e9ab..38b521e7 100644 --- a/serket/_src/image/geometric.py +++ b/serket/_src/image/geometric.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,16 +21,11 @@ import jax.random as jr from jax.scipy.ndimage import map_coordinates -import serket as sk +from serket import TreeClass, autoinit, field from serket._src.custom_transform import tree_eval from serket._src.nn.linear import Identity -from serket._src.utils import ( - CHWArray, - HWArray, - IsInstance, - Range, - validate_spatial_ndim, -) +from serket._src.utils.typing import CHWArray, HWArray +from serket._src.utils.validate import IsInstance, Range, validate_spatial_ndim def affine_2d(array: HWArray, matrix: HWArray) -> HWArray: @@ -189,7 +184,7 @@ def random_wave_transform_2d( return wave_transform_2d(image, length, amplitude) -class Rotate2D(sk.TreeClass): +class Rotate2D(TreeClass): """Rotate_2d a 2D image by an angle in dgrees in CCW direction .. image:: ../_static/rotate2d.png @@ -220,7 +215,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class RandomRotate2D(sk.TreeClass): +class RandomRotate2D(TreeClass): """Rotate_2d a 2D image by an angle in dgrees in CCW direction Args: @@ -273,7 +268,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class HorizontalShear2D(sk.TreeClass): +class HorizontalShear2D(TreeClass): """Shear an image horizontally .. image:: ../_static/horizontalshear2d.png @@ -304,7 +299,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class RandomHorizontalShear2D(sk.TreeClass): +class RandomHorizontalShear2D(TreeClass): """Shear an image horizontally with random angle choice. Args: @@ -358,7 +353,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class VerticalShear2D(sk.TreeClass): +class VerticalShear2D(TreeClass): """Shear an image vertically .. image:: ../_static/verticalshear2d.png @@ -389,7 +384,7 @@ def __call__(self, image: jax.Array) -> jax.Array: spatial_ndim: int = 2 -class RandomVerticalShear2D(sk.TreeClass): +class RandomVerticalShear2D(TreeClass): """Shear an image vertically with random angle choice. Args: @@ -443,7 +438,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class RandomPerspective2D(sk.TreeClass): +class RandomPerspective2D(TreeClass): """Applies a random perspective transform to a channel-first image. .. image:: ../_static/randomperspective2d.png @@ -481,8 +476,8 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class HorizontalTranslate2D(sk.TreeClass): +@autoinit +class HorizontalTranslate2D(TreeClass): """Translate an image horizontally by a pixel value. .. image:: ../_static/horizontaltranslate2d.png @@ -502,7 +497,7 @@ class HorizontalTranslate2D(sk.TreeClass): [ 0 0 21 22 23]]] """ - shift: int = sk.field(on_setattr=[IsInstance(int)]) + shift: int = field(on_setattr=[IsInstance(int)]) @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: @@ -511,8 +506,8 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class VerticalTranslate2D(sk.TreeClass): +@autoinit +class VerticalTranslate2D(TreeClass): """Translate an image vertically by a pixel value. .. image:: ../_static/verticaltranslate2d.png @@ -532,7 +527,7 @@ class VerticalTranslate2D(sk.TreeClass): [11 12 13 14 15]]] """ - shift: int = sk.field(on_setattr=[IsInstance(int)]) + shift: int = field(on_setattr=[IsInstance(int)]) @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: @@ -541,8 +536,8 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class RandomHorizontalTranslate2D(sk.TreeClass): +@autoinit +class RandomHorizontalTranslate2D(TreeClass): """Translate an image horizontally by a random pixel value. Note: @@ -579,7 +574,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class RandomVerticalTranslate2D(sk.TreeClass): +class RandomVerticalTranslate2D(TreeClass): """Translate an image vertically by a random pixel value. Note: @@ -618,7 +613,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class HorizontalFlip2D(sk.TreeClass): +class HorizontalFlip2D(TreeClass): """Flip channels left to right. .. image:: ../_static/horizontalflip2d.png @@ -648,8 +643,8 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class RandomHorizontalFlip2D(sk.TreeClass): +@autoinit +class RandomHorizontalFlip2D(TreeClass): """Flip channels left to right with a probability of `rate`. .. image:: ../_static/horizontalflip2d.png @@ -674,7 +669,7 @@ class RandomHorizontalFlip2D(sk.TreeClass): [25 24 23 22 21]]] """ - rate: float = sk.field(on_setattr=[IsInstance(float), Range(0.0, 1.0)]) + rate: float = field(on_setattr=[IsInstance(float), Range(0.0, 1.0)]) @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: @@ -685,7 +680,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class VerticalFlip2D(sk.TreeClass): +class VerticalFlip2D(TreeClass): """Flip channels up to down. .. image:: ../_static/verticalflip2d.png @@ -715,8 +710,8 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -@sk.autoinit -class RandomVerticalFlip2D(sk.TreeClass): +@autoinit +class RandomVerticalFlip2D(TreeClass): """Flip channels up to down with a probability of `rate`. .. image:: ../_static/verticalflip2d.png @@ -741,7 +736,7 @@ class RandomVerticalFlip2D(sk.TreeClass): [ 1 2 3 4 5]]] """ - rate: float = sk.field(on_setattr=[IsInstance(float), Range(0.0, 1.0)]) + rate: float = field(on_setattr=[IsInstance(float), Range(0.0, 1.0)]) @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: @@ -752,7 +747,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray: spatial_ndim: int = 2 -class WaveTransform2D(sk.TreeClass): +class WaveTransform2D(TreeClass): """Apply a wave transform to an image. .. image:: ../_static/wavetransform2d.png @@ -775,7 +770,7 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -class RandomWaveTransform2D(sk.TreeClass): +class RandomWaveTransform2D(TreeClass): """Apply a random wave transform to an image. .. image:: ../_static/wavetransform2d.png diff --git a/serket/_src/nn/__init__.py b/serket/_src/nn/__init__.py index dbf0b046..650d4be4 100644 --- a/serket/_src/nn/__init__.py +++ b/serket/_src/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/serket/_src/nn/activation.py b/serket/_src/nn/activation.py index 2138c4cf..7101f26e 100644 --- a/serket/_src/nn/activation.py +++ b/serket/_src/nn/activation.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,17 +22,18 @@ import jax.numpy as jnp from jax import lax -import serket as sk -from serket._src.utils import IsInstance, Range, ScalarLike, single_dispatch +from serket import TreeClass, autoinit, field +from serket._src.utils.dispatch import single_dispatch +from serket._src.utils.validate import IsInstance, Range, ScalarLike T = TypeVar("T") -@sk.autoinit -class CeLU(sk.TreeClass): +@autoinit +class CeLU(TreeClass): """Celu activation function""" - alpha: float = sk.field( + alpha: float = field( default=1.0, on_setattr=[ScalarLike()], on_getattr=[lax.stop_gradient_p.bind], @@ -42,11 +43,11 @@ def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.celu(input, alpha=self.alpha) -@sk.autoinit -class ELU(sk.TreeClass): +@autoinit +class ELU(TreeClass): """Exponential linear unit""" - alpha: float = sk.field( + alpha: float = field( default=1.0, on_setattr=[ScalarLike()], on_getattr=[lax.stop_gradient_p.bind], @@ -56,18 +57,18 @@ def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.elu(input, alpha=self.alpha) -@sk.autoinit -class GELU(sk.TreeClass): +@autoinit +class GELU(TreeClass): """Gaussian error linear unit""" - approximate: bool = sk.field(default=False, on_setattr=[IsInstance(bool)]) + approximate: bool = field(default=False, on_setattr=[IsInstance(bool)]) def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.gelu(input, approximate=self.approximate) -@sk.autoinit -class GLU(sk.TreeClass): +@autoinit +class GLU(TreeClass): """Gated linear unit""" def __call__(self, input: jax.Array) -> jax.Array: @@ -79,11 +80,11 @@ def hard_shrink(input: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array: return jnp.where(input > alpha, input, jnp.where(input < -alpha, input, 0.0)) -@sk.autoinit -class HardShrink(sk.TreeClass): +@autoinit +class HardShrink(TreeClass): """Hard shrink activation function""" - alpha: float = sk.field( + alpha: float = field( default=0.5, on_setattr=[Range(0), ScalarLike()], on_getattr=[lax.stop_gradient_p.bind], @@ -93,46 +94,46 @@ def __call__(self, input: jax.Array) -> jax.Array: return hard_shrink(input, self.alpha) -class HardSigmoid(sk.TreeClass): +class HardSigmoid(TreeClass): """Hard sigmoid activation function""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.hard_sigmoid(input) -class HardSwish(sk.TreeClass): +class HardSwish(TreeClass): """Hard swish activation function""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.hard_swish(input) -class HardTanh(sk.TreeClass): +class HardTanh(TreeClass): """Hard tanh activation function""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.hard_tanh(input) -class LogSigmoid(sk.TreeClass): +class LogSigmoid(TreeClass): """Log sigmoid activation function""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.log_sigmoid(input) -class LogSoftmax(sk.TreeClass): +class LogSoftmax(TreeClass): """Log softmax activation function""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.log_softmax(input) -@sk.autoinit -class LeakyReLU(sk.TreeClass): +@autoinit +class LeakyReLU(TreeClass): """Leaky ReLU activation function""" - negative_slope: float = sk.field( + negative_slope: float = field( default=0.01, on_setattr=[Range(0), ScalarLike()], on_getattr=[lax.stop_gradient_p.bind], @@ -142,35 +143,35 @@ def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.leaky_relu(input, self.negative_slope) -class ReLU(sk.TreeClass): +class ReLU(TreeClass): """ReLU activation function""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.relu(input) -class ReLU6(sk.TreeClass): +class ReLU6(TreeClass): """ReLU6 activation function""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.relu6(input) -class SeLU(sk.TreeClass): +class SeLU(TreeClass): """Scaled Exponential Linear Unit""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.selu(input) -class Sigmoid(sk.TreeClass): +class Sigmoid(TreeClass): """Sigmoid activation function""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.sigmoid(input) -class SoftPlus(sk.TreeClass): +class SoftPlus(TreeClass): """SoftPlus activation function""" def __call__(self, input: jax.Array) -> jax.Array: @@ -182,7 +183,7 @@ def softsign(x: jax.typing.ArrayLike) -> jax.Array: return x / (1 + jnp.abs(x)) -class SoftSign(sk.TreeClass): +class SoftSign(TreeClass): """SoftSign activation function""" def __call__(self, input: jax.Array) -> jax.Array: @@ -198,11 +199,11 @@ def softshrink(input: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array: ) -@sk.autoinit -class SoftShrink(sk.TreeClass): +@autoinit +class SoftShrink(TreeClass): """SoftShrink activation function""" - alpha: float = sk.field( + alpha: float = field( default=0.5, on_setattr=[Range(0), ScalarLike()], on_getattr=[lax.stop_gradient_p.bind], @@ -217,21 +218,21 @@ def squareplus(input: jax.typing.ArrayLike) -> jax.Array: return 0.5 * (input + jnp.sqrt(input * input + 4)) -class SquarePlus(sk.TreeClass): +class SquarePlus(TreeClass): """SquarePlus activation function""" def __call__(self, input: jax.Array) -> jax.Array: return squareplus(input) -class Swish(sk.TreeClass): +class Swish(TreeClass): """Swish activation function""" def __call__(self, input: jax.Array) -> jax.Array: return jax.nn.swish(input) -class Tanh(sk.TreeClass): +class Tanh(TreeClass): """Tanh activation function""" def __call__(self, input: jax.Array) -> jax.Array: @@ -243,7 +244,7 @@ def tanh_shrink(input: jax.typing.ArrayLike) -> jax.Array: return input - jnp.tanh(input) -class TanhShrink(sk.TreeClass): +class TanhShrink(TreeClass): """TanhShrink activation function""" def __call__(self, input: jax.Array) -> jax.Array: @@ -259,11 +260,11 @@ def thresholded_relu(input: jax.typing.ArrayLike, theta: float = 1.0) -> jax.Arr return jnp.where(input > theta, input, 0) -@sk.autoinit -class ThresholdedReLU(sk.TreeClass): +@autoinit +class ThresholdedReLU(TreeClass): """Thresholded ReLU activation function.""" - theta: float = sk.field( + theta: float = field( default=1.0, on_setattr=[Range(0), ScalarLike()], on_getattr=[lax.stop_gradient_p.bind], @@ -278,7 +279,7 @@ def mish(input: jax.typing.ArrayLike) -> jax.Array: return input * jax.nn.tanh(jax.nn.softplus(input)) -class Mish(sk.TreeClass): +class Mish(TreeClass): """Mish activation function https://arxiv.org/pdf/1908.08681.pdf.""" def __call__(self, input: jax.Array) -> jax.Array: @@ -290,11 +291,11 @@ def prelu(input: jax.typing.ArrayLike, a: float = 0.25) -> jax.Array: return jnp.where(input >= 0, input, input * a) -@sk.autoinit -class PReLU(sk.TreeClass): +@autoinit +class PReLU(TreeClass): """Parametric ReLU activation function""" - a: float = sk.field(default=0.25, on_setattr=[Range(0), ScalarLike()]) + a: float = field(default=0.25, on_setattr=[Range(0), ScalarLike()]) def __call__(self, input: jax.Array) -> jax.Array: return prelu(input, self.a) diff --git a/serket/_src/nn/attention.py b/serket/_src/nn/attention.py index e0b1361c..e5b31131 100644 --- a/serket/_src/nn/attention.py +++ b/serket/_src/nn/attention.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,8 +23,8 @@ from typing_extensions import Annotated import serket as sk -from serket._src.nn.initialization import DType, InitType -from serket._src.utils import maybe_lazy_call, maybe_lazy_init +from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init +from serket._src.utils.typing import DType, InitType """Defines attention layers.""" diff --git a/serket/_src/nn/convolution.py b/serket/_src/nn/convolution.py index cf5d37a5..32c0f21a 100644 --- a/serket/_src/nn/convolution.py +++ b/serket/_src/nn/convolution.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,25 +28,45 @@ from typing_extensions import Annotated import serket as sk -from serket._src.nn.initialization import DType, InitType, resolve_init -from serket._src.utils import ( +from serket._src.nn.initialization import resolve_init +from serket._src.utils.convert import canonicalize +from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init +from serket._src.utils.padding import ( + calculate_transpose_padding, + delayed_canonicalize_padding, +) +from serket._src.utils.typing import ( DilationType, + DType, + InitType, KernelSizeType, PaddingType, StridesType, - calculate_convolution_output_shape, - calculate_transpose_padding, - canonicalize, - delayed_canonicalize_padding, - generate_conv_dim_numbers, - maybe_lazy_call, - maybe_lazy_init, - positive_int_cb, + Weight, +) +from serket._src.utils.validate import ( validate_in_features_shape, + validate_pos_int, validate_spatial_ndim, ) -Weight = Annotated[jax.Array, "OI..."] + +def calculate_convolution_output_shape( + shape: tuple[int, ...], + kernel_size: tuple[int, ...], + padding: tuple[tuple[int, int], ...], + strides: tuple[int, ...], +): + """Compute the shape of the output of a convolutional layer.""" + return tuple( + (xi + (li + ri) - ki) // si + 1 + for xi, ki, si, (li, ri) in zip(shape, kernel_size, strides, padding) + ) + + +@ft.lru_cache(maxsize=None) +def generate_conv_dim_numbers(spatial_ndim) -> jax.lax.ConvDimensionNumbers: + return jax.lax.ConvDimensionNumbers(*((tuple(range(spatial_ndim + 2)),) * 3)) def fft_conv_general_dilated( @@ -562,15 +582,15 @@ def __init__( groups: int = 1, dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) - self.out_features = positive_int_cb(out_features) + self.in_features = validate_pos_int(in_features) + self.out_features = validate_pos_int(out_features) self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size") self.strides = canonicalize(strides, self.spatial_ndim, "strides") self.padding = padding self.dilation = canonicalize(dilation, self.spatial_ndim, "dilation") self.weight_init = weight_init self.bias_init = bias_init - self.groups = positive_int_cb(groups) + self.groups = validate_pos_int(groups) if self.out_features % self.groups != 0: raise ValueError(f"{(out_features % groups == 0)=}") @@ -1163,8 +1183,8 @@ def __init__( groups: int = 1, dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) - self.out_features = positive_int_cb(out_features) + self.in_features = validate_pos_int(in_features) + self.out_features = validate_pos_int(out_features) self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size") self.strides = canonicalize(strides, self.spatial_ndim, "strides") self.padding = padding # delayed canonicalization @@ -1172,12 +1192,12 @@ def __init__( self.dilation = canonicalize(dilation, self.spatial_ndim, "dilation") self.weight_init = weight_init self.bias_init = bias_init - self.groups = positive_int_cb(groups) + self.groups = validate_pos_int(groups) if self.out_features % self.groups != 0: raise ValueError(f"{(self.out_features % self.groups ==0)=}") - in_features = positive_int_cb(self.in_features) + in_features = validate_pos_int(self.in_features) weight_shape = (out_features, in_features // groups, *self.kernel_size) self.weight = resolve_init(self.weight_init)(key, weight_shape, dtype) @@ -1787,9 +1807,9 @@ def __init__( bias_init: InitType = "zeros", dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) + self.in_features = validate_pos_int(in_features) self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size") - self.depth_multiplier = positive_int_cb(depth_multiplier) + self.depth_multiplier = validate_pos_int(depth_multiplier) self.strides = canonicalize(strides, self.spatial_ndim, "strides") self.padding = padding # delayed canonicalization self.weight_init = weight_init @@ -2302,9 +2322,9 @@ def __init__( pointwise_bias_init: InitType = "zeros", dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) + self.in_features = validate_pos_int(in_features) self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size") - self.depth_multiplier = positive_int_cb(depth_multiplier) + self.depth_multiplier = validate_pos_int(depth_multiplier) self.strides = canonicalize(strides, self.spatial_ndim, "strides") self.padding = padding # delayed canonicalization self.depthwise_weight_init = depthwise_weight_init @@ -2902,8 +2922,8 @@ def __init__( key: jax.Array, dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) - self.out_features = positive_int_cb(out_features) + self.in_features = validate_pos_int(in_features) + self.out_features = validate_pos_int(out_features) self.modes: tuple[int, ...] = canonicalize(modes, self.spatial_ndim, "modes") weight_shape = (1, out_features, in_features, *self.modes) scale = 1 / (in_features * out_features) @@ -3148,8 +3168,8 @@ def __init__( bias_init: InitType = "zeros", dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) - self.out_features = positive_int_cb(out_features) + self.in_features = validate_pos_int(in_features) + self.out_features = validate_pos_int(out_features) self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size") self.in_size = canonicalize(in_size, self.spatial_ndim, name="in_size") self.strides = canonicalize(strides, self.spatial_ndim, "strides") diff --git a/serket/_src/nn/dropout.py b/serket/_src/nn/dropout.py index 923e0755..1310d343 100644 --- a/serket/_src/nn/dropout.py +++ b/serket/_src/nn/dropout.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,12 +24,12 @@ import serket as sk from serket._src.custom_transform import tree_eval -from serket._src.utils import ( +from serket._src.utils.convert import canonicalize +from serket._src.utils.mapping import kernel_map +from serket._src.utils.validate import ( IsInstance, Range, - canonicalize, - kernel_map, - positive_int_cb, + validate_pos_int, validate_spatial_ndim, ) @@ -179,8 +179,7 @@ def __call__(self, input: jax.Array, *, key: jax.Array) -> jax.Array: @property @abc.abstractmethod - def spatial_ndim(self): - ... + def spatial_ndim(self): ... class Dropout1D(DropoutND): @@ -299,7 +298,7 @@ def __init__( fill_value: int | float = 0, ): self.shape = canonicalize(shape, ndim=self.spatial_ndim, name="shape") - self.cutout_count = positive_int_cb(cutout_count) + self.cutout_count = validate_pos_int(cutout_count) self.fill_value = fill_value @ft.partial(validate_spatial_ndim, argnum=0) diff --git a/serket/_src/nn/initialization.py b/serket/_src/nn/initialization.py index 8411dac7..d8da63ab 100644 --- a/serket/_src/nn/initialization.py +++ b/serket/_src/nn/initialization.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,36 +14,14 @@ from __future__ import annotations from collections.abc import Callable as ABCCallable -from typing import Any, Callable, Literal, Tuple, Union, get_args +from typing import Callable, get_args import jax import jax.nn.initializers as ji import jax.tree_util as jtu -import numpy as np - -from serket._src.utils import single_dispatch - -InitLiteral = Literal[ - "he_normal", - "he_uniform", - "glorot_normal", - "glorot_uniform", - "lecun_normal", - "lecun_uniform", - "normal", - "uniform", - "ones", - "zeros", - "xavier_normal", - "xavier_uniform", - "orthogonal", -] - -Shape = Tuple[int, ...] -DType = Union[np.dtype, str, Any] -InitFuncType = Callable[[jax.Array, Shape, DType], jax.Array] -InitType = Union[InitLiteral, InitFuncType] +from serket._src.utils.dispatch import single_dispatch +from serket._src.utils.typing import InitFuncType, InitLiteral, InitType inits: list[InitFuncType] = [ ji.he_normal(), diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py index a13136c2..e70a6b45 100644 --- a/serket/_src/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,8 +27,11 @@ ActivationType, resolve_activation, ) -from serket._src.nn.initialization import DType, InitType, resolve_init -from serket._src.utils import maybe_lazy_call, maybe_lazy_init, positive_int_cb, tuplify +from serket._src.nn.initialization import resolve_init +from serket._src.utils.convert import tuplify +from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init +from serket._src.utils.typing import DType, InitType +from serket._src.utils.validate import validate_pos_int T = TypeVar("T") PyTree = Any @@ -245,8 +248,8 @@ class Embedding(sk.TreeClass): """ def __init__(self, in_features: int, out_features: int, key: jax.Array): - self.in_features = positive_int_cb(in_features) - self.out_features = positive_int_cb(out_features) + self.in_features = validate_pos_int(in_features) + self.out_features = validate_pos_int(out_features) self.weight = jr.normal(key, (self.out_features, self.in_features)) def __call__(self, input: jax.Array) -> jax.Array: diff --git a/serket/_src/nn/normalization.py b/serket/_src/nn/normalization.py index e1415e82..58f36544 100644 --- a/serket/_src/nn/normalization.py +++ b/serket/_src/nn/normalization.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,15 +23,15 @@ import serket as sk from serket._src.custom_transform import tree_eval, tree_state -from serket._src.nn.initialization import DType, InitType, resolve_init -from serket._src.utils import ( +from serket._src.nn.initialization import resolve_init +from serket._src.utils.convert import tuplify +from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init +from serket._src.utils.typing import DType, InitType +from serket._src.utils.validate import ( Range, ScalarLike, - maybe_lazy_call, - maybe_lazy_init, - positive_int_cb, - tuplify, validate_in_features_shape, + validate_pos_int, ) @@ -297,8 +297,8 @@ def __init__( bias_init: InitType = "zeros", dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) - self.groups = positive_int_cb(groups) + self.in_features = validate_pos_int(in_features) + self.groups = validate_pos_int(groups) self.eps = eps # needs more info for checking @@ -384,7 +384,7 @@ def __init__( bias_init: InitType = "zeros", dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) + self.in_features = validate_pos_int(in_features) self.eps = eps self.weight = resolve_init(weight_init)(key, (in_features,), dtype) self.bias = resolve_init(bias_init)(key, (in_features,), dtype) diff --git a/serket/_src/nn/pooling.py b/serket/_src/nn/pooling.py index 4fa1cc1f..8eff03fa 100644 --- a/serket/_src/nn/pooling.py +++ b/serket/_src/nn/pooling.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,15 +23,11 @@ from typing_extensions import Annotated import serket as sk -from serket._src.utils import ( - KernelSizeType, - PaddingType, - StridesType, - canonicalize, - delayed_canonicalize_padding, - kernel_map, - validate_spatial_ndim, -) +from serket._src.utils.convert import canonicalize +from serket._src.utils.mapping import kernel_map +from serket._src.utils.padding import delayed_canonicalize_padding +from serket._src.utils.typing import KernelSizeType, PaddingType, StridesType +from serket._src.utils.validate import validate_spatial_ndim def pool_nd( diff --git a/serket/_src/nn/recurrent.py b/serket/_src/nn/recurrent.py index 03fbf59a..28bf7faa 100644 --- a/serket/_src/nn/recurrent.py +++ b/serket/_src/nn/recurrent.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ from typing_extensions import ParamSpec import serket as sk +from serket import TreeClass, autoinit from serket._src.custom_transform import tree_state from serket._src.nn.activation import ActivationType, resolve_activation from serket._src.nn.convolution import ( @@ -34,16 +35,18 @@ FFTConv2D, FFTConv3D, ) -from serket._src.nn.initialization import DType, InitType -from serket._src.utils import ( +from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init +from serket._src.utils.typing import ( DilationType, + DType, + InitType, KernelSizeType, PaddingType, StridesType, - maybe_lazy_call, - maybe_lazy_init, - positive_int_cb, +) +from serket._src.utils.validate import ( validate_in_features_shape, + validate_pos_int, validate_spatial_ndim, ) @@ -71,15 +74,15 @@ def infer_in_features(_, input: jax.Array, *__, **___) -> int: updates = dict(in_features=infer_in_features) -@sk.autoinit -class RNNState(sk.TreeClass): +@autoinit +class RNNState(TreeClass): hidden_state: jax.Array class SimpleRNNState(RNNState): ... -class SimpleRNNCell(sk.TreeClass): +class SimpleRNNCell(TreeClass): """Vanilla RNN cell that defines the update rule for the hidden state Args: @@ -144,8 +147,8 @@ def __init__( ): k1, k2 = jr.split(key, 2) - self.in_features = positive_int_cb(in_features) - self.hidden_features = positive_int_cb(hidden_features) + self.in_features = validate_pos_int(in_features) + self.hidden_features = validate_pos_int(hidden_features) self.act = resolve_activation(act) i2h = sk.nn.Linear( @@ -198,7 +201,7 @@ def __call__( class DenseState(RNNState): ... -class LinearCell(sk.TreeClass): +class LinearCell(TreeClass): """No hidden state cell that applies a dense(Linear+activation) layer to the input Args: @@ -257,8 +260,8 @@ def __init__( key: jax.Array, dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) - self.hidden_features = positive_int_cb(hidden_features) + self.in_features = validate_pos_int(in_features) + self.hidden_features = validate_pos_int(hidden_features) self.act = resolve_activation(act) self.in_to_hidden = sk.nn.Linear( @@ -288,12 +291,12 @@ def __call__( spatial_ndim: int = 0 -@sk.autoinit +@autoinit class LSTMState(RNNState): cell_state: jax.Array -class LSTMCell(sk.TreeClass): +class LSTMCell(TreeClass): """LSTM cell that defines the update rule for the hidden state and cell state Args: @@ -361,8 +364,8 @@ def __init__( ): k1, k2 = jr.split(key, 2) - self.in_features = positive_int_cb(in_features) - self.hidden_features = positive_int_cb(hidden_features) + self.in_features = validate_pos_int(in_features) + self.hidden_features = validate_pos_int(hidden_features) self.act = resolve_activation(act) self.recurrent_act = resolve_activation(recurrent_act) @@ -424,7 +427,7 @@ def __call__( class GRUState(RNNState): ... -class GRUCell(sk.TreeClass): +class GRUCell(TreeClass): """GRU cell that defines the update rule for the hidden state and cell state Args: @@ -491,8 +494,8 @@ def __init__( ): k1, k2 = jr.split(key, 2) - self.in_features = positive_int_cb(in_features) - self.hidden_features = positive_int_cb(hidden_features) + self.in_features = validate_pos_int(in_features) + self.hidden_features = validate_pos_int(hidden_features) self.act = resolve_activation(act) self.recurrent_act = resolve_activation(recurrent_act) @@ -539,12 +542,12 @@ def __call__( spatial_ndim: int = 0 -@sk.autoinit +@autoinit class ConvLSTMNDState(RNNState): cell_state: jax.Array -class ConvLSTMNDCell(sk.TreeClass): +class ConvLSTMNDCell(TreeClass): @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( self, @@ -565,8 +568,8 @@ def __init__( ): k1, k2 = jr.split(key, 2) - self.in_features = positive_int_cb(in_features) - self.hidden_features = positive_int_cb(hidden_features) + self.in_features = validate_pos_int(in_features) + self.hidden_features = validate_pos_int(hidden_features) self.act = resolve_activation(act) self.recurrent_act = resolve_activation(recurrent_act) @@ -970,7 +973,7 @@ class FFTConvLSTM3DCell(ConvLSTMNDCell): class ConvGRUNDState(RNNState): ... -class ConvGRUNDCell(sk.TreeClass): +class ConvGRUNDCell(TreeClass): @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( self, @@ -991,8 +994,8 @@ def __init__( ): k1, k2 = jr.split(key, 2) - self.in_features = positive_int_cb(in_features) - self.hidden_features = positive_int_cb(hidden_features) + self.in_features = validate_pos_int(in_features) + self.hidden_features = validate_pos_int(hidden_features) self.act = resolve_activation(act) self.recurrent_act = resolve_activation(recurrent_act) diff --git a/serket/_src/nn/reshape.py b/serket/_src/nn/reshape.py index 9ea06c72..eb68749b 100644 --- a/serket/_src/nn/reshape.py +++ b/serket/_src/nn/reshape.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ import abc import functools as ft -from typing import Literal import jax import jax.random as jr @@ -24,17 +23,16 @@ import serket as sk from serket._src.custom_transform import tree_eval from serket._src.nn.linear import Identity -from serket._src.utils import ( +from serket._src.utils.convert import canonicalize +from serket._src.utils.mapping import kernel_map +from serket._src.utils.padding import delayed_canonicalize_padding +from serket._src.utils.typing import ( KernelSizeType, + MethodKind, PaddingType, StridesType, - canonicalize, - delayed_canonicalize_padding, - kernel_map, - validate_spatial_ndim, ) - -MethodKind = Literal["nearest", "linear", "cubic", "lanczos3", "lanczos5"] +from serket._src.utils.validate import validate_spatial_ndim def random_crop_nd( diff --git a/serket/_src/utils.py b/serket/_src/utils.py deleted file mode 100644 index 26e30a4b..00000000 --- a/serket/_src/utils.py +++ /dev/null @@ -1,723 +0,0 @@ -# Copyright 2023 serket authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import functools as ft -import inspect -import operator as op -from types import MethodType -from typing import Any, Callable, Sequence, Tuple, TypeVar, Union - -import jax -import jax.numpy as jnp -import numpy as np -from jax.util import safe_zip -from typing_extensions import Annotated, Literal, ParamSpec - -import serket as sk - -KernelSizeType = Union[int, Sequence[int]] -StridesType = Union[int, Sequence[int]] -PaddingType = Union[str, int, Sequence[int], Sequence[Tuple[int, int]]] -DilationType = Union[int, Sequence[int]] -P = ParamSpec("P") -T = TypeVar("T") -HWArray = Annotated[jax.Array, "HW"] -CHWArray = Annotated[jax.Array, "CHW"] -PaddingLiteral = Literal[ - "constant", - "edge", - "linear_ramp", - "maximum", - "mean", - "median", - "minimum", - "reflect", - "symmetric", - "wrap", -] -PaddingMode = Union[PaddingLiteral, Union[int, float], Callable] - - -@ft.lru_cache(maxsize=None) -def generate_conv_dim_numbers(spatial_ndim) -> jax.lax.ConvDimensionNumbers: - return jax.lax.ConvDimensionNumbers(*((tuple(range(spatial_ndim + 2)),) * 3)) - - -@ft.lru_cache(maxsize=128) -def calculate_transpose_padding( - padding, - kernel_size, - input_dilation, - extra_padding, -): - """Transpose padding to get the padding for the transpose convolution. - - Args: - padding: padding to transpose - kernel_size: kernel size to use for transposing padding - input_dilation: input dilation to use for transposing padding - extra_padding: extra padding to use for transposing padding - """ - return tuple( - ((ki - 1) * di - pl, (ki - 1) * di - pr + ep) - for (pl, pr), ki, ep, di in zip( - padding, kernel_size, extra_padding, input_dilation - ) - ) - - -def calculate_convolution_output_shape( - shape: tuple[int, ...], - kernel_size: tuple[int, ...], - padding: tuple[tuple[int, int], ...], - strides: tuple[int, ...], -): - """Compute the shape of the output of a convolutional layer.""" - return tuple( - (xi + (li + ri) - ki) // si + 1 - for xi, ki, si, (li, ri) in zip(shape, kernel_size, strides, padding) - ) - - -def same_padding_along_dim( - in_dim: int, - kernel_size: int, - stride: int, -) -> tuple[int, int]: - # https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2 - # di: input dimension - # ki: kernel size - # si: stride - if in_dim % stride == 0: - pad = max(kernel_size - stride, 0) - else: - pad = max(kernel_size - (in_dim % stride), 0) - - return (pad // 2, pad - pad // 2) - - -def resolve_tuple_padding( - in_dim: tuple[int, ...], - padding: PaddingType, - kernel_size: KernelSizeType, - strides: StridesType, -) -> tuple[tuple[int, int], ...]: - del in_dim, strides - if len(padding) != len(kernel_size): - raise ValueError(f"Length mismatch {len(kernel_size)=}!={len(padding)=}.") - - resolved_padding = [[]] * len(kernel_size) - - for i, item in enumerate(padding): - if isinstance(item, int): - resolved_padding[i] = (item, item) # ex: padding = (1, 2, 3) - - elif isinstance(item, tuple): - if len(item) != 2: - raise ValueError(f"Expected tuple of length 2, got {len(item)=}") - resolved_padding[i] = item - - return tuple(resolved_padding) - - -def resolve_int_padding( - in_dim: tuple[int, ...], - padding: PaddingType, - kernel_size: KernelSizeType, - strides: StridesType, -): - del in_dim, strides - return ((padding, padding),) * len(kernel_size) - - -def resolve_string_padding(in_dim, padding, kernel_size, strides): - if padding.lower() == "same": - return tuple( - same_padding_along_dim(di, ki, si) - for di, ki, si in zip(in_dim, kernel_size, strides) - ) - - if padding.lower() == "valid": - return ((0, 0),) * len(kernel_size) - - raise ValueError(f'string argument must be in ["same","valid"].Found {padding}') - - -@ft.lru_cache(maxsize=128) -def delayed_canonicalize_padding( - in_dim: tuple[int, ...], - padding: PaddingType, - kernel_size: KernelSizeType, - strides: StridesType, -): - # in case of `str` padding, we need to know the input dimension - # to calculate the padding thus we need to delay the canonicalization - # until the call - - if isinstance(padding, int): - return resolve_int_padding(in_dim, padding, kernel_size, strides) - - if isinstance(padding, str): - return resolve_string_padding(in_dim, padding, kernel_size, strides) - - if isinstance(padding, tuple): - return resolve_tuple_padding(in_dim, padding, kernel_size, strides) - - raise ValueError( - "Expected padding to be of:\n" - "* int, for same padding along all dimensions\n" - "* str, for `same` or `valid` padding along all dimensions\n" - "* tuple of int, for individual padding along each dimension\n" - "* tuple of tuple of int, for padding before and after each dimension\n" - f"Got {padding=}." - ) - - -def canonicalize(value, ndim, name: str | None = None): - if isinstance(value, (int, float)): - return (value,) * ndim - if isinstance(value, jax.Array): - return jnp.repeat(value, ndim) - if isinstance(value, tuple): - if len(value) != ndim: - raise ValueError(f"{len(value)=} != {ndim=} for {name=} and {value=}.") - return tuple(value) - - raise ValueError(f"Expected int or tuple , got {value=}.") - - -@sk.autoinit -class Range(sk.TreeClass): - min_val: float = -float("inf") - max_val: float = float("inf") - min_inclusive: bool = True - max_inclusive: bool = True - - def __call__(self, value: Any): - lop, ls = (op.ge, "[") if self.min_inclusive else (op.gt, "(") - rop, rs = (op.le, "]") if self.max_inclusive else (op.lt, ")") - - if lop(value, self.min_val) and rop(value, self.max_val): - return value - - raise ValueError(f"Not in {ls}{self.min_val}, {self.max_val}{rs} got {value=}.") - - -@sk.autoinit -class IsInstance(sk.TreeClass): - klass: type | Sequence[type] - - def __call__(self, value: Any): - if isinstance(value, self.klass): - return value - raise TypeError(f"Expected {self.klass}, got {type(value).__name__}") - - -class ScalarLike(sk.TreeClass): - """Check if the input is a scalar""" - - def __call__(self, value: Any): - if isinstance(value, (float, complex)): - return value - if ( - isinstance(value, (jax.Array, np.ndarray)) - and np.issubdtype(value.dtype, np.inexact) - and value.shape == () - ): - return value - raise ValueError(f"Expected inexact type got {value=}") - - -def positive_int_cb(value): - """Return if value is a positive integer, otherwise raise an error.""" - if not isinstance(value, int): - raise ValueError(f"value must be an integer, got {type(value).__name__}") - if value <= 0: - raise ValueError(f"{value=} must be positive.") - return value - - -def recursive_getattr(obj, attr: Sequence[str]): - return ( - getattr(obj, attr[0]) - if len(attr) == 1 - else recursive_getattr(getattr(obj, attr[0]), attr[1:]) - ) - - -def validate_spatial_ndim(func: Callable[P, T], argnum: int = 0) -> Callable[P, T]: - """Decorator to validate spatial input shape.""" - - @ft.wraps(func) - def wrapper(self, *args, **kwargs): - input = args[argnum] - spatial_ndim = self.spatial_ndim - - if input.ndim != spatial_ndim + 1: - spatial = ", ".join(("rows", "cols", "depths")[:spatial_ndim]) - name = type(self).__name__ - raise ValueError( - f"Dimesion mismatch error in inputs of {name}\n" - f"Input should satisfy:\n" - f" - {(spatial_ndim + 1) = } dimension, but got {input.ndim = }.\n" - f" - shape of (in_features, {spatial}), but got {input.shape = }.\n" - + ( - # maybe the user apply the layer on a batched input - "The input should be unbatched (no batch dimension).\n" - "To apply on batched input, use `jax.vmap(...)(input)`." - if input.ndim == spatial_ndim + 2 - else "" - ) - ) - return func(self, *args, **kwargs) - - return wrapper - - -def validate_in_features_shape(func: Callable[P, T], axis: int) -> Callable[P, T]: - """Decorator to validate input features.""" - - def check_axis_shape(input, in_features: int, axis: int) -> None: - if input.shape[axis] != in_features: - raise ValueError(f"Specified {in_features=}, got {input.shape[axis]=}.") - return input - - @ft.wraps(func) - def wrapper(self, array, *a, **k): - check_axis_shape(array, self.in_features, axis) - return func(self, array, *a, **k) - - return wrapper - - -@ft.lru_cache(maxsize=128) -def get_params(func: MethodType) -> tuple[inspect.Parameter, ...]: - """Get the arguments of func.""" - return tuple(inspect.signature(func).parameters.values()) - - -def tuplify(value: T) -> T | tuple[T]: - return value if isinstance(value, tuple) else (value,) - - -def kernel_map( - func: dict, - shape: tuple[int, ...], - kernel_size: tuple[int, ...], - strides: tuple[int, ...], - padding: tuple[tuple[int, int], ...], - padding_mode: PaddingMode = "constant", -) -> Callable: - """Minimal implementation of kmap from kernex""" - # Copyright 2023 Kernex authors - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # https://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - - # copied here to avoid requiring kernex as a dependency - # does not support most of the kernex features - if isinstance(padding_mode, (int, float)): - pad_kwargs = dict(mode="constant", constant_values=padding_mode) - elif isinstance(padding_mode, (str, Callable)): - pad_kwargs = dict(mode=padding_mode) - - gather_kwargs = dict( - mode="promise_in_bounds", - indices_are_sorted=True, - unique_indices=True, - ) - - def calculate_kernel_map_output_shape( - shape: tuple[int, ...], - kernel_size: tuple[int, ...], - strides: tuple[int, ...], - border: tuple[tuple[int, int], ...], - ) -> tuple[int, ...]: - return tuple( - (xi + (li + ri) - ki) // si + 1 - for xi, ki, si, (li, ri) in safe_zip(shape, kernel_size, strides, border) - ) - - @ft.partial(jax.profiler.annotate_function, name="general_arange") - def general_arange(di: int, ki: int, si: int, x0: int, xf: int) -> jax.Array: - # this function is used to calculate the windows indices for a given dimension - start, end = -x0 + ((ki - 1) // 2), di + xf - (ki // 2) - size = end - start - - res = ( - jax.lax.broadcasted_iota(dtype=jnp.int32, shape=(size, ki), dimension=0) - + jax.lax.broadcasted_iota(dtype=jnp.int32, shape=(ki, size), dimension=0).T - + (start) - - ((ki - 1) // 2) - ) - - # res[::si] is slightly slower. - return (res) if si == 1 else (res)[::si] - - @ft.lru_cache(maxsize=128) - def recursive_vmap(*, ndim: int): - def nvmap(n): - in_axes = [None] * ndim - in_axes[-n] = 0 - return ( - jax.vmap(lambda *x: x, in_axes=in_axes) - if n == 1 - else jax.vmap(nvmap(n - 1), in_axes=in_axes) - ) - - return nvmap(ndim) - - @ft.partial(jax.profiler.annotate_function, name="general_product") - def general_product(*args: jax.Array): - return recursive_vmap(ndim=len(args))(*args) - - def generate_views( - shape: tuple[int, ...], - kernel_size: tuple[int, ...], - strides: tuple[int, ...], - border: tuple[tuple[int, int], ...], - ) -> tuple[jax.Array, ...]: - dim_range = tuple( - general_arange(di, ki, si, x0, xf) - for (di, ki, si, (x0, xf)) in zip(shape, kernel_size, strides, border) - ) - matrix = general_product(*dim_range) - return tuple(map(lambda xi, wi: xi.reshape(-1, wi), matrix, kernel_size)) - - def absolute_wrapper(*a, **k): - def map_func(view: tuple[jax.Array, ...], array: jax.Array): - patch = array.at[ix_(*view)].get(**gather_kwargs) - return func(patch, *a, **k) - - return map_func - - def ix_(*args): - """modified version of jnp.ix_""" - n = len(args) - output = [] - for i, a in enumerate(args): - shape = [1] * n - shape[i] = a.shape[0] - output.append(jax.lax.broadcast_in_dim(a, shape, (i,))) - return tuple(output) - - pad_width = tuple([0, max(0, pi[0]) + max(0, pi[1])] for pi in padding) - args = (shape, kernel_size, strides, padding) - views = generate_views(*args) - output_shape = calculate_kernel_map_output_shape(*args) - - def single_call_wrapper(array: jax.Array, *a, **k): - padded_array = jnp.pad(array, pad_width, **pad_kwargs) - reduced_func = absolute_wrapper(*a, **k) - - def map_func(view): - return reduced_func(view, padded_array) - - result = jax.vmap(map_func)(views) - return result.reshape(*output_shape, *result.shape[1:]) - - return single_call_wrapper - - -def single_dispatch(argnum: int = 0): - """Single dispatch with argnum""" - - def decorator(func): - dispatcher = ft.singledispatch(func) - - @ft.wraps(func) - def wrapper(*args, **kwargs): - try: - klass = type(args[argnum]) - except IndexError: - argname = get_params(func)[argnum].name - klass = type(kwargs[argname]) - return dispatcher.dispatch(klass)(*args, **kwargs) - - wrapper.def_type = dispatcher.register - wrapper.registry = dispatcher.registry - ft.update_wrapper(wrapper, func) - return wrapper - - return decorator - - -"""This module provides decorators to handle lazy layers in a functional way. - - -Creating a _lazy_ ``Linear`` layer example: - -In this example, we create a _lazy_ ``Linear`` that initializes the weights -and biases based on the input. The ``__init__`` method is decorated with -``maybe_lazy_init`` with the condition that if ``in_features`` is ``None`` -then the layer is lazy. and the ``__call__`` method is decorated with -``maybe_lazy_call`` with the condition that if the instance ``in_features`` -is ``None`` then the layer is lazy. addditionally, we define an update function -that infers the ``in_features`` from the input shape and updates the -``in_features`` attribute to then -re-initialize the layer with the inferred ``in_features``. - -One benefit of this approach is that we can use the layer as a lazy layer -or a materialized layer without changing the code. This is useful to -translate code from both explicit and implicit shaped layer found in -libraries like ``pytorch`` and ``tensorflow``. - -As quick sketch how this work is in the following example: - ->>> import jax ->>> class Lazy: -... def __init__(self, dim_size: int | None): -... # let dim size be the array size -... # and if we dont have the array size -... # we can set it to None to be inferred later -... self.dim_size = dim_size -... def __call__(self, x): -... return x * self.dim_size ->>> def maybe_lazy_init(func): -... def wrapper(self, dim_size): -... if input is not None: -... return func(self, dim_size) -... # we do not execute the init function -... # because its lazy -... return None -... return wrapper ->>> def maybe_lazy_call(func): -... def wrapper(self, x): -... if self.dim_size is not None: -... return func(self, x) -... # the input is lazy , so we do infer the dim size -... # here. because `TreeClass` is immutable we need to -... # return a new instance of the class with the updated -... # dim size, but here we are just updating the dim size -... # of the current instance that is not immutable -... self.dim_size = x.size -... return func(self, x) -... return wrapper ->>> # now lets decorate our lazy class ->>> Lazy.__init__ = maybe_lazy_init(Lazy.__init__) ->>> Lazy.__call__ = maybe_lazy_call(Lazy.__call__) ->>> print(Lazy(2)(jax.numpy.ones([2]))) ->>> print(Lazy(None)(jax.numpy.ones([2]))) - - -Now lets create a lazy ``Linear`` layer using ``serket``: - ->>> import functools as ft ->>> import serket as sk ->>> import jax.numpy as jnp ->>> from serket._src.utils import maybe_lazy_call, maybe_lazy_init ->>> def is_lazy_init(self, in_features, out_features): -... # we need to define how to tell if the layer is lazy -... # based on the inputs -... return in_features is None # or anything else really ->>> def is_lazy_call(self, x): -... # we need to define how to tell if the layer is lazy -... # at the call time -... # replicating the lazy init condition -... return getattr(self, "in_features", False) is None ->>> def infer_in_features(self, x): -... # we need to define how to infer the in_features -... # based on the inputs at call time -... # for linear layers, we can infer the in_features as the last dimension -... return x.shape[-1] ->>> # lastly we need to assign this function to a dictionary that has the name ->>> # of the feature we want to infer ->>> updates = dict(in_features=infer_in_features) ->>> class SimpleLinear(sk.TreeClass): -... @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) -... def __init__(self, in_features, out_features): -... self.in_features = in_features -... self.out_features = out_features -... self.weight = jnp.ones((in_features, out_features)) # dummy weight -... self.bias = jnp.zeros((out_features,)) # dummy bias ->>> @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) -... def __call__(self, x): -... return x ->>> simple_lazy = SimpleLinear(None, 1) ->>> x = jnp.ones([10, 2]) # last dimension is the in_features of the layer ->>> print(repr(simple_lazy)) -SimpleLinear(in_features=None, out_features=1) ->>> _, material = sk.value_and_tree(lambda layer: layer(x))(simple_lazy) ->>> print(repr(material)) -SimpleLinear( - in_features=2, - out_features=1, - weight=f32[2,1](μ=1.00, σ=0.00, ∈[1.00,1.00]), - bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00]) -) -""" - - -# Maybe expose this as a public API -# Handling lazy layers - - -def maybe_lazy_init( - func: Callable[P, T], - is_lazy: Callable[..., bool], -) -> Callable[P, T]: - """Sets input arguments to instance attribute if lazy initialization is ``True``. - - The key idea is to store the input arguments to the instance attribute to - be used later when the instance is re-initialized using ``maybe_lazy_call`` - decorator. ``maybe_lazy_call`` assumes that the input arguments are stored - in the instance attribute and can be retrieved using ``vars(instance)``. - Meaning that upon re-initialization, ``obj.__init__(**vars(obj))`` will - re-initialize the instance with the same input arguments. - - Args: - func: The ``__init__`` method of a class. - is_lazy: A function that returns ``True`` if lazy initialization is ``True``. - the function accepts the same arguments as ``func``. - - Returns: - The decorated ``__init__`` method. - """ - - def inner(instance, *a, **k): - if not is_lazy(instance, *a, **k): - # continue with the original initialization - return func(instance, *a, **k) - - # store the input arguments to the instance - # until the instance is re-initialized (materialized) - # then use the stored arguments to re-initialize the instance - kwargs: dict[str, Any] = dict() - - for index, param in enumerate(get_params(func)[1:]): - # skip the self argument - if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: - # fetch from the positional arguments - # or the keyword arguments or the default value if exists - if len(a) > index: - # fetch from the positional arguments if available - kwargs[param.name] = a[index] - elif param.name in k: - # fetch from the keyword arguments - kwargs[param.name] = k[param.name] - elif param.default is not inspect.Parameter.empty: - # fetch from the default value if exists - kwargs[param.name] = param.default - - elif param.kind is inspect.Parameter.KEYWORD_ONLY: - # fetch from the keyword arguments - # or the default value - if param.name in k: - # fetch from the keyword arguments if exists - kwargs[param.name] = k[param.name] - elif param.default is not inspect.Parameter.empty: - # fetch from the default value if exists - kwargs[param.name] = param.default - else: - # dont support positional only arguments, etc. - # not to complicate things - raise NotImplementedError(f"{param.kind=}") - - for key, value in kwargs.items(): - # set the attribute to the instance - # these will be reused to re-initialize the instance - # after the first call - setattr(instance, key, value) - - # halt the initialization of the instance - # and move to the next call - return None - - return ft.wraps(func)(inner) - - -LAZY_CALL_ERROR = """\ -Cannot call ``{func_name}`` directly on a lazy layer. -use ``value_and_tree(lambda layer: layer{func_name}(...))(layer)`` instead to return a tuple of: - - Layer output. - - Materialized layer. - -Example: - >>> layer = {class_name}(...) - >>> layer(input) # this will raise an error - ... - - Instead use the following pattern: - - >>> output, material = value_and_tree(lambda layer: layer{func_name}(input))(layer) - >>> material(input) - ... -""" - - -def maybe_lazy_call( - func: Callable[P, T], - is_lazy: Callable[..., bool], - updates: dict[str, Callable[..., Any]], -) -> Callable[P, T]: - """Reinitialize the instance if it is lazy. - - Accompanying decorator for ``maybe_lazy_init``. - - Args: - func: The method to decorate that accepts the arguments needed to re-initialize - the instance. - is_lazy: A function that returns ``True`` if lazy initialization is ``True``. - the function accepts the same arguments as ``func``. - updates: A dictionary of updates to the instance attributes. this dictionary - maps the attribute name to a function that accepts the attribute value - and returns the updated value. the function accepts the same arguments - as ``func``. - """ - - @ft.wraps(func) - def inner(instance, *a, **k): - if not is_lazy(instance, *a, **k): - return func(instance, *a, **k) - - # the instance variables are the input arguments - # to the ``__init__`` method - kwargs = dict(vars(instance)) - - for key, update in updates.items(): - kwargs[key] = update(instance, *a, **k) - - try: - for key in kwargs: - # clear the instance information (i.e. the initial input arguments) - # use ``delattr`` to raise an error if the instance is immutable - # which is marking the instance as lazy and immutable - delattr(instance, key) - except AttributeError: - # the instance is lazy and immutable - func_name = func.__name__ - func_name = "" if func_name == "__call__" else f".{func_name}" - class_name = type(instance).__name__ - kwargs = dict(func_name=func_name, class_name=class_name) - raise RuntimeError(LAZY_CALL_ERROR.format(**kwargs)) - - # re-initialize the instance with the resolved arguments - # this will only works under `value_and_tree` that allows - # the instance to be mutable with it's context after being copied first - type(instance).__init__(instance, **kwargs) - # call the decorated function - return func(instance, *a, **k) - - return inner diff --git a/serket/_src/utils/__init__.py b/serket/_src/utils/__init__.py new file mode 100644 index 00000000..650d4be4 --- /dev/null +++ b/serket/_src/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/serket/_src/utils/convert.py b/serket/_src/utils/convert.py new file mode 100644 index 00000000..59847f2c --- /dev/null +++ b/serket/_src/utils/convert.py @@ -0,0 +1,38 @@ +# Copyright 2024 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence, TypeVar + +import jax +import jax.numpy as jnp + +T = TypeVar("T") + + +def canonicalize(value, ndim, name: str | None = None): + if isinstance(value, (int, float)): + return (value,) * ndim + if isinstance(value, jax.Array): + return jnp.repeat(value, ndim) + if isinstance(value, Sequence): + if len(value) != ndim: + raise ValueError(f"{len(value)=} != {ndim=} for {name=} and {value=}.") + return value + raise TypeError(f"Expected int or tuple for {name}, got {value=}.") + + +def tuplify(value: T) -> T | tuple[T]: + return tuple(value) if isinstance(value, Sequence) else (value,) diff --git a/serket/_src/utils/dispatch.py b/serket/_src/utils/dispatch.py new file mode 100644 index 00000000..4149a618 --- /dev/null +++ b/serket/_src/utils/dispatch.py @@ -0,0 +1,42 @@ +# Copyright 2024 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import functools as ft + +from serket._src.utils.inspect import get_params + + +def single_dispatch(argnum: int = 0): + """Single dispatch with argnum""" + + def decorator(func): + dispatcher = ft.singledispatch(func) + + @ft.wraps(func) + def wrapper(*args, **kwargs): + try: + klass = type(args[argnum]) + except IndexError: + argname = get_params(func)[argnum].name + klass = type(kwargs[argname]) + return dispatcher.dispatch(klass)(*args, **kwargs) + + wrapper.def_type = dispatcher.register + wrapper.registry = dispatcher.registry + ft.update_wrapper(wrapper, func) + return wrapper + + return decorator diff --git a/serket/_src/utils/inspect.py b/serket/_src/utils/inspect.py new file mode 100644 index 00000000..65568f9f --- /dev/null +++ b/serket/_src/utils/inspect.py @@ -0,0 +1,25 @@ +# Copyright 2024 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import functools as ft +import inspect +from types import MethodType + + +@ft.lru_cache(maxsize=128) +def get_params(func: MethodType) -> tuple[inspect.Parameter, ...]: + """Get the arguments of func.""" + return tuple(inspect.signature(func).parameters.values()) diff --git a/serket/_src/utils/lazy.py b/serket/_src/utils/lazy.py new file mode 100644 index 00000000..d2ccf1f5 --- /dev/null +++ b/serket/_src/utils/lazy.py @@ -0,0 +1,187 @@ +# Copyright 2024 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module provides decorators to handle lazy layers in a functional way. + +For instance: +>>> net = Linear(None, 10, key=...) +is a lazy layer because the input features are passed as ``None``. The layer +is not materialized yet. To materialize the layer, you can use ``value_and_tree`` +to materialize the layer and get the output along with the materialized layer: +>>> output, material = value_and_tree(lambda layer: layer(input))(net) + +- Contrary to other framework, `serket` is eager-first lazy-second. This means + that the lazy is opt-in and not the default behavior. + +- The key idea is to store the input arguments to the instance attribute to + be used later when the instance is re-initialized using ``maybe_lazy_call`` + decorator. ``maybe_lazy_call`` assumes that the input arguments are stored + in the instance attribute and can be retrieved using ``vars(instance)``. + Meaning that upon re-initialization, ``obj.__init__(**vars(obj))`` will + re-initialize the instance with the same input arguments. + +- Because the instance is immutable, the process of re-initialization is + performed under ``value_and_tree`` that allows the instance to be mutable + with it's context after being copied first. +""" + +from __future__ import annotations + +import functools as ft +import inspect +from typing import Any, Callable, TypeVar + +from typing_extensions import ParamSpec + +from serket._src.utils.inspect import get_params + +P = ParamSpec("P") +T = TypeVar("T") + + +def handle_pos_or_kw(param: inspect.Parameter, index: int, args, kwargs): + if len(args) > index: + return args[index] + if param.name in kwargs: + return kwargs[param.name] + if param.default is not inspect.Parameter.empty: + return param.default + raise TypeError(f"{param.name} is required") + + +def handle_kw_only(param: inspect.Parameter, index, args, kwargs): + del index, args + if param.name in kwargs: + return kwargs[param.name] + if param.default is not inspect.Parameter.empty: + return param.default + raise TypeError(f"{param.name} is required") + + +ParamHandler = Callable[[inspect.Parameter, int, tuple[Any, ...], dict[str, Any]], Any] +rules: dict[Any, ParamHandler] = {} +rules[inspect.Parameter.POSITIONAL_OR_KEYWORD] = handle_pos_or_kw +rules[inspect.Parameter.KEYWORD_ONLY] = handle_kw_only + + +def maybe_lazy_init( + func: Callable[P, T], + is_lazy: Callable[..., bool], +) -> Callable[P, T]: + """Sets input arguments to instance attribute if lazy initialization is ``True``. + + The key idea is to store the input arguments to the instance attribute to + be used later when the instance is re-initialized using ``maybe_lazy_call`` + decorator. ``maybe_lazy_call`` assumes that the input arguments are stored + in the instance attribute and can be retrieved using ``vars(instance)``. + Meaning that upon re-initialization, ``obj.__init__(**vars(obj))`` will + re-initialize the instance with the same input arguments. + + Args: + func: The ``__init__`` method of a class. + is_lazy: A function that returns ``True`` if lazy initialization is ``True``. + the function accepts the same arguments as ``func``. + + Returns: + The decorated ``__init__`` method. + """ + + @ft.wraps(func) + def inner(instance, *a, **k): + if not is_lazy(instance, *a, **k): + return func(instance, *a, **k) + + # store the input arguments to the instance + # until the instance is re-initialized (materialized) + # then use the stored arguments to re-initialize the instance + for i, p in enumerate(get_params(func)[1:]): + setattr(instance, p.name, rules[p.kind](p, i, a, k)) + + return None + + return inner + + +LAZY_ERROR = """\ +Cannot call ``{fname}`` directly on a lazy layer. +use ``value_and_tree(lambda layer: layer{fname}(...))(layer)`` instead to return a tuple of: + - Layer output. + - Materialized layer. + +Example: + >>> layer = {cname}(...) + >>> layer(input) # this will raise an error + ... + + Instead use the following pattern: + + >>> output, material = value_and_tree(lambda layer: layer{fname}(input))(layer) + >>> material(input) + ... +""" + + +def maybe_lazy_call( + func: Callable[P, T], + is_lazy: Callable[..., bool], + updates: dict[str, Callable[..., Any]], +) -> Callable[P, T]: + """Reinitialize the instance if it is lazy. + + Accompanying decorator for ``maybe_lazy_init``. + + Args: + func: The method to decorate that accepts the arguments needed to re-initialize + the instance. + is_lazy: A function that returns ``True`` if lazy initialization is ``True``. + the function accepts the same arguments as ``func``. + updates: A dictionary of updates to the instance attributes. this dictionary + maps the attribute name to a function that accepts the attribute value + and returns the updated value. the function accepts the same arguments + as ``func``. + """ + + @ft.wraps(func) + def inner(instance, *a, **k): + if not is_lazy(instance, *a, **k): + return func(instance, *a, **k) + + # the instance variables are the input arguments + # to the ``__init__`` method + partial_mapping = dict(vars(instance)) + + for key, update in updates.items(): + partial_mapping[key] = update(instance, *a, **k) + + try: + for key in partial_mapping: + # clear the instance information (i.e. the initial input arguments) + # use ``delattr`` to raise an error if the instance is immutable + # which is marking the instance as lazy and immutable + delattr(instance, key) + except AttributeError: + # the instance is lazy and immutable + fname = "" if (fname := func.__name__) == "__call__" else f".{fname}" + cname = type(instance).__name__ + raise RuntimeError(LAZY_ERROR.format(fname=fname, cname=cname)) + + # re-initialize the instance with the resolved arguments + # this will only works under `value_and_tree` that allows + # the instance to be mutable with it's context after being copied first + init = getattr(type(instance), "__init__") + init(instance, **partial_mapping) + # call the decorated function + return func(instance, *a, **k) + + return inner diff --git a/serket/_src/utils/mapping.py b/serket/_src/utils/mapping.py new file mode 100644 index 00000000..c3237644 --- /dev/null +++ b/serket/_src/utils/mapping.py @@ -0,0 +1,150 @@ +# Copyright 2024 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools as ft +from typing import Callable + +import jax +import jax.numpy as jnp +from jax.util import safe_zip + +from serket._src.utils.typing import PaddingMode + + +def kernel_map( + func: dict, + shape: tuple[int, ...], + kernel_size: tuple[int, ...], + strides: tuple[int, ...], + padding: tuple[tuple[int, int], ...], + padding_mode: PaddingMode = "constant", +) -> Callable: + """Minimal implementation of kmap from kernex""" + # Copyright 2023 Kernex authors + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + + # copied here to avoid requiring kernex as a dependency + # does not support most of the kernex features + if isinstance(padding_mode, (int, float)): + pad_kwargs = dict(mode="constant", constant_values=padding_mode) + elif isinstance(padding_mode, (str, Callable)): + pad_kwargs = dict(mode=padding_mode) + + gather_kwargs = dict( + mode="promise_in_bounds", + indices_are_sorted=True, + unique_indices=True, + ) + + def calculate_kernel_map_output_shape( + shape: tuple[int, ...], + kernel_size: tuple[int, ...], + strides: tuple[int, ...], + border: tuple[tuple[int, int], ...], + ) -> tuple[int, ...]: + return tuple( + (xi + (li + ri) - ki) // si + 1 + for xi, ki, si, (li, ri) in safe_zip(shape, kernel_size, strides, border) + ) + + @ft.partial(jax.profiler.annotate_function, name="general_arange") + def general_arange(di: int, ki: int, si: int, x0: int, xf: int) -> jax.Array: + # this function is used to calculate the windows indices for a given dimension + start, end = -x0 + ((ki - 1) // 2), di + xf - (ki // 2) + size = end - start + + res = ( + jax.lax.broadcasted_iota(dtype=jnp.int32, shape=(size, ki), dimension=0) + + jax.lax.broadcasted_iota(dtype=jnp.int32, shape=(ki, size), dimension=0).T + + (start) + - ((ki - 1) // 2) + ) + + # res[::si] is slightly slower. + return (res) if si == 1 else (res)[::si] + + @ft.lru_cache(maxsize=128) + def recursive_vmap(*, ndim: int): + def nvmap(n): + in_axes = [None] * ndim + in_axes[-n] = 0 + return ( + jax.vmap(lambda *x: x, in_axes=in_axes) + if n == 1 + else jax.vmap(nvmap(n - 1), in_axes=in_axes) + ) + + return nvmap(ndim) + + @ft.partial(jax.profiler.annotate_function, name="general_product") + def general_product(*args: jax.Array): + return recursive_vmap(ndim=len(args))(*args) + + def generate_views( + shape: tuple[int, ...], + kernel_size: tuple[int, ...], + strides: tuple[int, ...], + border: tuple[tuple[int, int], ...], + ) -> tuple[jax.Array, ...]: + dim_range = tuple( + general_arange(di, ki, si, x0, xf) + for (di, ki, si, (x0, xf)) in zip(shape, kernel_size, strides, border) + ) + matrix = general_product(*dim_range) + return tuple(map(lambda xi, wi: xi.reshape(-1, wi), matrix, kernel_size)) + + def absolute_wrapper(*a, **k): + def map_func(view: tuple[jax.Array, ...], array: jax.Array): + patch = array.at[ix_(*view)].get(**gather_kwargs) + return func(patch, *a, **k) + + return map_func + + def ix_(*args): + """modified version of jnp.ix_""" + n = len(args) + output = [] + for i, a in enumerate(args): + shape = [1] * n + shape[i] = a.shape[0] + output.append(jax.lax.broadcast_in_dim(a, shape, (i,))) + return tuple(output) + + pad_width = tuple([0, max(0, pi[0]) + max(0, pi[1])] for pi in padding) + args = (shape, kernel_size, strides, padding) + views = generate_views(*args) + output_shape = calculate_kernel_map_output_shape(*args) + + def single_call_wrapper(array: jax.Array, *a, **k): + padded_array = jnp.pad(array, pad_width, **pad_kwargs) + reduced_func = absolute_wrapper(*a, **k) + + def map_func(view): + return reduced_func(view, padded_array) + + result = jax.vmap(map_func)(views) + return result.reshape(*output_shape, *result.shape[1:]) + + return single_call_wrapper diff --git a/serket/_src/utils/padding.py b/serket/_src/utils/padding.py new file mode 100644 index 00000000..c5426141 --- /dev/null +++ b/serket/_src/utils/padding.py @@ -0,0 +1,136 @@ +# Copyright 2024 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import functools as ft + +from serket._src.utils.typing import KernelSizeType, PaddingType, StridesType + + +def same_padding_along_dim( + in_dim: int, + kernel_size: int, + stride: int, +) -> tuple[int, int]: + # https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2 + # di: input dimension + # ki: kernel size + # si: stride + if in_dim % stride == 0: + pad = max(kernel_size - stride, 0) + else: + pad = max(kernel_size - (in_dim % stride), 0) + + return (pad // 2, pad - pad // 2) + + +def resolve_tuple_padding( + in_dim: tuple[int, ...], + padding: PaddingType, + kernel_size: KernelSizeType, + strides: StridesType, +) -> tuple[tuple[int, int], ...]: + del in_dim, strides + if len(padding) != len(kernel_size): + raise ValueError(f"Length mismatch {len(kernel_size)=}!={len(padding)=}.") + + resolved_padding = [[]] * len(kernel_size) + + for i, item in enumerate(padding): + if isinstance(item, int): + resolved_padding[i] = (item, item) # ex: padding = (1, 2, 3) + + elif isinstance(item, tuple): + if len(item) != 2: + raise ValueError(f"Expected tuple of length 2, got {len(item)=}") + resolved_padding[i] = item + + return tuple(resolved_padding) + + +def resolve_int_padding( + in_dim: tuple[int, ...], + padding: PaddingType, + kernel_size: KernelSizeType, + strides: StridesType, +): + del in_dim, strides + return ((padding, padding),) * len(kernel_size) + + +def resolve_string_padding(in_dim, padding, kernel_size, strides): + if padding.lower() == "same": + return tuple( + same_padding_along_dim(di, ki, si) + for di, ki, si in zip(in_dim, kernel_size, strides) + ) + + if padding.lower() == "valid": + return ((0, 0),) * len(kernel_size) + + raise ValueError(f'string argument must be in ["same","valid"].Found {padding}') + + +@ft.lru_cache(maxsize=128) +def delayed_canonicalize_padding( + in_dim: tuple[int, ...], + padding: PaddingType, + kernel_size: KernelSizeType, + strides: StridesType, +): + # in case of `str` padding, we need to know the input dimension + # to calculate the padding thus we need to delay the canonicalization + # until the call + + if isinstance(padding, int): + return resolve_int_padding(in_dim, padding, kernel_size, strides) + + if isinstance(padding, str): + return resolve_string_padding(in_dim, padding, kernel_size, strides) + + if isinstance(padding, tuple): + return resolve_tuple_padding(in_dim, padding, kernel_size, strides) + + raise ValueError( + "Expected padding to be of:\n" + "* int, for same padding along all dimensions\n" + "* str, for `same` or `valid` padding along all dimensions\n" + "* tuple of int, for individual padding along each dimension\n" + "* tuple of tuple of int, for padding before and after each dimension\n" + f"Got {padding=}." + ) + + +@ft.lru_cache(maxsize=128) +def calculate_transpose_padding( + padding, + kernel_size, + input_dilation, + extra_padding, +): + """Transpose padding to get the padding for the transpose convolution. + + Args: + padding: padding to transpose + kernel_size: kernel size to use for transposing padding + input_dilation: input dilation to use for transposing padding + extra_padding: extra padding to use for transposing padding + """ + return tuple( + ((ki - 1) * di - pl, (ki - 1) * di - pr + ep) + for (pl, pr), ki, ep, di in zip( + padding, kernel_size, extra_padding, input_dilation + ) + ) diff --git a/serket/_src/utils/typing.py b/serket/_src/utils/typing.py new file mode 100644 index 00000000..1c78b744 --- /dev/null +++ b/serket/_src/utils/typing.py @@ -0,0 +1,67 @@ +# Copyright 2024 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Annotated, Any, Callable, Literal, Sequence, Tuple, TypeVar, Union + +import jax +import numpy as np +from typing_extensions import ParamSpec + +KernelSizeType = Union[int, Sequence[int]] +StridesType = Union[int, Sequence[int]] +PaddingType = Union[str, int, Sequence[int], Sequence[Tuple[int, int]]] +DilationType = Union[int, Sequence[int]] +P = ParamSpec("P") +T = TypeVar("T") +HWArray = Annotated[jax.Array, "HW"] +CHWArray = Annotated[jax.Array, "CHW"] +PaddingLiteral = Literal[ + "constant", + "edge", + "linear_ramp", + "maximum", + "mean", + "median", + "minimum", + "reflect", + "symmetric", + "wrap", +] +PaddingMode = Union[PaddingLiteral, Union[int, float], Callable] + + +InitLiteral = Literal[ + "he_normal", + "he_uniform", + "glorot_normal", + "glorot_uniform", + "lecun_normal", + "lecun_uniform", + "normal", + "uniform", + "ones", + "zeros", + "xavier_normal", + "xavier_uniform", + "orthogonal", +] + +Shape = Tuple[int, ...] +DType = Union[np.dtype, str, Any] +InitFuncType = Callable[[jax.Array, Shape, DType], jax.Array] +InitType = Union[InitLiteral, InitFuncType] +MethodKind = Literal["nearest", "linear", "cubic", "lanczos3", "lanczos5"] +Weight = Annotated[jax.Array, "OI..."] diff --git a/serket/_src/utils/validate.py b/serket/_src/utils/validate.py new file mode 100644 index 00000000..e96e886a --- /dev/null +++ b/serket/_src/utils/validate.py @@ -0,0 +1,124 @@ +# Copyright 2024 serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import functools as ft +import operator as op +from typing import Any, Callable, Sequence, TypeVar + +import jax +import numpy as np +from typing_extensions import ParamSpec + +from serket import TreeClass, autoinit + +P = ParamSpec("P") +T = TypeVar("T") + + +@autoinit +class Range(TreeClass): + min_val: float = -float("inf") + max_val: float = float("inf") + min_inclusive: bool = True + max_inclusive: bool = True + + def __call__(self, value: Any): + lop, ls = (op.ge, "[") if self.min_inclusive else (op.gt, "(") + rop, rs = (op.le, "]") if self.max_inclusive else (op.lt, ")") + + if lop(value, self.min_val) and rop(value, self.max_val): + return value + + raise ValueError(f"Not in {ls}{self.min_val}, {self.max_val}{rs} got {value=}.") + + +@autoinit +class IsInstance(TreeClass): + klass: type | Sequence[type] + + def __call__(self, value: Any): + if isinstance(value, self.klass): + return value + raise TypeError(f"Expected {self.klass}, got {type(value).__name__}") + + +class ScalarLike(TreeClass): + """Check if the input is a scalar""" + + def __call__(self, value: Any): + if isinstance(value, (float, complex)): + return value + if ( + isinstance(value, (jax.Array, np.ndarray)) + and np.issubdtype(value.dtype, np.inexact) + and value.shape == () + ): + return value + raise ValueError(f"Expected inexact type got {value=}") + + +def validate_pos_int(value): + """Return if value is a positive integer, otherwise raise an error.""" + if not isinstance(value, int): + raise ValueError(f"value must be an integer, got {type(value).__name__}") + if value <= 0: + raise ValueError(f"{value=} must be positive.") + return value + + +def validate_spatial_ndim(func: Callable[P, T], argnum: int = 0) -> Callable[P, T]: + """Decorator to validate spatial input shape.""" + + @ft.wraps(func) + def wrapper(self, *args, **kwargs): + input = args[argnum] + spatial_ndim = self.spatial_ndim + + if input.ndim != spatial_ndim + 1: + spatial = ", ".join(("rows", "cols", "depths")[:spatial_ndim]) + name = type(self).__name__ + raise ValueError( + f"Dimesion mismatch error in inputs of {name}\n" + f"Input should satisfy:\n" + f" - {(spatial_ndim + 1) = } dimension, but got {input.ndim = }.\n" + f" - shape of (in_features, {spatial}), but got {input.shape = }.\n" + + ( + # maybe the user apply the layer on a batched input + "The input should be unbatched (no batch dimension).\n" + "To apply on batched input, use `jax.vmap(...)(input)`." + if input.ndim == spatial_ndim + 2 + else "" + ) + ) + return func(self, *args, **kwargs) + + return wrapper + + +def validate_in_features_shape(func: Callable[P, T], axis: int) -> Callable[P, T]: + """Decorator to validate input features.""" + + def check_axis_shape(input, in_features: int, axis: int) -> None: + if input.shape[axis] != in_features: + raise ValueError(f"Specified {in_features=}, got {input.shape[axis]=}.") + return input + + @ft.wraps(func) + def wrapper(self, array, *a, **k): + check_axis_shape(array, self.in_features, axis) + return func(self, array, *a, **k) + + return wrapper diff --git a/serket/cluster/__init__.py b/serket/cluster/__init__.py index 885ffe31..a7676684 100644 --- a/serket/cluster/__init__.py +++ b/serket/cluster/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/serket/image/__init__.py b/serket/image/__init__.py index 9ece8df8..67698db0 100644 --- a/serket/image/__init__.py +++ b/serket/image/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index 525c17a6..5c7b9336 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/__init__.py b/tests/__init__.py index dbf0b046..650d4be4 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_activation.py b/tests/test_activation.py index 47ccd568..ae27d963 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_attention.py b/tests/test_attention.py index 7fce8228..ccf044d2 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_clustering.py b/tests/test_clustering.py index a2ead621..bd791f29 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,11 +19,13 @@ import jax.random as jr import numpy.testing as npt import pytest + import serket as sk # Suppress FutureWarning warnings.simplefilter(action="ignore", category=FutureWarning) + @pytest.mark.skip(reason="flaky test") def test_kmeans(): from sklearn.cluster import KMeans diff --git a/tests/test_containers.py b/tests/test_containers.py index 5054e108..13a8691a 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_conv.py b/tests/test_conv.py index 1fc89ba5..91165d1e 100644 --- a/tests/test_conv.py +++ b/tests/test_conv.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_convolution.py b/tests/test_convolution.py index 962707d6..584b82a5 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_dropout.py b/tests/test_dropout.py index 9b00d174..cfa8c034 100644 --- a/tests/test_dropout.py +++ b/tests/test_dropout.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 5f45793e..38d93087 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_linear.py b/tests/test_linear.py index e9a0b612..5cf0c6d8 100644 --- a/tests/test_linear.py +++ b/tests/test_linear.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 6e2f26f4..38fa1e45 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_pooling.py b/tests/test_pooling.py index 7af3c017..22927f12 100644 --- a/tests/test_pooling.py +++ b/tests/test_pooling.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_reshape.py b/tests/test_reshape.py index 0f55ab35..27c47eed 100644 --- a/tests/test_reshape.py +++ b/tests/test_reshape.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_rnn.py b/tests/test_rnn.py index 52605671..4817fb9d 100644 --- a/tests/test_rnn.py +++ b/tests/test_rnn.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_sequential.py b/tests/test_sequential.py index 8f108158..fe0af272 100644 --- a/tests/test_sequential.py +++ b/tests/test_sequential.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_utils.py b/tests/test_utils.py index 2ec2e740..9504330b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023 serket authors +# Copyright 2024 serket authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,15 +21,13 @@ import serket as sk from serket._src.nn.initialization import resolve_init -from serket._src.utils import ( - IsInstance, - ScalarLike, - canonicalize, +from serket._src.utils.convert import canonicalize +from serket._src.utils.padding import ( delayed_canonicalize_padding, - positive_int_cb, resolve_string_padding, resolve_tuple_padding, ) +from serket._src.utils.validate import IsInstance, ScalarLike, validate_pos_int @pytest.mark.parametrize( @@ -129,9 +127,9 @@ def test_scalar_like_error(): npt.assert_allclose(ScalarLike()(jax.numpy.array(1.0)), jax.numpy.array(1.0)) -def test_positive_int_cb_error(): +def test_validate_pos_int_error(): with pytest.raises(ValueError): - positive_int_cb(1.0) + validate_pos_int(1.0) def test_lazy_call():