Skip to content

Commit

Permalink
attention updates (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 9, 2023
1 parent dff6a25 commit 9a4711a
Show file tree
Hide file tree
Showing 14 changed files with 226 additions and 235 deletions.
42 changes: 24 additions & 18 deletions serket/_src/image/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
from serket._src.custom_transform import tree_eval
from serket._src.image.color import hsv_to_rgb_3d, rgb_to_hsv_3d
from serket._src.nn.linear import Identity
from serket._src.utils import CHWArray, HWArray, IsInstance, Range, validate_spatial_nd
from serket._src.utils import (
CHWArray,
HWArray,
IsInstance,
Range,
validate_spatial_ndim,
)


def pixel_shuffle_3d(array: CHWArray, upscale_factor: tuple[int, int]) -> CHWArray:
Expand Down Expand Up @@ -252,7 +258,7 @@ def __init__(self, upscale_factor: int = 1):

self.upscale_factor = (upscale_factor, upscale_factor)

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, array: CHWArray) -> CHWArray:
return pixel_shuffle_3d(array, self.upscale_factor)

Expand All @@ -275,7 +281,7 @@ class AdjustContrast2D(sk.TreeClass):

factor: float = 1.0

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
factor = jax.lax.stop_gradient(self.factor)
return jax.vmap(adjust_contrast_2d, in_axes=(0, None))(image, factor)
Expand Down Expand Up @@ -307,7 +313,7 @@ def __init__(self, range: tuple[float, float] = (0.5, 1)):

self.range = range

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
range = jax.lax.stop_gradient(self.range)
in_axes = (None, 0, None)
Expand Down Expand Up @@ -338,7 +344,7 @@ class AdjustBrightness2D(sk.TreeClass):

factor: float = sk.field(on_setattr=[IsInstance(float), Range(0, 1)])

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
factor = jax.lax.stop_gradient(self.factor)
return jax.vmap(adjust_brightness_2d, in_axes=(0, None))(image, factor)
Expand All @@ -360,7 +366,7 @@ class RandomBrightness2D(sk.TreeClass):

range: tuple[float, float] = sk.field(on_setattr=[IsInstance(tuple)])

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, array: CHWArray, *, key: jax.Array) -> CHWArray:
range = jax.lax.stop_gradient(self.range)
in_axes = (None, 0, None)
Expand Down Expand Up @@ -397,7 +403,7 @@ def __init__(self, scale: int):
raise ValueError(f"{scale=} must be a positive int")
self.scale = scale

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return jax.vmap(pixelate_2d, in_axes=(0, None))(image, self.scale)

Expand Down Expand Up @@ -434,7 +440,7 @@ class Solarize2D(sk.TreeClass):
threshold: float
max_val: float = 1.0

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
threshold, max_val = jax.lax.stop_gradient((self.threshold, self.max_val))
in_axes = (0, None, None)
Expand Down Expand Up @@ -488,7 +494,7 @@ class Posterize2D(sk.TreeClass):

bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)])

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return jax.vmap(posterize_2d, in_axes=(0, None))(image, self.bits)

Expand Down Expand Up @@ -541,7 +547,7 @@ class RandomJigSaw2D(sk.TreeClass):

tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)])

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
"""Mixes up tiles of an image.
Expand Down Expand Up @@ -579,7 +585,7 @@ def __init__(self, gain: float = 1, inv: bool = False):
self.gain = gain
self.inv = inv

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, array: CHWArray) -> CHWArray:
in_axes = (0, None, None)
gain = jax.lax.stop_gradient(self.gain)
Expand Down Expand Up @@ -614,7 +620,7 @@ def __init__(self, cutoff: float = 0.5, gain: float = 10, inv: bool = False):
self.gain = gain
self.inv = inv

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
in_axes = (0, None, None, None)
cutoff, gain = jax.lax.stop_gradient((self.cutoff, self.gain))
Expand All @@ -637,7 +643,7 @@ class AdjustHue2D(sk.TreeClass):
def __init__(self, factor: float):
self.factor = factor

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
factor = jax.lax.stop_gradient(self.factor)
return adjust_hue_3d(image, factor)
Expand All @@ -660,7 +666,7 @@ class RandomHue2D(sk.TreeClass):
def __init__(self, range: tuple[float, float]):
self.range = range

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
range = jax.lax.stop_gradient(self.range)
return random_hue_3d(key, image, range)
Expand All @@ -681,7 +687,7 @@ class AdjustSaturation2D(sk.TreeClass):
def __init__(self, factor: float):
self.factor = factor

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
factor = jax.lax.stop_gradient(self.factor)
return adust_saturation_3d(image, factor)
Expand All @@ -704,7 +710,7 @@ class RandomSaturation2D(sk.TreeClass):
def __init__(self, range: tuple[float, float]):
self.range = range

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
range = jax.lax.stop_gradient(self.range)
return random_saturation_3d(key, image, range)
Expand Down Expand Up @@ -743,8 +749,8 @@ class FourierDomainAdapt2D(sk.TreeClass):
def __init__(self, beta: float):
self.beta = beta

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim", argnum=0)
@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim", argnum=1)
@ft.partial(validate_spatial_ndim, argnum=0)
@ft.partial(validate_spatial_ndim, argnum=1)
def __call__(self, image: CHWArray, target: CHWArray) -> CHWArray:
"""Fourier Domain Adaptation
Expand Down
10 changes: 5 additions & 5 deletions serket/_src/image/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import jax.numpy as jnp

import serket as sk
from serket._src.utils import CHWArray, validate_spatial_nd
from serket._src.utils import CHWArray, validate_spatial_ndim


def rgb_to_grayscale(image: CHWArray, weights: jax.Array | None = None) -> CHWArray:
Expand Down Expand Up @@ -71,7 +71,7 @@ class RGBToGrayscale2D(sk.TreeClass):
def __init__(self, weights: jax.Array | None = None):
self.weights = weights

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return rgb_to_grayscale(image, self.weights)

Expand Down Expand Up @@ -141,7 +141,7 @@ class GrayscaleToRGB2D(sk.TreeClass):
(3, 5, 5)
"""

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return grayscale_to_rgb(image)

Expand All @@ -166,7 +166,7 @@ class RGBToHSV2D(sk.TreeClass):
- https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html
"""

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return rgb_to_hsv_3d(image)

Expand All @@ -189,7 +189,7 @@ class HSVToRGB2D(sk.TreeClass):
- https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html
"""

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return hsv_to_rgb_3d(image)

Expand Down
Loading

0 comments on commit 9a4711a

Please sign in to comment.