Skip to content

Commit

Permalink
Utils refactor (#103)
Browse files Browse the repository at this point in the history
* split utils to smaller files

* refactor lazy

* lazy note

* fix `handle_pos_or_kw`

* Update lazy.py
  • Loading branch information
ASEM000 committed Apr 8, 2024
1 parent d3b2389 commit e1aa2e6
Show file tree
Hide file tree
Showing 51 changed files with 1,084 additions and 1,036 deletions.
2 changes: 1 addition & 1 deletion serket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 serket authors
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion serket/_src/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 serket authors
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion serket/_src/cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 serket authors
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
4 changes: 2 additions & 2 deletions serket/_src/cluster/kmeans.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 serket authors
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -23,7 +23,7 @@

import serket as sk
from serket._src.custom_transform import tree_eval, tree_state
from serket._src.utils import IsInstance, Range
from serket._src.utils.validate import IsInstance, Range

"""K-means utility functions."""

Expand Down
12 changes: 6 additions & 6 deletions serket/_src/containers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 serket authors
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,9 +20,9 @@
import jax
import jax.random as jr

import serket as sk
from serket import TreeClass, tree_summary
from serket._src.custom_transform import tree_eval
from serket._src.utils import single_dispatch
from serket._src.utils.dispatch import single_dispatch


@single_dispatch(argnum=0)
Expand Down Expand Up @@ -53,7 +53,7 @@ def _(key: jax.Array, layers: Sequence[Callable[..., Any]], array: Any):
return array


class Sequential(sk.TreeClass):
class Sequential(TreeClass):
"""A sequential container for layers.
Args:
Expand Down Expand Up @@ -103,7 +103,7 @@ def __reversed__(self):
return reversed(self.layers)


@sk.tree_summary.def_type(Sequential)
@tree_summary.def_type(Sequential)
def _(node):
types = [type(x).__name__ for x in node]
return f"{type(node).__name__}[{','.join(types)}]"
Expand All @@ -121,7 +121,7 @@ def random_choice(key: jax.Array, layers: tuple[Callable[..., Any], ...], array:
return jax.lax.switch(index, layers, array)


class RandomChoice(sk.TreeClass):
class RandomChoice(TreeClass):
"""Randomly selects one of the given layers/functions.
Args:
Expand Down
4 changes: 2 additions & 2 deletions serket/_src/custom_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 serket authors
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,7 +22,7 @@
import jax

import serket as sk
from serket._src.utils import single_dispatch
from serket._src.utils.dispatch import single_dispatch

T = TypeVar("T")

Expand Down
2 changes: 1 addition & 1 deletion serket/_src/image/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 serket authors
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
69 changes: 32 additions & 37 deletions serket/_src/image/augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 serket authors
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,17 +21,12 @@
import jax.numpy as jnp
import jax.random as jr

import serket as sk
from serket import TreeClass, autoinit, field
from serket._src.custom_transform import tree_eval
from serket._src.image.color import hsv_to_rgb, rgb_to_hsv
from serket._src.nn.linear import Identity
from serket._src.utils import (
CHWArray,
HWArray,
IsInstance,
Range,
validate_spatial_ndim,
)
from serket._src.utils.typing import CHWArray, HWArray
from serket._src.utils.validate import IsInstance, Range, validate_spatial_ndim


def pixel_shuffle_3d(array: CHWArray, upscale_factor: tuple[int, int]) -> CHWArray:
Expand Down Expand Up @@ -73,7 +68,7 @@ def adjust_contrast_2d(image: HWArray, factor: float):
"""Adjusts the contrast of an image by scaling the pixel values by a factor.
Args:
array: input array \in [0, 1] with shape (height, width)
array: input array in [0, 1] with shape (height, width)
factor: contrast factor to adust the contrast by.
"""
_, _ = image.shape
Expand All @@ -97,7 +92,7 @@ def adjust_brightness_2d(image: HWArray, factor: float) -> HWArray:
"""Adjusts the brightness of an image by adding a value to the pixel values.
Args:
array: input array \in [0, 1] with shape (height, width)
array: input array in [0, 1] with shape (height, width)
factor: brightness factor to adust the brightness by.
"""
_, _ = image.shape
Expand Down Expand Up @@ -240,7 +235,7 @@ def fourier_domain_adapt_2d(image: HWArray, styler: HWArray, beta: float):
return image_out.astype(dtype)


class PixelShuffle2D(sk.TreeClass):
class PixelShuffle2D(TreeClass):
"""Rearrange elements in a tensor.
.. image:: ../_static/pixelshuffle2d.png
Expand All @@ -265,8 +260,8 @@ def __call__(self, array: CHWArray) -> CHWArray:
spatial_ndim: int = 2


@sk.autoinit
class AdjustContrast2D(sk.TreeClass):
@autoinit
class AdjustContrast2D(TreeClass):
"""Adjusts the contrast of an 2D input by scaling the pixel values by a factor.
.. image:: ../_static/adjustcontrast2d.png
Expand All @@ -289,7 +284,7 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


class RandomContrast2D(sk.TreeClass):
class RandomContrast2D(TreeClass):
"""Randomly adjusts the contrast of an 1D input by scaling the pixel values by a factor.
Args:
Expand Down Expand Up @@ -322,8 +317,8 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
spatial_ndim: int = 2


@sk.autoinit
class AdjustBrightness2D(sk.TreeClass):
@autoinit
class AdjustBrightness2D(TreeClass):
"""Adjusts the brightness of an 2D input by adding a value to the pixel values.
.. image:: ../_static/adjustbrightness2d.png
Expand All @@ -342,7 +337,7 @@ class AdjustBrightness2D(sk.TreeClass):
[1. 1. 1. 1. ]]]
"""

factor: float = sk.field(on_setattr=[IsInstance(float), Range(0, 1)])
factor: float = field(on_setattr=[IsInstance(float), Range(0, 1)])

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
Expand All @@ -352,8 +347,8 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


@sk.autoinit
class RandomBrightness2D(sk.TreeClass):
@autoinit
class RandomBrightness2D(TreeClass):
"""Randomly adjusts the brightness of an 2D input by adding a value to the pixel values.
Args:
Expand All @@ -364,7 +359,7 @@ class RandomBrightness2D(sk.TreeClass):
evaluation.
"""

range: tuple[float, float] = sk.field(on_setattr=[IsInstance(tuple)])
range: tuple[float, float] = field(on_setattr=[IsInstance(tuple)])

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, array: CHWArray, *, key: jax.Array) -> CHWArray:
Expand All @@ -375,7 +370,7 @@ def __call__(self, array: CHWArray, *, key: jax.Array) -> CHWArray:
spatial_ndim: int = 2


class Pixelate2D(sk.TreeClass):
class Pixelate2D(TreeClass):
"""Pixelate an image by upsizing and downsizing an image
.. image:: ../_static/pixelate2d.png
Expand Down Expand Up @@ -410,8 +405,8 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


