lib.model.losses + lib.keras_utils to Keras 3
torzdf committed Mar 13, 2024
1 parent 6b38957 commit 38feb60
Showing 4 changed files with 214 additions and 208 deletions.
159 changes: 82 additions & 77 deletions lib/keras_utils.py
@@ -5,21 +5,23 @@

import numpy as np

import keras.backend as K
from keras import ops, Variable

if T.TYPE_CHECKING:
from tensorflow import Tensor
import torch

# TODO these can probably be switched to pure pytorch

def frobenius_norm(matrix: Tensor,

def frobenius_norm(matrix: torch.Tensor,
axis: int = -1,
keep_dims: bool = True,
epsilon: float = 1e-15) -> Tensor:
epsilon: float = 1e-15) -> torch.Tensor:
""" Frobenius normalization for Keras Tensor
Parameters
----------
matrix: Tensor
matrix: :class:`torch.Tensor`
The matrix to normalize
axis: int, optional
The axis to normalize. Default: `-1`
@@ -30,13 +32,13 @@ def frobenius_norm(matrix: Tensor,
Returns
-------
Tensor
:class:`torch.Tensor`
The normalized output
"""
return K.sqrt(K.sum(K.pow(matrix, 2), axis=axis, keepdims=keep_dims) + epsilon)
return ops.sqrt(ops.sum(ops.power(matrix, 2), axis=axis, keepdims=keep_dims) + epsilon)
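A minimal usage sketch for the function above (the random input is illustrative; the import path follows the file shown in this diff):

    import numpy as np
    from lib.keras_utils import frobenius_norm

    batch = np.random.rand(4, 16, 16, 3).astype("float32")  # BHWC batch
    norms = frobenius_norm(batch)  # default axis/keep_dims -> shape (4, 16, 16, 1)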


def replicate_pad(image: Tensor, padding: int) -> Tensor:
def replicate_pad(image: torch.Tensor, padding: int) -> torch.Tensor:
""" Apply replication padding to an input batch of images. Expects 4D tensor in BHWC format.
Notes
Expand All @@ -47,22 +49,22 @@ def replicate_pad(image: Tensor, padding: int) -> Tensor:
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
Image tensor to pad
padding: int
The amount of padding to apply to each side of the input image
Returns
-------
Tensor
:class:`torch.Tensor`
The input image with replication padding applied
"""
top_pad = K.tile(image[:, :1, ...], (1, padding, 1, 1))
bottom_pad = K.tile(image[:, -1:, ...], (1, padding, 1, 1))
pad_top_bottom = K.concatenate([top_pad, image, bottom_pad], axis=1)
left_pad = K.tile(pad_top_bottom[..., :1, :], (1, 1, padding, 1))
right_pad = K.tile(pad_top_bottom[..., -1:, :], (1, 1, padding, 1))
padded = K.concatenate([left_pad, pad_top_bottom, right_pad], axis=2)
top_pad = ops.tile(image[:, :1, ...], (1, padding, 1, 1))
bottom_pad = ops.tile(image[:, -1:, ...], (1, padding, 1, 1))
pad_top_bottom = ops.concatenate([top_pad, image, bottom_pad], axis=1)
left_pad = ops.tile(pad_top_bottom[..., :1, :], (1, 1, padding, 1))
right_pad = ops.tile(pad_top_bottom[..., -1:, :], (1, 1, padding, 1))
padded = ops.concatenate([left_pad, pad_top_bottom, right_pad], axis=2)
return padded
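An illustrative check of the padding helper (the values are made up):

    import numpy as np
    from lib.keras_utils import replicate_pad

    image = np.arange(4, dtype="float32").reshape(1, 2, 2, 1)  # BHWC
    padded = replicate_pad(image, 1)  # shape (1, 4, 4, 1) with edge pixels repeated outwards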


@@ -101,23 +103,24 @@ def __init__(self, from_space: str, to_space: str) -> None:
"srgb_ycxcz": self._srgb_to_ycxcz,
"xyz_ycxcz": self._xyz_to_ycxcz,
"xyz_lab": self._xyz_to_lab,
"xyz_to_rgb": self._xyz_to_rgb,
"xyz_rgb": self._xyz_to_rgb,
"ycxcz_rgb": self._ycxcz_to_rgb,
"ycxcz_xyz": self._ycxcz_to_xyz}
func_name = f"{from_space.lower()}_{to_space.lower()}"
if func_name not in functions:
raise ValueError(f"The color transform {from_space} to {to_space} is not defined.")

self._func = functions[func_name]
self._ref_illuminant = K.constant(np.array([[[0.950428545, 1.000000000, 1.088900371]]]),
dtype="float32")
self._ref_illuminant = Variable(np.array([[[0.950428545, 1.000000000, 1.088900371]]]),
dtype="float32",
trainable=False)
self._inv_ref_illuminant = 1. / self._ref_illuminant

self._rgb_xyz_map = self._get_rgb_xyz_map()
self._xyz_multipliers = K.constant([116, 500, 200], dtype="float32")
self._xyz_multipliers = Variable([116, 500, 200], dtype="float32", trainable=False)

@classmethod
def _get_rgb_xyz_map(cls) -> tuple[Tensor, Tensor]:
def _get_rgb_xyz_map(cls) -> tuple[torch.Tensor, torch.Tensor]:
""" Obtain the mapping and inverse mapping for rgb to xyz color space conversion.
Returns
@@ -129,40 +132,41 @@ def _get_rgb_xyz_map(cls) -> tuple[Tensor, Tensor]:
[2613072 / 12288897, 8788810 / 12288897, 887015 / 12288897],
[1425312 / 73733382, 8788810 / 73733382, 70074185 / 73733382]])
inverse = np.linalg.inv(mapping)
return (K.constant(mapping, dtype="float32"), K.constant(inverse, dtype="float32"))
return (Variable(mapping, dtype="float32", trainable=False),
Variable(inverse, dtype="float32", trainable=False))

def __call__(self, image: Tensor) -> Tensor:
def __call__(self, image: torch.Tensor) -> torch.Tensor:
""" Call the colorspace conversion function.
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in the colorspace defined by :param:`from_space`
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in the colorspace defined by :param:`to_space`
"""
return self._func(image)
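Typical use of the converter, shown as a sketch. The class name ColorSpaceConvert is assumed here, as the class definition sits outside the visible hunks:

    import numpy as np
    from lib.keras_utils import ColorSpaceConvert  # assumed class name

    to_ycxcz = ColorSpaceConvert(from_space="srgb", to_space="ycxcz")
    srgb = np.random.rand(2, 64, 64, 3).astype("float32")
    ycxcz = to_ycxcz(srgb)  # same shape, channels now in YcXcZ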

def _rgb_to_lab(self, image: Tensor) -> Tensor:
def _rgb_to_lab(self, image: torch.Tensor) -> torch.Tensor:
""" RGB to LAB conversion.
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in RGB format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in LAB format
"""
converted = self._rgb_to_xyz(image)
return self._xyz_to_lab(converted)

def _rgb_xyz_rgb(self, image: Tensor, mapping: Tensor) -> Tensor:
def _rgb_xyz_rgb(self, image: torch.Tensor, mapping: torch.Tensor) -> torch.Tensor:
""" RGB to XYZ or XYZ to RGB conversion.
Notes
@@ -176,41 +180,41 @@ def _rgb_xyz_rgb(self, image: Tensor, mapping: Tensor) -> Tensor:
Parameters
----------
mapping: Tensor
mapping: :class:`torch.Tensor`
The mapping matrix to perform either the XYZ to RGB or RGB to XYZ color space
conversion
image: Tensor
image: :class:`torch.Tensor`
The image tensor in RGB or XYZ format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor converted to XYZ or RGB format
"""
dim = K.int_shape(image)
image = K.permute_dimensions(image, (0, 3, 1, 2))
image = K.reshape(image, (dim[0], dim[3], dim[1] * dim[2]))
converted = K.permute_dimensions(K.dot(mapping, image), (1, 2, 0))
return K.reshape(converted, dim)
dim = image.shape
image = ops.transpose(image, (0, 3, 1, 2))
image = ops.reshape(image, (dim[0], dim[3], dim[1] * dim[2]))
converted = ops.transpose(ops.dot(mapping, image), (0, 2, 1))
return ops.reshape(converted, dim)
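For reference, the same 3x3 channel mixing can be expressed without the reshape/permute round-trip. This is a sketch assuming BHWC input, not the committed implementation:

    from keras import ops

    def apply_color_matrix(image, mapping):
        """ Illustrative helper: mix the channels of a BHWC tensor through a 3x3 matrix. """
        return ops.matmul(image, ops.transpose(mapping))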

def _rgb_to_xyz(self, image: Tensor) -> Tensor:
def _rgb_to_xyz(self, image: torch.Tensor) -> torch.Tensor:
""" RGB to XYZ conversion.
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in RGB format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in XYZ format
"""
return self._rgb_xyz_rgb(image, self._rgb_xyz_map[0])

@classmethod
def _srgb_to_rgb(cls, image: Tensor) -> Tensor:
def _srgb_to_rgb(cls, image: torch.Tensor) -> torch.Tensor:
""" SRGB to RGB conversion.
Notes
@@ -219,126 +223,127 @@ def _srgb_to_rgb(cls, image: Tensor) -> Tensor:
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in SRGB format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in RGB format
"""
limit = 0.04045
return K.switch(image > limit,
K.pow((K.clip(image, limit, None) + 0.055) / 1.055, 2.4),
image / 12.92)
limit = np.float32(0.04045)
return ops.where(image > limit,
ops.power((ops.clip(image, limit, np.inf) + 0.055) / 1.055, 2.4),
image / 12.92)
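A quick numeric check of the piecewise sRGB-to-linear transfer applied above (sample values are arbitrary):

    low, high = 0.02, 0.5
    assert abs(low / 12.92 - 0.0015480) < 1e-6                    # at or below 0.04045: linear segment
    assert abs(((high + 0.055) / 1.055) ** 2.4 - 0.2140) < 1e-3   # above 0.04045: gamma segment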

def _srgb_to_ycxcz(self, image: Tensor) -> Tensor:
def _srgb_to_ycxcz(self, image: torch.Tensor) -> torch.Tensor:
""" SRGB to YcXcZ conversion.
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in SRGB format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in YcXcZ format
"""
converted = self._srgb_to_rgb(image)
converted = self._rgb_to_xyz(converted)
return self._xyz_to_ycxcz(converted)

def _xyz_to_lab(self, image: Tensor) -> Tensor:
def _xyz_to_lab(self, image: torch.Tensor) -> torch.Tensor:
""" XYZ to LAB conversion.
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in XYZ format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in LAB format
"""
image = image * self._inv_ref_illuminant
delta = 6 / 29
delta_cube = delta ** 3
factor = 1 / (3 * (delta ** 2))

clamped_term = K.pow(K.clip(image, delta_cube, None), 1.0 / 3.0)
clamped_term = ops.power(ops.clip(image, delta_cube, np.inf), 1.0 / 3.0)
div = factor * image + (4 / 29)

image = K.switch(image > delta_cube, clamped_term, div)
return K.concatenate([self._xyz_multipliers[0] * image[..., 1:2] - 16.,
self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])],
axis=-1)
image = ops.where(image > delta_cube, clamped_term, div)

return ops.concatenate([self._xyz_multipliers[0] * image[..., 1:2] - 16.,
self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])],
axis=-1)
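The branch above is the standard CIELAB transfer; as a small sketch, the linear and cube-root branches agree at the breakpoint t = delta ** 3:

    delta = 6 / 29
    t = delta ** 3
    assert abs((t / (3 * delta ** 2) + 4 / 29) - t ** (1.0 / 3.0)) < 1e-12  # both branches give 6/29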

def _xyz_to_rgb(self, image: Tensor) -> Tensor:
def _xyz_to_rgb(self, image: torch.Tensor) -> torch.Tensor:
""" XYZ to YcXcZ conversion.
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in XYZ format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in RGB format
"""
return self._rgb_xyz_rgb(image, self._rgb_xyz_map[1])

def _xyz_to_ycxcz(self, image: Tensor) -> Tensor:
def _xyz_to_ycxcz(self, image: torch.Tensor) -> torch.Tensor:
""" XYZ to YcXcZ conversion.
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in XYZ format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in YcXcZ format
"""
image = image * self._inv_ref_illuminant
return K.concatenate([self._xyz_multipliers[0] * image[..., 1:2] - 16.,
self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])],
axis=-1)
return ops.concatenate([self._xyz_multipliers[0] * image[..., 1:2] - 16.,
self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])],
axis=-1)

def _ycxcz_to_rgb(self, image: Tensor) -> Tensor:
def _ycxcz_to_rgb(self, image: torch.Tensor) -> torch.Tensor:
""" YcXcZ to RGB conversion.
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in YcXcZ format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in RGB format
"""
converted = self._ycxcz_to_xyz(image)
return self._xyz_to_rgb(converted)

def _ycxcz_to_xyz(self, image: Tensor) -> Tensor:
def _ycxcz_to_xyz(self, image: torch.Tensor) -> torch.Tensor:
""" YcXcZ to XYZ conversion.
Parameters
----------
image: Tensor
image: :class:`torch.Tensor`
The image tensor in YcXcZ format
Returns
-------
Tensor
:class:`torch.Tensor`
The image tensor in XYZ format
"""
ch_y = (image[..., 0:1] + 16.) / self._xyz_multipliers[0]
return K.concatenate([ch_y + (image[..., 1:2] / self._xyz_multipliers[1]),
ch_y,
ch_y - (image[..., 2:3] / self._xyz_multipliers[2])],
axis=-1) * self._ref_illuminant
return ops.concatenate([ch_y + (image[..., 1:2] / self._xyz_multipliers[1]),
ch_y,
ch_y - (image[..., 2:3] / self._xyz_multipliers[2])],
axis=-1) * self._ref_illuminant
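As an end-to-end sanity sketch (class name assumed as above), XYZ -> YcXcZ -> XYZ should round-trip to numerically identical values:

    import numpy as np
    from keras import ops
    from lib.keras_utils import ColorSpaceConvert  # assumed class name

    forward = ColorSpaceConvert(from_space="xyz", to_space="ycxcz")
    backward = ColorSpaceConvert(from_space="ycxcz", to_space="xyz")
    xyz = np.random.rand(1, 8, 8, 3).astype("float32")
    assert np.allclose(ops.convert_to_numpy(backward(forward(xyz))), xyz, atol=1e-5)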