diff --git a/docs/conf.py b/docs/conf.py index 38efadf..fed4b05 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -70,8 +70,9 @@ autodoc_default_options = { "member-order": "bysource", - "special-members": '__call__', + "special-members": "__call__", "exclude-members": "__repr__, __str__, __weakref__", + "inherited-members": True, } # -- Options for HTML output ------------------------------------------------- @@ -151,4 +152,4 @@ # Tell sphinx-autodoc-typehints to generate stub parameter annotations including # types, even if the parameters aren't explicitly documented. -always_document_param_types = True \ No newline at end of file +always_document_param_types = True diff --git a/serket/_src/image/augment.py b/serket/_src/image/augment.py index b7e8b24..260104d 100644 --- a/serket/_src/image/augment.py +++ b/serket/_src/image/augment.py @@ -27,9 +27,12 @@ from serket._src.utils import IsInstance, Range, validate_spatial_nd -def pixel_shuffle_2d(x: jax.Array, upscale_factor: int | tuple[int, int]) -> jax.Array: +def pixel_shuffle_2d( + array: jax.Array, + upscale_factor: int | tuple[int, int], +) -> jax.Array: """Rearrange elements in a tensor.""" - channels = x.shape[0] + channels = array.shape[0] sr, sw = upscale_factor oc = channels // (sr * sw) @@ -37,11 +40,11 @@ def pixel_shuffle_2d(x: jax.Array, upscale_factor: int | tuple[int, int]) -> jax if not (channels % (sr * sw)) == 0: raise ValueError(f"{channels=} not divisible by {sr*sw}.") - ih, iw = x.shape[1], x.shape[2] - x = jnp.reshape(x, (sr, sw, oc, ih, iw)) - x = jnp.transpose(x, (2, 3, 0, 4, 1)) - x = jnp.reshape(x, (oc, ih * sr, iw * sw)) - return x + ih, iw = array.shape[1], array.shape[2] + array = jnp.reshape(array, (sr, sw, oc, ih, iw)) + array = jnp.transpose(array, (2, 3, 0, 4, 1)) + array = jnp.reshape(array, (oc, ih * sr, iw * sw)) + return array class PixelShuffle2D(sk.TreeClass): @@ -83,21 +86,34 @@ def spatial_ndim(self) -> int: return 2 -def adjust_contrast_nd(x: jax.Array, contrast_factor: float): - """Adjusts the contrast of an image by scaling the pixel values by a factor.""" - μ = jnp.mean(x, axis=tuple(range(1, x.ndim)), keepdims=True) - return (contrast_factor * (x - μ) + μ).astype(x.dtype) +def adjust_contrast_nd(array: jax.Array, contrast_factor: float): + """Adjusts the contrast of an image by scaling the pixel values by a factor. + + Args: + array: input array + contrast_factor: contrast factor to adust the contrast by. + + + """ + μ = jnp.mean(array, axis=tuple(range(1, array.ndim)), keepdims=True) + return (contrast_factor * (array - μ) + μ).astype(array.dtype) def random_contrast_nd( - x: jax.Array, + array: jax.Array, contrast_range: tuple[float, float], key: jr.KeyArray = jr.PRNGKey(0), ) -> jax.Array: - """Randomly adjusts the contrast of an image by scaling the pixel values by a factor.""" + """Randomly adjusts the contrast of an image by scaling the pixel values by a factor. + + Args: + array: input array + contrast_range: contrast range to adust the contrast by. accepts a tuple of length 2. + key: random key + """ minval, maxval = contrast_range contrast_factor = jr.uniform(key=key, shape=(), minval=minval, maxval=maxval) - return adjust_contrast_nd(x, contrast_factor) + return adjust_contrast_nd(array, contrast_factor) @sk.autoinit diff --git a/serket/_src/image/filter.py b/serket/_src/image/filter.py index e7ae94d..679ef00 100644 --- a/serket/_src/image/filter.py +++ b/serket/_src/image/filter.py @@ -103,7 +103,16 @@ def calculate_average_kernel( kernel_size: int, dtype: DType, ) -> Annotated[jax.Array, "HW"]: - kernel = jnp.ones((kernel_size)) + """Calculate average kernel. + + Args: + kernel_size: size of the convolving kernel. Accept an int. + dtype: data type of the kernel. + + Returns: + Average kernel. shape is (1, kernel_size). + """ + kernel = jnp.ones((kernel_size), dtype=dtype) kernel = kernel / jnp.sum(kernel) kernel = kernel.astype(dtype) kernel = jnp.expand_dims(kernel, 0) @@ -190,7 +199,17 @@ def calculate_gaussian_kernel( sigma: float, dtype: DType, ) -> Annotated[jax.Array, "HW"]: - x = jnp.arange(kernel_size) - kernel_size // 2 + """Calculate gaussian kernel. + + Args: + kernel_size: size of the convolving kernel. Accept an int. + sigma: sigma of gaussian kernel. + dtype: data type of the kernel. + + Returns: + gaussian kernel. shape is (1, kernel_size). + """ + x = jnp.arange(kernel_size, dtype=dtype) - kernel_size // 2 x = x + 0.5 if kernel_size % 2 == 0 else x kernel = jnp.exp(-(x**2) / (2 * sigma**2)) kernel = kernel / jnp.sum(kernel) @@ -340,6 +359,15 @@ def __call__(self, x: jax.Array) -> jax.Array: def calculate_box_kernel(kernel_size: int, dtype: DType) -> Annotated[jax.Array, "HW"]: + """Calculate box kernel. + + Args: + kernel_size: size of the convolving kernel. Accept an int. + dtype: data type of the kernel. + + Returns: + Box kernel. shape is (1, kernel_size). + """ kernel = jnp.ones((kernel_size)) kernel = kernel.astype(dtype) kernel = jnp.expand_dims(kernel, 0) @@ -425,6 +453,15 @@ def calculate_laplacian_kernel( kernel_size: tuple[int, int], dtype: DType, ) -> Annotated[jax.Array, "HW"]: + """Calculate laplacian kernel. + + Args: + kernel_size: size of the convolving kernel. Accepts tuple of two ints. + dtype: data type of the kernel. + + Returns: + Laplacian kernel. shape is (kernel_size[0], kernel_size[1]). + """ ky, kx = kernel_size kernel = jnp.ones((ky, kx)) kernel = kernel.at[ky // 2, kx // 2].set(1 - jnp.sum(kernel)).astype(dtype) @@ -481,7 +518,7 @@ class FFTLaplacian2D(Laplacian2DBase): """Apply Laplacian filter to a channel-first image using FFT. .. image:: ../_static/laplacian2d.png - + Args: kernel_size: size of the convolving kernel. Accepts int or tuple of two ints. dtype: data type of the layer. Defaults to ``jnp.float32``.