Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 11, 2023
1 parent 2e870a7 commit 7a6ee8e
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 19 deletions.
5 changes: 3 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@

autodoc_default_options = {
"member-order": "bysource",
"special-members": '__call__',
"special-members": "__call__",
"exclude-members": "__repr__, __str__, __weakref__",
"inherited-members": True,
}

# -- Options for HTML output -------------------------------------------------
Expand Down Expand Up @@ -151,4 +152,4 @@

# Tell sphinx-autodoc-typehints to generate stub parameter annotations including
# types, even if the parameters aren't explicitly documented.
always_document_param_types = True
always_document_param_types = True
44 changes: 30 additions & 14 deletions serket/_src/image/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,24 @@
from serket._src.utils import IsInstance, Range, validate_spatial_nd


def pixel_shuffle_2d(x: jax.Array, upscale_factor: int | tuple[int, int]) -> jax.Array:
def pixel_shuffle_2d(
array: jax.Array,
upscale_factor: int | tuple[int, int],
) -> jax.Array:
"""Rearrange elements in a tensor."""
channels = x.shape[0]
channels = array.shape[0]

sr, sw = upscale_factor
oc = channels // (sr * sw)

if not (channels % (sr * sw)) == 0:
raise ValueError(f"{channels=} not divisible by {sr*sw}.")

ih, iw = x.shape[1], x.shape[2]
x = jnp.reshape(x, (sr, sw, oc, ih, iw))
x = jnp.transpose(x, (2, 3, 0, 4, 1))
x = jnp.reshape(x, (oc, ih * sr, iw * sw))
return x
ih, iw = array.shape[1], array.shape[2]
array = jnp.reshape(array, (sr, sw, oc, ih, iw))
array = jnp.transpose(array, (2, 3, 0, 4, 1))
array = jnp.reshape(array, (oc, ih * sr, iw * sw))
return array


class PixelShuffle2D(sk.TreeClass):
Expand Down Expand Up @@ -83,21 +86,34 @@ def spatial_ndim(self) -> int:
return 2


def adjust_contrast_nd(x: jax.Array, contrast_factor: float):
"""Adjusts the contrast of an image by scaling the pixel values by a factor."""
μ = jnp.mean(x, axis=tuple(range(1, x.ndim)), keepdims=True)
return (contrast_factor * (x - μ) + μ).astype(x.dtype)
def adjust_contrast_nd(array: jax.Array, contrast_factor: float):
"""Adjusts the contrast of an image by scaling the pixel values by a factor.
Args:
array: input array
contrast_factor: contrast factor to adust the contrast by.
"""
μ = jnp.mean(array, axis=tuple(range(1, array.ndim)), keepdims=True)
return (contrast_factor * (array - μ) + μ).astype(array.dtype)


def random_contrast_nd(
x: jax.Array,
array: jax.Array,
contrast_range: tuple[float, float],
key: jr.KeyArray = jr.PRNGKey(0),
) -> jax.Array:
"""Randomly adjusts the contrast of an image by scaling the pixel values by a factor."""
"""Randomly adjusts the contrast of an image by scaling the pixel values by a factor.
Args:
array: input array
contrast_range: contrast range to adust the contrast by. accepts a tuple of length 2.
key: random key
"""
minval, maxval = contrast_range
contrast_factor = jr.uniform(key=key, shape=(), minval=minval, maxval=maxval)
return adjust_contrast_nd(x, contrast_factor)
return adjust_contrast_nd(array, contrast_factor)


@sk.autoinit
Expand Down
43 changes: 40 additions & 3 deletions serket/_src/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,16 @@ def calculate_average_kernel(
kernel_size: int,
dtype: DType,
) -> Annotated[jax.Array, "HW"]:
kernel = jnp.ones((kernel_size))
"""Calculate average kernel.
Args:
kernel_size: size of the convolving kernel. Accept an int.
dtype: data type of the kernel.
Returns:
Average kernel. shape is (1, kernel_size).
"""
kernel = jnp.ones((kernel_size), dtype=dtype)
kernel = kernel / jnp.sum(kernel)
kernel = kernel.astype(dtype)
kernel = jnp.expand_dims(kernel, 0)
Expand Down Expand Up @@ -190,7 +199,17 @@ def calculate_gaussian_kernel(
sigma: float,
dtype: DType,
) -> Annotated[jax.Array, "HW"]:
x = jnp.arange(kernel_size) - kernel_size // 2
"""Calculate gaussian kernel.
Args:
kernel_size: size of the convolving kernel. Accept an int.
sigma: sigma of gaussian kernel.
dtype: data type of the kernel.
Returns:
gaussian kernel. shape is (1, kernel_size).
"""
x = jnp.arange(kernel_size, dtype=dtype) - kernel_size // 2
x = x + 0.5 if kernel_size % 2 == 0 else x
kernel = jnp.exp(-(x**2) / (2 * sigma**2))
kernel = kernel / jnp.sum(kernel)
Expand Down Expand Up @@ -340,6 +359,15 @@ def __call__(self, x: jax.Array) -> jax.Array:


def calculate_box_kernel(kernel_size: int, dtype: DType) -> Annotated[jax.Array, "HW"]:
"""Calculate box kernel.
Args:
kernel_size: size of the convolving kernel. Accept an int.
dtype: data type of the kernel.
Returns:
Box kernel. shape is (1, kernel_size).
"""
kernel = jnp.ones((kernel_size))
kernel = kernel.astype(dtype)
kernel = jnp.expand_dims(kernel, 0)
Expand Down Expand Up @@ -425,6 +453,15 @@ def calculate_laplacian_kernel(
kernel_size: tuple[int, int],
dtype: DType,
) -> Annotated[jax.Array, "HW"]:
"""Calculate laplacian kernel.
Args:
kernel_size: size of the convolving kernel. Accepts tuple of two ints.
dtype: data type of the kernel.
Returns:
Laplacian kernel. shape is (kernel_size[0], kernel_size[1]).
"""
ky, kx = kernel_size
kernel = jnp.ones((ky, kx))
kernel = kernel.at[ky // 2, kx // 2].set(1 - jnp.sum(kernel)).astype(dtype)
Expand Down Expand Up @@ -481,7 +518,7 @@ class FFTLaplacian2D(Laplacian2DBase):
"""Apply Laplacian filter to a channel-first image using FFT.
.. image:: ../_static/laplacian2d.png
Args:
kernel_size: size of the convolving kernel. Accepts int or tuple of two ints.
dtype: data type of the layer. Defaults to ``jnp.float32``.
Expand Down

0 comments on commit 7a6ee8e

Please sign in to comment.