Skip to content

Commit

Permalink
add dtype to filter layers rename affine to geometric
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 4, 2023
1 parent d56841e commit dff0b2a
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 35 deletions.
18 changes: 9 additions & 9 deletions serket/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .affine import (
from .augment import (
AdjustContrast2D,
JigSaw2D,
PixelShuffle2D,
Posterize2D,
RandomContrast2D,
)
from .filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D
from .geometric import (
HorizontalFlip2D,
HorizontalShear2D,
HorizontalTranslate2D,
Expand All @@ -29,14 +37,6 @@
VerticalShear2D,
VerticalTranslate2D,
)
from .augment import (
AdjustContrast2D,
JigSaw2D,
PixelShuffle2D,
Posterize2D,
RandomContrast2D,
)
from .filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D

__all__ = (
# image
Expand Down
45 changes: 41 additions & 4 deletions serket/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import serket as sk
from serket.nn.convolution import DepthwiseConv2D, DepthwiseFFTConv2D
from serket.nn.initialization import DType
from serket.utils import (
maybe_lazy_call,
maybe_lazy_init,
Expand Down Expand Up @@ -54,6 +55,7 @@ class AvgBlur2D(sk.TreeClass):
Args:
in_features: number of input channels.
kernel_size: size of the convolving kernel.
dtype: data type of the layer. Defaults to ``jnp.float32``.
Example:
>>> import serket as sk
Expand Down Expand Up @@ -92,7 +94,14 @@ class AvgBlur2D(sk.TreeClass):
"""

@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(self, in_features: int | None, kernel_size: int | tuple[int, int]):
def __init__(
self,
in_features: int | None,
kernel_size: int,
*,
dtype: DType = jnp.float32,
):
kernel_size = positive_int_cb(kernel_size)
weight = jnp.ones(kernel_size)
weight = weight / jnp.sum(weight)
weight = weight[:, None]
Expand All @@ -104,6 +113,7 @@ def __init__(self, in_features: int | None, kernel_size: int | tuple[int, int]):
padding="same",
weight_init=lambda *_: weight,
bias_init=None,
dtype=dtype,
)

self.conv2 = DepthwiseConv2D(
Expand All @@ -112,6 +122,7 @@ def __init__(self, in_features: int | None, kernel_size: int | tuple[int, int]):
padding="same",
weight_init=lambda *_: jnp.moveaxis(weight, 2, 3),
bias_init=None,
dtype=dtype,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
Expand All @@ -134,6 +145,7 @@ class GaussianBlur2D(sk.TreeClass):
in_features: number of input features
kernel_size: kernel size
sigma: sigma. Defaults to 1.
dtype: data type of the layer. Defaults to ``jnp.float32``.
Example:
>>> import serket as sk
Expand Down Expand Up @@ -172,7 +184,14 @@ class GaussianBlur2D(sk.TreeClass):
"""

@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0):
def __init__(
self,
in_features: int,
kernel_size: int,
*,
sigma: float = 1.0,
dtype: DType = jnp.float32,
):
kernel_size = positive_int_cb(kernel_size)
self.sigma = sigma

Expand All @@ -189,6 +208,7 @@ def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0):
padding="same",
weight_init=lambda *_: weight,
bias_init=None,
dtype=dtype,
)

self.conv2 = DepthwiseFFTConv2D(
Expand All @@ -197,6 +217,7 @@ def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0):
padding="same",
weight_init=lambda *_: jnp.moveaxis(weight, 2, 3),
bias_init=None,
dtype=dtype,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
Expand All @@ -218,6 +239,7 @@ class Filter2D(sk.TreeClass):
Args:
in_features: number of input channels.
kernel: kernel array.
dtype: data type of the layer. Defaults to ``jnp.float32``.
Example:
>>> import serket as sk
Expand All @@ -232,7 +254,13 @@ class Filter2D(sk.TreeClass):
"""

@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(self, in_features: int, kernel: jax.Array):
def __init__(
self,
in_features: int,
kernel: jax.Array,
*,
dtype: DType = jnp.float32,
):
if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

Expand All @@ -246,6 +274,7 @@ def __init__(self, in_features: int, kernel: jax.Array):
padding="same",
weight_init=lambda *_: weight,
bias_init=None,
dtype=dtype,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
Expand All @@ -267,6 +296,7 @@ class FFTFilter2D(sk.TreeClass):
Args:
in_features: number of input channels
kernel: kernel array
dtype: data type of the layer. Defaults to ``jnp.float32``.
Example:
>>> import serket as sk
Expand All @@ -281,7 +311,13 @@ class FFTFilter2D(sk.TreeClass):
"""

@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(self, in_features: int, kernel: jax.Array):
def __init__(
self,
in_features: int,
kernel: jax.Array,
*,
dtype: DType = jnp.float32,
):
if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

Expand All @@ -295,6 +331,7 @@ def __init__(self, in_features: int, kernel: jax.Array):
padding="same",
weight_init=lambda *_: weight,
bias_init=None,
dtype=dtype,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
Expand Down
File renamed without changes.
14 changes: 0 additions & 14 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,6 @@
)
from serket.nn.convolution import Conv1DLocal, Conv2DLocal

# Copyright 2023 Serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def test_fft_conv1d():
x = jnp.array(
Expand Down
16 changes: 8 additions & 8 deletions tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
import pytest

import serket as sk
from serket.image.affine import (
from serket.image.augment import (
AdjustContrast2D,
JigSaw2D,
PixelShuffle2D,
RandomContrast2D,
)
from serket.image.filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D
from serket.image.geometric import (
HorizontalFlip2D,
HorizontalShear2D,
HorizontalTranslate2D,
Expand All @@ -34,13 +41,6 @@
VerticalShear2D,
VerticalTranslate2D,
)
from serket.image.augment import (
AdjustContrast2D,
JigSaw2D,
PixelShuffle2D,
RandomContrast2D,
)
from serket.image.filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D


def test_AvgBlur2D():
Expand Down

0 comments on commit dff0b2a

Please sign in to comment.