@sk.autoinit
class Solarize2D(sk.TreeClass):
@autoinit
class Solarize2D(TreeClass):
"""Inverts all values above a given threshold.
.. image:: ../_static/solarize2d.png
Expand Down Expand Up @@ -450,8 +445,8 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


@sk.autoinit
class Posterize2D(sk.TreeClass):
@autoinit
class Posterize2D(TreeClass):
"""Reduce the number of bits for each color channel.
.. image:: ../_static/posterize2d.png
Expand Down Expand Up @@ -492,7 +487,7 @@ class Posterize2D(sk.TreeClass):
- https://github.com/python-pillow/Pillow/blob/main/src/PIL/ImageOps.py#L547
"""

bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)])
bits: int = field(on_setattr=[IsInstance(int), Range(1, 8)])

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
Expand All @@ -501,8 +496,8 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


@sk.autoinit
class RandomJigSaw2D(sk.TreeClass):
@autoinit
class RandomJigSaw2D(TreeClass):
"""Mixes up tiles of an image.
.. image:: ../_static/jigsaw2d.png
Expand Down Expand Up @@ -545,7 +540,7 @@ class RandomJigSaw2D(sk.TreeClass):
- https://imgaug.readthedocs.io/en/latest/source/overview/geometric.html#jigsaw
"""

tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)])
tiles: int = field(on_setattr=[IsInstance(int), Range(1)])

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
Expand All @@ -562,7 +557,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
spatial_ndim: int = 2


class AdjustLog2D(sk.TreeClass):
class AdjustLog2D(TreeClass):
"""Adjust log correction on the input 2D image of range [0, 1].
Args:
Expand Down Expand Up @@ -594,7 +589,7 @@ def __call__(self, array: CHWArray) -> CHWArray:
spatial_ndim: int = 2


