Skip to content

Commit

Permalink
training: Linting, fixes and docstrings
Browse files Browse the repository at this point in the history
setup: Update requirements files
  • Loading branch information
torzdf committed Apr 17, 2024
1 parent cc5a320 commit 5bcd748
Show file tree
Hide file tree
Showing 17 changed files with 600 additions and 516 deletions.
89 changes: 45 additions & 44 deletions lib/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@

import numpy as np

from keras import ops, Variable
from keras import ops
from keras.backend.exports import Variable

if T.TYPE_CHECKING:
import torch
from keras import KerasTensor

# TODO these can probably be switched to pure pytorch


def frobenius_norm(matrix: torch.Tensor,
def frobenius_norm(matrix: KerasTensor,
axis: int = -1,
keep_dims: bool = True,
epsilon: float = 1e-15) -> torch.Tensor:
epsilon: float = 1e-15) -> KerasTensor:
""" Frobenius normalization for Keras Tensor
Parameters
----------
matrix: :class:`torch.Tensor`
matrix: :class:`keras.KerasTensor`
The matrix to normalize
axis: int, optional
The axis to normalize. Default: `-1`
Expand All @@ -32,13 +33,13 @@ def frobenius_norm(matrix: torch.Tensor,
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The normalized output
"""
return ops.sqrt(ops.sum(ops.power(matrix, 2), axis=axis, keepdims=keep_dims) + epsilon)


def replicate_pad(image: torch.Tensor, padding: int) -> torch.Tensor:
def replicate_pad(image: KerasTensor, padding: int) -> KerasTensor:
""" Apply replication padding to an input batch of images. Expects 4D tensor in BHWC format.
Notes
Expand All @@ -49,14 +50,14 @@ def replicate_pad(image: torch.Tensor, padding: int) -> torch.Tensor:
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
Image tensor to pad
pad: int
The amount of padding to apply to each side of the input image
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The input image with replication padding applied
"""
top_pad = ops.tile(image[:, :1, ...], (1, padding, 1, 1))
Expand Down Expand Up @@ -120,7 +121,7 @@ def __init__(self, from_space: str, to_space: str) -> None:
self._xyz_multipliers = Variable([116, 500, 200], dtype="float32", trainable=False)

