Skip to content

Commit

Permalink
add CenterCrop2D
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 12, 2023
1 parent 21a09c3 commit 4a2a441
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/API/geometric.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ Geometric API
---------------------------------
.. currentmodule:: serket.image

.. autoclass:: CenterCrop2D
.. autoclass:: HorizontalFlip2D
.. autoclass:: HorizontalShear2D
.. autoclass:: HorizontalTranslate2D
Expand Down
Binary file added docs/_static/centercrop2d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Install from github::
.. toctree::
:caption: Introduction
:caption: 👋 Introduction
:maxdepth: 1

notebooks/mental_model
Expand All @@ -76,9 +76,9 @@ Install from github::
:maxdepth: 1

train_examples
notebooks/train_eval
notebooks/evaluation
notebooks/lazy_initialization
notebooks/train_mp
notebooks/mixed_precision
notebooks/checkpointing
notebooks/regularization

Expand All @@ -88,7 +88,7 @@ Install from github::


.. toctree::
:caption: API Documentation
:caption: 📃 API Documentation
:maxdepth: 1

API/common
Expand Down
File renamed without changes.
File renamed without changes.
63 changes: 62 additions & 1 deletion serket/_src/image/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import serket as sk
from serket._src.custom_transform import tree_eval
from serket._src.nn.linear import Identity
from serket._src.utils import IsInstance, validate_spatial_nd
from serket._src.utils import IsInstance, canonicalize, validate_spatial_nd


def affine_2d(
Expand Down Expand Up @@ -196,6 +196,23 @@ def random_vertical_translate_2d(
return vertical_translate_2d(image, shift)


def center_crop_2d(image: Annotated[jax.Array, "HW"], height: int, width: int):
"""Crops an image with the given size keeping the same center of the original.
Args:
image: 2D image.
height: target height to crop the image to. accepts an int
width: target width to crop the image to. accepts an int
"""
_, _ = image.shape
h, w = image.shape
center_h, center_w = h // 2, w // 2
left = max(center_w - width // 2, 0)
top = max(center_h - height // 2, 0)
return jax.lax.dynamic_slice(image, (top, left), (height, width))


class Rotate2D(sk.TreeClass):
"""Rotate_2d a 2D image by an angle in dgrees in CCW direction
Expand Down Expand Up @@ -873,6 +890,50 @@ def spatial_ndim(self) -> int:
return 2


class CenterCrop2D(sk.TreeClass):
"""Crop the center of a channel-first image.
.. image:: ../_static/centercrop2d.png
Args:
size: The size of the output image. accepts a single int or a tuple of two ints.
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 145).reshape(1, 12, 12)
>>> print(x)
[[[ 1 2 3 4 5 6 7 8 9 10 11 12]
[ 13 14 15 16 17 18 19 20 21 22 23 24]
[ 25 26 27 28 29 30 31 32 33 34 35 36]
[ 37 38 39 40 41 42 43 44 45 46 47 48]
[ 49 50 51 52 53 54 55 56 57 58 59 60]
[ 61 62 63 64 65 66 67 68 69 70 71 72]
[ 73 74 75 76 77 78 79 80 81 82 83 84]
[ 85 86 87 88 89 90 91 92 93 94 95 96]
[ 97 98 99 100 101 102 103 104 105 106 107 108]
[109 110 111 112 113 114 115 116 117 118 119 120]
[121 122 123 124 125 126 127 128 129 130 131 132]
[133 134 135 136 137 138 139 140 141 142 143 144]]]
>>> print(sk.image.CenterCrop2D(4)(x))
[[[53 54 55 56]
[65 66 67 68]
[77 78 79 80]
[89 90 91 92]]]
"""

def __init__(self, size: int | tuple[int, int]):
self.size = canonicalize(size, ndim=2, name="size")

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
return jax.vmap(center_crop_2d, in_axes=(0, None, None))(x, *self.size)

@property
def spatial_ndim(self) -> int:
return 2


@tree_eval.def_eval(RandomRotate2D)
@tree_eval.def_eval(RandomHorizontalShear2D)
@tree_eval.def_eval(RandomVerticalShear2D)
Expand Down
2 changes: 2 additions & 0 deletions serket/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
UnsharpMask2D,
)
from serket._src.image.geometric import (
CenterCrop2D,
HorizontalFlip2D,
HorizontalShear2D,
HorizontalTranslate2D,
Expand Down Expand Up @@ -72,6 +73,7 @@
"Laplacian2D",
"UnsharpMask2D",
# geometric
"CenterCrop2D",
"HorizontalFlip2D",
"HorizontalShear2D",
"HorizontalTranslate2D",
Expand Down

0 comments on commit 4a2a441

Please sign in to comment.