Skip to content

Commit

Permalink
images to image docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 27, 2023
1 parent 2f19451 commit af72b9c
Show file tree
Hide file tree
Showing 23 changed files with 106 additions and 65 deletions.
3 changes: 3 additions & 0 deletions docs/API/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Image filtering
.. autoclass:: Filter2D
.. autoclass:: GaussianBlur2D
.. autoclass:: HistogramEqualization2D
.. autoclass:: HorizontalFlip2D
.. autoclass:: HorizontalShear2D
.. autoclass:: HorizontalTranslate2D
.. autoclass:: JigSaw2D
Expand All @@ -21,5 +22,7 @@ Image filtering
.. autoclass:: RandomVerticalShear2D
.. autoclass:: RandomVerticalTranslate2D
.. autoclass:: Rotate2D
.. autoclass:: Solarize2D
.. autoclass:: VerticalFlip2D
.. autoclass:: VerticalShear2D
.. autoclass:: VerticalTranslate2D
Binary file added docs/_static/adjustcontrast2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/avgblur2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/filter2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/gaussianblur2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/horizontalflip2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/horizontalshear2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/horizontaltranslate2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/jigsaw2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/pixelate2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/pixelshuffle2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/posterize.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/randomcontrast2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/randomperspective2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/rotate2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/solarize2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/verticalflip2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/verticalshear2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/verticaltranslate2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added lenna.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions serket/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
Filter2D,
GaussianBlur2D,
HistogramEqualization2D,
HorizontalFlip2D,
HorizontalShear2D,
HorizontalTranslate2D,
JigSaw2D,
Expand All @@ -109,6 +110,7 @@
RandomVerticalTranslate2D,
Rotate2D,
Solarize2D,
VerticalFlip2D,
VerticalShear2D,
VerticalTranslate2D,
)
Expand Down Expand Up @@ -161,7 +163,6 @@
Crop2D,
Crop3D,
Flatten,
HorizontalFlip2D,
Pad1D,
Pad2D,
Pad3D,
Expand All @@ -178,7 +179,6 @@
Upsample1D,
Upsample2D,
Upsample3D,
VerticalFlip2D,
)

