Skip to content

Commit

Permalink
more restructure + docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 26, 2023
1 parent 0d821a0 commit 1c47f5a
Show file tree
Hide file tree
Showing 29 changed files with 922 additions and 840 deletions.
3 changes: 2 additions & 1 deletion docs/API/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
convolution
normalization
image_filtering
misc
reshaping
random_transforms
activations
recurrent
misc


8 changes: 7 additions & 1 deletion docs/API/image_filtering.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,10 @@ Image filtering
.. autoclass:: AvgBlur2D
.. autoclass:: GaussianBlur2D
.. autoclass:: Filter2D
.. autoclass:: FFTFilter2D
.. autoclass:: FFTFilter2D

.. autoclass:: HistogramEqualization2D
.. autoclass:: PixelShuffle2D

.. autoclass:: AdjustContrast2D
.. autoclass:: RandomContrast2D
24 changes: 0 additions & 24 deletions docs/API/misc.rst
Original file line number Diff line number Diff line change
@@ -1,33 +1,9 @@
Misc
---------------------------------
.. currentmodule:: serket.nn

.. autoclass:: FlipLeftRight2D
.. autoclass:: FlipUpDown2D
.. autoclass:: Resize1D
.. autoclass:: Resize2D
.. autoclass:: Resize3D
.. autoclass:: Upsample1D
.. autoclass:: Upsample2D
.. autoclass:: Upsample3D
.. autoclass:: Pad1D
.. autoclass:: Pad2D
.. autoclass:: Pad3D

.. autoclass:: VGG16Block
.. autoclass:: VGG19Block
.. autoclass:: UNetBlock

.. autoclass:: Crop1D
.. autoclass:: Crop2D
.. autoclass:: Crop3D

.. autoclass:: Flatten
.. autoclass:: Unflatten

.. autoclass:: HistogramEqualization2D
.. autoclass:: PixelShuffle2D

.. autoclass:: AdjustContrast2D
.. autoclass:: RandomContrast2D

24 changes: 24 additions & 0 deletions docs/API/reshaping.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
Reshaping
---------------------------------
.. currentmodule:: serket.nn


.. autoclass:: Resize1D
.. autoclass:: Resize2D
.. autoclass:: Resize3D

.. autoclass:: Upsample1D
.. autoclass:: Upsample2D
.. autoclass:: Upsample3D

.. autoclass:: Pad1D
.. autoclass:: Pad2D
.. autoclass:: Pad3D

.. autoclass:: Crop1D
.. autoclass:: Crop2D
.. autoclass:: Crop3D

.. autoclass:: Flatten
.. autoclass:: Unflatten

25 changes: 20 additions & 5 deletions serket/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
ThresholdedReLU,
)
from .blocks import UNetBlock, VGG16Block, VGG19Block
from .blur import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D
from .containers import Sequential
from .contrast import AdjustContrast2D, RandomContrast2D
from .convolution import (
Expand All @@ -66,7 +65,6 @@
SeparableConv2D,
SeparableConv3D,
)
from .crop import Crop1D, Crop2D, Crop3D, RandomCrop1D, RandomCrop2D, RandomCrop3D
from .cutout import RandomCutout1D, RandomCutout2D
from .dropout import Dropout, Dropout1D, Dropout2D, Dropout3D
from .fft_convolution import (
Expand All @@ -83,12 +81,11 @@
SeparableFFTConv2D,
SeparableFFTConv3D,
)
from .flatten import Flatten, Unflatten
from .flip import FlipLeftRight2D, FlipUpDown2D
from .fully_connected import FNN, MLP
from .image_filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D
from .linear import Bilinear, Embedding, GeneralLinear, Identity, Linear, Multilinear
from .normalization import BatchNorm, GroupNorm, InstanceNorm, LayerNorm
from .padding import Pad1D, Pad2D, Pad3D
from .pooling import (
AdaptiveAvgPool1D,
AdaptiveAvgPool2D,
Expand Down Expand Up @@ -133,7 +130,25 @@
ScanRNN,
SimpleRNNCell,
)
from .resize import Resize1D, Resize2D, Resize3D, Upsample1D, Upsample2D, Upsample3D
from .reshape import (
Crop1D,
Crop2D,
Crop3D,
Flatten,
Pad1D,
Pad2D,
Pad3D,
RandomCrop1D,
RandomCrop2D,
RandomCrop3D,
Resize1D,
Resize2D,
Resize3D,
Unflatten,
Upsample1D,
Upsample2D,
Upsample3D,
)

