diff --git a/serket/image/__init__.py b/serket/image/__init__.py index 89fb687..29ed2f3 100644 --- a/serket/image/__init__.py +++ b/serket/image/__init__.py @@ -12,7 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .affine import ( +from .augment import ( + AdjustContrast2D, + JigSaw2D, + PixelShuffle2D, + Posterize2D, + RandomContrast2D, +) +from .filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D +from .geometric import ( HorizontalFlip2D, HorizontalShear2D, HorizontalTranslate2D, @@ -29,14 +37,6 @@ VerticalShear2D, VerticalTranslate2D, ) -from .augment import ( - AdjustContrast2D, - JigSaw2D, - PixelShuffle2D, - Posterize2D, - RandomContrast2D, -) -from .filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D __all__ = ( # image diff --git a/serket/image/filter.py b/serket/image/filter.py index 64b64bd..bbefefc 100644 --- a/serket/image/filter.py +++ b/serket/image/filter.py @@ -22,6 +22,7 @@ import serket as sk from serket.nn.convolution import DepthwiseConv2D, DepthwiseFFTConv2D +from serket.nn.initialization import DType from serket.utils import ( maybe_lazy_call, maybe_lazy_init, @@ -54,6 +55,7 @@ class AvgBlur2D(sk.TreeClass): Args: in_features: number of input channels. kernel_size: size of the convolving kernel. + dtype: data type of the layer. Defaults to ``jnp.float32``. Example: >>> import serket as sk @@ -92,7 +94,14 @@ class AvgBlur2D(sk.TreeClass): """ @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) - def __init__(self, in_features: int | None, kernel_size: int | tuple[int, int]): + def __init__( + self, + in_features: int | None, + kernel_size: int, + *, + dtype: DType = jnp.float32, + ): + kernel_size = positive_int_cb(kernel_size) weight = jnp.ones(kernel_size) weight = weight / jnp.sum(weight) weight = weight[:, None] @@ -104,6 +113,7 @@ def __init__(self, in_features: int | None, kernel_size: int | tuple[int, int]): padding="same", weight_init=lambda *_: weight, bias_init=None, + dtype=dtype, ) self.conv2 = DepthwiseConv2D( @@ -112,6 +122,7 @@ def __init__(self, in_features: int | None, kernel_size: int | tuple[int, int]): padding="same", weight_init=lambda *_: jnp.moveaxis(weight, 2, 3), bias_init=None, + dtype=dtype, ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @@ -134,6 +145,7 @@ class GaussianBlur2D(sk.TreeClass): in_features: number of input features kernel_size: kernel size sigma: sigma. Defaults to 1. + dtype: data type of the layer. Defaults to ``jnp.float32``. Example: >>> import serket as sk @@ -172,7 +184,14 @@ class GaussianBlur2D(sk.TreeClass): """ @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) - def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0): + def __init__( + self, + in_features: int, + kernel_size: int, + *, + sigma: float = 1.0, + dtype: DType = jnp.float32, + ): kernel_size = positive_int_cb(kernel_size) self.sigma = sigma @@ -189,6 +208,7 @@ def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0): padding="same", weight_init=lambda *_: weight, bias_init=None, + dtype=dtype, ) self.conv2 = DepthwiseFFTConv2D( @@ -197,6 +217,7 @@ def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0): padding="same", weight_init=lambda *_: jnp.moveaxis(weight, 2, 3), bias_init=None, + dtype=dtype, ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @@ -218,6 +239,7 @@ class Filter2D(sk.TreeClass): Args: in_features: number of input channels. kernel: kernel array. + dtype: data type of the layer. Defaults to ``jnp.float32``. Example: >>> import serket as sk @@ -232,7 +254,13 @@ class Filter2D(sk.TreeClass): """ @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) - def __init__(self, in_features: int, kernel: jax.Array): + def __init__( + self, + in_features: int, + kernel: jax.Array, + *, + dtype: DType = jnp.float32, + ): if not isinstance(kernel, jax.Array) or kernel.ndim != 2: raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)") @@ -246,6 +274,7 @@ def __init__(self, in_features: int, kernel: jax.Array): padding="same", weight_init=lambda *_: weight, bias_init=None, + dtype=dtype, ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @@ -267,6 +296,7 @@ class FFTFilter2D(sk.TreeClass): Args: in_features: number of input channels kernel: kernel array + dtype: data type of the layer. Defaults to ``jnp.float32``. Example: >>> import serket as sk @@ -281,7 +311,13 @@ class FFTFilter2D(sk.TreeClass): """ @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) - def __init__(self, in_features: int, kernel: jax.Array): + def __init__( + self, + in_features: int, + kernel: jax.Array, + *, + dtype: DType = jnp.float32, + ): if not isinstance(kernel, jax.Array) or kernel.ndim != 2: raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)") @@ -295,6 +331,7 @@ def __init__(self, in_features: int, kernel: jax.Array): padding="same", weight_init=lambda *_: weight, bias_init=None, + dtype=dtype, ) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) diff --git a/serket/image/affine.py b/serket/image/geometric.py similarity index 100% rename from serket/image/affine.py rename to serket/image/geometric.py diff --git a/tests/test_convolution.py b/tests/test_convolution.py index 937c268..9a27e7b 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -44,20 +44,6 @@ ) from serket.nn.convolution import Conv1DLocal, Conv2DLocal -# Copyright 2023 Serket authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - def test_fft_conv1d(): x = jnp.array( diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 9b172b9..741ab34 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -20,7 +20,14 @@ import pytest import serket as sk -from serket.image.affine import ( +from serket.image.augment import ( + AdjustContrast2D, + JigSaw2D, + PixelShuffle2D, + RandomContrast2D, +) +from serket.image.filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D +from serket.image.geometric import ( HorizontalFlip2D, HorizontalShear2D, HorizontalTranslate2D, @@ -34,13 +41,6 @@ VerticalShear2D, VerticalTranslate2D, ) -from serket.image.augment import ( - AdjustContrast2D, - JigSaw2D, - PixelShuffle2D, - RandomContrast2D, -) -from serket.image.filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D def test_AvgBlur2D():