Skip to content

Commit

Permalink
nits
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 5, 2023
1 parent 3efc30b commit d6bcaf2
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 261 deletions.
48 changes: 24 additions & 24 deletions serket/_src/image/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,9 @@ class AdjustContrast2D(sk.TreeClass):
factor: float = 1.0

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: CHWArray) -> CHWArray:
def __call__(self, image: CHWArray) -> CHWArray:
factor = jax.lax.stop_gradient(self.factor)
return jax.vmap(adjust_contrast_2d, in_axes=(0, None))(x, factor)
return jax.vmap(adjust_contrast_2d, in_axes=(0, None))(image, factor)

spatial_ndim: int = 2

Expand Down Expand Up @@ -308,10 +308,10 @@ def __init__(self, range: tuple[float, float] = (0.5, 1)):
self.range = range

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray, *, key: jax.Array) -> CHWArray:
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
range = jax.lax.stop_gradient(self.range)
in_axes = (None, 0, None)
return jax.vmap(random_contrast_2d, in_axes=in_axes)(key, array, range)
return jax.vmap(random_contrast_2d, in_axes=in_axes)(key, image, range)

spatial_ndim: int = 2

Expand Down Expand Up @@ -339,9 +339,9 @@ class AdjustBrightness2D(sk.TreeClass):
factor: float = sk.field(on_setattr=[IsInstance(float), Range(0, 1)])

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray) -> CHWArray:
def __call__(self, image: CHWArray) -> CHWArray:
factor = jax.lax.stop_gradient(self.factor)
return jax.vmap(adjust_brightness_2d, in_axes=(0, None))(array, factor)
return jax.vmap(adjust_brightness_2d, in_axes=(0, None))(image, factor)

spatial_ndim: int = 2

Expand Down Expand Up @@ -398,8 +398,8 @@ def __init__(self, scale: int):
self.scale = scale

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray) -> CHWArray:
return jax.vmap(pixelate_2d, in_axes=(0, None))(array, self.scale)
def __call__(self, image: CHWArray) -> CHWArray:
return jax.vmap(pixelate_2d, in_axes=(0, None))(image, self.scale)

spatial_ndim: int = 2

Expand Down Expand Up @@ -435,10 +435,10 @@ class Solarize2D(sk.TreeClass):
max_val: float = 1.0

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray) -> CHWArray:
def __call__(self, image: CHWArray) -> CHWArray:
threshold, max_val = jax.lax.stop_gradient((self.threshold, self.max_val))
in_axes = (0, None, None)
args = (array, threshold, max_val)
args = (image, threshold, max_val)
return jax.vmap(solarize_2d, in_axes=in_axes)(*args)

spatial_ndim: int = 2
Expand Down Expand Up @@ -489,8 +489,8 @@ class Posterize2D(sk.TreeClass):
bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)])

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray) -> CHWArray:
return jax.vmap(posterize_2d, in_axes=(0, None))(array, self.bits)
def __call__(self, image: CHWArray) -> CHWArray:
return jax.vmap(posterize_2d, in_axes=(0, None))(image, self.bits)

spatial_ndim: int = 2

Expand Down Expand Up @@ -542,15 +542,15 @@ class RandomJigSaw2D(sk.TreeClass):
tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)])

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray, *, key: jax.Array) -> CHWArray:
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
"""Mixes up tiles of an image.
Args:
x: channel-first image (CHW)
key: random key
"""
in_axes = (None, 0, None)
args = (key, array, self.tiles)
args = (key, image, self.tiles)
return jax.vmap(random_jigsaw_2d, in_axes=in_axes)(*args)

spatial_ndim: int = 2
Expand Down Expand Up @@ -615,10 +615,10 @@ def __init__(self, cutoff: float = 0.5, gain: float = 10, inv: bool = False):
self.inv = inv

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray) -> CHWArray:
def __call__(self, image: CHWArray) -> CHWArray:
in_axes = (0, None, None, None)
cutoff, gain = jax.lax.stop_gradient((self.cutoff, self.gain))
args = (array, cutoff, gain, self.inv)
args = (image, cutoff, gain, self.inv)
return jax.vmap(adjust_sigmoid_2d, in_axes=in_axes)(*args)

spatial_ndim: int = 2
Expand All @@ -638,9 +638,9 @@ def __init__(self, factor: float):
self.factor = factor

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray) -> CHWArray:
def __call__(self, image: CHWArray) -> CHWArray:
factor = jax.lax.stop_gradient(self.factor)
return adjust_hue_3d(array, factor)
return adjust_hue_3d(image, factor)

spatial_ndim: int = 2

Expand All @@ -661,9 +661,9 @@ def __init__(self, range: tuple[float, float]):
self.range = range

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray, *, key: jax.Array) -> CHWArray:
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
range = jax.lax.stop_gradient(self.range)
return random_hue_3d(key, array, range)
return random_hue_3d(key, image, range)

spatial_ndim: int = 2

Expand All @@ -682,9 +682,9 @@ def __init__(self, factor: float):
self.factor = factor

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, array: CHWArray) -> CHWArray:
def __call__(self, image: CHWArray) -> CHWArray:
factor = jax.lax.stop_gradient(self.factor)
return adust_saturation_3d(array, factor)
return adust_saturation_3d(image, factor)

spatial_ndim: int = 2

Expand All @@ -705,9 +705,9 @@ def __init__(self, range: tuple[float, float]):
self.range = range

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: CHWArray, *, key: jax.Array) -> CHWArray:
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
range = jax.lax.stop_gradient(self.range)
return random_saturation_3d(key, x, range)
return random_saturation_3d(key, image, range)

spatial_ndim: int = 2

Expand Down
Loading

0 comments on commit d6bcaf2

Please sign in to comment.