diff --git a/pyproject.toml b/pyproject.toml index a10f10f..eb1a5ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,27 +6,32 @@ build-backend = "setuptools.build_meta" name = "serket" dynamic = ["version"] requires-python = ">=3.8" -license = {text = "Apache-2.0"} +license = { text = "Apache-2.0" } description = "Functional neural network library in JAX" -authors = [{name = "Mahmoud Asem", email = "mahmoudasem00@gmail.com"}] -keywords = ["jax", "neural-networks", "functional-programming", "machine-learning"] -dependencies = ["jax>=0.4.7", "typing-extensions"] +authors = [{ name = "Mahmoud Asem", email = "mahmoudasem00@gmail.com" }] +keywords = [ + "jax", + "neural-networks", + "functional-programming", + "machine-learning", +] +dependencies = ["pytreeclass>=0.4.0"] -classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Console", - "Intended Audience :: Science/Research", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Python Modules", +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", ] [tool.setuptools.dynamic] -version = {attr = "serket.__version__" } +version = { attr = "serket.__version__" } [tool.setuptools.packages.find] include = ["serket", "serket.*"] @@ -39,5 +44,5 @@ Source = "https://github.com/ASEM000/Serket" select = ["F", "E", "I"] line-length = 120 ignore = [ - "E731", # do not assign a lambda expression, use a def + "E731", # do not assign a lambda expression, use a def ] diff --git a/serket/experimental/test_lazy_class.py b/serket/experimental/test_lazy_class.py index b86fa18..9d32ad5 100644 --- a/serket/experimental/test_lazy_class.py +++ b/serket/experimental/test_lazy_class.py @@ -17,8 +17,8 @@ import jax import jax.numpy as jnp import pytest -import pytreeclass as pytc +import serket as sk from serket.experimental import lazy_class @@ -30,7 +30,7 @@ def test_lazy_class(): infer_method_name="__call__", # -> `infer_func` is applied to `__call__` method lazy_marker=None, # -> `None` is used to indicate a lazy argument ) - class LazyLinear(pytc.TreeClass): + class LazyLinear(sk.TreeClass): weight: jax.Array bias: jax.Array diff --git a/serket/nn/activation.py b/serket/nn/activation.py index d903d96..e9d9fd5 100644 --- a/serket/nn/activation.py +++ b/serket/nn/activation.py @@ -18,198 +18,198 @@ import jax import jax.numpy as jnp -import pytreeclass as pytc from jax import lax +import serket as sk from serket.nn.utils import IsInstance, Range, ScalarLike -class AdaptiveLeakyReLU(pytc.TreeClass): +class AdaptiveLeakyReLU(sk.TreeClass): """Leaky ReLU activation function Note: https://arxiv.org/pdf/1906.01170.pdf. 
""" - a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()]) - v: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + v: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: v = jax.lax.stop_gradient(self.v) return jnp.maximum(0, self.a * x) - v * jnp.maximum(0, -self.a * x) -class AdaptiveReLU(pytc.TreeClass): +class AdaptiveReLU(sk.TreeClass): """ReLU activation function with learnable parameters Note: https://arxiv.org/pdf/1906.01170.pdf. """ - a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: return jnp.maximum(0, self.a * x) -class AdaptiveSigmoid(pytc.TreeClass): +class AdaptiveSigmoid(sk.TreeClass): """Sigmoid activation function with learnable `a` parameter Note: https://arxiv.org/pdf/1906.01170.pdf. """ - a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: return 1 / (1 + jnp.exp(-self.a * x)) -class AdaptiveTanh(pytc.TreeClass): +class AdaptiveTanh(sk.TreeClass): """Tanh activation function with learnable parameters Note: https://arxiv.org/pdf/1906.01170.pdf. """ - a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: a = self.a return (jnp.exp(a * x) - jnp.exp(-a * x)) / (jnp.exp(a * x) + jnp.exp(-a * x)) -class CeLU(pytc.TreeClass): +class CeLU(sk.TreeClass): """Celu activation function""" - alpha: float = pytc.field(default=1.0, callbacks=[ScalarLike()]) + alpha: float = sk.field(default=1.0, callbacks=[ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.celu(x, alpha=lax.stop_gradient(self.alpha)) -class ELU(pytc.TreeClass): +class ELU(sk.TreeClass): """Exponential linear unit""" - alpha: float = pytc.field(default=1.0, callbacks=[ScalarLike()]) + alpha: float = sk.field(default=1.0, callbacks=[ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.elu(x, alpha=lax.stop_gradient(self.alpha)) -class GELU(pytc.TreeClass): +class GELU(sk.TreeClass): """Gaussian error linear unit""" - approximate: bool = pytc.field(default=1.0, callbacks=[IsInstance(bool)]) + approximate: bool = sk.field(default=1.0, callbacks=[IsInstance(bool)]) def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.gelu(x, approximate=self.approximate) -class GLU(pytc.TreeClass): +class GLU(sk.TreeClass): """Gated linear unit""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.glu(x) -class HardShrink(pytc.TreeClass): +class HardShrink(sk.TreeClass): """Hard shrink activation function""" - alpha: float = pytc.field(default=0.5, callbacks=[Range(0), ScalarLike()]) + alpha: float = sk.field(default=0.5, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: alpha = lax.stop_gradient(self.alpha) return jnp.where(x > alpha, x, jnp.where(x < -alpha, x, 0.0)) -class HardSigmoid(pytc.TreeClass): +class HardSigmoid(sk.TreeClass): """Hard sigmoid activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.hard_sigmoid(x) -class HardSwish(pytc.TreeClass): +class HardSwish(sk.TreeClass): """Hard 
swish activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.hard_swish(x) -class HardTanh(pytc.TreeClass): +class HardTanh(sk.TreeClass): """Hard tanh activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.hard_tanh(x) -class LogSigmoid(pytc.TreeClass): +class LogSigmoid(sk.TreeClass): """Log sigmoid activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.log_sigmoid(x) -class LogSoftmax(pytc.TreeClass): +class LogSoftmax(sk.TreeClass): """Log softmax activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.log_softmax(x) -class LeakyReLU(pytc.TreeClass): +class LeakyReLU(sk.TreeClass): """Leaky ReLU activation function""" - negative_slope: float = pytc.field(default=0.01, callbacks=[Range(0), ScalarLike()]) + negative_slope: float = sk.field(default=0.01, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.leaky_relu(x, lax.stop_gradient(self.negative_slope)) -class ReLU(pytc.TreeClass): +class ReLU(sk.TreeClass): """ReLU activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.relu(x) -class ReLU6(pytc.TreeClass): +class ReLU6(sk.TreeClass): """ReLU6 activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.relu6(x) -class SeLU(pytc.TreeClass): +class SeLU(sk.TreeClass): """Scaled Exponential Linear Unit""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.selu(x) -class Sigmoid(pytc.TreeClass): +class Sigmoid(sk.TreeClass): """Sigmoid activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.sigmoid(x) -class SoftPlus(pytc.TreeClass): +class SoftPlus(sk.TreeClass): """SoftPlus activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.softplus(x) -class SoftSign(pytc.TreeClass): +class SoftSign(sk.TreeClass): """SoftSign activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return x / (1 + jnp.abs(x)) -class SoftShrink(pytc.TreeClass): +class SoftShrink(sk.TreeClass): """SoftShrink activation function""" - alpha: float = pytc.field(default=0.5, callbacks=[Range(0), ScalarLike()]) + alpha: float = sk.field(default=0.5, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: alpha = lax.stop_gradient(self.alpha) @@ -220,61 +220,61 @@ def __call__(self, x: jax.Array, **k) -> jax.Array: ) -class SquarePlus(pytc.TreeClass): +class SquarePlus(sk.TreeClass): """SquarePlus activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return 0.5 * (x + jnp.sqrt(x * x + 4)) -class Swish(pytc.TreeClass): +class Swish(sk.TreeClass): """Swish activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.swish(x) -class Tanh(pytc.TreeClass): +class Tanh(sk.TreeClass): """Tanh activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return jax.nn.tanh(x) -class TanhShrink(pytc.TreeClass): +class TanhShrink(sk.TreeClass): """TanhShrink activation function""" def __call__(self, x: jax.Array, **k) -> jax.Array: return x - jax.nn.tanh(x) -class ThresholdedReLU(pytc.TreeClass): +class ThresholdedReLU(sk.TreeClass): """Thresholded ReLU activation function.""" - theta: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + theta: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: theta = 
lax.stop_gradient(self.theta) return jnp.where(x > theta, x, 0) -class Mish(pytc.TreeClass): +class Mish(sk.TreeClass): """Mish activation function https://arxiv.org/pdf/1908.08681.pdf.""" def __call__(self, x: jax.Array, **k) -> jax.Array: return x * jax.nn.tanh(jax.nn.softplus(x)) -class PReLU(pytc.TreeClass): +class PReLU(sk.TreeClass): """Parametric ReLU activation function""" - a: float = pytc.field(default=0.25, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=0.25, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array, **k) -> jax.Array: return jnp.where(x >= 0, x, x * self.a) -class Snake(pytc.TreeClass): +class Snake(sk.TreeClass): """Snake activation function Args: @@ -284,7 +284,7 @@ class Snake(pytc.TreeClass): https://arxiv.org/pdf/2006.08195.pdf. """ - a: float = pytc.field(callbacks=[Range(0), ScalarLike()], default=1.0) + a: float = sk.field(callbacks=[Range(0), ScalarLike()], default=1.0) def __call__(self, x: jax.Array, **k) -> jax.Array: a = lax.stop_gradient(self.a) @@ -360,7 +360,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array: ] -act_map: dict[str, pytc.TreeClass] = dict(zip(get_args(ActivationLiteral), acts)) +act_map: dict[str, sk.TreeClass] = dict(zip(get_args(ActivationLiteral), acts)) ActivationFunctionType = Callable[[jax.typing.ArrayLike], jax.Array] ActivationType = Union[ActivationLiteral, ActivationFunctionType] diff --git a/serket/nn/blocks/unet.py b/serket/nn/blocks/unet.py index f1bd99a..1af5076 100644 --- a/serket/nn/blocks/unet.py +++ b/serket/nn/blocks/unet.py @@ -20,13 +20,12 @@ import jax import jax.numpy as jnp -import pytreeclass as pytc import serket as sk from serket.nn.utils import positive_int_cb -class ResizeAndCat(pytc.TreeClass): +class ResizeAndCat(sk.TreeClass): def __call__(self, x1: jax.Array, x2: jax.Array) -> jax.Array: """resize a tensor to the same size as another tensor and concatenate x2 to x1 along the channel axis""" x1 = jax.image.resize(x1, shape=x2.shape, method="nearest") @@ -34,7 +33,7 @@ def __call__(self, x1: jax.Array, x2: jax.Array) -> jax.Array: return x1 -class DoubleConvBlock(pytc.TreeClass): +class DoubleConvBlock(sk.TreeClass): def __init__(self, in_features: int, out_features: int): self.conv1 = sk.nn.Conv2D( in_features=in_features, @@ -59,7 +58,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array: return x -class UpscaleBlock(pytc.TreeClass): +class UpscaleBlock(sk.TreeClass): def __init__(self, in_features: int, out_features: int): self.conv = sk.nn.Conv2DTranspose( in_features=in_features, @@ -72,7 +71,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array: return self.conv(x) -class UNetBlock(pytc.TreeClass): +class UNetBlock(sk.TreeClass): """Vanilla UNet Args: @@ -82,10 +81,10 @@ class UNetBlock(pytc.TreeClass): init_features : number of features in the first block. 
Default is 64 """ - in_features: int = pytc.field(callbacks=[positive_int_cb]) - out_features: int = pytc.field(callbacks=[positive_int_cb]) - blocks: int = pytc.field(callbacks=[positive_int_cb], default=4) - init_features: int = pytc.field(callbacks=[positive_int_cb], default=64) + in_features: int = sk.field(callbacks=[positive_int_cb]) + out_features: int = sk.field(callbacks=[positive_int_cb]) + blocks: int = sk.field(callbacks=[positive_int_cb], default=4) + init_features: int = sk.field(callbacks=[positive_int_cb], default=64) def __post_init__(self): """ diff --git a/serket/nn/blocks/vgg.py b/serket/nn/blocks/vgg.py index 741a1e1..f1661b2 100644 --- a/serket/nn/blocks/vgg.py +++ b/serket/nn/blocks/vgg.py @@ -16,12 +16,11 @@ import jax import jax.random as jr -import pytreeclass as pytc import serket as sk -class VGG16Block(pytc.TreeClass): +class VGG16Block(sk.TreeClass): def __init__( self, in_features: int, @@ -113,7 +112,7 @@ def __call__(self, x: jax.Array, **kwargs) -> jax.Array: return x -class VGG19Block(pytc.TreeClass): +class VGG19Block(sk.TreeClass): def __init__( self, in_feautres: int, diff --git a/serket/nn/blur.py b/serket/nn/blur.py index 0aa2870..b8c2dcd 100644 --- a/serket/nn/blur.py +++ b/serket/nn/blur.py @@ -19,33 +19,34 @@ import jax import jax.numpy as jnp -import pytreeclass as pytc from jax import lax +import serket as sk from serket.nn.convolution import DepthwiseConv2D from serket.nn.fft_convolution import DepthwiseFFTConv2D from serket.nn.utils import positive_int_cb, validate_axis_shape, validate_spatial_ndim -class AvgBlur2D(pytc.TreeClass): - def __init__(self, in_features: int, kernel_size: int | tuple[int, int]): - """Average blur 2D layer - Args: - in_features: number of input channels - kernel_size: size of the convolving kernel +class AvgBlur2D(sk.TreeClass): + """Average blur 2D layer + + Args: + in_features: number of input channels. + kernel_size: size of the convolving kernel. - Example: + Example: >>> import serket as sk >>> import jax.numpy as jnp >>> layer = sk.nn.AvgBlur2D(in_features=1, kernel_size=3) >>> print(layer(jnp.ones((1,5,5)))) [[[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448] - [0.6666667 1. 1. 1. 0.6666667 ] - [0.6666667 1. 1. 1. 0.6666667 ] - [0.6666667 1. 1. 1. 0.6666667 ] - [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] + [0.6666667 1. 1. 1. 0.6666667 ] + [0.6666667 1. 1. 1. 0.6666667 ] + [0.6666667 1. 1. 1. 0.6666667 ] + [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] + """ - """ + def __init__(self, in_features: int, kernel_size: int | tuple[int, int]): self.in_features = positive_int_cb(in_features) self.kernel_size = positive_int_cb(kernel_size) @@ -82,26 +83,27 @@ def spatial_ndim(self) -> int: return 2 -class GaussianBlur2D(pytc.TreeClass): - def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0): - """Apply Gaussian blur to a channel-first image. +class GaussianBlur2D(sk.TreeClass): + """Apply Gaussian blur to a channel-first image. - Args: - in_features: number of input features - kernel_size: kernel size - sigma: sigma. Defaults to 1. + Args: + in_features: number of input features + kernel_size: kernel size + sigma: sigma. Defaults to 1. - Example: + Example: >>> import serket as sk >>> import jax.numpy as jnp >>> layer = sk.nn.GaussianBlur2D(in_features=1, kernel_size=3) >>> print(layer(jnp.ones((1,5,5)))) [[[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764] - [0.7259314 1. 1. 1. 0.7259314] - [0.7259314 1. 1. 1. 0.7259314] - [0.7259314 1. 1. 1. 
0.7259314] - [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] - """ + [0.7259314 1. 1. 1. 0.7259314] + [0.7259314 1. 1. 1. 0.7259314] + [0.7259314 1. 1. 1. 0.7259314] + [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] + """ + + def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0): self.in_features = positive_int_cb(in_features) self.kernel_size = positive_int_cb(kernel_size) @@ -143,24 +145,26 @@ def spatial_ndim(self) -> int: return 2 -class Filter2D(pytc.TreeClass): - def __init__(self, in_features: int, kernel: jax.Array): - """Apply 2D filter for each channel - Args: - in_features: number of input channels - kernel: kernel array +class Filter2D(sk.TreeClass): + """Apply 2D filter for each channel - Example: + Args: + in_features: number of input channels. + kernel: kernel array. + + Example: >>> import serket as sk >>> import jax.numpy as jnp >>> layer = sk.nn.Filter2D(in_features=1, kernel=jnp.ones((3,3))) >>> print(layer(jnp.ones((1,5,5)))) [[[4. 6. 6. 6. 4.] - [6. 9. 9. 9. 6.] - [6. 9. 9. 9. 6.] - [6. 9. 9. 9. 6.] - [4. 6. 6. 6. 4.]]] - """ + [6. 9. 9. 9. 6.] + [6. 9. 9. 9. 6.] + [6. 9. 9. 9. 6.] + [4. 6. 6. 6. 4.]]] + """ + + def __init__(self, in_features: int, kernel: jax.Array): if not isinstance(kernel, jax.Array) or kernel.ndim != 2: raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)") @@ -186,24 +190,26 @@ def spatial_ndim(self) -> int: return 2 -class FFTFilter2D(pytc.TreeClass): - def __init__(self, in_features: int, kernel: jax.Array): - """Apply 2D filter for each channel using FFT - Args: - in_features: number of input channels - kernel: kernel array +class FFTFilter2D(sk.TreeClass): + """Apply 2D filter for each channel using FFT - Example: + Args: + in_features: number of input channels + kernel: kernel array + + Example: >>> import serket as sk >>> import jax.numpy as jnp >>> layer = sk.nn.FFTFilter2D(in_features=1, kernel=jnp.ones((3,3))) >>> print(layer(jnp.ones((1,5,5)))) [[[4.0000005 6.0000005 6.000001 6.0000005 4.0000005] - [6.0000005 9. 9. 9. 6.0000005] - [6.0000005 9. 9. 9. 6.0000005] - [6.0000005 9. 9. 9. 6.0000005] - [4. 6.0000005 6.0000005 6.0000005 4. ]]] - """ + [6.0000005 9. 9. 9. 6.0000005] + [6.0000005 9. 9. 9. 6.0000005] + [6.0000005 9. 9. 9. 6.0000005] + [4. 6.0000005 6.0000005 6.0000005 4. ]]] + """ + + def __init__(self, in_features: int, kernel: jax.Array): if not isinstance(kernel, jax.Array) or kernel.ndim != 2: raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)") diff --git a/serket/nn/containers.py b/serket/nn/containers.py index d334f6d..4ed018d 100644 --- a/serket/nn/containers.py +++ b/serket/nn/containers.py @@ -19,12 +19,12 @@ import jax import jax.random as jr -import pytreeclass as pytc +import serket as sk from serket.nn.utils import IsInstance -class Sequential(pytc.TreeClass): +class Sequential(sk.TreeClass): """A sequential container for layers. Args: @@ -38,10 +38,14 @@ class Sequential(pytc.TreeClass): >>> layers = sk.nn.Sequential((lambda x: x + 1, lambda x: x * 2)) >>> print(layers(jnp.array([1, 2, 3]), key=jr.PRNGKey(0))) [4 6 8] + + Note: + Layer might be a function or a class with a `__call__` method, additionally + it might have a key argument for random number generation. """ # allow list then cast to tuple avoid mutability issues - layers: tuple[Any, ...] = pytc.field(callbacks=[IsInstance((tuple, list)), tuple]) + layers: tuple[Any, ...] 
= sk.field(callbacks=[IsInstance((tuple, list)), tuple]) def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array: for key, layer in zip(jr.split(key, len(self.layers)), self.layers): diff --git a/serket/nn/contrast.py b/serket/nn/contrast.py index 9fe8908..0e9238c 100644 --- a/serket/nn/contrast.py +++ b/serket/nn/contrast.py @@ -20,9 +20,9 @@ import jax import jax.numpy as jnp import jax.random as jr -import pytreeclass as pytc from jax import lax +import serket as sk from serket.nn.utils import validate_spatial_ndim @@ -43,10 +43,10 @@ def random_contrast_nd( return adjust_contrast_nd(x, contrast_factor) -class AdjustContrastND(pytc.TreeClass): +class AdjustContrastND(sk.TreeClass): """Adjusts the contrast of an NDimage by scaling the pixel values by a factor. - See: + Note: https://www.tensorflow.org/api_docs/python/tf/image/adjust_contrast https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py """ @@ -65,12 +65,13 @@ def spatial_ndim(self) -> int: class AdjustContrast2D(AdjustContrastND): - def __init__(self, contrast_factor=1.0): - """Adjusts the contrast of an image by scaling the pixel values by a factor. + """Adjusts the contrast of an image by scaling the pixel values by a factor. + + Args: + contrast_factor: contrast factor to adjust the image by. + """ - Args: - contrast_factor: contrast factor to adjust the image by. - """ + def __init__(self, contrast_factor=1.0): super().__init__(contrast_factor=contrast_factor) @property @@ -78,16 +79,17 @@ def spatial_ndim(self) -> int: return 2 -class RandomContrastND(pytc.TreeClass): +class RandomContrastND(sk.TreeClass): + """Randomly adjusts the contrast of an image by scaling the pixel + values by a factor. + + Args: + contrast_range: range of contrast factors to randomly sample from. + """ + contrast_range: tuple def __init__(self, contrast_range=(0.5, 1)): - """Randomly adjusts the contrast of an image by scaling the pixel - values by a factor. - - Args: - contrast_range: range of contrast factors to randomly sample from. - """ if not ( isinstance(contrast_range, tuple) and len(contrast_range) == 2 @@ -121,13 +123,14 @@ def spatial_ndim(self) -> int: class RandomContrast2D(RandomContrastND): - def __init__(self, contrast_range=(0.5, 1)): - """Randomly adjusts the contrast of an image by scaling the pixel - values by a factor. + """Randomly adjusts the contrast of an image by scaling the pixel + values by a factor. - Args: - contrast_range: range of contrast factors to randomly sample from. - """ + Args: + contrast_range: range of contrast factors to randomly sample from. + """ + + def __init__(self, contrast_range=(0.5, 1)): super().__init__(contrast_range=contrast_range) @property diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py index 303e7fd..9aba873 100644 --- a/serket/nn/convolution.py +++ b/serket/nn/convolution.py @@ -22,9 +22,9 @@ import jax import jax.numpy as jnp import jax.random as jr -import pytreeclass as pytc from jax.lax import ConvDimensionNumbers +import serket as sk from serket.nn.initialization import InitType, resolve_init_func from serket.nn.utils import ( DilationType, @@ -46,7 +46,7 @@ def generate_conv_dim_numbers(spatial_ndim): return ConvDimensionNumbers(*((tuple(range(spatial_ndim + 2)),) * 3)) -class ConvND(pytc.TreeClass): +class ConvND(sk.TreeClass): def __init__( self, in_features: int, @@ -62,25 +62,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """Convolutional layer. 
-
-        Args:
-            in_features: number of input features
-            out_features: number of output features
-            kernel_size: size of the convolutional kernel
-            strides: stride of the convolution
-            padding: padding of the input
-            input_dilation: dilation of the input
-            kernel_dilation: dilation of the convolutional kernel
-            weight_init_func: function to use for initializing the weights
-            bias_init_func: function to use for initializing the bias
-            groups: number of groups to use for grouped convolution
-            key: key to use for initializing the weights
-
-        Note:
-            https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
-        """
-
         # already checked by the callbacks self.in_features = positive_int_cb(in_features) self.out_features = positive_int_cb(out_features) self.kernel_size = canonicalize( kernel_size, self.spatial_ndim, name="kernel_size" )
@@ -152,6 +133,60 @@ def spatial_ndim(self) -> int:
 class Conv1D(ConvND):
+    """1D Convolutional layer.
+
+    Args:
+        in_features: number of input feature maps (i.e. input channels).
+        out_features: number of output feature maps (i.e. output channels).
+        kernel_size: size of the convolutional kernel. accepts:
+            * single integer for same kernel size in all dimensions.
+            * sequence of integers for different kernel sizes in each dimension.
+        strides: stride of the convolution. accepts:
+            * single integer for same stride in all dimensions.
+            * sequence of integers for different strides in each dimension.
+        padding: padding of the input before convolution. accepts:
+            * single integer for same padding in all dimensions.
+            * tuple of integers for different padding in each dimension.
+            * tuple of a tuple of two integers for before and after padding in
+              each dimension.
+            * "same"/"SAME" for padding such that the output has the same shape
+              as the input.
+            * "valid"/"VALID" for no padding.
+        input_dilation: dilation of the input. accepts:
+            * single integer for same dilation in all dimensions.
+            * sequence of integers for different dilation in each dimension.
+        kernel_dilation: dilation of the convolutional kernel. accepts:
+            * single integer for same dilation in all dimensions.
+            * sequence of integers for different dilation in each dimension.
+        weight_init_func: function to use for initializing the weights. defaults
+            to `glorot_uniform`.
+        bias_init_func: function to use for initializing the bias. defaults to
+            `zeros`. set to `None` to not use a bias.
+        groups: number of groups to use for grouped convolution.
+        key: key to use for initializing the weights. defaults to `0`.
+
+    Example:
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> layer = sk.nn.Conv1D(in_features=1, out_features=2, kernel_size=3)
+        >>> # single sample
+        >>> x = jnp.ones((1, 5))
+        >>> print(layer(x).shape)
+        (2, 5)
+        >>> # batch of samples
+        >>> x = jnp.ones((2, 1, 5))
+        >>> print(jax.vmap(layer)(x).shape)
+        (2, 2, 5)
+
+    Note:
+        https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
+    """
+
     def __init__( self, in_features: int, out_features: int, kernel_size: int | tuple[int, ...], *, strides=1, padding="same", input_dilation=1, kernel_dilation=1, weight_init_func: InitType = "glorot_uniform", bias_init_func: InitType = "zeros", groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ):
-        """1D Convolutional layer. 
-
-        Args:
-            in_features: number of input features
-            out_features: number of output features
-            kernel_size: size of the convolutional kernel
-            strides: stride of the convolution
-            padding: padding of the input
-            input_dilation: dilation of the input
-            kernel_dilation: dilation of the convolutional kernel
-            weight_init_func: function to use for initializing the weights
-            bias_init_func: function to use for initializing the bias
-            groups: number of groups to use for grouped convolution
-            key: key to use for initializing the weights
-
-        Example:
-            >>> import jax.numpy as jnp
-            >>> import serket as sk
-            >>> layer = sk.nn.Conv1D(in_features=1, out_features=2, kernel_size=3)
-            >>> x = jnp.ones((1, 5))
-            >>> print(layer(x).shape)
-            (2, 5)
-
-        Note:
-            https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
-        """
-
         super().__init__( in_features=in_features, out_features=out_features,
@@ -214,6 +222,60 @@ def spatial_ndim(self) -> int:
 class Conv2D(ConvND):
+    """2D Convolutional layer.
+
+    Args:
+        in_features: number of input feature maps (i.e. input channels).
+        out_features: number of output feature maps (i.e. output channels).
+        kernel_size: size of the convolutional kernel. accepts:
+            * single integer for same kernel size in all dimensions.
+            * sequence of integers for different kernel sizes in each dimension.
+        strides: stride of the convolution. accepts:
+            * single integer for same stride in all dimensions.
+            * sequence of integers for different strides in each dimension.
+        padding: padding of the input before convolution. accepts:
+            * single integer for same padding in all dimensions.
+            * tuple of integers for different padding in each dimension.
+            * tuple of a tuple of two integers for before and after padding in
+              each dimension.
+            * "same"/"SAME" for padding such that the output has the same shape
+              as the input.
+            * "valid"/"VALID" for no padding.
+        input_dilation: dilation of the input. accepts:
+            * single integer for same dilation in all dimensions.
+            * sequence of integers for different dilation in each dimension.
+        kernel_dilation: dilation of the convolutional kernel. accepts:
+            * single integer for same dilation in all dimensions.
+            * sequence of integers for different dilation in each dimension.
+        weight_init_func: function to use for initializing the weights. defaults
+            to `glorot_uniform`.
+        bias_init_func: function to use for initializing the bias. defaults to
+            `zeros`. set to `None` to not use a bias.
+        groups: number of groups to use for grouped convolution.
+        key: key to use for initializing the weights. defaults to `0`.
+
+    Example:
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> layer = sk.nn.Conv2D(in_features=1, out_features=2, kernel_size=3)
+        >>> # single sample
+        >>> x = jnp.ones((1, 5, 5))
+        >>> print(layer(x).shape)
+        (2, 5, 5)
+        >>> # batch of samples
+        >>> x = jnp.ones((2, 1, 5, 5))
+        >>> print(jax.vmap(layer)(x).shape)
+        (2, 2, 5, 5)
+
+    Note:
+        https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
+    """
+
     def __init__( self, in_features: int, out_features: int, kernel_size: int | tuple[int, ...], *, strides=1, padding="same", input_dilation=1, kernel_dilation=1, weight_init_func: InitType = "glorot_uniform", bias_init_func: InitType = "zeros", groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ):
-        """2D Convolutional layer. 
-
-        Args:
-            in_features: number of input features
-            out_features: number of output features
-            kernel_size: size of the convolutional kernel
-            strides: stride of the convolution
-            padding: padding of the input
-            input_dilation: dilation of the input
-            kernel_dilation: dilation of the convolutional kernel
-            weight_init_func: function to use for initializing the weights
-            bias_init_func: function to use for initializing the bias
-            groups: number of groups to use for grouped convolution
-            key: key to use for initializing the weights
-
-        Example:
-            >>> import jax.numpy as jnp
-            >>> import serket as sk
-            >>> layer = sk.nn.Conv2D(in_features=1, out_features=2, kernel_size=3)
-            >>> x = jnp.ones((1, 5, 5))
-            >>> print(layer(x).shape)
-            (2, 5, 5)
-
-        Note:
-            https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
-        """
-
         super().__init__( in_features=in_features, out_features=out_features,
@@ -276,6 +311,60 @@ def spatial_ndim(self) -> int:
 class Conv3D(ConvND):
+    """3D Convolutional layer.
+
+    Args:
+        in_features: number of input feature maps (i.e. input channels).
+        out_features: number of output feature maps (i.e. output channels).
+        kernel_size: size of the convolutional kernel. accepts:
+            * single integer for same kernel size in all dimensions.
+            * sequence of integers for different kernel sizes in each dimension.
+        strides: stride of the convolution. accepts:
+            * single integer for same stride in all dimensions.
+            * sequence of integers for different strides in each dimension.
+        padding: padding of the input before convolution. accepts:
+            * single integer for same padding in all dimensions.
+            * tuple of integers for different padding in each dimension.
+            * tuple of a tuple of two integers for before and after padding in
+              each dimension.
+            * "same"/"SAME" for padding such that the output has the same shape
+              as the input.
+            * "valid"/"VALID" for no padding.
+        input_dilation: dilation of the input. accepts:
+            * single integer for same dilation in all dimensions.
+            * sequence of integers for different dilation in each dimension.
+        kernel_dilation: dilation of the convolutional kernel. accepts:
+            * single integer for same dilation in all dimensions.
+            * sequence of integers for different dilation in each dimension.
+        weight_init_func: function to use for initializing the weights. defaults
+            to `glorot_uniform`.
+        bias_init_func: function to use for initializing the bias. defaults to
+            `zeros`. set to `None` to not use a bias.
+        groups: number of groups to use for grouped convolution.
+        key: key to use for initializing the weights. defaults to `0`.
+
+    Example:
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> layer = sk.nn.Conv3D(in_features=1, out_features=2, kernel_size=3)
+        >>> # single sample
+        >>> x = jnp.ones((1, 5, 5, 5))
+        >>> print(layer(x).shape)
+        (2, 5, 5, 5)
+        >>> # batch of samples
+        >>> x = jnp.ones((2, 1, 5, 5, 5))
+        >>> print(jax.vmap(layer)(x).shape)
+        (2, 2, 5, 5, 5)
+
+    Note:
+        https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
+    """
+
     def __init__( self, in_features: int, out_features: int, kernel_size: int | tuple[int, ...], *, strides=1, padding="same", input_dilation=1, kernel_dilation=1, weight_init_func: InitType = "glorot_uniform", bias_init_func: InitType = "zeros", groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ):
-        """3D Convolutional layer. 
- - Args: - in_features: number of input features - out_features: number of output features - kernel_size: size of the convolutional kernel - strides: stride of the convolution - padding: padding of the input - input_dilation: dilation of the input - kernel_dilation: dilation of the convolutional kernel - weight_init_func: function to use for initializing the weights - bias_init_func: function to use for initializing the bias - groups: number of groups to use for grouped convolution - key: key to use for initializing the weights - - Example: - >>> import jax.numpy as jnp - >>> import serket as sk - >>> layer = sk.nn.Conv3D(in_features=1, out_features=2, kernel_size=3) - >>> x = jnp.ones((1, 5, 5, 5)) - >>> print(layer(x).shape) - (2, 5, 5, 5) - - - Note: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - """ - super().__init__( in_features=in_features, out_features=out_features, @@ -338,10 +399,7 @@ def spatial_ndim(self) -> int: return 3 -# ---------------------------------------------------------------------------- # - - -class ConvNDTranspose(pytc.TreeClass): +class ConvNDTranspose(sk.TreeClass): def __init__( self, in_features: int, @@ -357,21 +415,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """Convolutional Transpose Layer - - Args: - in_features : Number of input channels - out_features : Number of output channels - kernel_size : Size of the convolutional kernel - strides : Stride of the convolution - padding : Padding of the input - output_padding : Additional size added to one side of the output shape - kernel_dilation : Dilation of the convolutional kernel - weight_init_func : Weight initialization function - bias_init_func : Bias initialization function - groups : Number of groups - key : PRNG key - """ self.in_features = positive_int_cb(in_features) self.out_features = positive_int_cb(out_features) self.kernel_size = canonicalize( @@ -446,6 +489,59 @@ def spatial_ndim(self) -> int: class Conv1DTranspose(ConvNDTranspose): + """1D Convolution transpose layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + out_features: number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + output_padding: padding of the output after convolution. accepts: + * single integer for same padding in all dimensions. + kernel_dilation: dilation of the convolutional kernel accepts: + * single integer for same dilation in all dimensions. 
+ * sequence of integers for different dilation in each dimension. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + groups: number of groups to use for grouped convolution. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.Conv1DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5)) + >>> print(layer(x).shape) + (2, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5) + + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + """ + def __init__( self, in_features: int, @@ -461,21 +557,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """1D Convolutional Transpose Layer. - - Args: - in_features : Number of input channels - out_features : Number of output channels - kernel_size : Size of the convolutional kernel - strides : Stride of the convolution - padding : Padding of the input - output_padding : Additional size added to one side of the output shape - kernel_dilation : Dilation of the convolutional kernel - weight_init_func : Weight initialization function - bias_init_func : Bias initialization function - groups : Number of groups - key : PRNG key - """ super().__init__( in_features=in_features, out_features=out_features, @@ -496,6 +577,59 @@ def spatial_ndim(self) -> int: class Conv2DTranspose(ConvNDTranspose): + """2D Convolution transpose layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + out_features: number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + output_padding: padding of the output after convolution. accepts: + * single integer for same padding in all dimensions. + kernel_dilation: dilation of the convolutional kernel accepts: + * single integer for same dilation in all dimensions. + * sequence of integers for different dilation in each dimension. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + groups: number of groups to use for grouped convolution. + key: key to use for initializing the weights. defaults to `0`. 
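+
+    Note:
+        A minimal sketch of using the layer under `jax.jit`. It assumes only
+        that the layer is a pytree (which `sk.TreeClass` provides), so it can
+        be passed as an argument to transformed functions:
+
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> layer = sk.nn.Conv2DTranspose(1, 2, 3)
+        >>> # pass the layer itself as an argument; jit retraces on new structure
+        >>> print(jax.jit(lambda net, x: net(x))(layer, jnp.ones((1, 5, 5))).shape)
+        (2, 5, 5)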
+ + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.Conv2DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5) + + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + """ + def __init__( self, in_features: int, @@ -511,21 +645,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """2D Convolutional Transpose Layer. - - Args: - in_features : Number of input channels - out_features : Number of output channels - kernel_size : Size of the convolutional kernel - strides : Stride of the convolution - padding : Padding of the input - output_padding : Additional size added to one side of the output shape - kernel_dilation : Dilation of the convolutional kernel - weight_init_func : Weight initialization function - bias_init_func : Bias initialization function - groups : Number of groups - key : PRNG key - """ super().__init__( in_features=in_features, out_features=out_features, @@ -546,6 +665,59 @@ def spatial_ndim(self) -> int: class Conv3DTranspose(ConvNDTranspose): + """3D Convolution transpose layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + out_features: number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + output_padding: padding of the output after convolution. accepts: + * single integer for same padding in all dimensions. + kernel_dilation: dilation of the convolutional kernel accepts: + * single integer for same dilation in all dimensions. + * sequence of integers for different dilation in each dimension. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + groups: number of groups to use for grouped convolution. + key: key to use for initializing the weights. defaults to `0`. 
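+
+    Note:
+        A sketch of shape-only evaluation with `jax.eval_shape` (illustrative;
+        it assumes only that the layer is callable on abstract arrays, which
+        JAX's tracing utilities provide):
+
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> layer = sk.nn.Conv3DTranspose(1, 2, 3)
+        >>> # no FLOPs are spent; only shapes and dtypes are propagated
+        >>> out = jax.eval_shape(layer, jax.ShapeDtypeStruct((1, 5, 5, 5), jnp.float32))
+        >>> print(out.shape)
+        (2, 5, 5, 5)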
+ + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.Conv3DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5, 5) + + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + """ + def __init__( self, in_features: int, @@ -561,21 +733,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """3D Convolutional Transpose Layer. - - Args: - in_features : Number of input channels - out_features : Number of output channels - kernel_size : Size of the convolutional kernel - strides : Stride of the convolution - padding : Padding of the input - output_padding : Additional size added to one side of the output shape - kernel_dilation : Dilation of the convolutional kernel - weight_init_func : Weight initialization function - bias_init_func : Bias initialization function - groups : Number of groups - key : PRNG key - """ super().__init__( in_features=in_features, out_features=out_features, @@ -595,10 +752,7 @@ def spatial_ndim(self) -> int: return 3 -# ---------------------------------------------------------------------------- # - - -class DepthwiseConvND(pytc.TreeClass): +class DepthwiseConvND(sk.TreeClass): def __init__( self, in_features: int, @@ -611,22 +765,6 @@ def __init__( bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """Depthwise Convolutional layer. - - Args: - in_features: number of input features - kernel_size: size of the convolution kernel - depth_multiplier : number of output channels per input channel - strides: stride of the convolution - padding: padding of the input - weight_init_func: function to initialize the weights - bias_init_func: function to initialize the bias - key: random key for weight initialization - - Note: - https://keras.io/api/layers/convolution_layers/depthwise_convolution2d/ - https://github.com/google/flax/blob/main/flax/linen/linear.py - """ self.in_features = positive_int_cb(in_features) self.kernel_size = canonicalize( kernel_size, self.spatial_ndim, name="kernel_size" @@ -683,6 +821,48 @@ def spatial_ndim(self) -> int: class DepthwiseConv1D(DepthwiseConvND): + """1D Depthwise convolution layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. 
+ weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. + + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.DepthwiseConv1D(3, 3, depth_multiplier=2, strides=2) + >>> l1(jnp.ones((3, 32))).shape + (6, 16) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ + def __init__( self, in_features: int, @@ -695,28 +875,6 @@ def __init__( bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """1D Depthwise Convolutional layer. - - Args: - in_features: number of input features - kernel_size: size of the convolution kernel - depth_multiplier : number of output channels per input channel - strides: stride of the convolution - padding: padding of the input - weight_init_func: function to initialize the weights - bias_init_func: function to initialize the bias - key: random key for weight initialization - - Example: - >>> l1 = DepthwiseConv1D(3, 3, depth_multiplier=2, strides=2, padding="SAME") - >>> l1(jnp.ones((3, 32))).shape - (6, 16) - - Note: - https://keras.io/api/layers/convolution_layers/depthwise_convolution2d/ - https://github.com/google/flax/blob/main/flax/linen/linear.py - """ - super().__init__( in_features=in_features, kernel_size=kernel_size, @@ -734,6 +892,48 @@ def spatial_ndim(self) -> int: class DepthwiseConv2D(DepthwiseConvND): + """2D Depthwise convolution layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. 
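+
+    Note:
+        A rough parameter-count sketch (assuming the standard depthwise
+        formulation of one kernel per input channel per multiplier, not a
+        measured property of this class): the layer carries
+        `in_features * depth_multiplier * prod(kernel_size)` weights plus
+        `in_features * depth_multiplier` biases.
+
+        >>> # e.g. DepthwiseConv2D(3, 3, depth_multiplier=2)
+        >>> print(3 * 2 * (3 * 3), "weights,", 3 * 2, "biases")
+        54 weights, 6 biases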
+ + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.DepthwiseConv2D(3, 3, depth_multiplier=2, strides=2) + >>> l1(jnp.ones((3, 32, 32))).shape + (6, 16, 16) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ + def __init__( self, in_features: int, @@ -746,28 +946,6 @@ def __init__( bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """2D Depthwise Convolutional layer. - - Args: - in_features: number of input features - kernel_size: size of the convolution kernel - depth_multiplier : number of output channels per input channel - strides: stride of the convolution - padding: padding of the input - weight_init_func: function to initialize the weights - bias_init_func: function to initialize the bias - key: random key for weight initialization - - Example: - >>> l1 = DepthwiseConv2D(3, 3, depth_multiplier=2, strides=2, padding="SAME") - >>> l1(jnp.ones((3, 32, 32))).shape - (6, 16, 16) - - Note: - https://keras.io/api/layers/convolution_layers/depthwise_convolution2d/ - https://github.com/google/flax/blob/main/flax/linen/linear.py - """ - super().__init__( in_features=in_features, kernel_size=kernel_size, @@ -785,6 +963,48 @@ def spatial_ndim(self) -> int: class DepthwiseConv3D(DepthwiseConvND): + """3D Depthwise convolution layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. + + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.DepthwiseConv3D(3, 3, depth_multiplier=2, strides=2) + >>> l1(jnp.ones((3, 32, 32, 32))).shape + (6, 16, 16, 16) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ + def __init__( self, in_features: int, @@ -797,28 +1017,6 @@ def __init__( bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """3D Depthwise Convolutional layer. 
- - Args: - in_features: number of input features - kernel_size: size of the convolution kernel - depth_multiplier : number of output channels per input channel - strides: stride of the convolution - padding: padding of the input - weight_init_func: function to initialize the weights - bias_init_func: function to initialize the bias - key: random key for weight initialization - - Example: - >>> l1 = DepthwiseConv3D(3, 3, depth_multiplier=2, strides=2, padding="SAME") - >>> l1(jnp.ones((3, 32, 32, 32))).shape - (6, 16, 16, 16) - - Note: - https://keras.io/api/layers/convolution_layers/depthwise_convolution2d/ - https://github.com/google/flax/blob/main/flax/linen/linear.py - """ - super().__init__( in_features=in_features, kernel_size=kernel_size, @@ -835,10 +1033,56 @@ def spatial_ndim(self) -> int: return 3 -# ---------------------------------------------------------------------------- # - +class SeparableConv1D(sk.TreeClass): + """1D Separable convolution layer. + + Separable convolution is a depthwise convolution followed by a pointwise + convolution. The objective is to reduce the number of parameters in the + convolutional layer. For example, for I input features and O output features, + and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O + parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O + parameters. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. + + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.SeparableConv1D(3, 3, 3, depth_multiplier=2) + >>> l1(jnp.ones((3, 32))).shape + (3, 32) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ -class SeparableConv1D(pytc.TreeClass): def __init__( self, in_features: int, @@ -853,28 +1097,6 @@ def __init__( pointwise_bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """1D Separable convolutional layer. - - Args: - in_features : Number of input channels. - out_features : Number of output channels. - kernel_size : Size of the convolving kernel. 
- depth_multiplier : Number of depthwise convolution output channels - for each input channel. - strides : Stride of the convolution. - padding : Padding to apply to the input. - depthwise_weight_init_func : Function to initialize the depthwise - convolution weights. - pointwise_weight_init_func : Function to initialize the pointwise - convolution weights. - pointwise_bias_init_func : Function to initialize the pointwise - convolution bias. - - Note: - https://en.wikipedia.org/wiki/Separable_filter - https://keras.io/api/layers/convolution_layers/separable_convolution2d/ - https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/depthwise_conv.py - """ self.depthwise_conv = DepthwiseConv1D( in_features=in_features, depth_multiplier=depth_multiplier, @@ -907,7 +1129,55 @@ def spatial_ndim(self) -> int: return 1 -class SeparableConv2D(pytc.TreeClass): +class SeparableConv2D(sk.TreeClass): + """2D Separable convolution layer. + + Separable convolution is a depthwise convolution followed by a pointwise + convolution. The objective is to reduce the number of parameters in the + convolutional layer. For example, for I input features and O output features, + and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O + parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O + parameters. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.SeparableConv2D(3, 3, 3, depth_multiplier=2) + >>> l1(jnp.ones((3, 32, 32))).shape + (3, 32, 32) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ + def __init__( self, in_features: int, @@ -922,28 +1192,6 @@ def __init__( pointwise_bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """2D Separable convolutional layer. - - Args: - in_features : Number of input channels. - out_features : Number of output channels. - kernel_size : Size of the convolving kernel. - depth_multiplier : Number of depthwise convolution output channels - for each input channel. 
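The "depthwise followed by pointwise" decomposition can also be written out by hand with layers from this file (a sketch; sk.nn.Conv2D with kernel_size=1 plays the pointwise role, and the depthwise defaults are assumed to preserve spatial shape):

    >>> import jax.numpy as jnp
    >>> import serket as sk
    >>> depthwise = sk.nn.DepthwiseConv2D(3, 3, depth_multiplier=2)  # 3 -> 6 channels
    >>> pointwise = sk.nn.Conv2D(6, 5, 1)  # 1x1 kernel mixes 6 -> 5 channels
    >>> pointwise(depthwise(jnp.ones((3, 32, 32)))).shape
    (5, 32, 32)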
- strides : Stride of the convolution. - padding : Padding to apply to the input. - depthwise_weight_init_func : Function to initialize the depthwise - convolution weights. - pointwise_weight_init_func : Function to initialize the pointwise - convolution weights. - pointwise_bias_init_func : Function to initialize the pointwise - convolution bias. - - Note: - https://en.wikipedia.org/wiki/Separable_filter - https://keras.io/api/layers/convolution_layers/separable_convolution2d/ - https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/depthwise_conv.py - """ self.depthwise_conv = DepthwiseConv2D( in_features=in_features, depth_multiplier=depth_multiplier, @@ -976,7 +1224,55 @@ def spatial_ndim(self) -> int: return 2 -class SeparableConv3D(pytc.TreeClass): +class SeparableConv3D(sk.TreeClass): + """3D Separable convolution layer. + + Separable convolution is a depthwise convolution followed by a pointwise + convolution. The objective is to reduce the number of parameters in the + convolutional layer. For example, for I input features and O output features, + and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O + parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O + parameters. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.SeparableConv3D(3, 3, 3, depth_multiplier=2) + >>> l1(jnp.ones((3, 32, 32, 32))).shape + (3, 32, 32, 32) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ + def __init__( self, in_features: int, @@ -991,28 +1287,6 @@ def __init__( pointwise_bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """3D Separable convolutional layer. - - Note: - https://en.wikipedia.org/wiki/Separable_filter - https://keras.io/api/layers/convolution_layers/separable_convolution2d/ - https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/depthwise_conv.py - - Args: - in_features : Number of input channels. - out_features : Number of output channels. 
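For 3D kernels the savings compound with the kernel volume. A quick ratio for the docstring's formula (pure arithmetic):

    >>> I, O, K = 32, 64, 3 * 3 * 3
    >>> round((I * O * K + O) / (I * K + I * O + O), 1)  # roughly 19x fewer parameters
    18.6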
-            kernel_size : Size of the convolving kernel.
-            depth_multiplier : Number of depthwise convolution output channels
-                for each input channel.
-            strides : Stride of the convolution.
-            padding : Padding to apply to the input.
-            depthwise_weight_init_func : Function to initialize the depthwise
-                convolution weights.
-            pointwise_weight_init_func : Function to initialize the pointwise
-                convolution weights.
-            pointwise_bias_init_func : Function to initialize the pointwise
-                convolution bias.
-        """
         self.depthwise_conv = DepthwiseConv3D(
             in_features=in_features,
             depth_multiplier=depth_multiplier,
@@ -1045,10 +1319,7 @@ def spatial_ndim(self) -> int:
         return 3
 
 
-# ---------------------------------------------------------------------------- #
-
-
-class ConvNDLocal(pytc.TreeClass):
+class ConvNDLocal(sk.TreeClass):
     def __init__(
         self,
         in_features: int,
@@ -1064,23 +1335,6 @@ def __init__(
         bias_init_func: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """Local convolutional layer.
-
-        Args:
-            in_features: number of input features
-            out_features: number of output features
-            kernel_size: size of the convolution kernel
-            in_size: size of the input
-            strides: stride of the convolution
-            padding: padding of the convolution
-            input_dilation: dilation of the input
-            kernel_dilation: dilation of the convolution kernel
-            weight_init_func: weight initialization function
-            bias_init_func: bias initialization function
-            key: random number generator key
-        Note:
-            https://keras.io/api/layers/locally_connected_layers/
-        """
         # checked by callbacks
         self.in_features = positive_int_cb(in_features)
         self.out_features = positive_int_cb(out_features)
@@ -1157,6 +1411,51 @@ def spatial_ndim(self) -> int:
 
 
 class Conv1DLocal(ConvNDLocal):
+    """1D Local convolutional layer.
+
+    A local convolutional layer applies the kernel to local regions of the
+    input like a regular convolution, except that the kernel weights are
+    *not* shared across the spatial dimensions of the input.
+
+
+    Args:
+        in_features: number of input feature maps, i.e. the number of input
+            channels (the size of the leading dimension of the input array),
+            regardless of the number of spatial dimensions.
+        kernel_size: size of the convolutional kernel. accepts:
+            * single integer for same kernel size in all dimensions.
+            * sequence of integers for different kernel sizes in each dimension.
+        in_size: the size of the spatial dimensions of the input, i.e. excluding
+            the leading channel dimension. accepts a sequence of integer(s).
+        strides: stride of the convolution. accepts:
+            * single integer for same stride in all dimensions.
+            * sequence of integers for different strides in each dimension.
+        padding: padding of the input before convolution. accepts:
+            * single integer for same padding in all dimensions.
+            * tuple of integers for different padding in each dimension.
+            * tuple of a tuple of two integers for before and after padding in
+              each dimension.
+            * "same"/"SAME" for padding such that the output has the same shape
+              as the input.
+            * "valid"/"VALID" for no padding.
+        weight_init_func: function to use for initializing the weights. defaults
+            to `glorot uniform`.
+        bias_init_func: function to use for initializing the bias. defaults to
+            `zeros`. set to `None` to not use a bias.
+        key: key to use for initializing the weights. defaults to `0`.
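The "unshared weights" point above has a concrete cost: the local layer keeps one kernel per output position. A rough size comparison (a sketch; counts all array leaves of each layer pytree, with the attribute layout assumed from this PR):

    >>> import jax
    >>> import jax.numpy as jnp
    >>> import serket as sk
    >>> n_params = lambda t: sum(x.size for x in jax.tree_util.tree_leaves(t) if hasattr(x, "size"))
    >>> shared = sk.nn.Conv1D(3, 3, 3)
    >>> local = sk.nn.Conv1DLocal(3, 3, 3, in_size=(32,))
    >>> n_params(local) > n_params(shared)
    True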
+ + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.Conv1DLocal(3, 3, 3, in_size=(32,)) + >>> l1(jnp.ones((3, 32))).shape + (3, 32) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ + def __init__( self, in_features: int, @@ -1171,23 +1470,6 @@ def __init__( bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """1D Local convolutional layer. - - Args: - in_features: number of input features - out_features: number of output features - kernel_size: size of the convolution kernel - in_size: size of the input - strides: stride of the convolution - padding: padding of the convolution - input_dilation: dilation of the input - kernel_dilation: dilation of the convolution kernel - weight_init_func: weight initialization function - bias_init_func: bias initialization function - key: random number generator key - Note: - https://keras.io/api/layers/locally_connected_layers/ - """ super().__init__( in_features=in_features, out_features=out_features, @@ -1207,6 +1489,51 @@ def spatial_ndim(self) -> int: class Conv2DLocal(ConvNDLocal): + """2D Local convolutional layer. + + Local convolutional layer is a convolutional layer where the convolution + kernel is applied to a local region of the input. This means that the kernel + weights are *not* shared across the spatial dimensions of the input. + + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + in_size: the size of the spatial dimensions of the input. e.g excluding + the first dimension. accepts a sequence of integer(s). + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.Conv2DLocal(3, 3, 3, in_size=(32, 32)) + >>> l1(jnp.ones((3, 32, 32))).shape + (3, 32, 32) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ + def __init__( self, in_features: int, @@ -1221,23 +1548,6 @@ def __init__( bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """2D Local convolutional layer. 
- - Args: - in_features: number of input features - out_features: number of output features - kernel_size: size of the convolution kernel - in_size: size of the input - strides: stride of the convolution - padding: padding of the convolution - input_dilation: dilation of the input - kernel_dilation: dilation of the convolution kernel - weight_init_func: weight initialization function - bias_init_func: bias initialization function - key: random number generator key - Note: - https://keras.io/api/layers/locally_connected_layers/ - """ super().__init__( in_features=in_features, out_features=out_features, @@ -1257,6 +1567,51 @@ def spatial_ndim(self) -> int: class Conv3DLocal(ConvNDLocal): + """3D Local convolutional layer. + + Local convolutional layer is a convolutional layer where the convolution + kernel is applied to a local region of the input. This means that the kernel + weights are *not* shared across the spatial dimensions of the input. + + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + in_size: the size of the spatial dimensions of the input. e.g excluding + the first dimension. accepts a sequence of integer(s). + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.Conv3DLocal(3, 3, 3, in_size=(32, 32, 32)) + >>> l1(jnp.ones((3, 32, 32, 32))).shape + (3, 32, 32, 32) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ + def __init__( self, in_features: int, @@ -1271,23 +1626,6 @@ def __init__( bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """3D Local convolutional layer. 
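As with the other layers in this file, the local variants act on a single unbatched array; batching is left to jax.vmap (a sketch, mirroring the FFT-layer examples later in this diff):

    >>> import jax
    >>> import jax.numpy as jnp
    >>> import serket as sk
    >>> layer = sk.nn.Conv3DLocal(3, 3, 3, in_size=(8, 8, 8))
    >>> x = jnp.ones((4, 3, 8, 8, 8))  # batch of 4 samples
    >>> jax.vmap(layer)(x).shape
    (4, 3, 8, 8, 8)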
- - Args: - in_features: number of input features - out_features: number of output features - kernel_size: size of the convolution kernel - in_size: size of the input - strides: stride of the convolution - padding: padding of the convolution - input_dilation: dilation of the input - kernel_dilation: dilation of the convolution kernel - weight_init_func: weight initialization function - bias_init_func: bias initialization function - key: random number generator key - Note: - https://keras.io/api/layers/locally_connected_layers/ - """ super().__init__( in_features=in_features, out_features=out_features, diff --git a/serket/nn/crop.py b/serket/nn/crop.py index 693565a..79bd96a 100644 --- a/serket/nn/crop.py +++ b/serket/nn/crop.py @@ -19,24 +19,21 @@ import jax import jax.random as jr -import pytreeclass as pytc from jax import lax +import serket as sk from serket.nn.utils import canonicalize, validate_spatial_ndim -class CropND(pytc.TreeClass): - def __init__( - self, - size: int | tuple[int, ...], - start: int | tuple[int, ...], - ): - """Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input. +class CropND(sk.TreeClass): + """Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input. - Args: - size: size of the slice - start: start of the slice - """ + Args: + size: size of the slice, accepted values are integers or tuples of integers. + start: start of the slice, accepted values are integers or tuples of integers. + """ + + def __init__(self, size: int | tuple[int, ...], start: int | tuple[int, ...]): self.size = canonicalize(size, self.spatial_ndim, name="size") self.start = canonicalize(start, self.spatial_ndim, name="start") @@ -53,21 +50,22 @@ def spatial_ndim(self) -> int: class Crop1D(CropND): - def __init__( - self, - size: int | tuple[int], - start: int | tuple[int, int], - ): - """Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input. - - Example: - >>> import jax - >>> import jax.numpy as jnp - >>> import serket as sk - >>> x = jnp.arange(1, 6).reshape(1, 5) - >>> print(sk.nn.Crop1D(size=3, start=1)(x)) - [[2 3 4]] - """ + """Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input. + + Args: + size: size of the slice, either a single int or a tuple of int. + start: start of the slice, either a single int or a tuple of int. + + Example: + >>> import jax + >>> import jax.numpy as jnp + >>> import serket as sk + >>> x = jnp.arange(1, 6).reshape(1, 5) + >>> print(sk.nn.Crop1D(size=3, start=1)(x)) + [[2 3 4]] + """ + + def __init__(self, size: int | tuple[int], start: int | tuple[int]): super().__init__(size, start) @property @@ -76,35 +74,33 @@ def spatial_ndim(self) -> int: class Crop2D(CropND): - def __init__( - self, - size: int | tuple[int, int], - start: int | tuple[int, int], - ): - """Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input. - - Args: - size: size of the slice, either a single int or a tuple of two ints - start: start of the slice, either a single int or a tuple of two ints for start along each axis - - Example: - >>> # start = (2, 0) and size = (3, 3) - >>> # i.e. 
start at index 2 along the first axis and index 0 along the second axis - >>> import jax.numpy as jnp - >>> import serket as sk - >>> x = jnp.arange(1, 26).reshape((1, 5, 5)) - >>> print(x) - [[[ 1 2 3 4 5] - [ 6 7 8 9 10] - [11 12 13 14 15] - [16 17 18 19 20] - [21 22 23 24 25]]] - - >>> print(sk.nn.Crop2D(size=3, start=(2, 0))(x)) - [[[11 12 13] - [16 17 18] - [21 22 23]]] - """ + """Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input. + + Args: + size: size of the slice, either a single int or a tuple of two ints + for size along each axis. + start: start of the slice, either a single int or a tuple of two ints + for start along each axis. + + Example: + >>> # start = (2, 0) and size = (3, 3) + >>> # i.e. start at index 2 along the first axis and index 0 along the second axis + >>> import jax.numpy as jnp + >>> import serket as sk + >>> x = jnp.arange(1, 26).reshape((1, 5, 5)) + >>> print(x) + [[[ 1 2 3 4 5] + [ 6 7 8 9 10] + [11 12 13 14 15] + [16 17 18 19 20] + [21 22 23 24 25]]] + >>> print(sk.nn.Crop2D(size=3, start=(2, 0))(x)) + [[[11 12 13] + [16 17 18] + [21 22 23]]] + """ + + def __init__(self, size: int | tuple[int, int], start: int | tuple[int, int]): super().__init__(size, start) @property @@ -113,17 +109,20 @@ def spatial_ndim(self) -> int: class Crop3D(CropND): + """Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input. + + Args: + size: size of the slice, either a single int or a tuple of three ints + for size along each axis. + start: start of the slice, either a single int or a tuple of three + ints for start along each axis. + """ + def __init__( self, size: int | tuple[int, int, int], start: int | tuple[int, int, int], ): - """Applies jax.lax.dynamic_slice_in_dim to the second dimension of the input. - - Args: - size: size of the slice, either a single int or a tuple of two ints - start: start of the slice, either a single int or a tuple of three ints for start along each axis - """ super().__init__(size, start) @property @@ -131,13 +130,8 @@ def spatial_ndim(self) -> int: return 3 -class RandomCropND(pytc.TreeClass): +class RandomCropND(sk.TreeClass): def __init__(self, size: int | tuple[int, ...]): - """Applies jax.lax.dynamic_slice_in_dim with a random start along each axis - - Args: - size: size of the slice - """ self.size = canonicalize(size, self.spatial_ndim, name="size") @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -161,12 +155,14 @@ def spatial_ndim(self) -> int: class RandomCrop1D(RandomCropND): - def __init__(self, size: int | tuple[int]): - """Applies jax.lax.dynamic_slice_in_dim with a random start along each axis + """Applies jax.lax.dynamic_slice_in_dim with a random start along each axis - Args: - size: size of the slice - """ + Args: + size: size of the slice, either a single int or a tuple of int. accepted + values are either a single int or a tuple of int denoting the size. + """ + + def __init__(self, size: int | tuple[int]): super().__init__(size) @property @@ -175,12 +171,14 @@ def spatial_ndim(self) -> int: class RandomCrop2D(RandomCropND): - def __init__(self, size: int | tuple[int, int]): - """Applies jax.lax.dynamic_slice_in_dim with a random start along each axis + """Applies jax.lax.dynamic_slice_in_dim with a random start along each axis - Args: - size: size of the slice - """ + Args: + size: size of the slice in each axis. accepted values are either a single int + or a tuple of two ints denoting the size along each axis. 
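A usage sketch for the random crops (assumes __call__ accepts the same keyword-only `key` as RandomCropND above; the start index along each axis is drawn from that key):

    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> import serket as sk
    >>> x = jnp.arange(25).reshape(1, 5, 5)
    >>> sk.nn.RandomCrop2D(size=3)(x, key=jr.PRNGKey(0)).shape
    (1, 3, 3)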
+    """
+
+    def __init__(self, size: int | tuple[int, int]):
         super().__init__(size)
 
     @property
@@ -189,12 +187,14 @@ def spatial_ndim(self) -> int:
 
 
 class RandomCrop3D(RandomCropND):
-    def __init__(self, size: int | tuple[int, int, int]):
-        """Applies jax.lax.dynamic_slice_in_dim with a random start along each axis
-
-        Args:
-            size: size of the slice
-        """
+    """Applies jax.lax.dynamic_slice_in_dim with a random start along each axis
+
+    Args:
+        size: size of the slice in each axis. accepted values are either a single int
+            or a tuple of three ints denoting the size along each axis.
+    """
+
+    def __init__(self, size: int | tuple[int, int, int]):
         super().__init__(size)
 
     @property
diff --git a/serket/nn/cutout.py b/serket/nn/cutout.py
index bb89b3d..e8535b6 100644
--- a/serket/nn/cutout.py
+++ b/serket/nn/cutout.py
@@ -19,9 +19,9 @@
 
 import jax
 import jax.numpy as jnp
 import jax.random as jr
-import pytreeclass as pytc
 from jax import lax
 
+import serket as sk
 from serket.nn.utils import canonicalize, positive_int_cb, validate_spatial_ndim
 
 
@@ -110,28 +110,31 @@ def scan_step(x, key):
         return x
 
 
-class RandomCutout1D(pytc.TreeClass):
+class RandomCutout1D(sk.TreeClass):
+    """Random Cutouts for spatial 1D array.
+
+    Args:
+        shape: shape of the cutout. accepts an int or a tuple of int.
+        cutout_count: number of holes. Defaults to 1.
+        fill_value: value to fill the cutout region with. Defaults to 0.
+
+    Note:
+        https://arxiv.org/abs/1708.04552
+        https://keras.io/api/keras_cv/layers/preprocessing/random_cutout/
+
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> print(sk.nn.RandomCutout1D(5)(jnp.ones((1, 10)) * 100))
+        [[100. 100. 100. 100.   0.   0.   0.   0.   0. 100.]]
+    """
+
     def __init__(
         self,
-        shape: tuple[int],
+        shape: int | tuple[int],
         cutout_count: int = 1,
         fill_value: int | float = 0,
     ):
-        """Random Cutouts for spatial 1D array.
-
-        Args:
-            shape: shape of the cutout
-            cutout_count: number of holes. Defaults to 1.
-            fill_value: fill_value to fill. Defaults to 0.
-
-        See:
-            https://arxiv.org/abs/1708.04552
-            https://keras.io/api/keras_cv/layers/preprocessing/random_cutout/
-
-        Examples:
-            >>> RandomCutout1D(5)(jnp.ones((1, 10))*100)
-            [[100., 100., 100., 100., 0., 0., 0., 0., 0., 100.]]
-        """
         self.shape = canonicalize(shape, ndim=1, name="shape")
         self.cutout_count = positive_int_cb(cutout_count)
         self.fill_value = fill_value
@@ -146,24 +149,25 @@ def spatial_ndim(self) -> int:
         return 1
 
 
-class RandomCutout2D(pytc.TreeClass):
+class RandomCutout2D(sk.TreeClass):
+    """Random Cutouts for spatial 2D array.
+
+    Args:
+        shape: shape of the cutout. accepts int or a two element tuple.
+        cutout_count: number of holes. Defaults to 1.
+        fill_value: value to fill the cutout region with. Defaults to 0.
+
+    Note:
+        https://arxiv.org/abs/1708.04552
+        https://keras.io/api/keras_cv/layers/preprocessing/random_cutout/
+    """
+
     def __init__(
         self,
         shape: int | tuple[int, int],
         cutout_count: int = 1,
         fill_value: int | float = 0,
     ):
-        """Random Cutouts for spatial 2D array
-
-        Args:
-            shape: shape of the cutout
-            cutout_count: number of holes. Defaults to 1.
-            fill_value: fill_value to fill. Defaults to 0.
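A 2D counterpart of the RandomCutout1D example above (a sketch; the cutout positions are random, so only the shape is checked):

    >>> import jax.numpy as jnp
    >>> import serket as sk
    >>> x = jnp.ones((1, 10, 10)) * 100
    >>> sk.nn.RandomCutout2D(shape=(3, 3), cutout_count=2)(x).shape
    (1, 10, 10)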
-
-        See:
-            https://arxiv.org/abs/1708.04552
-            https://keras.io/api/keras_cv/layers/preprocessing/random_cutout/
-        """
         self.shape = canonicalize(shape, 2, name="shape")
         self.cutout_count = positive_int_cb(cutout_count)
         self.fill_value = fill_value
diff --git a/serket/nn/dropout.py b/serket/nn/dropout.py
index dd3a932..ff40bfd 100644
--- a/serket/nn/dropout.py
+++ b/serket/nn/dropout.py
@@ -19,13 +19,13 @@
 
 import jax.numpy as jnp
 import jax.random as jr
-import pytreeclass as pytc
 from jax import lax
 
+import serket as sk
 from serket.nn.utils import Range, validate_spatial_ndim
 
 
-class Dropout(pytc.TreeClass):
+class Dropout(sk.TreeClass):
     """Randomly zeroes some of the elements of the input
     tensor with probability :attr:`p` using samples from a Bernoulli distribution.
 
@@ -35,7 +35,7 @@ class Dropout(pytc.TreeClass):
 
     Example:
         >>> import serket as sk
-        >>> import pytreeclass as pytc
+        >>> import jax.numpy as jnp
         >>> layer = sk.nn.Dropout(0.5)
         >>> # change `p` to 0.0 to turn off dropout
-        >>> layer = layer.at["p"].set(0.0, is_leaf=pytc.is_frozen)
+        >>> layer = layer.at["p"].set(0.0, is_leaf=sk.is_frozen)
@@ -44,7 +44,7 @@ class Dropout(pytc.TreeClass):
 
     Use `p`= 0.0 to turn off dropout.
     """
 
-    p: float = pytc.field(default=0.5, callbacks=[Range(0, 1)])
+    p: float = sk.field(default=0.5, callbacks=[Range(0, 1)])
 
     def __call__(self, x, *, key: jr.KeyArray = jr.PRNGKey(0)):
         return jnp.where(
@@ -54,23 +54,10 @@ def __call__(self, x, *, key: jr.KeyArray = jr.PRNGKey(0)):
         )
 
 
-class DropoutND(pytc.TreeClass):
-    """Drops full feature maps along the channel axis.
+class DropoutND(sk.TreeClass):
+    """Drops full feature maps along the channel axis."""
 
-    Args:
-        p: fraction of an elements to be zeroed out
-
-    Note:
-        https://keras.io/api/layers/regularization_layers/spatial_dropout1d/
-        https://arxiv.org/abs/1411.4280
-
-    Example:
-        >>> layer = DropoutND(0.5)
-        >>> layer(jnp.ones((1, 10)))
-        [[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]]
-    """
-
-    p: float = pytc.field(default=0.5, callbacks=[Range(0, 1)])
+    p: float = sk.field(default=0.5, callbacks=[Range(0, 1)])
 
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     def __call__(self, x, *, key=jr.PRNGKey(0)):
@@ -94,16 +81,18 @@ def __init__(self, p: float = 0.5):
         """Drops full feature maps along the channel axis.
 
         Args:
-            p: fraction of an elements to be zeroed out
+            p: fraction of elements to be zeroed out.
+
+        Example:
+            >>> import serket as sk
+            >>> import jax.numpy as jnp
+            >>> layer = sk.nn.Dropout1D(0.5)
+            >>> print(layer(jnp.ones((1, 10))))
+            [[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
 
         Note:
            https://keras.io/api/layers/regularization_layers/spatial_dropout1d/
            https://arxiv.org/abs/1411.4280
-
-        Example:
-            >>> layer = DropoutND(0.5)
-            >>> layer(jnp.ones((1, 10)))
-            [[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]]
         """
         super().__init__(p=p)
 
@@ -117,18 +106,24 @@ def __init__(self, p: float = 0.5):
         """Drops full feature maps along the channel axis.
 
         Args:
-            p: fraction of an elements to be zeroed out
+            p: fraction of elements to be zeroed out.
 
         Note:
            https://keras.io/api/layers/regularization_layers/spatial_dropout1d/
           https://arxiv.org/abs/1411.4280
 
         Example:
-            >>> layer = DropoutND(0.5)
-            >>> layer(jnp.ones((1, 10)))
-            [[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]]
+            >>> import serket as sk
+            >>> import jax.numpy as jnp
+            >>> layer = sk.nn.Dropout2D(0.5)
+            >>> print(layer(jnp.ones((1, 5, 5)))) # doctest: +NORMALIZE_WHITESPACE
+            [[[2. 2. 2. 2. 2.]
+              [2. 2. 2. 2. 2.]
+              [2. 2. 2. 2. 2.]
+              [2. 2. 2. 2. 2.]
+              [2. 2. 2. 2. 
2.]]]
         """
-        super().__init__(p=p, spatial_ndim=2)
+        super().__init__(p=p)
 
     @property
     def spatial_ndim(self) -> int:
@@ -140,18 +135,24 @@ def __init__(self, p: float = 0.5):
         """Drops full feature maps along the channel axis.
 
         Args:
-            p: fraction of an elements to be zeroed out
+            p: fraction of elements to be zeroed out.
+
+        Example:
+            >>> import serket as sk
+            >>> import jax.numpy as jnp
+            >>> layer = sk.nn.Dropout3D(0.5)
+            >>> print(layer(jnp.ones((1, 2, 2, 2)))) # doctest: +NORMALIZE_WHITESPACE
+            [[[[2. 2.]
+               [2. 2.]]
+
+              [[2. 2.]
+               [2. 2.]]]]
 
         Note:
            https://keras.io/api/layers/regularization_layers/spatial_dropout1d/
           https://arxiv.org/abs/1411.4280
-
-        Example:
-            >>> layer = DropoutND(0.5)
-            >>> layer(jnp.ones((1, 10)))
-            [[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]]
         """
-        super().__init__(p=p, spatial_ndim=3)
+        super().__init__(p=p)
 
     @property
     def spatial_ndim(self) -> int:
diff --git a/serket/nn/fft_convolution.py b/serket/nn/fft_convolution.py
index 6f2883d..c49b12f 100644
--- a/serket/nn/fft_convolution.py
+++ b/serket/nn/fft_convolution.py
@@ -20,8 +20,8 @@
 
 import jax
 import jax.numpy as jnp
 import jax.random as jr
-import pytreeclass as pytc
 
+import serket as sk
 from serket.nn.initialization import InitType, resolve_init_func
 from serket.nn.utils import (
     DilationType,
@@ -84,10 +84,10 @@ def _general_intersperse(
 
 def _general_pad(x: jax.Array, pad_width: tuple[tuple[int, int], ...]) -> jax.Array:
     """Pad the input with `pad_width` on each side. Negative value will lead to cropping.
     Example:
-        >>> _general_pad(jnp.ones([3,3]),((0,0),(-1,1)))
-        [[1., 1., 0.],
-        [1., 1., 0.],
-        [1., 1., 0.]]
+        >>> print(_general_pad(jnp.ones([3,3]),((0,0),(-1,1)))) # doctest: +NORMALIZE_WHITESPACE
+        [[1. 1. 0.]
+        [1. 1. 0.]
+        [1. 1. 0.]]
 
     """
     for axis, (lhs, rhs) in enumerate(pad_width := list(pad_width)):
@@ -150,7 +150,7 @@ def fft_conv_general_dilated(
     return jax.lax.slice(z, start, end, (1, 1, *strides))
 
 
-class FFTConvND(pytc.TreeClass):
+class FFTConvND(sk.TreeClass):
     def __init__(
         self,
         in_features: int,
@@ -165,24 +165,6 @@ def __init__(
         groups: int = 1,
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """FFT Convolutional layer.
-
-        Args:
-            in_features: number of input features
-            out_features: number of output features
-            kernel_size: size of the convolutional kernel
-            strides: stride of the convolution
-            padding: padding of the input
-            kernel_dilation: dilation of the kernel
-            weight_init_func: function to use for initializing the weights
-            bias_init_func: function to use for initializing the bias
-            groups: number of groups to use for grouped convolution
-            key: key to use for initializing the weights
-
-        See:
-            https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
-            The implementation is tested against https://github.com/fkodom/fft-conv-pytorch
-        """
         self.in_features = positive_int_cb(in_features)
         self.out_features = positive_int_cb(out_features)
         self.kernel_size = canonicalize(
@@ -245,6 +227,57 @@ def spatial_ndim(self) -> int:
 
 
 class FFTConv1D(FFTConvND):
+    """1D FFT Convolutional layer.
+
+    Args:
+        in_features: number of input feature maps, i.e. the number of input
+            channels (the size of the leading dimension of the input array),
+            regardless of the number of spatial dimensions.
+        out_features: number of output feature maps, i.e. the number of output
+            channels (the size of the leading dimension of the output array),
+            regardless of the number of spatial dimensions. each output channel
+            has its own kernel and bias entry.
+        kernel_size: size of the convolutional kernel. 
accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + kernel_dilation: dilation of the convolutional kernel accepts: + * single integer for same dilation in all dimensions. + * sequence of integers for different dilation in each dimension. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + groups: number of groups to use for grouped convolution. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.FFTConv1D(in_features=1, out_features=2, kernel_size=3) + >>> # single sample + >>> x = jnp.ones((1, 5)) + >>> print(layer(x).shape) + (2, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5) + + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + """ + def __init__( self, in_features: int, @@ -259,24 +292,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """1D FFT Convolutional layer. - - Args: - in_features: number of input features - out_features: number of output features - kernel_size: size of the convolutional kernel - strides: stride of the convolution - padding: padding of the input - kernel_dilation: dilation of the kernel - weight_init_func: function to use for initializing the weights - bias_init_func: function to use for initializing the bias - groups: number of groups to use for grouped convolution - key: key to use for initializing the weights - - See: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - The implementation is tested against https://github.com/fkodom/fft-conv-pytorch - """ super().__init__( in_features=in_features, out_features=out_features, @@ -296,6 +311,57 @@ def spatial_ndim(self) -> int: class FFTConv2D(FFTConvND): + """2D FFT Convolutional layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + out_features: number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. 
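The FFT layers rely on the convolution theorem: convolution in the spatial domain is pointwise multiplication in the frequency domain. A minimal numeric check with plain jax.numpy, independent of serket:

    >>> import jax.numpy as jnp
    >>> a = jnp.array([1.0, 2.0, 3.0, 4.0])
    >>> k = jnp.array([1.0, 0.0, -1.0])
    >>> n = a.size + k.size - 1
    >>> via_fft = jnp.fft.irfft(jnp.fft.rfft(a, n) * jnp.fft.rfft(k, n), n)
    >>> bool(jnp.allclose(jnp.convolve(a, k), via_fft, atol=1e-5))
    True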
+ * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + kernel_dilation: dilation of the convolutional kernel accepts: + * single integer for same dilation in all dimensions. + * sequence of integers for different dilation in each dimension. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + groups: number of groups to use for grouped convolution. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.FFTConv2D(in_features=1, out_features=2, kernel_size=3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5) + + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + """ + def __init__( self, in_features: int, @@ -310,24 +376,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """2D FFT Convolutional layer. - - Args: - in_features: number of input features - out_features: number of output features - kernel_size: size of the convolutional kernel - strides: stride of the convolution - padding: padding of the input - kernel_dilation: dilation of the kernel - weight_init_func: function to use for initializing the weights - bias_init_func: function to use for initializing the bias - groups: number of groups to use for grouped convolution - key: key to use for initializing the weights - - See: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - The implementation is tested against https://github.com/fkodom/fft-conv-pytorch - """ super().__init__( in_features=in_features, out_features=out_features, @@ -347,6 +395,57 @@ def spatial_ndim(self) -> int: class FFTConv3D(FFTConvND): + """3D FFT Convolutional layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + out_features: number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + kernel_dilation: dilation of the convolutional kernel accepts: + * single integer for same dilation in all dimensions. 
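Note that the batched `jax.vmap(layer)(x)` examples in these docstrings call vmap from the top-level jax namespace, so `import jax` is needed alongside the jax.numpy import. A complete version (sketch):

    >>> import jax
    >>> import jax.numpy as jnp
    >>> import serket as sk
    >>> layer = sk.nn.FFTConv3D(in_features=1, out_features=2, kernel_size=3)
    >>> jax.vmap(layer)(jnp.ones((2, 1, 5, 5, 5))).shape
    (2, 2, 5, 5, 5)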
+ * sequence of integers for different dilation in each dimension. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + groups: number of groups to use for grouped convolution. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.FFTConv3D(in_features=1, out_features=2, kernel_size=3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5, 5) + + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + """ + def __init__( self, in_features: int, @@ -361,24 +460,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """3D FFT Convolutional layer. - - Args: - in_features: number of input features - out_features: number of output features - kernel_size: size of the convolutional kernel - strides: stride of the convolution - padding: padding of the input - kernel_dilation: dilation of the kernel - weight_init_func: function to use for initializing the weights - bias_init_func: function to use for initializing the bias - groups: number of groups to use for grouped convolution - key: key to use for initializing the weights - - See: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - The implementation is tested against https://github.com/fkodom/fft-conv-pytorch - """ super().__init__( in_features=in_features, out_features=out_features, @@ -397,10 +478,7 @@ def spatial_ndim(self) -> int: return 3 -# ---------------------------------------------------------------------------- # - - -class FFTConvNDTranspose(pytc.TreeClass): +class FFTConvNDTranspose(sk.TreeClass): def __init__( self, in_features: int, @@ -416,21 +494,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """Convolutional Transpose Layer - - Args: - in_features : Number of input channels - out_features : Number of output channels - kernel_size : Size of the convolutional kernel - strides : Stride of the convolution - padding : Padding of the input - output_padding : Additional size added to one side of the output shape - kernel_dilation : Dilation of the kernel - weight_init_func : Weight initialization function - bias_init_func : Bias initialization function - groups : Number of groups - key : PRNG key - """ self.in_features = positive_int_cb(in_features) self.out_features = positive_int_cb(out_features) self.kernel_size = canonicalize( @@ -509,6 +572,59 @@ def spatial_ndim(self) -> int: class FFTConv1DTranspose(FFTConvNDTranspose): + """1D FFT Convolution transpose layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + out_features: number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + strides: stride of the convolution. 
accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + output_padding: padding of the output after convolution. accepts: + * single integer for same padding in all dimensions. + kernel_dilation: dilation of the convolutional kernel accepts: + * single integer for same dilation in all dimensions. + * sequence of integers for different dilation in each dimension. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + groups: number of groups to use for grouped convolution. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.FFTConv1DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5)) + >>> print(layer(x).shape) + (2, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5) + + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + """ + def __init__( self, in_features: int, @@ -524,21 +640,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """1D FFT Convolutional Transpose Layer. - - Args: - in_features : Number of input channels - out_features : Number of output channels - kernel_size : Size of the convolutional kernel - strides : Stride of the convolution - padding : Padding of the input - output_padding : Additional size added to one side of the output shape - kernel_dilation : Dilation of the kernel - weight_init_func : Weight initialization function - bias_init_func : Bias initialization function - groups : Number of groups - key : PRNG key - """ super().__init__( in_features=in_features, out_features=out_features, @@ -559,6 +660,59 @@ def spatial_ndim(self) -> int: class FFTConv2DTranspose(FFTConvNDTranspose): + """2D FFT Convolution transpose layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + out_features: number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. 
+ * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + output_padding: padding of the output after convolution. accepts: + * single integer for same padding in all dimensions. + kernel_dilation: dilation of the convolutional kernel accepts: + * single integer for same dilation in all dimensions. + * sequence of integers for different dilation in each dimension. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + groups: number of groups to use for grouped convolution. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.FFTConv2DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5) + + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + """ + def __init__( self, in_features: int, @@ -574,21 +728,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """2D FFT Convolutional Transpose Layer. - - Args: - in_features : Number of input channels - out_features : Number of output channels - kernel_size : Size of the convolutional kernel - strides : Stride of the convolution - padding : Padding of the input - output_padding : Additional size added to one side of the output shape - kernel_dilation : Dilation of the kernel - weight_init_func : Weight initialization function - bias_init_func : Bias initialization function - groups : Number of groups - key : PRNG key - """ super().__init__( in_features=in_features, out_features=out_features, @@ -609,6 +748,59 @@ def spatial_ndim(self) -> int: class FFTConv3DTranspose(FFTConvNDTranspose): + """3D FFT Convolution transpose layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + out_features: number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + output_padding: padding of the output after convolution. accepts: + * single integer for same padding in all dimensions. + kernel_dilation: dilation of the convolutional kernel accepts: + * single integer for same dilation in all dimensions. + * sequence of integers for different dilation in each dimension. 
+ weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + groups: number of groups to use for grouped convolution. + key: key to use for initializing the weights. defaults to `0`. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.FFTConv3DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5, 5) + + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + """ + def __init__( self, in_features: int, @@ -624,21 +816,6 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - """3D FFT Convolutional Transpose Layer. - - Args: - in_features : Number of input channels - out_features : Number of output channels - kernel_size : Size of the convolutional kernel - strides : Stride of the convolution - padding : Padding of the input - output_padding : Additional size added to one side of the output shape - kernel_dilation : Dilation of the kernel - weight_init_func : Weight initialization function - bias_init_func : Bias initialization function - groups : Number of groups - key : PRNG key - """ super().__init__( in_features=in_features, out_features=out_features, @@ -658,10 +835,7 @@ def spatial_ndim(self) -> int: return 3 -# ----------------------------------------------------------------------------- # - - -class DepthwiseFFTConvND(pytc.TreeClass): +class DepthwiseFFTConvND(sk.TreeClass): def __init__( self, in_features: int, @@ -674,27 +848,6 @@ def __init__( bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """Depthwise Convolutional layer. - - Args: - in_features: number of input features - kernel_size: size of the convolution kernel - depth_multiplier : number of output channels per input channel - strides: stride of the convolution - padding: padding of the input - weight_init_func: function to initialize the weights - bias_init_func: function to initialize the bias - key: random key for weight initialization - - Examples:---- - >>> l1 = DepthwiseConvND(3, 3, depth_multiplier=2, strides=2, padding="SAME") - >>> l1(jnp.ones((3, 32, 32))).shape - (3, 16, 16, 6) - - Note: - https://keras.io/api/layers/convolution_layers/depthwise_convolution2d/ - https://github.com/google/flax/blob/main/flax/linen/linear.py - """ self.in_features = positive_int_cb(in_features) self.kernel_size = canonicalize( kernel_size, self.spatial_ndim, name="kernel_size" @@ -749,6 +902,48 @@ def spatial_ndim(self) -> int: class DepthwiseFFTConv1D(DepthwiseFFTConvND): + """1D Depthwise FFT convolution layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. 
+ * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. + + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> l1 = sk.nn.DepthwiseFFTConv1D(3, 3, depth_multiplier=2, strides=2) + >>> l1(jnp.ones((3, 32))).shape + (6, 16) + + Note: + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py + """ + def __init__( self, in_features: int, @@ -761,22 +956,6 @@ def __init__( bias_init_func: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - """1D Depthwise FFT Convolutional layer. - - Args: - in_features: number of input features - kernel_size: size of the convolution kernel - depth_multiplier : number of output channels per input channel - strides: stride of the convolution - padding: padding of the input - weight_init_func: function to initialize the weights - bias_init_func: function to initialize the bias - key: random key for weight initialization - - Note: - https://keras.io/api/layers/convolution_layers/depthwise_convolution2d/ - https://github.com/google/flax/blob/main/flax/linen/linear.py - """ super().__init__( in_features=in_features, kernel_size=kernel_size, @@ -794,6 +973,48 @@ def spatial_ndim(self) -> int: class DepthwiseFFTConv2D(DepthwiseFFTConvND): + """2D Depthwise convolution layer. + + Args: + in_features: number of input feature maps, for 1D convolution this is the + length of the input, for 2D convolution this is the number of input + channels, for 3D convolution this is the number of input channels. + kernel_size: size of the convolutional kernel. accepts: + * single integer for same kernel size in all dimensions. + * sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. + strides: stride of the convolution. accepts: + * single integer for same stride in all dimensions. + * sequence of integers for different strides in each dimension. + padding: padding of the input before convolution. accepts: + * single integer for same padding in all dimensions. + * tuple of integers for different padding in each dimension. + * tuple of a tuple of two integers for before and after padding in + each dimension. + * "same"/"SAME" for padding such that the output has the same shape + as the input. + * "valid"/"VALID" for no padding. + weight_init_func: function to use for initializing the weights. defaults + to `glorot uniform`. + bias_init_func: function to use for initializing the bias. defaults to + `zeros`. set to `None` to not use a bias. + key: key to use for initializing the weights. defaults to `0`. 
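Since the FFT layers and their direct counterparts share the same initializers and default key, the FFT path can be sanity-checked numerically (a sketch; assumes identical weight draws for identical shapes and keys, and FFT round-off within a small tolerance):

    >>> import jax.numpy as jnp
    >>> import serket as sk
    >>> x = jnp.linspace(0.0, 1.0, 3 * 32).reshape(3, 32)
    >>> direct = sk.nn.DepthwiseConv1D(3, 3, depth_multiplier=2)
    >>> fft = sk.nn.DepthwiseFFTConv1D(3, 3, depth_multiplier=2)
    >>> bool(jnp.allclose(direct(x), fft(x), atol=1e-4))
    True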
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> l1 = sk.nn.DepthwiseFFTConv2D(3, 3, depth_multiplier=2, strides=2)
+        >>> l1(jnp.ones((3, 32, 32))).shape
+        (6, 16, 16)
+
+    Note:
+        - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
+        - https://github.com/google/flax/blob/main/flax/linen/linear.py
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -806,22 +1027,6 @@ def __init__(
         bias_init_func: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """2D Depthwise FFT Convolutional layer.
-
-        Args:
-            in_features: number of input features
-            kernel_size: size of the convolution kernel
-            depth_multiplier : number of output channels per input channel
-            strides: stride of the convolution
-            padding: padding of the input
-            weight_init_func: function to initialize the weights
-            bias_init_func: function to initialize the bias
-            key: random key for weight initialization
-
-        Note:
-            https://keras.io/api/layers/convolution_layers/depthwise_convolution2d/
-            https://github.com/google/flax/blob/main/flax/linen/linear.py
-        """
         super().__init__(
             in_features=in_features,
             kernel_size=kernel_size,
@@ -839,6 +1044,48 @@ def spatial_ndim(self) -> int:


 class DepthwiseFFTConv3D(DepthwiseFFTConvND):
+    """3D Depthwise FFT convolution layer.
+
+    Args:
+        in_features: number of input feature maps, i.e. the number of input
+            channels.
+        kernel_size: size of the convolutional kernel. accepts:
+            * single integer for same kernel size in all dimensions.
+            * sequence of integers for different kernel sizes in each dimension.
+        depth_multiplier: multiplier for the number of output channels. for example
+            if the input has 32 channels and the depth multiplier is 2 then the
+            output will have 64 channels.
+        strides: stride of the convolution. accepts:
+            * single integer for same stride in all dimensions.
+            * sequence of integers for different strides in each dimension.
+        padding: padding of the input before convolution. accepts:
+            * single integer for same padding in all dimensions.
+            * tuple of integers for different padding in each dimension.
+            * tuple of a tuple of two integers for before and after padding in
+              each dimension.
+            * "same"/"SAME" for padding such that the output has the same shape
+              as the input.
+            * "valid"/"VALID" for no padding.
+        weight_init_func: function to use for initializing the weights. defaults
+            to `glorot uniform`.
+        bias_init_func: function to use for initializing the bias. defaults to
+            `zeros`. set to `None` to not use a bias.
+        key: key to use for initializing the weights. defaults to `0`.
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> l1 = sk.nn.DepthwiseFFTConv3D(3, 3, depth_multiplier=2, strides=2)
+        >>> l1(jnp.ones((3, 32, 32, 32))).shape
+        (6, 16, 16, 16)
+
+    Note:
+        - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
+        - https://github.com/google/flax/blob/main/flax/linen/linear.py
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -851,23 +1098,6 @@ def __init__(
         bias_init_func: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """3D Depthwise FFT Convolutional layer.
-
-        Args:
-            in_features: number of input features
-            kernel_size: size of the convolution kernel
-            depth_multiplier : number of output channels per input channel
-            strides: stride of the convolution
-            padding: padding of the input
-            weight_init_func: function to initialize the weights
-            bias_init_func: function to initialize the bias
-            spatial_ndim: number of spatial dimensions
-            key: random key for weight initialization
-
-        Note:
-            https://keras.io/api/layers/convolution_layers/depthwise_convolution2d/
-            https://github.com/google/flax/blob/main/flax/linen/linear.py
-        """
         super().__init__(
             in_features=in_features,
             kernel_size=kernel_size,
@@ -884,10 +1114,56 @@ def spatial_ndim(self) -> int:
         return 3


-# ---------------------------------------------------------------------------- #
+class SeparableFFTConv1D(sk.TreeClass):
+    """1D Separable FFT convolution layer.
+
+    Separable convolution is a depthwise convolution followed by a pointwise
+    convolution. The objective is to reduce the number of parameters in the
+    convolutional layer. For example, for `I` input features, `O` output
+    features, and a kernel of size `K0 * ... * Kn`, standard convolution has
+    `I * O * K0 * ... * Kn + O` parameters, whereas separable convolution has
+    `I * K0 * ... * Kn + I * O + O` parameters.
+
+    Args:
+        in_features: number of input feature maps, i.e. the number of input
+            channels.
+        out_features: number of output feature maps, i.e. the number of output
+            channels of the pointwise convolution.
+        kernel_size: size of the convolutional kernel. accepts:
+            * single integer for same kernel size in all dimensions.
+            * sequence of integers for different kernel sizes in each dimension.
+        depth_multiplier: multiplier for the number of intermediate channels
+            produced by the depthwise stage. for example if the input has 32
+            channels and the depth multiplier is 2 then the depthwise stage
+            outputs 64 channels.
+        strides: stride of the convolution. accepts:
+            * single integer for same stride in all dimensions.
+            * sequence of integers for different strides in each dimension.
+        padding: padding of the input before convolution. accepts:
+            * single integer for same padding in all dimensions.
+            * tuple of integers for different padding in each dimension.
+            * tuple of a tuple of two integers for before and after padding in
+              each dimension.
+            * "same"/"SAME" for padding such that the output has the same shape
+              as the input.
+            * "valid"/"VALID" for no padding.
+        depthwise_weight_init_func: function to use for initializing the
+            depthwise convolution weights.
+        pointwise_weight_init_func: function to use for initializing the
+            pointwise convolution weights.
+        pointwise_bias_init_func: function to use for initializing the
+            pointwise convolution bias. set to `None` to not use a bias.
+        key: key to use for initializing the weights. defaults to `0`.
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> l1 = sk.nn.SeparableFFTConv1D(3, 3, 3, depth_multiplier=2)
+        >>> l1(jnp.ones((3, 32))).shape
+        (3, 32)
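+
+        A quick parameter-count check based on the formula above (an
+        illustrative sketch; it counts only the depthwise kernel weights,
+        the pointwise weights, and the pointwise bias):
+
+        >>> I, O, K = 3, 3, 3  # in_features, out_features, kernel size
+        >>> I * O * K + O  # standard convolution
+        30
+        >>> I * K + I * O + O  # separable convolution
+        21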
+
+    Note:
+        - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
+        - https://github.com/google/flax/blob/main/flax/linen/linear.py
+    """

-class SeparableFFTConv1D(pytc.TreeClass):
     def __init__(
         self,
         in_features: int,
@@ -902,28 +1178,6 @@ def __init__(
         pointwise_bias_init_func: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """Separable 1D FFT Convolutional layer.
-        Args:
-            in_features : Number of input channels.
-            out_features : Number of output channels.
-            kernel_size : Size of the convolving kernel.
-            depth_multiplier : Number of depthwise convolution output channels
-                for each input channel.
-            strides : Stride of the convolution.
-            padding : Padding to apply to the input.
-            depthwise_weight_init_func : Function to initialize the depthwise
-                convolution weights.
-            pointwise_weight_init_func : Function to initialize the pointwise
-                convolution weights.
-            pointwise_bias_init_func : Function to initialize the pointwise
-                convolution bias.
-
-        Note:
-            https://en.wikipedia.org/wiki/Separable_filter
-            https://keras.io/api/layers/convolution_layers/separable_convolution2d/
-            https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/depthwise_conv.py
-
-        """
         self.in_features = in_features
         self.depth_multiplier = canonicalize(
             depth_multiplier,
@@ -965,7 +1219,55 @@ def spatial_ndim(self) -> int:
         return 1


-class SeparableFFTConv2D(pytc.TreeClass):
+class SeparableFFTConv2D(sk.TreeClass):
+    """2D Separable FFT convolution layer.
+
+    Separable convolution is a depthwise convolution followed by a pointwise
+    convolution. The objective is to reduce the number of parameters in the
+    convolutional layer. For example, for `I` input features, `O` output
+    features, and a kernel of size `K0 * ... * Kn`, standard convolution has
+    `I * O * K0 * ... * Kn + O` parameters, whereas separable convolution has
+    `I * K0 * ... * Kn + I * O + O` parameters.
+
+    Args:
+        in_features: number of input feature maps, i.e. the number of input
+            channels.
+        out_features: number of output feature maps, i.e. the number of output
+            channels of the pointwise convolution.
+        kernel_size: size of the convolutional kernel. accepts:
+            * single integer for same kernel size in all dimensions.
+            * sequence of integers for different kernel sizes in each dimension.
+        depth_multiplier: multiplier for the number of intermediate channels
+            produced by the depthwise stage. for example if the input has 32
+            channels and the depth multiplier is 2 then the depthwise stage
+            outputs 64 channels.
+        strides: stride of the convolution. accepts:
+            * single integer for same stride in all dimensions.
+            * sequence of integers for different strides in each dimension.
+        padding: padding of the input before convolution. accepts:
+            * single integer for same padding in all dimensions.
+            * tuple of integers for different padding in each dimension.
+            * tuple of a tuple of two integers for before and after padding in
+              each dimension.
+            * "same"/"SAME" for padding such that the output has the same shape
+              as the input.
+            * "valid"/"VALID" for no padding.
+        depthwise_weight_init_func: function to use for initializing the
+            depthwise convolution weights.
+        pointwise_weight_init_func: function to use for initializing the
+            pointwise convolution weights.
+        pointwise_bias_init_func: function to use for initializing the
+            pointwise convolution bias. set to `None` to not use a bias.
+        key: key to use for initializing the weights. defaults to `0`.
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> l1 = sk.nn.SeparableFFTConv2D(3, 3, 3, depth_multiplier=2)
+        >>> l1(jnp.ones((3, 32, 32))).shape
+        (3, 32, 32)
+
+    Note:
+        - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
+        - https://github.com/google/flax/blob/main/flax/linen/linear.py
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -980,27 +1282,6 @@ def __init__(
         pointwise_bias_init_func: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """Separable 2D FFT Convolutional layer.
-
-        Args:
-            in_features : Number of input channels.
-            out_features : Number of output channels.
-            kernel_size : Size of the convolving kernel.
-            depth_multiplier : Number of depthwise convolution output channels
-                for each input channel.
-            strides : Stride of the convolution.
-            padding : Padding to apply to the input.
-            depthwise_weight_init_func : Function to initialize the depthwise
-                convolution weights.
-            pointwise_weight_init_func : Function to initialize the pointwise
-                convolution weights.
-            pointwise_bias_init_func : Function to initialize the pointwise
-                convolution bias.
-        Note:
-            https://en.wikipedia.org/wiki/Separable_filter
-            https://keras.io/api/layers/convolution_layers/separable_convolution2d/
-            https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/depthwise_conv.py
-        """
         self.in_features = in_features
         self.depth_multiplier = canonicalize(
             depth_multiplier,
@@ -1042,7 +1323,55 @@ def spatial_ndim(self) -> int:
         return 2


-class SeparableFFTConv3D(pytc.TreeClass):
+class SeparableFFTConv3D(sk.TreeClass):
+    """3D Separable FFT convolution layer.
+
+    Separable convolution is a depthwise convolution followed by a pointwise
+    convolution. The objective is to reduce the number of parameters in the
+    convolutional layer. For example, for `I` input features, `O` output
+    features, and a kernel of size `K0 * ... * Kn`, standard convolution has
+    `I * O * K0 * ... * Kn + O` parameters, whereas separable convolution has
+    `I * K0 * ... * Kn + I * O + O` parameters.
+
+    Args:
+        in_features: number of input feature maps, i.e. the number of input
+            channels.
+        out_features: number of output feature maps, i.e. the number of output
+            channels of the pointwise convolution.
+        kernel_size: size of the convolutional kernel. accepts:
+            * single integer for same kernel size in all dimensions.
+            * sequence of integers for different kernel sizes in each dimension.
+        depth_multiplier: multiplier for the number of intermediate channels
+            produced by the depthwise stage. for example if the input has 32
+            channels and the depth multiplier is 2 then the depthwise stage
+            outputs 64 channels.
+        strides: stride of the convolution. accepts:
+            * single integer for same stride in all dimensions.
+            * sequence of integers for different strides in each dimension.
+        padding: padding of the input before convolution. accepts:
+            * single integer for same padding in all dimensions.
+            * tuple of integers for different padding in each dimension.
+            * tuple of a tuple of two integers for before and after padding in
+              each dimension.
+            * "same"/"SAME" for padding such that the output has the same shape
+              as the input.
+            * "valid"/"VALID" for no padding.
+        depthwise_weight_init_func: function to use for initializing the
+            depthwise convolution weights.
+        pointwise_weight_init_func: function to use for initializing the
+            pointwise convolution weights.
+        pointwise_bias_init_func: function to use for initializing the
+            pointwise convolution bias. set to `None` to not use a bias.
+        key: key to use for initializing the weights. defaults to `0`.
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> l1 = sk.nn.SeparableFFTConv3D(3, 3, 3, depth_multiplier=2)
+        >>> l1(jnp.ones((3, 32, 32, 32))).shape
+        (3, 32, 32, 32)
+
+    Note:
+        - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
+        - https://github.com/google/flax/blob/main/flax/linen/linear.py
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -1057,25 +1386,6 @@ def __init__(
         pointwise_bias_init_func: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """Separable 3D FFT Convolutional layer.
-
-        Note:
-            See:
-                https://en.wikipedia.org/wiki/Separable_filter
-                https://keras.io/api/layers/convolution_layers/separable_convolution2d/
-                https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/depthwise_conv.py
-
-        Args:
-            in_features : Number of input channels.
-            out_features : Number of output channels.
-            kernel_size : Size of the convolving kernel.
- depth_multiplier : Number of depthwise convolution output channels for each input channel. - strides : Stride of the convolution. - padding : Padding to apply to the input. - depthwise_weight_init_func : Function to initialize the depthwise convolution weights. - pointwise_weight_init_func : Function to initialize the pointwise convolution weights. - pointwise_bias_init_func : Function to initialize the pointwise convolution bias. - """ self.in_features = in_features self.depth_multiplier = canonicalize( depth_multiplier, diff --git a/serket/nn/flatten.py b/serket/nn/flatten.py index b90a021..e9c037a 100644 --- a/serket/nn/flatten.py +++ b/serket/nn/flatten.py @@ -16,16 +16,18 @@ import jax import jax.numpy as jnp -import pytreeclass as pytc +import serket as sk from serket.nn.utils import IsInstance -class Flatten(pytc.TreeClass): - """ +class Flatten(sk.TreeClass): + """Flatten an array from dim `start_dim` to `end_dim` (inclusive). + Args: start_dim: the first dim to flatten end_dim: the last dim to flatten (inclusive) + Returns: a function that flattens a jnp.ndarray @@ -49,8 +51,8 @@ class Flatten(pytc.TreeClass): https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html?highlight=flatten#torch.nn.Flatten """ - start_dim: int = pytc.field(default=0, callbacks=[IsInstance(int)]) - end_dim: int = pytc.field(default=-1, callbacks=[IsInstance(int)]) + start_dim: int = sk.field(default=0, callbacks=[IsInstance(int)]) + end_dim: int = sk.field(default=-1, callbacks=[IsInstance(int)]) def __call__(self, x: jax.Array) -> jax.Array: start_dim = self.start_dim + (0 if self.start_dim >= 0 else x.ndim) @@ -58,12 +60,13 @@ def __call__(self, x: jax.Array) -> jax.Array: return jax.lax.collapse(x, start_dim, end_dim) -class Unflatten(pytc.TreeClass): - dim: int = pytc.field(default=0, callbacks=[IsInstance(int)]) - shape: tuple = pytc.field(default=None, callbacks=[IsInstance(tuple)]) +class Unflatten(sk.TreeClass): + """Unflatten an array. + + Args: + dim: the dim to unflatten. + shape: the shape to unflatten to. accepts a tuple of ints. - """ - Example: >>> Unflatten(0, (1,2,3,4,5))(jnp.ones([120])).shape (1, 2, 3, 4, 5) @@ -74,6 +77,9 @@ class Unflatten(pytc.TreeClass): https://pytorch.org/docs/stable/generated/torch.nn.Unflatten.html?highlight=unflatten """ + dim: int = sk.field(default=0, callbacks=[IsInstance(int)]) + shape: tuple = sk.field(default=None, callbacks=[IsInstance(tuple)]) + def __call__(self, x: jax.Array, **k) -> jax.Array: shape = list(x.shape) shape = [*shape[: self.dim], *self.shape, *shape[self.dim + 1 :]] diff --git a/serket/nn/flip.py b/serket/nn/flip.py index 7140e08..1fdb015 100644 --- a/serket/nn/flip.py +++ b/serket/nn/flip.py @@ -18,30 +18,29 @@ import jax import jax.numpy as jnp -import pytreeclass as pytc +import serket as sk from serket.nn.utils import validate_spatial_ndim -class FlipLeftRight2D(pytc.TreeClass): - def __init__(self): - """Flip channels left to right. +class FlipLeftRight2D(sk.TreeClass): + """Flip channels left to right. 
-        Note:
-            https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
+    Note:
+        https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py

-        Examples:
-            >>> x = jnp.arange(1,10).reshape(1,3, 3)
-            >>> x
-            [[[1 2 3]
-              [4 5 6]
-              [7 8 9]]]
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> x = jnp.arange(1, 10).reshape(1, 3, 3)
+        >>> print(x)
+        [[[1 2 3]
+          [4 5 6]
+          [7 8 9]]]

-            >>> FlipLeftRight2D()(x)
-            [[[3 2 1]
-              [6 5 4]
-              [9 8 7]]]
-        """
+        >>> print(FlipLeftRight2D()(x))
+        [[[3 2 1]
+          [6 5 4]
+          [9 8 7]]]
+    """

     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -53,25 +52,24 @@ def spatial_ndim(self) -> int:
         return 2


-class FlipUpDown2D(pytc.TreeClass):
-    def __init__(self):
-        """Flip channels up to down.
+class FlipUpDown2D(sk.TreeClass):
+    """Flip channels up to down.

-        Note:
-            https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
+    Note:
+        https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py

-        Examples:
-            >>> x = jnp.arange(1,10).reshape(1,3, 3)
-            >>> x
-            [[[1 2 3]
-              [4 5 6]
-              [7 8 9]]]
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> x = jnp.arange(1, 10).reshape(1, 3, 3)
+        >>> print(x)
+        [[[1 2 3]
+          [4 5 6]
+          [7 8 9]]]

-            >>> FlipUpDown2D()(x)
-            [[[7 8 9]
-              [4 5 6]
-              [1 2 3]]]
-        """
+        >>> print(FlipUpDown2D()(x))
+        [[[7 8 9]
+          [4 5 6]
+          [1 2 3]]]
+    """

     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     def __call__(self, x: jax.Array, **k) -> jax.Array:
diff --git a/serket/nn/fully_connected.py b/serket/nn/fully_connected.py
index cb72d34..e481279 100644
--- a/serket/nn/fully_connected.py
+++ b/serket/nn/fully_connected.py
@@ -18,8 +18,8 @@

 import jax
 import jax.random as jr
-import pytreeclass as pytc

+import serket as sk
 from serket.nn.activation import ActivationType, resolve_activation
 from serket.nn.initialization import InitType
 from serket.nn.linear import Linear
@@ -27,7 +27,33 @@

 PyTree = Any


-class FNN(pytc.TreeClass):
+class FNN(sk.TreeClass):
+    """Fully connected neural network
+
+    Args:
+        layers: Sequence of layer sizes
+        act_func: a single Activation function to be applied between layers or
+            `len(layers)-2` Sequence of activation functions applied between
+            layers.
+        weight_init_func: Weight initializer function.
+        bias_init_func: Bias initializer function. Defaults to `zeros`.
+        key: Random key for weight and bias initialization.
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> fnn = sk.nn.FNN([10, 5, 2])
+        >>> fnn(jnp.ones((3, 10))).shape
+        (3, 2)
+
+    Note:
+        - layers argument yields len(layers) - 1 linear layers with required
+          `len(layers)-2` activation functions, for example, `layers=[10, 5, 2]`
+          yields 2 linear layers with weight shapes (10, 5) and (5, 2)
+          and a single activation function applied between them.
+        - `FNN` uses python `for` loop to apply layers and activation functions.
+    """
+
     def __init__(
         self,
         layers: Sequence[int],
@@ -37,30 +63,6 @@ def __init__(
         bias_init_func: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """Fully connected neural network
-        Args:
-            layers: Sequence of layer sizes
-            act_func: a single Activation function to be applied between layers or
-                `len(layers)-2` Sequence of activation functions applied between
-                layers.
-            weight_init_func: Weight initializer function.
-            bias_init_func: Bias initializer function. Defaults to lambda key,
-                shape: jnp.ones(shape).
-            key: Random key for weight and bias initialization.
-
-        Example:
-            >>> fnn = FNN([10, 5, 2])
-            >>> fnn(jnp.ones((3, 10))).shape
-            (3, 2)
-
-        Note:
-            - layers argument yields len(layers) - 1 linear layers with required
-              `len(layers)-2` activation functions, for example, `layers=[10, 5, 2]`
-              yields 2 linear layers with weight shapes (10, 5) and (5, 2)
-              and single activation function is applied between them.
-            - `FNN` uses python `for` loop to apply layers and activation functions.
-        """
-
         keys = jr.split(key, len(layers) - 1)
         num_hidden_layers = len(layers) - 2
@@ -104,7 +106,30 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
         return self._single_call(x, **k)


-class MLP(pytc.TreeClass):
+class MLP(sk.TreeClass):
+    """Multi-layer perceptron.
+
+    Args:
+        in_features: Number of input features.
+        out_features: Number of output features.
+        hidden_size: Number of hidden units in each hidden layer.
+        num_hidden_layers: Number of hidden layers including the output layer.
+        act_func: Activation function.
+        weight_init_func: Weight initialization function.
+        bias_init_func: Bias initialization function.
+        key: Random number generator key.
+
+    Note:
+        - MLP with `in_features`=1, `out_features`=2, `hidden_size`=4,
+          `num_hidden_layers`=2 is equivalent to `[1, 4, 4, 2]` which has one
+          input layer (1, 4), one intermediate layer (4, 4), and one output
+          layer (4, 2) = `num_hidden_layers` + 1
+        - `MLP` exploits the same input/output size for intermediate layers to
+          use `jax.lax.scan`, which offers better compilation speed for a large
+          number of layers and produces a smaller `jaxpr`, but could be slower
+          than an equivalent `FNN` for a small number of layers.
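+
+    Example:
+        A minimal usage sketch following the `[1, 4, 4, 2]` equivalence in
+        the note above (illustrative):
+
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> mlp = sk.nn.MLP(1, 2, hidden_size=4, num_hidden_layers=2)
+        >>> mlp(jnp.ones((3, 1))).shape
+        (3, 2)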
- """ if hidden_size < 1: raise ValueError(f"hidden_size must be positive, got {hidden_size}") diff --git a/serket/nn/linear.py b/serket/nn/linear.py index f5ef47a..5aeea2f 100644 --- a/serket/nn/linear.py +++ b/serket/nn/linear.py @@ -19,8 +19,8 @@ import jax import jax.numpy as jnp import jax.random as jr -import pytreeclass as pytc +import serket as sk from serket.nn.initialization import InitType, resolve_init_func from serket.nn.utils import IsInstance, positive_int_cb @@ -73,7 +73,28 @@ def _general_linear_einsum_string(*axes: tuple[int, ...]) -> str: return f"{input_string},{weight_string}->{result_string}" -class Multilinear(pytc.TreeClass): +class Multilinear(sk.TreeClass): + """Linear layer with arbitrary number of inputs applied to last axis of each input + + Args: + in_features: number of input features for each input + out_features: number of output features + weight_init_func: function to initialize the weights + bias_init_func: function to initialize the bias + key: key for the random number generator + + Example: + >>> # Bilinear layer + >>> layer = Multilinear((5,6), 7) + >>> layer(jnp.ones((1,5)), jnp.ones((1,6))).shape + (1, 7) + + >>> # Trilinear layer + >>> layer = Multilinear((5,6,7), 8) + >>> layer(jnp.ones((1,5)), jnp.ones((1,6)), jnp.ones((1,7))).shape + (1, 8) + """ + def __init__( self, in_features: int | tuple[int, ...] | None, @@ -83,26 +104,6 @@ def __init__( bias_init_func: InitType = "ones", key: jr.KeyArray = jr.PRNGKey(0), ): - """Linear layer with arbitrary number of inputs applied to last axis of each input - - Args: - in_features: number of input features for each input - out_features: number of output features - weight_init_func: function to initialize the weights - bias_init_func: function to initialize the bias - key: key for the random number generator - - Example: - >>> # Bilinear layer - >>> layer = Multilinear((5,6), 7) - >>> layer(jnp.ones((1,5)), jnp.ones((1,6))).shape - (1, 7) - - >>> # Trilinear layer - >>> layer = Multilinear((5,6,7), 8) - >>> layer(jnp.ones((1,5)), jnp.ones((1,6)), jnp.ones((1,7))).shape - (1, 8) - """ if not isinstance(in_features, (tuple, int)): raise ValueError(f"Expected tuple or int for {in_features=}.") @@ -162,6 +163,22 @@ def __init__( class Bilinear(Multilinear): + """Bilinear layer + + Args: + in1_features: number of input features for the first input + in2_features: number of input features for the second input + out_features: number of output features + weight_init_func: function to initialize the weights + bias_init_func: function to initialize the bias + key: key for the random number generator + + Example: + >>> layer = Bilinear(5, 6, 7) + >>> layer(jnp.ones((1,5)), jnp.ones((1,6))).shape + (1, 7) + """ + def __init__( self, in1_features: int, @@ -172,21 +189,6 @@ def __init__( bias_init_func: InitType = "ones", key: jr.KeyArray = jr.PRNGKey(0), ): - """Bilinear layer - - Args: - in1_features: number of input features for the first input - in2_features: number of input features for the second input - out_features: number of output features - weight_init_func: function to initialize the weights - bias_init_func: function to initialize the bias - key: key for the random number generator - - Example: - >>> layer = Bilinear(5, 6, 7) - >>> layer(jnp.ones((1,5)), jnp.ones((1,6))).shape - (1, 7) - """ super().__init__( (in1_features, in2_features), out_features, @@ -196,7 +198,27 @@ def __init__( ) -class GeneralLinear(pytc.TreeClass): +class GeneralLinear(sk.TreeClass): + """Apply a Linear Layer to input at in_axes 
+
+    Args:
+        in_features: number of input features corresponding to in_axes
+        out_features: number of output features
+        in_axes: axes to apply the linear layer to
+        weight_init_func: weight initialization function
+        bias_init_func: bias initialization function
+        key: random key
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> x = jnp.ones([1, 2, 3, 4])
+        >>> layer = sk.nn.GeneralLinear(in_features=(1, 2), in_axes=(0, 1), out_features=5)
+        >>> assert layer(x).shape == (3, 4, 5)
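+
+        The call above amounts to a single `einsum` contraction over the
+        chosen axes (an illustrative sketch with the einsum string spelled
+        out for this example; the bias is omitted):
+
+        >>> w = jnp.ones((1, 2, 5))
+        >>> jnp.einsum("abcd,abe->cde", x, w).shape
+        (3, 4, 5)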
+
+    Note:
+        This layer is similar to flax linen's `DenseGeneral`; the difference
+        is that this layer uses einsum to apply the linear layer to the
+        specified axes.
+    """
+
     def __init__(
         self,
         in_features: tuple[int, ...],
@@ -207,26 +229,6 @@ def __init__(
         bias_init_func: InitType = "ones",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """Apply a Linear Layer to input at in_axes
-
-        Args:
-            in_features: number of input features corresponding to in_axes
-            out_features: number of output features
-            in_axes: axes to apply the linear layer to
-            weight_init_func: weight initialization function
-            bias_init_func: bias initialization function
-            key: random key
-
-        Example:
-            >>> x = jnp.ones([1, 2, 3, 4])
-            >>> layer = GeneralLinear(in_features=(1, 2), in_axes=(0, 1), out_features=5)
-            >>> assert layer(x).shape == (3, 4, 5)
-
-        Note:
-            This layer is similar to to flax linen's DenseGeneral, the difference is that
-            this layer uses einsum to apply the linear layer to the specified axes.
-        """
-
         self.in_features = IsInstance(tuple)(in_features)
         self.out_features = out_features
         self.in_axes = IsInstance(tuple)(in_axes)
@@ -255,35 +257,36 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
         return x


-class Identity(pytc.TreeClass):
+class Identity(sk.TreeClass):
     """Identity layer"""

     def __call__(self, x: jax.Array, **k) -> jax.Array:
         return x


-class Embedding(pytc.TreeClass):
+class Embedding(sk.TreeClass):
+    """Defines an embedding layer.
+
+    Args:
+        in_features: vocabulary size.
+        out_features: embedding size.
+        key: random key to initialize the weights.
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> # 10 words in the vocabulary, each word is represented by a 3 dimensional vector
+        >>> table = sk.nn.Embedding(10,3)
+        >>> # take the last word in the vocab
+        >>> table(jnp.array([9]))
+        Array([[0.43810904, 0.35078037, 0.13254273]], dtype=float32)
+    """
+
     def __init__(
         self,
         in_features: int,
         out_features: int,
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """Defines an embedding layer.
-
-        Args:
-            in_features: vocabulary size.
-            out_features: embedding size.
-            key: random key to initialize the weights.
-
-        Example:
-            >>> import serket as sk
-            >>> # 10 words in the vocabulary, each word is represented by a 3 dimensional vector
-            >>> table = sk.nn.Embedding(10,3)
-            >>> # take the last word in the vocab
-            >>> table(jnp.array([9]))
-            Array([[0.43810904, 0.35078037, 0.13254273]], dtype=float32)
-        """
         self.in_features = positive_int_cb(in_features)
         self.out_features = positive_int_cb(out_features)
         self.weight = jr.uniform(key, (self.in_features, self.out_features))
@@ -304,28 +307,29 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
         return jnp.take(self.weight, x, axis=0)


-class MergeLinear(pytc.TreeClass):
-    def __init__(self, *layers: tuple[Linear, ...]):
-        """Merge multiple linear layers with the same `out_features`.
+class MergeLinear(sk.TreeClass):
+    """Merge multiple linear layers with the same `out_features`.

-        Args:
-            layers: linear layers to merge
-
-        Example:
-            >>> import serket as sk
-            >>> import numpy.testing as npt
-            >>> layer1 = sk.nn.Linear(5, 6)  # 5 input features, 6 output features
-            >>> layer2 = sk.nn.Linear(7, 6)  # 7 input features, 6 output features
-            >>> merged_layer = sk.nn.MergeLinear(layer1, layer2)  # 12 input features, 6 output features
-            >>> x1 = jnp.ones([1, 5])  # 1 sample, 5 features
-            >>> x2 = jnp.ones([1, 7])  # 1 sample, 7 features
-            >>> y = merged_layer(x1, x2)  # one matrix multiplication
-            >>> z = layer1(x1) + layer2(x2)  # two matrix multiplications
-            >>> npt.assert_allclose(y, z, atol=1e-6)
-
-        Note:
-            Use this layer to reduce the matrix multiplication operations in the forward pass.
-        """
+    Args:
+        layers: linear layers to merge
+
+    Example:
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> import numpy.testing as npt
+        >>> layer1 = sk.nn.Linear(5, 6)  # 5 input features, 6 output features
+        >>> layer2 = sk.nn.Linear(7, 6)  # 7 input features, 6 output features
+        >>> merged_layer = sk.nn.MergeLinear(layer1, layer2)  # 12 input features, 6 output features
+        >>> x1 = jnp.ones([1, 5])  # 1 sample, 5 features
+        >>> x2 = jnp.ones([1, 7])  # 1 sample, 7 features
+        >>> y = merged_layer(x1, x2)  # one matrix multiplication
+        >>> z = layer1(x1) + layer2(x2)  # two matrix multiplications
+        >>> npt.assert_allclose(y, z, atol=1e-6)
+
+    Note:
+        Use this layer to reduce the matrix multiplication operations in the forward pass.
+    """
+
+    def __init__(self, *layers: tuple[Linear, ...]):
         out_dim0 = layers[0].out_features
         if not all(isinstance(layer, Linear) for layer in layers):
             raise TypeError("All layers must be instances of Linear.")
diff --git a/serket/nn/normalization.py b/serket/nn/normalization.py
index 99bc049..69cecd3 100644
--- a/serket/nn/normalization.py
+++ b/serket/nn/normalization.py
@@ -14,11 +14,14 @@

 from __future__ import annotations

+from typing import NamedTuple
+
 import jax
 import jax.numpy as jnp
-import pytreeclass as pytc
+from jax.custom_batching import custom_vmap

-from serket.nn.utils import Range, ScalarLike, positive_int_cb
+import serket as sk
+from serket.nn.utils import IsInstance, Range, ScalarLike, positive_int_cb


 def layer_norm(
@@ -84,8 +87,18 @@ def group_norm(
     return x̂


-class LayerNorm(pytc.TreeClass):
-    eps: float = pytc.field(callbacks=[Range(0), ScalarLike()])
+class LayerNorm(sk.TreeClass):
+    """Layer Normalization
+
+    Transforms the input by scaling and shifting to have zero mean and unit
+    variance.
+    See: https://nn.labml.ai/normalization/layer_norm/index.html
+
+    Args:
+        normalized_shape: the shape of the input to be normalized.
+        eps: a value added to the denominator for numerical stability.
+        affine: a boolean value that when set to True, this module has
+            learnable affine parameters.
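+
+    Example:
+        A quick sketch checking the zero-mean property (illustrative; assumes
+        `LayerNorm` is exported under `sk.nn`):
+
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> layer = sk.nn.LayerNorm((5,))
+        >>> y = layer(jnp.arange(5.0))
+        >>> bool(jnp.allclose(y.mean(), 0.0, atol=1e-6))
+        True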
- """ self.normalized_shape = ( normalized_shape if isinstance(normalized_shape, tuple) @@ -125,8 +129,19 @@ def __call__(self, x: jax.Array, **kwargs) -> jax.Array: ) -class GroupNorm(pytc.TreeClass): - eps: float = pytc.field(callbacks=[Range(0), ScalarLike()]) +class GroupNorm(sk.TreeClass): + """Group Normalization + See: https://nn.labml.ai/normalization/group_norm/index.html + transform the input by scaling and shifting to have zero mean and unit variance. + + Args: + in_features : the shape of the input to be normalized. + groups : number of groups to separate the channels into. + eps : a value added to the denominator for numerical stability. + affine : a boolean value that when set to True, this module has learnable affine parameters. + """ + + eps: float = sk.field(callbacks=[Range(0), ScalarLike()]) def __init__( self, @@ -136,17 +151,6 @@ def __init__( eps: float = 1e-5, affine: bool = True, ): - """Group Normalization - See: https://nn.labml.ai/normalization/group_norm/index.html - transform the input by scaling and shifting to have zero mean and unit variance. - - Args: - in_features : the shape of the input to be normalized. - groups : number of groups to separate the channels into. - eps : a value added to the denominator for numerical stability. - affine : a boolean value that when set to True, this module has learnable affine parameters. - """ - # checked by callbacks self.in_features = positive_int_cb(in_features) self.groups = positive_int_cb(groups) self.affine = affine @@ -172,6 +176,16 @@ def __call__(self, x: jax.Array, **k) -> jax.Array: class InstanceNorm(GroupNorm): + """Instance Normalization + See: https://nn.labml.ai/normalization/instance_norm/index.html + transform the input by scaling and shifting to have zero mean and unit variance. + + Args: + in_features : the shape of the input to be normalized. + eps : a value added to the denominator for numerical stability. + affine : a boolean value that when set to True, this module has learnable affine parameters. + """ + def __init__( self, in_features: int, @@ -179,18 +193,73 @@ def __init__( eps: float = 1e-5, affine: bool = True, ): - """Instance Normalization - See: https://nn.labml.ai/normalization/instance_norm/index.html - transform the input by scaling and shifting to have zero mean and unit variance. - - Args: - in_features : the shape of the input to be normalized. - eps : a value added to the denominator for numerical stability. - affine : a boolean value that when set to True, this module has learnable affine parameters. 
- """ super().__init__( in_features=in_features, groups=in_features, eps=eps, affine=affine, ) + + +class BatchNormState(NamedTuple): + running_mean: jax.Array + running_var: jax.Array + + +@custom_vmap +def batchnorm( + x: jax.Array, + state: tuple[jax.Array, jax.Array], + *, + momentum: float = 0.1, + eps: float = 1e-5, + gamma: jax.Array | None = None, + beta: jax.Array | None = None, + track_running_stats: bool = False, +): + del momentum, eps, gamma, beta, track_running_stats + return x, state + + +@batchnorm.def_vmap +def _( + axis_size, + in_batched, + x: jax.Array, + state: tuple[jax.Array, jax.Array], + *, + momentum: float = 0.1, + eps: float = 1e-5, + track_running_stats: bool = True, +): + run_mean, run_var = state + + axes = [0] + list(range(2, x.ndim)) + + batch_mean, batch_var = jnp.mean(x, axis=axes), jnp.var(x, axis=axes) + + run_mean = jnp.where( + track_running_stats, + (1 - momentum) * run_mean + momentum * batch_mean, + batch_mean, + ) + + run_var = jnp.where( + track_running_stats, + (1 - momentum) * run_var + momentum * batch_var * (axis_size / (axis_size - 1)), + batch_var, + ) + x_normalized = (x - batch_mean) * jax.lax.rsqrt(batch_var + eps) + return (x_normalized, (run_mean, run_var)), (True, (True, True)) + + +class BatchNorm(sk.TreeClass): + in_features: int = sk.field(callbacks=[IsInstance(int), Range(1)]) + momentum: float = sk.field(callbacks=[Range(0, 1), ScalarLike()]) + eps: float = sk.field(callbacks=[Range(0), ScalarLike()]) + track_running_stats: bool = sk.field(callbacks=[IsInstance(bool)]) + + def __post_init__(self): + self.state = BatchNormState( + jnp.zeros(self.in_features), jnp.ones(self.in_features) + ) diff --git a/serket/nn/padding.py b/serket/nn/padding.py index 223130e..24b9217 100644 --- a/serket/nn/padding.py +++ b/serket/nn/padding.py @@ -19,24 +19,13 @@ import jax import jax.numpy as jnp -import pytreeclass as pytc +import serket as sk from serket.nn.utils import delayed_canonicalize_padding, validate_spatial_ndim -class PadND(pytc.TreeClass): +class PadND(sk.TreeClass): def __init__(self, padding: int | tuple[int, int], value: float = 0.0): - """ - Args: - padding: padding to apply to each side of the input. - value: value to pad with. Defaults to 0.0. - - Note: - https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.pad.html - https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding1D - https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding2D - https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding3D - """ self.padding = delayed_canonicalize_padding( in_dim=None, padding=padding, @@ -60,20 +49,21 @@ def spatial_ndim(self) -> int: class Pad1D(PadND): - def __init__(self, padding: int | tuple[int, int], value: float = 0.0): - """ - Pad a 1D tensor. + """ + Pad a 1D tensor. - Args: - padding: padding to apply to each side of the input. - value: value to pad with. Defaults to 0.0. + Args: + padding: padding to apply to each side of the input. + value: value to pad with. Defaults to 0.0. 
- Note: - https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.pad.html - https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding1D - https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding2D - https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding3D - """ + Note: + https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.pad.html + https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding1D + https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding2D + https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding3D + """ + + def __init__(self, padding: int | tuple[int, int], value: float = 0.0): super().__init__(padding=padding, value=value) @property diff --git a/serket/nn/pooling.py b/serket/nn/pooling.py index 741a72e..740d586 100644 --- a/serket/nn/pooling.py +++ b/serket/nn/pooling.py @@ -21,8 +21,8 @@ import jax import jax.numpy as jnp import kernex as kex -import pytreeclass as pytc +import serket as sk from serket.nn.utils import ( KernelSizeType, PaddingType, @@ -33,10 +33,10 @@ ) # Based on colab hardware benchmarks `kernex` seems to -# be faster on CPU and on par with JAX on GPU. +# be faster on CPU and on par with JAX on GPU for low number of channels. -class GeneralPoolND(pytc.TreeClass): +class GeneralPoolND(sk.TreeClass): def __init__( self, kernel_size: KernelSizeType, @@ -114,7 +114,7 @@ def __init__( ) -class GlobalPoolND(pytc.TreeClass): +class GlobalPoolND(sk.TreeClass): def __init__(self, keepdims: bool = True, operation: Callable = jnp.mean): """Apply global pooling to the input with function `func` applied to the kernel. @@ -137,7 +137,7 @@ def spatial_ndim(self) -> int: ... -class AdaptivePoolND(pytc.TreeClass): +class AdaptivePoolND(sk.TreeClass): output_size: tuple[int, ...] 
def __init__(self, output_size: tuple[int, ...], *, func: Callable = None): @@ -184,6 +184,14 @@ def spatial_ndim(self) -> int: class MaxPool1D(GeneralPoolND): + """1D Max Pooling layer + + Args: + kernel_size: size of the kernel + strides: strides of the kernel + padding: padding of the kernel (valid, same) or tuple of ints + """ + def __init__( self, kernel_size: int, @@ -191,12 +199,6 @@ def __init__( *, padding: str = "valid", ): - """1D Max Pooling layer - Args: - kernel_size: size of the kernel - strides: strides of the kernel - padding: padding of the kernel (valid, same) or tuple of ints - """ super().__init__( kernel_size=kernel_size, strides=strides, @@ -210,6 +212,13 @@ def spatial_ndim(self) -> int: class MaxPool2D(GeneralPoolND): + """2D Max Pooling layer + Args: + kernel_size: size of the kernel + strides: strides of the kernel + padding: padding of the kernel (valid, same) or tuple of ints + """ + def __init__( self, kernel_size: int, @@ -217,12 +226,6 @@ def __init__( *, padding: str = "valid", ): - """2D Max Pooling layer - Args: - kernel_size: size of the kernel - strides: strides of the kernel - padding: padding of the kernel (valid, same) or tuple of ints - """ super().__init__( kernel_size=kernel_size, strides=strides, @@ -236,6 +239,13 @@ def spatial_ndim(self) -> int: class MaxPool3D(GeneralPoolND): + """3D Max Pooling layer + Args: + kernel_size: size of the kernel + strides: strides of the kernel + padding: padding of the kernel (valid, same) or tuple of ints + """ + def __init__( self, kernel_size: int, @@ -243,12 +253,6 @@ def __init__( *, padding: str = "valid", ): - """3D Max Pooling layer - Args: - kernel_size: size of the kernel - strides: strides of the kernel - padding: padding of the kernel (valid, same) or tuple of ints - """ super().__init__( kernel_size=kernel_size, strides=strides, @@ -262,6 +266,13 @@ def spatial_ndim(self) -> int: class AvgPool1D(GeneralPoolND): + """1D Average Pooling layer + Args: + kernel_size: size of the kernel + strides: strides of the kernel + padding: padding of the kernel (valid, same) or tuple of ints + """ + def __init__( self, kernel_size: int, @@ -269,12 +280,6 @@ def __init__( *, padding: str = "valid", ): - """1D Average Pooling layer - Args: - kernel_size: size of the kernel - strides: strides of the kernel - padding: padding of the kernel (valid, same) or tuple of ints - """ super().__init__( kernel_size=kernel_size, strides=strides, @@ -288,6 +293,13 @@ def spatial_ndim(self) -> int: class AvgPool2D(GeneralPoolND): + """2D Average Pooling layer + Args: + kernel_size: size of the kernel + strides: strides of the kernel + padding: padding of the kernel (valid, same) or tuple of ints + """ + def __init__( self, kernel_size: int, @@ -295,12 +307,6 @@ def __init__( *, padding: str = "valid", ): - """2D Average Pooling layer - Args: - kernel_size: size of the kernel - strides: strides of the kernel - padding: padding of the kernel (valid, same) or tuple of ints - """ super().__init__( kernel_size=kernel_size, strides=strides, @@ -314,6 +320,13 @@ def spatial_ndim(self) -> int: class AvgPool3D(GeneralPoolND): + """3D Average Pooling layer + Args: + kernel_size: size of the kernel + strides: strides of the kernel + padding: padding of the kernel (valid, same) or tuple of ints + """ + def __init__( self, kernel_size: int, @@ -321,12 +334,6 @@ def __init__( *, padding: str = "valid", ): - """3D Average Pooling layer - Args: - kernel_size: size of the kernel - strides: strides of the kernel - padding: padding of the kernel 
(valid, same) or tuple of ints - """ super().__init__( kernel_size=kernel_size, strides=strides, @@ -340,6 +347,15 @@ def spatial_ndim(self) -> int: class LPPool1D(LPPoolND): + """1D Lp pooling to the input. + + Args: + norm_type: norm type + kernel_size: size of the kernel + strides: strides of the kernel + padding: padding of the kernel + """ + def __init__( self, norm_type: float, @@ -348,14 +364,6 @@ def __init__( *, padding: PaddingType = "valid", ): - """1D Lp pooling to the input. - - Args: - norm_type: norm type - kernel_size: size of the kernel - strides: strides of the kernel - padding: padding of the kernel - """ super().__init__( norm_type=norm_type, kernel_size=kernel_size, @@ -369,6 +377,15 @@ def spatial_ndim(self) -> int: class LPPool2D(LPPoolND): + """2D Lp pooling to the input. + + Args: + norm_type: norm type + kernel_size: size of the kernel + strides: strides of the kernel + padding: padding of the kernel + """ + def __init__( self, norm_type: float, @@ -377,14 +394,6 @@ def __init__( *, padding: tuple[tuple[int, int], ...] | str = "valid", ): - """2D Lp pooling to the input. - - Args: - norm_type: norm type - kernel_size: size of the kernel - strides: strides of the kernel - padding: padding of the kernel - """ super().__init__( norm_type=norm_type, kernel_size=kernel_size, @@ -398,6 +407,15 @@ def spatial_ndim(self) -> int: class LPPool3D(LPPoolND): + """3D Lp pooling to the input. + + Args: + norm_type: norm type + kernel_size: size of the kernel + strides: strides of the kernel + padding: padding of the kernel + """ + def __init__( self, norm_type: float, @@ -406,14 +424,6 @@ def __init__( *, padding: PaddingType = "valid", ): - """3D Lp pooling to the input. - - Args: - norm_type: norm type - kernel_size: size of the kernel - strides: strides of the kernel - padding: padding of the kernel - """ super().__init__( norm_type=norm_type, kernel_size=kernel_size, @@ -424,11 +434,12 @@ def __init__( class GlobalAvgPool1D(GlobalPoolND): + """1D Global Average Pooling layer + Args: + keepdims: whether to keep the dimensions or not + """ + def __init__(self, keepdims: bool = True): - """1D Global Average Pooling layer - Args: - keepdims: whether to keep the dimensions or not - """ super().__init__(operation=jnp.mean, keepdims=keepdims) @property @@ -437,11 +448,12 @@ def spatial_ndim(self) -> int: class GlobalAvgPool2D(GlobalPoolND): + """2D Global Average Pooling layer + Args: + keepdims: whether to keep the dimensions or not + """ + def __init__(self, keepdims: bool = True): - """2D Global Average Pooling layer - Args: - keepdims: whether to keep the dimensions or not - """ super().__init__(operation=jnp.mean, keepdims=keepdims) @property @@ -450,11 +462,12 @@ def spatial_ndim(self) -> int: class GlobalAvgPool3D(GlobalPoolND): + """3D Global Average Pooling layer + Args: + keepdims: whether to keep the dimensions or not + """ + def __init__(self, keepdims: bool = True): - """3D Global Average Pooling layer - Args: - keepdims: whether to keep the dimensions or not - """ super().__init__(operation=jnp.mean, keepdims=keepdims) @property @@ -463,11 +476,12 @@ def spatial_ndim(self) -> int: class GlobalMaxPool1D(GlobalPoolND): + """1D Global Max Pooling layer + Args: + keepdims: whether to keep the dimensions or not + """ + def __init__(self, keepdims: bool = True): - """1D Global Max Pooling layer - Args: - keepdims: whether to keep the dimensions or not - """ super().__init__(operation=jnp.max, keepdims=keepdims) @property @@ -476,11 +490,12 @@ def spatial_ndim(self) -> int: 
class GlobalMaxPool2D(GlobalPoolND): + """2D Global Max Pooling layer + Args: + keepdims: whether to keep the dimensions or not + """ + def __init__(self, keepdims: bool = True): - """2D Global Max Pooling layer - Args: - keepdims: whether to keep the dimensions or not - """ super().__init__(operation=jnp.max, keepdims=keepdims) @property @@ -489,11 +504,12 @@ def spatial_ndim(self) -> int: class GlobalMaxPool3D(GlobalPoolND): + """3D Global Max Pooling layer + Args: + keepdims: whether to keep the dimensions or not + """ + def __init__(self, keepdims: bool = True): - """3D Global Max Pooling layer - Args: - keepdims: whether to keep the dimensions or not - """ super().__init__(operation=jnp.max, keepdims=keepdims) @property @@ -502,11 +518,12 @@ def spatial_ndim(self) -> int: class AdaptiveAvgPool1D(AdaptivePoolND): + """1D Adaptive Average Pooling layer + Args: + output_size: size of the output + """ + def __init__(self, output_size: tuple[int, ...]): - """1D Adaptive Average Pooling layer - Args: - output_size: size of the output - """ super().__init__(output_size=output_size, func=jnp.mean) @property @@ -515,11 +532,12 @@ def spatial_ndim(self) -> int: class AdaptiveAvgPool2D(AdaptivePoolND): + """2D Adaptive Average Pooling layer + Args: + output_size: size of the output + """ + def __init__(self, output_size: tuple[int, ...]): - """2D Adaptive Average Pooling layer - Args: - output_size: size of the output - """ super().__init__(output_size=output_size, func=jnp.mean) @property @@ -528,11 +546,12 @@ def spatial_ndim(self) -> int: class AdaptiveAvgPool3D(AdaptivePoolND): + """3D Adaptive Average Pooling layer + Args: + output_size: size of the output + """ + def __init__(self, output_size: tuple[int, ...]): - """3D Adaptive Average Pooling layer - Args: - output_size: size of the output - """ super().__init__(output_size=output_size, func=jnp.mean) @property @@ -541,11 +560,12 @@ def spatial_ndim(self) -> int: class AdaptiveMaxPool1D(AdaptivePoolND): + """1D Adaptive Max Pooling layer + Args: + output_size: size of the output + """ + def __init__(self, output_size: tuple[int, ...]): - """1D Adaptive Max Pooling layer - Args: - output_size: size of the output - """ super().__init__(output_size=output_size, func=jnp.max) @property @@ -554,11 +574,12 @@ def spatial_ndim(self) -> int: class AdaptiveMaxPool2D(AdaptivePoolND): + """2D Adaptive Max Pooling layer + Args: + output_size: size of the output + """ + def __init__(self, output_size: tuple[int, ...]): - """2D Adaptive Max Pooling layer - Args: - output_size: size of the output - """ super().__init__(output_size=output_size, func=jnp.max) @property @@ -567,11 +588,12 @@ def spatial_ndim(self) -> int: class AdaptiveMaxPool3D(AdaptivePoolND): + """3D Adaptive Max Pooling layer + Args: + output_size: size of the output + """ + def __init__(self, output_size: tuple[int, ...]): - """3D Adaptive Max Pooling layer - Args: - output_size: size of the output - """ super().__init__(output_size=output_size, func=jnp.max) @property diff --git a/serket/nn/preprocessing.py b/serket/nn/preprocessing.py index 020cde4..38ab07c 100644 --- a/serket/nn/preprocessing.py +++ b/serket/nn/preprocessing.py @@ -18,23 +18,25 @@ import jax import jax.numpy as jnp -import pytreeclass as pytc +import serket as sk from serket.nn.utils import positive_int_cb, validate_spatial_ndim -class HistogramEqualization2D(pytc.TreeClass): +class HistogramEqualization2D(sk.TreeClass): + """Apply histogram equalization to 2D spatial array channel wise + + Args: + bins: number of 
bins. Defaults to 256. + + Note: + https://en.wikipedia.org/wiki/Histogram_equalization + http://www.janeriksolem.net/histogram-equalization-with-python-and.html + https://scikit-image.org/docs/stable/api/skimage.exposure.html#skimage.exposure.equalize_hist + https://stackoverflow.com/questions/28518684/histogram-equalization-of-grayscale-images-with-numpy + """ + def __init__(self, bins: int = 256): - """Apply histogram equalization to 2D spatial array channel wise - Args: - bins: number of bins. Defaults to 256. - - Note: - https://en.wikipedia.org/wiki/Histogram_equalization - http://www.janeriksolem.net/histogram-equalization-with-python-and.html - https://scikit-image.org/docs/stable/api/skimage.exposure.html#skimage.exposure.equalize_hist - https://stackoverflow.com/questions/28518684/histogram-equalization-of-grayscale-images-with-numpy - """ self.bins = positive_int_cb(bins) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -51,7 +53,17 @@ def spatial_ndim(self) -> int: return 2 -class PixelShuffle2D(pytc.TreeClass): +class PixelShuffle2D(sk.TreeClass): + """Rearrange elements in a tensor. + + Args: + upscale_factor: factor to increase spatial resolution by. accepts a + single integer or a tuple of length 2. defaults to 1. + + Note: + https://arxiv.org/abs/1609.05158 + """ + def __init__(self, upscale_factor: int | tuple[int, int] = 1): if isinstance(upscale_factor, int): if upscale_factor < 1: diff --git a/serket/nn/random_transform.py b/serket/nn/random_transform.py index b045f69..0209e69 100644 --- a/serket/nn/random_transform.py +++ b/serket/nn/random_transform.py @@ -18,16 +18,16 @@ import jax import jax.random as jr -import pytreeclass as pytc from jax.lax import stop_gradient +import serket as sk from serket.nn.crop import RandomCrop2D from serket.nn.padding import Pad2D from serket.nn.resize import Resize2D from serket.nn.utils import Range -class RandomApply(pytc.TreeClass): +class RandomApply(sk.TreeClass): """ Randomly applies a layer with probability p. 
@@ -36,6 +36,8 @@ class RandomApply(pytc.TreeClass): p: probability of applying the layer Example: + >>> import serket as sk + >>> import jax.numpy as jnp >>> layer = RandomApply(sk.nn.MaxPool2D(kernel_size=2, strides=2), p=0.0) >>> layer(jnp.ones((1, 10, 10))).shape (1, 10, 10) @@ -50,7 +52,7 @@ class RandomApply(pytc.TreeClass): """ layer: Any - p: float = pytc.field(default=0.5, callbacks=[Range(0, 1)]) + p: float = sk.field(default=0.5, callbacks=[Range(0, 1)]) def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)): if not jr.bernoulli(key, stop_gradient(self.p)): @@ -58,7 +60,7 @@ def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)): return self.layer(x) -class RandomZoom2D(pytc.TreeClass): +class RandomZoom2D(sk.TreeClass): def __init__( self, height_factor: tuple[float, float] = (0.0, 1.0), @@ -85,17 +87,20 @@ def __init__( def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array: keys = jr.split(key, 4) + height_factor = jax.lax.stop_gradient(self.height_factor) + width_factor = jax.lax.stop_gradient(self.width_factor) + height_factor = jr.uniform( keys[0], shape=(), - minval=self.height_factor[0], - maxval=self.height_factor[1], + minval=height_factor[0], + maxval=height_factor[1], ) width_factor = jr.uniform( keys[1], shape=(), - minval=self.width_factor[0], - maxval=self.width_factor[1], + minval=width_factor[0], + maxval=width_factor[1], ) R, C = x.shape[1:3] # R = rows, C = cols diff --git a/serket/nn/recurrent.py b/serket/nn/recurrent.py index 4e32764..f662405 100644 --- a/serket/nn/recurrent.py +++ b/serket/nn/recurrent.py @@ -21,7 +21,6 @@ import jax import jax.numpy as jnp import jax.random as jr -import pytreeclass as pytc import serket as sk from serket.nn.activation import ActivationType, resolve_activation @@ -42,11 +41,11 @@ # Non Spatial RNN -class RNNState(pytc.TreeClass): +class RNNState(sk.TreeClass): hidden_state: jax.Array -class RNNCell(pytc.TreeClass): +class RNNCell(sk.TreeClass): """Abstract class for RNN cells. 
Subclasses must implement: @@ -163,6 +162,26 @@ class DenseState(RNNState): class DenseCell(RNNCell): + """No hidden state cell that applies a dense(Linear+activation) layer to the input + + Args: + in_features: the number of input features + hidden_features: the number of hidden features + weight_init_func: the function to use to initialize the weights + bias_init_func: the function to use to initialize the bias + act_func: the activation function to use for the hidden state update, + use `None` for no activation + key: the key to use to initialize the weights + + Example: + >>> cell = DenseCell(10, 20) # 10-dimensional input, 20-dimensional hidden state + >>> dummy_state = cell.init_state() # 20-dimensional hidden state + >>> x = jnp.ones((10,)) # 10 features + >>> result = cell(x, dummy_state) + >>> result.hidden_state.shape # 20 features + (20,) + """ + def __init__( self, in_features: int, @@ -173,26 +192,6 @@ def __init__( act_func: ActivationType = jax.nn.tanh, key: jr.KeyArray = jr.PRNGKey(0), ): - """No hidden state cell that applies a dense(Linear+activation) layer to the input - - Args: - in_features: the number of input features - hidden_features: the number of hidden features - weight_init_func: the function to use to initialize the weights - bias_init_func: the function to use to initialize the bias - act_func: the activation function to use for the hidden state update, - use `None` for no activation - key: the key to use to initialize the weights - - Example: - >>> cell = DenseCell(10, 20) # 10-dimensional input, 20-dimensional hidden state - >>> dummy_state = cell.init_state() # 20-dimensional hidden state - >>> x = jnp.ones((10,)) # 10 features - >>> result = cell(x, dummy_state) - >>> result.hidden_state.shape # 20 features - (20,) - """ - self.in_features = positive_int_cb(in_features) self.hidden_features = positive_int_cb(hidden_features) self.act_func = resolve_activation(act_func) @@ -229,6 +228,23 @@ class LSTMState(RNNState): class LSTMCell(RNNCell): + """LSTM cell that defines the update rule for the hidden state and cell state + + Args: + in_features: the number of input features + hidden_features: the number of hidden features + weight_init_func: the function to use to initialize the weights + bias_init_func: the function to use to initialize the bias + recurrent_weight_init_func: the function to use to initialize the recurrent weights + act_func: the activation function to use for the hidden state update + recurrent_act_func: the activation function to use for the cell state update + key: the key to use to initialize the weights + + Note: + https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTMCell + https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/recurrent.py + """ + def __init__( self, in_features: int, @@ -241,21 +257,6 @@ def __init__( recurrent_act_func: ActivationType | None = "sigmoid", key: jr.KeyArray = jr.PRNGKey(0), ): - """LSTM cell that defines the update rule for the hidden state and cell state - Args: - in_features: the number of input features - hidden_features: the number of hidden features - weight_init_func: the function to use to initialize the weights - bias_init_func: the function to use to initialize the bias - recurrent_weight_init_func: the function to use to initialize the recurrent weights - act_func: the activation function to use for the hidden state update - recurrent_act_func: the activation function to use for the cell state update - key: the key to use to initialize the weights - - Note: - 
-            https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/recurrent.py
-        """

         k1, k2 = jr.split(key, 2)
         self.in_features = positive_int_cb(in_features)
@@ -313,6 +314,22 @@ class GRUState(RNNState):


 class GRUCell(RNNCell):
+    """GRU cell that defines the update rule for the hidden state and cell state
+
+    Args:
+        in_features: the number of input features
+        hidden_features: the number of hidden features
+        weight_init_func: the function to use to initialize the weights
+        bias_init_func: the function to use to initialize the bias
+        recurrent_weight_init_func: the function to use to initialize the recurrent weights
+        act_func: the activation function to use for the hidden state update
+        recurrent_act_func: the activation function to use for the cell state update
+        key: the key to use to initialize the weights
+
+    Note:
+        https://keras.io/api/layers/recurrent_layers/gru/
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -325,20 +342,6 @@ def __init__(
         recurrent_act_func: ActivationType | None = "sigmoid",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """GRU cell that defines the update rule for the hidden state and cell state
-        Args:
-            in_features: the number of input features
-            hidden_features: the number of hidden features
-            weight_init_func: the function to use to initialize the weights
-            bias_init_func: the function to use to initialize the bias
-            recurrent_weight_init_func: the function to use to initialize the recurrent weights
-            act_func: the activation function to use for the hidden state update
-            recurrent_act_func: the activation function to use for the cell state update
-            key: the key to use to initialize the weights
-
-        See:
-            https://keras.io/api/layers/recurrent_layers/gru/
-        """

         k1, k2 = jr.split(key, 2)
         self.in_features = positive_int_cb(in_features)
@@ -395,6 +398,27 @@ class ConvLSTMNDState(RNNState):


 class ConvLSTMNDCell(RNNCell):
+    """Convolution LSTM cell that defines the update rule for the hidden state and cell state
+
+    Args:
+        in_features: Number of input features
+        hidden_features: Number of output features
+        kernel_size: Size of the convolutional kernel
+        strides: Stride of the convolution
+        padding: Padding of the convolution
+        input_dilation: Dilation of the input
+        kernel_dilation: Dilation of the convolutional kernel
+        weight_init_func: Weight initialization function
+        bias_init_func: Bias initialization function
+        recurrent_weight_init_func: Recurrent weight initialization function
+        act_func: Activation function
+        recurrent_act_func: Recurrent activation function
+        key: PRNG key
+
+    Note:
+        https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -413,24 +437,6 @@ def __init__(
         recurrent_act_func: ActivationType | None = "hard_sigmoid",
         key: jr.KeyArray = jr.PRNGKey(0),
         conv_layer: Any = None,
     ):
-        """Convolution LSTM cell that defines the update rule for the hidden state and cell state
-        Args:
-            in_features: Number of input features
-            hidden_features: Number of output features
-            kernel_size: Size of the convolutional kernel
-            strides: Stride of the convolution
-            padding: Padding of the convolution
-            input_dilation: Dilation of the input
-            kernel_dilation: Dilation of the convolutional kernel
-            weight_init_func: Weight initialization function
-            bias_init_func: Bias initialization function
-            recurrent_weight_init_func: Recurrent weight initialization function
-            act_func: Activation function
-            recurrent_act_func: Recurrent activation function
-            key: PRNG key
-
-        See:
-            https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D
-        """

        k1, k2 = jr.split(key, 2)
        self.in_features = positive_int_cb(in_features)
@@ -489,6 +495,27 @@ def init_state(self, spatial_dim: tuple[int, ...]) -> ConvLSTMNDState:


 class ConvLSTM1DCell(ConvLSTMNDCell):
+    """1D Convolution LSTM cell that defines the update rule for the hidden state and cell state
+    Args:
+        in_features: Number of input features
+        hidden_features: Number of output features
+        kernel_size: Size of the convolutional kernel
+        strides: Stride of the convolution
+        padding: Padding of the convolution
+        input_dilation: Dilation of the input
+        kernel_dilation: Dilation of the convolutional kernel
+        weight_init_func: Weight initialization function
+        bias_init_func: Bias initialization function
+        recurrent_weight_init_func: Recurrent weight initialization function
+        act_func: Activation function
+        recurrent_act_func: Recurrent activation function
+        key: PRNG key
+        spatial_ndim: Number of spatial dimensions.
+
+    Note:
+        https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -506,26 +533,6 @@ def __init__(
         recurrent_act_func: ActivationType | None = "hard_sigmoid",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """1D Convolution LSTM cell that defines the update rule for the hidden state and cell state
-        Args:
-            in_features: Number of input features
-            hidden_features: Number of output features
-            kernel_size: Size of the convolutional kernel
-            strides: Stride of the convolution
-            padding: Padding of the convolution
-            input_dilation: Dilation of the input
-            kernel_dilation: Dilation of the convolutional kernel
-            weight_init_func: Weight initialization function
-            bias_init_func: Bias initialization function
-            recurrent_weight_init_func: Recurrent weight initialization function
-            act_func: Activation function
-            recurrent_act_func: Recurrent activation function
-            key: PRNG key
-            spatial_ndim: Number of spatial dimensions.
-
-        Note:
-            https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D
-        """
         super().__init__(
             in_features=in_features,
             hidden_features=hidden_features,
@@ -549,6 +556,27 @@ def spatial_ndim(self) -> int:


 class ConvLSTM2DCell(ConvLSTMNDCell):
+    """2D Convolution LSTM cell that defines the update rule for the hidden state and cell state
+    Args:
+        in_features: Number of input features
+        hidden_features: Number of output features
+        kernel_size: Size of the convolutional kernel
+        strides: Stride of the convolution
+        padding: Padding of the convolution
+        input_dilation: Dilation of the input
+        kernel_dilation: Dilation of the convolutional kernel
+        weight_init_func: Weight initialization function
+        bias_init_func: Bias initialization function
+        recurrent_weight_init_func: Recurrent weight initialization function
+        act_func: Activation function
+        recurrent_act_func: Recurrent activation function
+        key: PRNG key
+        spatial_ndim: Number of spatial dimensions.
+
+    Note:
+        https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM2D
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -566,26 +594,6 @@ def __init__(
         recurrent_act_func: ActivationType | None = "hard_sigmoid",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """2D Convolution LSTM cell that defines the update rule for the hidden state and cell state
-        Args:
-            in_features: Number of input features
-            hidden_features: Number of output features
-            kernel_size: Size of the convolutional kernel
-            strides: Stride of the convolution
-            padding: Padding of the convolution
-            input_dilation: Dilation of the input
-            kernel_dilation: Dilation of the convolutional kernel
-            weight_init_func: Weight initialization function
-            bias_init_func: Bias initialization function
-            recurrent_weight_init_func: Recurrent weight initialization function
-            act_func: Activation function
-            recurrent_act_func: Recurrent activation function
-            key: PRNG key
-            spatial_ndim: Number of spatial dimensions.
-
-        Note:
-            https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D
-        """
         super().__init__(
             in_features=in_features,
             hidden_features=hidden_features,
@@ -609,6 +617,27 @@ def spatial_ndim(self) -> int:


 class ConvLSTM3DCell(ConvLSTMNDCell):
+    """3D Convolution LSTM cell that defines the update rule for the hidden state and cell state
+    Args:
+        in_features: Number of input features
+        hidden_features: Number of output features
+        kernel_size: Size of the convolutional kernel
+        strides: Stride of the convolution
+        padding: Padding of the convolution
+        input_dilation: Dilation of the input
+        kernel_dilation: Dilation of the convolutional kernel
+        weight_init_func: Weight initialization function
+        bias_init_func: Bias initialization function
+        recurrent_weight_init_func: Recurrent weight initialization function
+        act_func: Activation function
+        recurrent_act_func: Recurrent activation function
+        key: PRNG key
+        spatial_ndim: Number of spatial dimensions.
+
+    Note:
+        https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM3D
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -626,26 +655,6 @@ def __init__(
         recurrent_act_func: ActivationType | None = "hard_sigmoid",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """3D Convolution LSTM cell that defines the update rule for the hidden state and cell state
-        Args:
-            in_features: Number of input features
-            hidden_features: Number of output features
-            kernel_size: Size of the convolutional kernel
-            strides: Stride of the convolution
-            padding: Padding of the convolution
-            input_dilation: Dilation of the input
-            kernel_dilation: Dilation of the convolutional kernel
-            weight_init_func: Weight initialization function
-            bias_init_func: Bias initialization function
-            recurrent_weight_init_func: Recurrent weight initialization function
-            act_func: Activation function
-            recurrent_act_func: Recurrent activation function
-            key: PRNG key
-            spatial_ndim: Number of spatial dimensions.
-
-        Note:
-            https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D
-        """
         super().__init__(
             in_features=in_features,
             hidden_features=hidden_features,
@@ -673,6 +682,25 @@ class ConvGRUNDState(RNNState):


 class ConvGRUNDCell(RNNCell):
+    """Convolution GRU cell that defines the update rule for the hidden state and cell state
+    Args:
+        in_features: Number of input features
+        hidden_features: Number of output features
+        kernel_size: Size of the convolutional kernel
+        strides: Stride of the convolution
+        padding: Padding of the convolution
+        input_dilation: Dilation of the input
+        kernel_dilation: Dilation of the convolutional kernel
+        weight_init_func: Weight initialization function
+        bias_init_func: Bias initialization function
+        recurrent_weight_init_func: Recurrent weight initialization function
+        act_func: Activation function
+        recurrent_act_func: Recurrent activation function
+        key: PRNG key
+        spatial_ndim: Number of spatial dimensions.
+
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -691,24 +719,6 @@ def __init__(
         key: jr.KeyArray = jr.PRNGKey(0),
         conv_layer: Any = None,
     ):
-        """Convolution GRU cell that defines the update rule for the hidden state and cell state
-        Args:
-            in_features: Number of input features
-            hidden_features: Number of output features
-            kernel_size: Size of the convolutional kernel
-            strides: Stride of the convolution
-            padding: Padding of the convolution
-            input_dilation: Dilation of the input
-            kernel_dilation: Dilation of the convolutional kernel
-            weight_init_func: Weight initialization function
-            bias_init_func: Bias initialization function
-            recurrent_weight_init_func: Recurrent weight initialization function
-            act_func: Activation function
-            recurrent_act_func: Recurrent activation function
-            key: PRNG key
-            spatial_ndim: Number of spatial dimensions.
-
-        """

         k1, k2 = jr.split(key, 2)
         self.in_features = positive_int_cb(in_features)
@@ -764,6 +774,25 @@ def init_state(self, spatial_dim: tuple[int, ...]) -> ConvGRUNDState:


 class ConvGRU1DCell(ConvGRUNDCell):
+    """1D Convolution GRU cell that defines the update rule for the hidden state and cell state
+    Args:
+        in_features: Number of input features
+        hidden_features: Number of output features
+        kernel_size: Size of the convolutional kernel
+        strides: Stride of the convolution
+        padding: Padding of the convolution
+        input_dilation: Dilation of the input
+        kernel_dilation: Dilation of the convolutional kernel
+        weight_init_func: Weight initialization function
+        bias_init_func: Bias initialization function
+        recurrent_weight_init_func: Recurrent weight initialization function
+        act_func: Activation function
+        recurrent_act_func: Recurrent activation function
+        key: PRNG key
+        spatial_ndim: Number of spatial dimensions.
+
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -781,24 +810,6 @@ def __init__(
         recurrent_act_func: ActivationType | None = "sigmoid",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """1D Convolution GRU cell that defines the update rule for the hidden state and cell state
-        Args:
-            in_features: Number of input features
-            hidden_features: Number of output features
-            kernel_size: Size of the convolutional kernel
-            strides: Stride of the convolution
-            padding: Padding of the convolution
-            input_dilation: Dilation of the input
-            kernel_dilation: Dilation of the convolutional kernel
-            weight_init_func: Weight initialization function
-            bias_init_func: Bias initialization function
-            recurrent_weight_init_func: Recurrent weight initialization function
-            act_func: Activation function
-            recurrent_act_func: Recurrent activation function
-            key: PRNG key
-            spatial_ndim: Number of spatial dimensions.
-
-        """
         super().__init__(
             in_features=in_features,
             hidden_features=hidden_features,
@@ -822,6 +833,25 @@ def spatial_ndim(self) -> int:


 class ConvGRU2DCell(ConvGRUNDCell):
+    """2D Convolution GRU cell that defines the update rule for the hidden state and cell state
+    Args:
+        in_features: Number of input features
+        hidden_features: Number of output features
+        kernel_size: Size of the convolutional kernel
+        strides: Stride of the convolution
+        padding: Padding of the convolution
+        input_dilation: Dilation of the input
+        kernel_dilation: Dilation of the convolutional kernel
+        weight_init_func: Weight initialization function
+        bias_init_func: Bias initialization function
+        recurrent_weight_init_func: Recurrent weight initialization function
+        act_func: Activation function
+        recurrent_act_func: Recurrent activation function
+        key: PRNG key
+        spatial_ndim: Number of spatial dimensions.
+
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -839,24 +869,6 @@ def __init__(
         recurrent_act_func: ActivationType | None = "sigmoid",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """2D Convolution GRU cell that defines the update rule for the hidden state and cell state
-        Args:
-            in_features: Number of input features
-            hidden_features: Number of output features
-            kernel_size: Size of the convolutional kernel
-            strides: Stride of the convolution
-            padding: Padding of the convolution
-            input_dilation: Dilation of the input
-            kernel_dilation: Dilation of the convolutional kernel
-            weight_init_func: Weight initialization function
-            bias_init_func: Bias initialization function
-            recurrent_weight_init_func: Recurrent weight initialization function
-            act_func: Activation function
-            recurrent_act_func: Recurrent activation function
-            key: PRNG key
-            spatial_ndim: Number of spatial dimensions.
-
-        """
         super().__init__(
             in_features=in_features,
             hidden_features=hidden_features,
@@ -880,6 +892,25 @@ def spatial_ndim(self) -> int:


 class ConvGRU3DCell(ConvGRUNDCell):
+    """3D Convolution GRU cell that defines the update rule for the hidden state and cell state
+    Args:
+        in_features: Number of input features
+        hidden_features: Number of output features
+        kernel_size: Size of the convolutional kernel
+        strides: Stride of the convolution
+        padding: Padding of the convolution
+        input_dilation: Dilation of the input
+        kernel_dilation: Dilation of the convolutional kernel
+        weight_init_func: Weight initialization function
+        bias_init_func: Bias initialization function
+        recurrent_weight_init_func: Recurrent weight initialization function
+        act_func: Activation function
+        recurrent_act_func: Recurrent activation function
+        key: PRNG key
+        spatial_ndim: Number of spatial dimensions.
+
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -897,24 +928,6 @@ def __init__(
         recurrent_act_func: ActivationType | None = "sigmoid",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        """3D Convolution GRU cell that defines the update rule for the hidden state and cell state
-        Args:
-            in_features: Number of input features
-            hidden_features: Number of output features
-            kernel_size: Size of the convolutional kernel
-            strides: Stride of the convolution
-            padding: Padding of the convolution
-            input_dilation: Dilation of the input
-            kernel_dilation: Dilation of the convolutional kernel
-            weight_init_func: Weight initialization function
-            bias_init_func: Bias initialization function
-            recurrent_weight_init_func: Recurrent weight initialization function
-            act_func: Activation function
-            recurrent_act_func: Recurrent activation function
-            key: PRNG key
-            spatial_ndim: Number of spatial dimensions.
-
-        """
         super().__init__(
             in_features=in_features,
             hidden_features=hidden_features,
@@ -941,27 +954,30 @@ def spatial_ndim(self) -> int:


 # Scanning API


-class ScanRNN(pytc.TreeClass):
+class ScanRNN(sk.TreeClass):
+    """Scans RNN cell over a sequence.
+
+    Args:
+        cell: the RNN cell to use.
+        backward_cell: the RNN cell to use for bidirectional scanning.
+        return_sequences: whether to return the hidden state for each timestep.
+        return_state: whether to also return the final state.
+
+    Example:
+        >>> cell = SimpleRNNCell(10, 20)  # 10-dimensional input, 20-dimensional hidden state
+        >>> rnn = ScanRNN(cell)
+        >>> x = jnp.ones((5, 10))  # 5 timesteps, 10 features
+        >>> result = rnn(x)  # 20 features
+    """
+
     def __init__(
         self,
         cell: RNNCell,
         backward_cell: RNNCell | None = None,
         *,
         return_sequences: bool = False,
+        return_state: bool = False,
     ):
-        """Scans RNN cell over a sequence.
-
-        Args:
-            cell: the RNN cell to use
-            backward_cell: the RNN cell to use for bidirectional scanning.
-            return_sequences: whether to return the hidden state for each timestep
-
-        Example:
-            >>> cell = SimpleRNNCell(10, 20)  # 10-dimensional input, 20-dimensional hidden state
-            >>> rnn = ScanRNN(cell)
-            >>> x = jnp.ones((5, 10))  # 5 timesteps, 10 features
-            >>> result = rnn(x)  # 20 features
-        """
         if not isinstance(cell, RNNCell):
             raise TypeError(f"Expected {cell=} to be an instance of RNNCell.")
@@ -975,7 +991,7 @@ def __init__(
     def __call__(
         self,
         x: jax.Array,
-        state: RNNCell | None = None,
+        state: RNNState | None = None,
         backward_state: RNNState | None = None,
         **k,
     ) -> jax.Array:
diff --git a/serket/nn/resize.py b/serket/nn/resize.py
index 2538179..21f4ec0 100644
--- a/serket/nn/resize.py
+++ b/serket/nn/resize.py
@@ -19,14 +19,14 @@
 from typing import Literal

 import jax
-import pytreeclass as pytc
+import serket as sk

 from serket.nn.utils import canonicalize, validate_spatial_ndim

 MethodKind = Literal["nearest", "linear", "cubic", "lanczos3", "lanczos5"]


-class ResizeND(pytc.TreeClass):
+class ResizeND(sk.TreeClass):
     """
     Resize an image to a given size using a given interpolation method.

@@ -72,7 +72,7 @@ def spatial_ndim(self) -> int:
         ...


-class UpsampleND(pytc.TreeClass):
+class UpsampleND(sk.TreeClass):
     def __init__(
         self,
         scale: int | tuple[int, ...] = 1,
@@ -101,28 +101,29 @@ def spatial_ndim(self) -> int:


 class Resize1D(ResizeND):
+    """Resize a 1D input to a given size using a given interpolation method.
+
+    Args:
+        size: the size of the output. if size is None, the output size is
+            calculated as input size * scale
+        method: the method of interpolation. Defaults to "nearest". choices are:
+            - "nearest": Nearest neighbor interpolation. The values of antialias
+              and precision are ignored.
+            - "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
+              If antialias is True, uses a triangular filter when downsampling.
+            - "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
+              cubic kernel.
+            - "lanczos3": Lanczos resampling, using a kernel of radius 3.
+            - "lanczos5": Lanczos resampling, using a kernel of radius 5.
+        antialias: whether to use antialiasing. Defaults to True.
+    """
+
     def __init__(
         self,
         size: int | tuple[int, ...],
         method: MethodKind = "nearest",
         antialias=True,
     ):
-        """Resize a 1D input to a given size using a given interpolation method.
-
-        Args:
-            size: the size of the output. if size is None, the output size is
-                calculated as input size * scale
-            method: the method of interpolation. Defaults to "nearest". choices are:
-                - "nearest": Nearest neighbor interpolation. The values of antialias
-                  and precision are ignored.
-                - "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
-                  If antialias is True, uses a triangular filter when downsampling.
-                - "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
-                  cubic kernel.
-                - "lanczos3": Lanczos resampling, using a kernel of radius 3.
-                - "lanczos5": Lanczos resampling, using a kernel of radius 5.
-            antialias: whether to use antialiasing. Defaults to True.
-        """
         super().__init__(size=size, method=method, antialias=antialias)

     @property
@@ -131,28 +132,29 @@ def spatial_ndim(self) -> int:


 class Resize2D(ResizeND):
+    """Resize a 2D input to a given size using a given interpolation method.
+
+    Args:
+        size: the size of the output. if size is None, the output size is
+            calculated as input size * scale
+        method: the method of interpolation. Defaults to "nearest". choices are:
+            - "nearest": Nearest neighbor interpolation. The values of antialias
+              and precision are ignored.
+            - "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
+              If antialias is True, uses a triangular filter when downsampling.
+            - "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
+              cubic kernel.
+            - "lanczos3": Lanczos resampling, using a kernel of radius 3.
+            - "lanczos5": Lanczos resampling, using a kernel of radius 5.
+        antialias: whether to use antialiasing. Defaults to True.
+    """
+
     def __init__(
         self,
         size: int | tuple[int, ...],
         method: MethodKind = "nearest",
         antialias=True,
     ):
-        """Resize a 2D input to a given size using a given interpolation method.
-
-        Args:
-            size: the size of the output. if size is None, the output size is
-                calculated as input size * scale
-            method: the method of interpolation. Defaults to "nearest". choices are:
-                - "nearest": Nearest neighbor interpolation. The values of antialias
-                  and precision are ignored.
-                - "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
-                  If antialias is True, uses a triangular filter when downsampling.
-                - "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
-                  cubic kernel.
-                - "lanczos3": Lanczos resampling, using a kernel of radius 3.
-                - "lanczos5": Lanczos resampling, using a kernel of radius 5.
-            antialias: whether to use antialiasing. Defaults to True.
-        """
         super().__init__(size=size, method=method, antialias=antialias)

     @property
@@ -161,28 +163,29 @@ def spatial_ndim(self) -> int:


 class Resize3D(ResizeND):
+    """Resize a 3D input to a given size using a given interpolation method.
+
+    Args:
+        size: the size of the output. if size is None, the output size is
+            calculated as input size * scale
+        method: the method of interpolation. Defaults to "nearest". choices are:
+            - "nearest": Nearest neighbor interpolation. The values of antialias
+              and precision are ignored.
+            - "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
+              If antialias is True, uses a triangular filter when downsampling.
+            - "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
+              cubic kernel.
+            - "lanczos3": Lanczos resampling, using a kernel of radius 3.
+            - "lanczos5": Lanczos resampling, using a kernel of radius 5.
+        antialias: whether to use antialiasing. Defaults to True.
+    """
+
     def __init__(
         self,
         size: int | tuple[int, ...],
         method: MethodKind = "nearest",
         antialias=True,
     ):
-        """Resize a 3D input to a given size using a given interpolation method.
-
-        Args:
-            size: the size of the output. if size is None, the output size is
-                calculated as input size * scale
-            method: the method of interpolation. Defaults to "nearest". choices are:
-                - "nearest": Nearest neighbor interpolation. The values of antialias
-                  and precision are ignored.
-                - "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
-                  If antialias is True, uses a triangular filter when downsampling.
-                - "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
-                  cubic kernel.
-                - "lanczos3": Lanczos resampling, using a kernel of radius 3.
-                - "lanczos5": Lanczos resampling, using a kernel of radius 5.
-            antialias: whether to use antialiasing. Defaults to True.
-        """
         super().__init__(size=size, method=method, antialias=antialias)

     @property
@@ -191,13 +194,14 @@ def spatial_ndim(self) -> int:


 class Upsample1D(UpsampleND):
-    def __init__(self, scale: int, method: str = "nearest"):
-        """Upsample a 1D input to a given size using a given interpolation method.
+    """Upsample a 1D input to a given size using a given interpolation method.
+
+    Args:
+        scale: the scale of the output.
+        method: the method of interpolation. Defaults to "nearest".
+    """

-        Args:
-            scale: the scale of the output.
-            method: the method of interpolation. Defaults to "nearest".
-        """
+    def __init__(self, scale: int, method: str = "nearest"):
         super().__init__(scale=scale, method=method)

     @property
@@ -206,13 +210,14 @@ def spatial_ndim(self) -> int:


 class Upsample2D(UpsampleND):
-    def __init__(self, scale: int | tuple[int, int], method: str = "nearest"):
-        """Upsample a 2D input to a given size using a given interpolation method.
+    """Upsample a 2D input to a given size using a given interpolation method.

-        Args:
-            scale: the scale of the output.
-            method: the method of interpolation. Defaults to "nearest".
-        """
+    Args:
+        scale: the scale of the output.
+        method: the method of interpolation. Defaults to "nearest".
+    """
+
+    def __init__(self, scale: int | tuple[int, int], method: str = "nearest"):
         super().__init__(scale=scale, method=method)

     @property
@@ -221,13 +226,14 @@ def spatial_ndim(self) -> int:


 class Upsample3D(UpsampleND):
-    def __init__(self, scale: int | tuple[int, int, int], method: str = "nearest"):
-        """Upsample a 1D input to a given size using a given interpolation method.
+    """Upsample a 3D input to a given size using a given interpolation method.

-        Args:
-            scale: the scale of the output.
-            method: the method of interpolation. Defaults to "nearest".
-        """
+    Args:
+        scale: the scale of the output.
+        method: the method of interpolation. Defaults to "nearest".
+    """
+
+    def __init__(self, scale: int | tuple[int, int, int], method: str = "nearest"):
         super().__init__(scale=scale, method=method)

     @property
diff --git a/serket/nn/utils.py b/serket/nn/utils.py
index 3ad423b..0ba2a1f 100644
--- a/serket/nn/utils.py
+++ b/serket/nn/utils.py
@@ -20,7 +20,8 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-import pytreeclass as pytc
+
+import serket as sk

 KernelSizeType = Union[int, Sequence[int]]
 StridesType = Union[int, Sequence[int]]
@@ -178,7 +179,7 @@ def canonicalize(value, ndim, *, name: str | None = None):
     raise ValueError(f"Expected int or tuple , got {value=}.")


-class Range(pytc.TreeClass):
+class Range(sk.TreeClass):
     """Check if the input is in the range [min_val, max_val]."""

     min_val: float = -float("inf")
@@ -190,7 +191,7 @@ def __call__(self, value: Any):
         raise ValueError(f"Not in range[{self.min_val}, {self.max_val}] got {value=}.")


-class IsInstance(pytc.TreeClass):
+class IsInstance(sk.TreeClass):
     """Check if the input is an instance of expected_type."""

     predicted_type: type | Sequence[type]
@@ -201,7 +202,7 @@ def __call__(self, value: Any):
         raise TypeError(f"Expected {self.predicted_type}, got {type(value).__name__}")


-class ScalarLike(pytc.TreeClass):
+class ScalarLike(sk.TreeClass):
     """Check if the input is a scalar"""

     def __call__(self, value: Any):
diff --git a/tests/test_dropout.py b/tests/test_dropout.py
index e3ba43f..ee3e5ae 100644
--- a/tests/test_dropout.py
+++ b/tests/test_dropout.py
@@ -15,8 +15,8 @@
 import jax.numpy as jnp
 import numpy.testing as npt
 import pytest
-import pytreeclass as pytc

+import serket as sk
 from serket.nn import Dropout

@@ -26,7 +26,7 @@ def test_dropout():
     layer = Dropout(1.0)
     npt.assert_allclose(layer(x), jnp.array([0.0, 0.0, 0.0, 0.0, 0.0]))

-    layer = layer.at["p"].set(0.0, is_leaf=pytc.is_frozen)
+    layer = layer.at["p"].set(0.0, is_leaf=sk.is_frozen)
     npt.assert_allclose(layer(x), x)

     with pytest.raises(ValueError):
diff --git a/tests/test_linear.py b/tests/test_linear.py
index 91ee3d8..a01b137 100644
--- a/tests/test_linear.py
+++ b/tests/test_linear.py
@@ -17,8 +17,8 @@
 import jax.tree_util as jtu
 import numpy.testing as npt
 import pytest
-import pytreeclass as pytc

+import serket as sk
 from serket.nn import (
     FNN,
     Bilinear,
@@ -46,7 +46,7 @@ def test_linear():

     @jax.value_and_grad
     def loss_func(NN, x, y):
-        NN = NN.at[...].apply(pytc.unfreeze, is_leaf=pytc.is_frozen)
+        NN = sk.tree_unmask(NN)
         return jnp.mean((NN(x) - y) ** 2)

     @jax.jit
@@ -54,19 +54,17 @@ def update(NN, x, y):
         value, grad = loss_func(NN, x, y)
         return value, jtu.tree_map(lambda x, g: x - 1e-3 * g, NN, grad)

-    NN = FNN(
+    nn = FNN(
         [1, 128, 128, 1],
         act_func="relu",
         weight_init_func="he_normal",
         bias_init_func="ones",
     )

-    # NN = jtu.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, NN)
-    NN = NN.at[pytc.bcmap(pytc.is_nondiff)(NN)].apply(pytc.freeze)
+    nn = sk.tree_mask(nn)

-    # print(pytc.tree_diagram(NN))
     for _ in range(20_000):
-        value, NN = update(NN, x, y)
+        value, nn = update(nn, x, y)

     npt.assert_allclose(jnp.array(4.933563e-05), value, atol=1e-3)
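
A quick smoke test of the recurrent API as it reads after this diff. This is a reviewer's sketch, not code from the PR: it assumes `SimpleRNNCell` and `ScanRNN` are importable from `serket.nn` and behave as the `ScanRNN` docstring above describes, so the printed shape is what that docstring claims for a non-`return_sequences` scan (the final 20-feature hidden state):

    import jax.numpy as jnp
    from serket.nn import ScanRNN, SimpleRNNCell  # import path assumed

    cell = SimpleRNNCell(10, 20)  # 10 input features -> 20 hidden features
    rnn = ScanRNN(cell)           # scan the cell over the leading (time) axis
    x = jnp.ones((5, 10))         # 5 timesteps, 10 features
    result = rnn(x)               # final hidden state, per the docstring
    print(result.shape)           # expected: (20,)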
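
The `tests/test_linear.py` hunk is also the clearest illustration of the new masking idiom: `sk.tree_mask` replaces the old `pytc.freeze`/`pytc.is_nondiff` combination before `jit`, and `sk.tree_unmask` reverses it inside the loss. Below is a condensed sketch of that training pattern; the layer sizes and step count are shortened relative to the test, and everything else mirrors the hunk:

    import jax
    import jax.numpy as jnp
    import jax.tree_util as jtu

    import serket as sk
    from serket.nn import FNN

    x = jnp.linspace(0, 1, 100)[:, None]
    y = x**2

    @jax.value_and_grad
    def loss_func(nn, x, y):
        nn = sk.tree_unmask(nn)  # unmask non-differentiable leaves before calling
        return jnp.mean((nn(x) - y) ** 2)

    @jax.jit
    def update(nn, x, y):
        value, grad = loss_func(nn, x, y)
        # masked leaves are not tree leaves, so tree_map only updates the weights
        return value, jtu.tree_map(lambda p, g: p - 1e-3 * g, nn, grad)

    nn = FNN([1, 32, 1], act_func="relu")
    nn = sk.tree_mask(nn)  # mask non-differentiable leaves once, before jit

    for _ in range(1_000):
        value, nn = update(nn, x, y)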