@classmethod
def _get_rgb_xyz_map(cls) -> tuple[torch.Tensor, torch.Tensor]:
def _get_rgb_xyz_map(cls) -> tuple[KerasTensor, KerasTensor]:
""" Obtain the mapping and inverse mapping for rgb to xyz color space conversion.
Returns
Expand All @@ -135,38 +136,38 @@ def _get_rgb_xyz_map(cls) -> tuple[torch.Tensor, torch.Tensor]:
return (Variable(mapping, dtype="float32", trainable=False),
Variable(inverse, dtype="float32", trainable=False))

def __call__(self, image: torch.Tensor) -> torch.Tensor:
def __call__(self, image: KerasTensor) -> KerasTensor:
""" Call the colorspace conversion function.
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in the colorspace defined by :param:`from_space`
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in the colorspace defined by :param:`to_space`
"""
return self._func(image)

def _rgb_to_lab(self, image: torch.Tensor) -> torch.Tensor:
def _rgb_to_lab(self, image: KerasTensor) -> KerasTensor:
""" RGB to LAB conversion.
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in RGB format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in LAB format
"""
converted = self._rgb_to_xyz(image)
return self._xyz_to_lab(converted)

def _rgb_xyz_rgb(self, image: torch.Tensor, mapping: torch.Tensor) -> torch.Tensor:
def _rgb_xyz_rgb(self, image: KerasTensor, mapping: KerasTensor) -> KerasTensor:
""" RGB to XYZ or XYZ to RGB conversion.
Notes
Expand All @@ -180,16 +181,16 @@ def _rgb_xyz_rgb(self, image: torch.Tensor, mapping: torch.Tensor) -> torch.Tens
Parameters
----------
mapping: :class:`torch.Tensor`
mapping: :class:`keras.KerasTensor`
The mapping matrix to perform either the XYZ to RGB or RGB to XYZ color space
conversion
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in RGB format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in XYZ format
"""
dim = image.shape
Expand All @@ -198,23 +199,23 @@ def _rgb_xyz_rgb(self, image: torch.Tensor, mapping: torch.Tensor) -> torch.Tens
converted = ops.transpose(ops.dot(mapping, image), (0, 2, 1))
return ops.reshape(converted, dim)

def _rgb_to_xyz(self, image: torch.Tensor) -> torch.Tensor:
def _rgb_to_xyz(self, image: KerasTensor) -> KerasTensor:
""" RGB to XYZ conversion.
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in RGB format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in XYZ format
"""
return self._rgb_xyz_rgb(image, self._rgb_xyz_map[0])

@classmethod
def _srgb_to_rgb(cls, image: torch.Tensor) -> torch.Tensor:
def _srgb_to_rgb(cls, image: KerasTensor) -> KerasTensor:
""" SRGB to RGB conversion.
Notes
Expand All @@ -223,47 +224,47 @@ def _srgb_to_rgb(cls, image: torch.Tensor) -> torch.Tensor:
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in SRGB format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in RGB format
"""
limit = np.float32(0.04045)
return ops.where(image > limit,
ops.power((ops.clip(image, limit, np.inf) + 0.055) / 1.055, 2.4),
image / 12.92)

def _srgb_to_ycxcz(self, image: torch.Tensor) -> torch.Tensor:
def _srgb_to_ycxcz(self, image: KerasTensor) -> KerasTensor:
""" SRGB to YcXcZ conversion.
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in SRGB format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in YcXcZ format
"""
converted = self._srgb_to_rgb(image)
converted = self._rgb_to_xyz(converted)
return self._xyz_to_ycxcz(converted)

def _xyz_to_lab(self, image: torch.Tensor) -> torch.Tensor:
def _xyz_to_lab(self, image: KerasTensor) -> KerasTensor:
""" XYZ to LAB conversion.
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in XYZ format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in LAB format
"""
image = image * self._inv_ref_illuminant
Expand All @@ -280,66 +281,66 @@ def _xyz_to_lab(self, image: torch.Tensor) -> torch.Tensor:
self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])],
axis=-1)

def _xyz_to_rgb(self, image: torch.Tensor) -> torch.Tensor:
def _xyz_to_rgb(self, image: KerasTensor) -> KerasTensor:
""" XYZ to YcXcZ conversion.
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in XYZ format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in RGB format
"""
return self._rgb_xyz_rgb(image, self._rgb_xyz_map[1])

def _xyz_to_ycxcz(self, image: torch.Tensor) -> torch.Tensor:
def _xyz_to_ycxcz(self, image: KerasTensor) -> KerasTensor:
""" XYZ to YcXcZ conversion.
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in XYZ format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in YcXcZ format
"""
image = image * self._inv_ref_illuminant
return ops.concatenate([self._xyz_multipliers[0] * image[..., 1:2] - 16.,
self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])],
axis=-1)

def _ycxcz_to_rgb(self, image: torch.Tensor) -> torch.Tensor:
def _ycxcz_to_rgb(self, image: KerasTensor) -> KerasTensor:
""" YcXcZ to RGB conversion.
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in YcXcZ format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in RGB format
"""
converted = self._ycxcz_to_xyz(image)
return self._xyz_to_rgb(converted)

def _ycxcz_to_xyz(self, image: torch.Tensor) -> torch.Tensor:
def _ycxcz_to_xyz(self, image: KerasTensor) -> KerasTensor:
""" YcXcZ to XYZ conversion.
Parameters
----------
image: :class:`torch.Tensor`
image: :class:`keras.KerasTensor`
The image tensor in YcXcZ format
Returns
-------
:class:`torch.Tensor`
:class:`keras.KerasTensor`
The image tensor in XYZ format
"""
ch_y = (image[..., 0:1] + 16.) / self._xyz_multipliers[0]
Expand Down
16 changes: 11 additions & 5 deletions lib/model/autoclip.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
""" Auto clipper for clipping gradients. """
from __future__ import annotations

import logging
import typing as T

import numpy as np
import torch

from lib.logger import parse_class_init

if T.TYPE_CHECKING:
from keras import KerasTensor

logger = logging.getLogger(__name__)


Expand All @@ -29,26 +35,26 @@ def __init__(self, clip_percentile: int, history_size: int = 10000) -> None:

self._clip_percentile = clip_percentile
self._history_size = history_size
self._grad_history = []
self._grad_history: list[float] = []

logger.debug("Initialized %s", self.__class__.__name__)

def __call__(self, gradients: list[torch.Tensor]) -> list[torch.Tensor]:
def __call__(self, gradients: list[KerasTensor]) -> list[KerasTensor]:
""" Call the AutoClip function.
Parameters
----------
gradients: list[:class:`torch.Tensor`]
gradients: list[:class:`keras.KerasTensor`]
The list of gradient tensors for the optimizer
Returns
----------
list[:class:`torch.Tensor`]
list[:class:`keras.KerasTensor`]
The autoclipped gradients
"""
self._grad_history.append(sum(g.data.norm(2).item() ** 2
for g in gradients if g is not None) ** (1. / 2))
self._grad_history = self._grad_history[-self._history_size:]
clip_value = np.percentile(self._grad_history, self._clip_percentile)
torch.nn.utils.clip_grad_norm_(gradients, clip_value)
torch.nn.utils.clip_grad_norm_(gradients, T.cast(float, clip_value))
return gradients
Loading

0 comments on commit 5bcd748

Please sign in to comment.