class AdjustSigmoid2D(sk.TreeClass):
class AdjustSigmoid2D(TreeClass):
"""Adjust sigmoid correction on the input 2D image of range [0, 1].
Expand Down Expand Up @@ -630,7 +625,7 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


class AdjustHue2D(sk.TreeClass):
class AdjustHue2D(TreeClass):
"""Adjust hue of an RGB image.
.. image:: ../_static/adjusthue2d.png
Expand All @@ -651,7 +646,7 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


class RandomHue2D(sk.TreeClass):
class RandomHue2D(TreeClass):
"""Randomly adjust hue of an RGB image.
Args:
Expand All @@ -674,7 +669,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
spatial_ndim: int = 2


class AdjustSaturation2D(sk.TreeClass):
class AdjustSaturation2D(TreeClass):
"""Adjust saturation of an RGB image.
.. image:: ../_static/adjustsaturation2d.png
Expand All @@ -695,7 +690,7 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


class RandomSaturation2D(sk.TreeClass):
class RandomSaturation2D(TreeClass):
"""Randomly adjust saturation of an RGB image.
Args:
Expand All @@ -718,7 +713,7 @@ def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
spatial_ndim: int = 2


class FourierDomainAdapt2D(sk.TreeClass):
class FourierDomainAdapt2D(TreeClass):
"""Domain adaptation via style transfer
.. image:: ../_static/fourierdomainadapt2d.png
Expand Down
15 changes: 8 additions & 7 deletions serket/_src/image/color.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 serket authors
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,8 +21,9 @@
import jax
import jax.numpy as jnp

import serket as sk
from serket._src.utils import CHWArray, validate_spatial_ndim
from serket import TreeClass
from serket._src.utils.typing import CHWArray
from serket._src.utils.validate import validate_spatial_ndim


def rgb_to_grayscale(image: CHWArray, weights: jax.Array | None = None) -> CHWArray:
Expand Down Expand Up @@ -50,7 +51,7 @@ def grayscale_to_rgb(image: CHWArray) -> CHWArray:
return jnp.concatenate([image, image, image], axis=0)


class RGBToGrayscale2D(sk.TreeClass):
class RGBToGrayscale2D(TreeClass):
"""Converts a channel-first RGB image to grayscale.
.. image:: ../_static/rgbtograyscale2d.png
Expand Down Expand Up @@ -135,7 +136,7 @@ def hsv_to_rgb(image: CHWArray) -> CHWArray:
return out


class GrayscaleToRGB2D(sk.TreeClass):
class GrayscaleToRGB2D(TreeClass):
"""Converts a grayscale image to RGB.
Example:
Expand All @@ -155,7 +156,7 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


class RGBToHSV2D(sk.TreeClass):
class RGBToHSV2D(TreeClass):
"""Converts an RGB image to HSV.
.. image:: ../_static/rgbtohsv2d.png
Expand All @@ -180,7 +181,7 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


class HSVToRGB2D(sk.TreeClass):
class HSVToRGB2D(TreeClass):
"""Converts an HSV image to RGB.
Example:
Expand Down
Loading

0 comments on commit e1aa2e6

Please sign in to comment.