__all__ = (
"blocks",
Expand Down
13 changes: 8 additions & 5 deletions serket/nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
class AdaptiveLeakyReLU(sk.TreeClass):
"""Leaky ReLU activation function
Note:
Reference:
https://arxiv.org/pdf/1906.01170.pdf.
"""

Expand All @@ -43,7 +43,8 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
@sk.autoinit
class AdaptiveReLU(sk.TreeClass):
"""ReLU activation function with learnable parameters
Note:
Reference:
https://arxiv.org/pdf/1906.01170.pdf.
"""

Expand All @@ -56,7 +57,8 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
@sk.autoinit
class AdaptiveSigmoid(sk.TreeClass):
"""Sigmoid activation function with learnable `a` parameter
Note:
Reference:
https://arxiv.org/pdf/1906.01170.pdf.
"""

Expand All @@ -69,7 +71,8 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
@sk.autoinit
class AdaptiveTanh(sk.TreeClass):
"""Tanh activation function with learnable parameters
Note:
Reference:
https://arxiv.org/pdf/1906.01170.pdf.
"""

Expand Down Expand Up @@ -299,7 +302,7 @@ class Snake(sk.TreeClass):
Args:
a: scalar (frequency) parameter of the activation function, default is 1.0.
Note:
Reference:
https://arxiv.org/pdf/2006.08195.pdf.
"""

Expand Down
32 changes: 8 additions & 24 deletions serket/nn/contrast.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@ def random_contrast_nd(

@sk.autoinit
class AdjustContrastND(sk.TreeClass):
"""Adjusts the contrast of an NDimage by scaling the pixel values by a factor.
Note:
https://www.tensorflow.org/api_docs/python/tf/image/adjust_contrast
https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
"""

contrast_factor: float = 1.0

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
Expand All @@ -66,28 +59,19 @@ def spatial_ndim(self) -> int:


class AdjustContrast2D(AdjustContrastND):
"""Adjusts the contrast of an image by scaling the pixel values by a factor.
"""Adjusts the contrast of an 2D input by scaling the pixel values by a factor.
Args:
contrast_factor: contrast factor to adjust the image by.
Reference:
- https://www.tensorflow.org/api_docs/python/tf/image/adjust_contrast
- https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
"""

@property
def spatial_ndim(self) -> int:
return 2


@sk.autoinit
class RandomContrastND(sk.TreeClass):
"""Randomly adjusts the contrast of an image by scaling the pixel
values by a factor.
Args:
contrast_range: range of contrast factors to randomly sample from.
"""

contrast_range: tuple

def __init__(self, contrast_range=(0.5, 1)):
if not (
isinstance(contrast_range, tuple)
Expand Down Expand Up @@ -123,11 +107,11 @@ def spatial_ndim(self) -> int:


class RandomContrast2D(RandomContrastND):
"""Randomly adjusts the contrast of an image by scaling the pixel
values by a factor.
"""Randomly adjusts the contrast of an 1D input by scaling the pixel values by a factor.
Args:
contrast_range: range of contrast factors to randomly sample from.
Reference:
- https://www.tensorflow.org/api_docs/python/tf/image/adjust_contrast
- https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
"""

@property
Expand Down
42 changes: 21 additions & 21 deletions serket/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ class Conv1D(ConvND):
>>> print(jax.vmap(layer)(x).shape)
(2, 2, 5)
Note:
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
"""

@property
Expand Down Expand Up @@ -261,8 +261,8 @@ class Conv2D(ConvND):
>>> print(jax.vmap(layer)(x).shape)
(2, 2, 5, 5)
Note:
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
"""

@property
Expand Down Expand Up @@ -332,8 +332,8 @@ class Conv3D(ConvND):
>>> print(jax.vmap(layer)(x).shape)
(2, 2, 5, 5, 5)
Note:
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
"""

@property
Expand Down Expand Up @@ -485,8 +485,8 @@ class Conv1DTranspose(ConvNDTranspose):
>>> print(jax.vmap(layer)(x).shape)
(2, 2, 5)
Note:
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
"""

@property
Expand Down Expand Up @@ -555,8 +555,8 @@ class Conv2DTranspose(ConvNDTranspose):
>>> print(jax.vmap(layer)(x).shape)
(2, 2, 5, 5)
Note:
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
"""

@property
Expand Down Expand Up @@ -625,8 +625,8 @@ class Conv3DTranspose(ConvNDTranspose):
>>> print(jax.vmap(layer)(x).shape)
(2, 2, 5, 5, 5)
Note:
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
"""

@property
Expand Down Expand Up @@ -745,7 +745,7 @@ class DepthwiseConv1D(DepthwiseConvND):
>>> l1(jnp.ones((3, 32))).shape
(6, 16)
Note:
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""
Expand Down Expand Up @@ -799,7 +799,7 @@ class DepthwiseConv2D(DepthwiseConvND):
>>> l1(jnp.ones((3, 32, 32))).shape
(6, 16, 16)
Note:
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""
Expand Down Expand Up @@ -854,7 +854,7 @@ class DepthwiseConv3D(DepthwiseConvND):
>>> l1(jnp.ones((3, 32, 32, 32))).shape
(6, 16, 16, 16)
Note:
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""
Expand Down Expand Up @@ -977,7 +977,7 @@ class SeparableConv1D(SeparableConvND):
>>> l1(jnp.ones((3, 32))).shape
(3, 32)
Note:
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""
Expand Down Expand Up @@ -1049,7 +1049,7 @@ class SeparableConv2D(SeparableConvND):
>>> l1(jnp.ones((3, 32, 32))).shape
(3, 32, 32)
Note:
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""
Expand Down Expand Up @@ -1121,7 +1121,7 @@ class SeparableConv3D(SeparableConvND):
>>> l1(jnp.ones((3, 32, 32, 32))).shape
(3, 32, 32, 32)
Note:
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""
Expand Down Expand Up @@ -1277,7 +1277,7 @@ class Conv1DLocal(ConvNDLocal):
>>> l1(jnp.ones((3, 32))).shape
(3, 32)
Note:
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""
Expand Down Expand Up @@ -1338,7 +1338,7 @@ class Conv2DLocal(ConvNDLocal):
>>> l1(jnp.ones((3, 32, 32))).shape
(3, 32, 32)
Note:
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""
Expand Down Expand Up @@ -1399,7 +1399,7 @@ class Conv3DLocal(ConvNDLocal):
>>> l1(jnp.ones((3, 32, 32, 32))).shape
(3, 32, 32, 32)
Note:
Reference:
- https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""
Expand Down
Loading

0 comments on commit 1c47f5a

Please sign in to comment.