__all__ = (
Expand Down Expand Up @@ -279,6 +279,7 @@
"Filter2D",
"GaussianBlur2D",
"HistogramEqualization2D",
"HorizontalFlip2D",
"HorizontalShear2D",
"HorizontalTranslate2D",
"JigSaw2D",
Expand All @@ -294,6 +295,7 @@
"RandomVerticalTranslate2D",
"Rotate2D",
"Solarize2D",
"VerticalFlip2D",
"VerticalShear2D",
"VerticalTranslate2D",
# kmeans
Expand Down Expand Up @@ -343,7 +345,6 @@
"Crop2D",
"Crop3D",
"Flatten",
"HorizontalFlip2D",
"Pad1D",
"Pad2D",
"Pad3D",
Expand All @@ -360,7 +361,6 @@
"Upsample1D",
"Upsample2D",
"Upsample3D",
"VerticalFlip2D",
# block
"blocks",
)
100 changes: 99 additions & 1 deletion serket/nn/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def infer_in_features(_, x, *__, **___) -> int:
class AvgBlur2D(sk.TreeClass):
"""Average blur 2D layer
.. image:: ../_static/avgblur2d.png
Args:
in_features: number of input channels.
kernel_size: size of the convolving kernel.
Expand Down Expand Up @@ -132,6 +134,8 @@ def spatial_ndim(self) -> int:
class GaussianBlur2D(sk.TreeClass):
"""Apply Gaussian blur to a channel-first image.
.. image:: ../_static/gaussianblur2d.png
Args:
in_features: number of input features
kernel_size: kernel size
Expand Down Expand Up @@ -215,6 +219,8 @@ def spatial_ndim(self) -> int:
class Filter2D(sk.TreeClass):
"""Apply 2D filter for each channel
.. image:: ../_static/filter2d.png
Args:
in_features: number of input channels.
kernel: kernel array.
Expand Down Expand Up @@ -262,6 +268,8 @@ def spatial_ndim(self) -> int:
class FFTFilter2D(sk.TreeClass):
"""Apply 2D filter for each channel using FFT
.. image:: ../_static/filter2d.png
Args:
in_features: number of input channels
kernel: kernel array
Expand Down Expand Up @@ -408,6 +416,11 @@ def random_contrast_nd(
class AdjustContrast2D(sk.TreeClass):
"""Adjusts the contrast of an 2D input by scaling the pixel values by a factor.
.. image:: ../_static/adjustcontrast2d.png
Args:
contrast_factor: contrast factor to adust the contrast by. Defaults to 1.0.
Reference:
- https://www.tensorflow.org/api_docs/python/tf/image/adjust_contrast
- https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
Expand All @@ -417,7 +430,8 @@ class AdjustContrast2D(sk.TreeClass):

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
return lax.stop_gradient(adjust_contrast_nd(x, self.contrast_factor))
contrast_factor = jax.lax.stop_gradient(self.contrast_factor)
return adjust_contrast_nd(x, contrast_factor)

@property
def spatial_ndim(self) -> int:
Expand Down Expand Up @@ -523,6 +537,8 @@ def random_rotate(
class Rotate2D(sk.TreeClass):
"""Rotate a 2D image by an angle in dgrees in CCW direction
.. image:: ../_static/rotate2d.png
Args:
angle: angle to rotate in degrees counter-clockwise direction.
Expand Down Expand Up @@ -596,6 +612,8 @@ def spatial_ndim(self) -> int:
class HorizontalShear2D(sk.TreeClass):
"""Shear an image horizontally
.. image:: ../_static/horizontalshear2d.png
Args:
angle: angle to rotate in degrees counter-clockwise direction.
Expand Down Expand Up @@ -668,6 +686,8 @@ def spatial_ndim(self) -> int:
class VerticalShear2D(sk.TreeClass):
"""Shear an image vertically
.. image:: ../_static/verticalshear2d.png
Args:
angle: angle to rotate in degrees counter-clockwise direction.
Expand Down Expand Up @@ -751,6 +771,8 @@ def pixelate(image: jax.Array, scale: int = 16) -> jax.Array:
class Pixelate2D(sk.TreeClass):
"""Pixelate an image by upsizing and downsizing an image
.. image:: ../_static/pixelate2d.png
Args:
scale: the scale to which the image will be downsized before being upsized
to the original shape. for example, ``scale=2`` means the image will
Expand Down Expand Up @@ -818,6 +840,8 @@ def random_perspective(
class RandomPerspective2D(sk.TreeClass):
"""Applies a random perspective transform to a channel-first image.
.. image:: ../_static/randomperspective2d.png
Args:
scale: the scale of the random perspective transform. Higher scale will
lead to higher degree of perspective transform. default to 1.0. 0.0
Expand Down Expand Up @@ -922,6 +946,8 @@ def solarize(
class Solarize2D(sk.TreeClass):
"""Inverts all values above a given threshold.
.. image:: ../_static/solarize2d.png
Args:
threshold: The threshold value above which to invert.
max_val: The maximum value of the image. e.g. 255 for uint8 images.
Expand Down Expand Up @@ -975,6 +1001,8 @@ def posterize(image: jax.Array, bits: int) -> jax.Array:
class Posterize2D(sk.TreeClass):
"""Reduce the number of bits for each color channel.
.. image:: ../_static/nn/posterize2d.png
Args:
bits: The number of bits to keep for each channel (1-8).
Expand Down Expand Up @@ -1061,6 +1089,8 @@ def random_vertical_translate(image: jax.Array, key: jr.KeyArray) -> jax.Array:
class HorizontalTranslate2D(sk.TreeClass):
"""Translate an image horizontally by a pixel value.
.. image:: ../_static/horizontaltranslate2d.png
Args:
shift: The number of pixels to shift the image by.
Expand Down Expand Up @@ -1091,6 +1121,8 @@ def spatial_ndim(self) -> int:
class VerticalTranslate2D(sk.TreeClass):
"""Translate an image vertically by a pixel value.
.. image:: ../_static/verticaltranslate2d.png
Args:
shift: The number of pixels to shift the image by.
Expand Down Expand Up @@ -1213,6 +1245,8 @@ def jigsaw(
class JigSaw2D(sk.TreeClass):
"""Mixes up tiles of an image.
.. image:: ../_static/jigsaw2d.png
Args:
tiles: number of tiles per side
Expand Down Expand Up @@ -1267,6 +1301,70 @@ def spatial_ndim(self) -> int:
return 2


class HorizontalFlip2D(sk.TreeClass):
"""Flip channels left to right.
.. image:: ../_static/horizontalflip2d.png
Examples:
>>> import jax.numpy as jnp
>>> import serket as sk
>>> x = jnp.arange(1,10).reshape(1,3, 3)
>>> print(x)
[[[1 2 3]
[4 5 6]
[7 8 9]]]
>>> print(sk.nn.HorizontalFlip2D()(x))
[[[3 2 1]
[6 5 4]
[9 8 7]]]
Reference:
- https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.vmap(lambda x: jnp.flip(x, axis=1))(x)

@property
def spatial_ndim(self) -> int:
return 2


class VerticalFlip2D(sk.TreeClass):
"""Flip channels up to down.
.. image:: ../_static/verticalflip2d.png
Examples:
>>> import jax.numpy as jnp
>>> import serket as sk
>>> x = jnp.arange(1,10).reshape(1,3, 3)
>>> print(x)
[[[1 2 3]
[4 5 6]
[7 8 9]]]
>>> print(sk.nn.VerticalFlip2D()(x))
[[[7 8 9]
[4 5 6]
[1 2 3]]]
Reference:
- https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.vmap(lambda x: jnp.flip(x, axis=0))(x)

@property
def spatial_ndim(self) -> int:
return 2


@tree_eval.def_eval(RandomContrast2D)
@tree_eval.def_eval(RandomRotate2D)
@tree_eval.def_eval(RandomHorizontalShear2D)
Expand Down
60 changes: 0 additions & 60 deletions serket/nn/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,66 +507,6 @@ def spatial_ndim(self) -> int:
return 3


class HorizontalFlip2D(sk.TreeClass):
"""Flip channels left to right.
Examples:
>>> import jax.numpy as jnp
>>> import serket as sk
>>> x = jnp.arange(1,10).reshape(1,3, 3)
>>> print(x)
[[[1 2 3]
[4 5 6]
[7 8 9]]]
>>> print(sk.nn.HorizontalFlip2D()(x))
[[[3 2 1]
[6 5 4]
[9 8 7]]]
Reference:
- https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.vmap(lambda x: jnp.flip(x, axis=1))(x)

@property
def spatial_ndim(self) -> int:
return 2


class VerticalFlip2D(sk.TreeClass):
"""Flip channels up to down.
Examples:
>>> import jax.numpy as jnp
>>> import serket as sk
>>> x = jnp.arange(1,10).reshape(1,3, 3)
>>> print(x)
[[[1 2 3]
[4 5 6]
[7 8 9]]]
>>> print(sk.nn.VerticalFlip2D()(x))
[[[7 8 9]
[4 5 6]
[1 2 3]]]
Reference:
- https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.vmap(lambda x: jnp.flip(x, axis=0))(x)

@property
def spatial_ndim(self) -> int:
return 2


def zoom_axis(
x: jax.Array,
factor: float,
Expand Down

0 comments on commit af72b9c

Please sign in to comment.