Skip to content

Commit

Permalink
standardize geometric functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 10, 2023
1 parent a2363cd commit 760b31a
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 207 deletions.
16 changes: 8 additions & 8 deletions serket/_src/image/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import serket as sk
from serket._src.custom_transform import tree_eval
from serket._src.nn.linear import Identity
from serket._src.utils import IsInstance, Range, validate_spatial_ndim
from serket._src.utils import IsInstance, Range, validate_spatial_nd


def pixel_shuffle_2d(x: jax.Array, upscale_factor: int | tuple[int, int]) -> jax.Array:
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, upscale_factor: int | tuple[int, int] = 1):

raise ValueError("upscale_factor must be an integer or tuple of length 2")

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
return pixel_shuffle_2d(x, self.upscale_factor)

Expand Down Expand Up @@ -116,7 +116,7 @@ class AdjustContrast2D(sk.TreeClass):

contrast_factor: float = 1.0

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
contrast_factor = jax.lax.stop_gradient(self.contrast_factor)
return adjust_contrast_nd(x, contrast_factor)
Expand Down Expand Up @@ -155,7 +155,7 @@ def __init__(self, contrast_range: tuple[float, float] = (0.5, 1)):

self.contrast_range = contrast_range

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
return random_contrast_nd(x, lax.stop_gradient(self.contrast_range), key=key)

Expand Down Expand Up @@ -203,7 +203,7 @@ def __init__(self, scale: int):
raise ValueError(f"{scale=} must be a positive int")
self.scale = scale

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
return pixelate(x, jax.lax.stop_gradient(self.scale))

Expand Down Expand Up @@ -251,7 +251,7 @@ class Solarize2D(sk.TreeClass):
threshold: float
max_val: float = 1.0

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
threshold, max_val = jax.lax.stop_gradient((self.threshold, self.max_val))
return solarize(x, threshold, max_val)
Expand Down Expand Up @@ -320,7 +320,7 @@ class Posterize2D(sk.TreeClass):

bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)])

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
bits = jax.lax.stop_gradient(self.bits)
return jax.vmap(posterize, in_axes=(0, None))(x, bits)
Expand Down Expand Up @@ -410,7 +410,7 @@ class JigSaw2D(sk.TreeClass):

tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)])

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
"""Mixes up tiles of an image.
Expand Down
14 changes: 7 additions & 7 deletions serket/_src/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
generate_conv_dim_numbers,
positive_int_cb,
resolve_string_padding,
validate_spatial_ndim,
validate_spatial_nd,
)


Expand Down Expand Up @@ -141,7 +141,7 @@ class AvgBlur2D(AvgBlur2DBase):
[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]]
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel)
Expand Down Expand Up @@ -170,7 +170,7 @@ class FFTAvgBlur2D(AvgBlur2DBase):
[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]]
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel)
Expand Down Expand Up @@ -231,7 +231,7 @@ class GaussianBlur2D(GaussianBlur2DBase):
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel)
Expand Down Expand Up @@ -261,7 +261,7 @@ class FFTGaussianBlur2D(GaussianBlur2DBase):
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
"""

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel)
Expand Down Expand Up @@ -300,7 +300,7 @@ def __init__(
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")
self.kernel = kernel.astype(dtype)

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel)
Expand Down Expand Up @@ -338,7 +338,7 @@ def __init__(self, kernel: jax.Array, *, dtype: DType = jnp.float32):

self.kernel = kernel.astype(dtype)

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel)
Expand Down
Loading

0 comments on commit 760b31a

Please sign in to comment.