Skip to content

Commit

Permalink
lib.model.loss to keras3
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Mar 13, 2024
1 parent 0748b11 commit 8fc5863
Showing 1 changed file with 54 additions and 60 deletions.
114 changes: 54 additions & 60 deletions lib/model/losses/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import typing as T

import numpy as np
import keras.backend as K
from keras.losses import Loss
from keras import ops
from keras import ops, Variable

import torch

if T.TYPE_CHECKING:
Expand Down Expand Up @@ -89,7 +89,7 @@ def _get_patches(self, inputs: torch.Tensor) -> torch.Tensor:
col_to = (j + 1) * patch_cols
patch_list.append(inputs[:, row_from: row_to, col_from: col_to, :])

retval = K.stack(patch_list, axis=1)
retval = ops.stack(patch_list, axis=1)
return retval

def _tensor_to_frequency_spectrum(self, patch: torch.Tensor) -> torch.Tensor:
Expand All @@ -105,20 +105,10 @@ def _tensor_to_frequency_spectrum(self, patch: torch.Tensor) -> torch.Tensor:
:class:`torch.Tensor`
The DFT frequencies split into real and imaginary numbers as float32
"""
# TODO fix this for when self._patch_factor != 1.
rows, cols = self._dims
patch = K.permute_dimensions(patch, (0, 1, 4, 2, 3)) # move channels to first

patch = patch / np.sqrt(rows * cols) # Orthonormalization

patch = K.cast(patch, "complex64")
freq = tf.signal.fft2d(patch)[..., None]

freq = K.concatenate([tf.math.real(freq), tf.math.imag(freq)], axis=-1)
freq = K.cast(freq, "float32")

freq = K.permute_dimensions(freq, (0, 1, 3, 4, 2, 5)) # channels to last

patch = ops.transpose(patch, (0, 1, 4, 2, 3)) # move channels to first
freq = torch.fft.fft2(patch, norm="ortho")
freq = ops.stack([freq.real, freq.imag], axis=-1)
freq = ops.transpose(freq, (0, 1, 3, 4, 2, 5)) # channels to last
return freq

def _get_weight_matrix(self, freq_true: torch.Tensor, freq_pred: torch.Tensor) -> torch.Tensor:
Expand All @@ -136,20 +126,20 @@ def _get_weight_matrix(self, freq_true: torch.Tensor, freq_pred: torch.Tensor) -
:class:`torch.Tensor`
The weights matrix for prioritizing hard frequencies
"""
weights = K.square(freq_pred - freq_true)
weights = K.sqrt(weights[..., 0] + weights[..., 1])
weights = K.pow(weights, self._alpha)
weights = ops.square(freq_pred - freq_true)
weights = ops.sqrt(weights[..., 0] + weights[..., 1])
weights = ops.power(weights, self._alpha)

if self._log_matrix: # adjust the spectrum weight matrix by logarithm
weights = K.log(weights + 1.0)
weights = ops.log(weights + 1.0)

if self._batch_matrix: # calculate the spectrum weight matrix using batch-based statistics
weights = weights / K.max(weights)
weights = weights / ops.max(weights)
else:
weights = weights / K.max(K.max(weights, axis=-2), axis=-2)[..., None, None, :]
weights = weights / ops.max(ops.max(weights, axis=-2), axis=-2)[..., None, None, :]

weights = K.switch(tf.math.is_nan(weights), K.zeros_like(weights), weights)
weights = K.clip(weights, min_value=0.0, max_value=1.0)
weights = ops.where(torch.isnan(weights), ops.zeros_like(weights), weights)
weights = ops.clip(weights, x_min=0.0, x_max=1.0)

return weights

Expand All @@ -172,7 +162,7 @@ def _calculate_loss(cls,
The final loss matrix
"""

tmp = K.square(freq_pred - freq_true) # freq distance using squared Euclidean distance
tmp = ops.square(freq_pred - freq_true) # freq distance using squared Euclidean distance

freq_distance = tmp[..., 0] + tmp[..., 1]
loss = weight_matrix * freq_distance # dynamic spectrum weighting (Hadamard product)
Expand All @@ -195,7 +185,7 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
The loss for this batch of images
"""
if not all(self._dims):
rows, cols = K.int_shape(y_true)[1:3]
rows, cols = y_true.shape[1:3]
assert cols % self._patch_factor == 0 and rows % self._patch_factor == 0, (
"Patch factor must be a divisor of the image height and width")
self._dims = (rows, cols)
Expand All @@ -207,8 +197,8 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
freq_pred = self._tensor_to_frequency_spectrum(patches_pred)

if self._ave_spectrum: # whether to use minibatch average spectrum
freq_true = K.mean(freq_true, axis=0, keepdims=True)
freq_pred = K.mean(freq_pred, axis=0, keepdims=True)
freq_true = ops.mean(freq_true, axis=0, keepdims=True)
freq_pred = ops.mean(freq_pred, axis=0, keepdims=True)

weight_matrix = self._get_weight_matrix(freq_true, freq_pred)
return self._calculate_loss(freq_true, freq_pred, weight_matrix)
Expand Down Expand Up @@ -256,10 +246,10 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
The loss value from the results of function(y_pred - y_true)
"""
diff = y_pred - y_true
second = (K.pow(K.pow(diff/self._beta, 2.) / K.abs(2. - self._alpha) + 1.,
(self._alpha / 2.)) - 1.)
loss = (K.abs(2. - self._alpha)/self._alpha) * second
loss = K.mean(loss, axis=-1) * self._beta
second = (ops.power(ops.power(diff/self._beta, 2.) / ops.abs(2. - self._alpha) + 1.,
(self._alpha / 2.)) - 1.)
loss = (ops.abs(2. - self._alpha)/self._alpha) * second
loss = ops.mean(loss, axis=-1) * self._beta
return loss


Expand Down Expand Up @@ -317,7 +307,7 @@ def _diff_x(cls, img: torch.Tensor) -> torch.Tensor:
x_left = img[:, :, 1:2, :] - img[:, :, 0:1, :]
x_inner = img[:, :, 2:, :] - img[:, :, :-2, :]
x_right = img[:, :, -1:, :] - img[:, :, -2:-1, :]
x_out = K.concatenate([x_left, x_inner, x_right], axis=2)
x_out = ops.concatenate([x_left, x_inner, x_right], axis=2)
return x_out * 0.5

@classmethod
Expand All @@ -326,7 +316,7 @@ def _diff_y(cls, img: torch.Tensor) -> torch.Tensor:
y_top = img[:, 1:2, :, :] - img[:, 0:1, :, :]
y_inner = img[:, 2:, :, :] - img[:, :-2, :, :]
y_bot = img[:, -1:, :, :] - img[:, -2:-1, :, :]
y_out = K.concatenate([y_top, y_inner, y_bot], axis=1)
y_out = ops.concatenate([y_top, y_inner, y_bot], axis=1)
return y_out * 0.5

@classmethod
Expand All @@ -335,7 +325,7 @@ def _diff_xx(cls, img: torch.Tensor) -> torch.Tensor:
x_left = img[:, :, 1:2, :] + img[:, :, 0:1, :]
x_inner = img[:, :, 2:, :] + img[:, :, :-2, :]
x_right = img[:, :, -1:, :] + img[:, :, -2:-1, :]
x_out = K.concatenate([x_left, x_inner, x_right], axis=2)
x_out = ops.concatenate([x_left, x_inner, x_right], axis=2)
return x_out - 2.0 * img

@classmethod
Expand All @@ -344,7 +334,7 @@ def _diff_yy(cls, img: torch.Tensor) -> torch.Tensor:
y_top = img[:, 1:2, :, :] + img[:, 0:1, :, :]
y_inner = img[:, 2:, :, :] + img[:, :-2, :, :]
y_bot = img[:, -1:, :, :] + img[:, -2:-1, :, :]
y_out = K.concatenate([y_top, y_inner, y_bot], axis=1)
y_out = ops.concatenate([y_top, y_inner, y_bot], axis=1)
return y_out - 2.0 * img

@classmethod
Expand All @@ -355,37 +345,37 @@ def _diff_xy(cls, img: torch.Tensor) -> torch.Tensor:
top = img[:, 1:2, 1:2, :] + img[:, 0:1, 0:1, :]
inner = img[:, 2:, 1:2, :] + img[:, :-2, 0:1, :]
bottom = img[:, -1:, 1:2, :] + img[:, -2:-1, 0:1, :]
xy_left = K.concatenate([top, inner, bottom], axis=1)
xy_left = ops.concatenate([top, inner, bottom], axis=1)
# Mid
top = img[:, 1:2, 2:, :] + img[:, 0:1, :-2, :]
mid = img[:, 2:, 2:, :] + img[:, :-2, :-2, :]
bottom = img[:, -1:, 2:, :] + img[:, -2:-1, :-2, :]
xy_mid = K.concatenate([top, mid, bottom], axis=1)
xy_mid = ops.concatenate([top, mid, bottom], axis=1)
# Right
top = img[:, 1:2, -1:, :] + img[:, 0:1, -2:-1, :]
inner = img[:, 2:, -1:, :] + img[:, :-2, -2:-1, :]
bottom = img[:, -1:, -1:, :] + img[:, -2:-1, -2:-1, :]
xy_right = K.concatenate([top, inner, bottom], axis=1)
xy_right = ops.concatenate([top, inner, bottom], axis=1)

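# Xout2
# NOTE(review): xy_left, xy_mid and xy_right are reassigned in this section before xy_out1 is
# built, so xy_out1 and xy_out2 end up concatenated from the same tensors and their difference
# is zero. Confirm whether xy_out1 should be concatenated before this section starts.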
# Left
top = img[:, 0:1, 1:2, :] + img[:, 1:2, 0:1, :]
inner = img[:, :-2, 1:2, :] + img[:, 2:, 0:1, :]
bottom = img[:, -2:-1, 1:2, :] + img[:, -1:, 0:1, :]
xy_left = K.concatenate([top, inner, bottom], axis=1)
xy_left = ops.concatenate([top, inner, bottom], axis=1)
# Mid
top = img[:, 0:1, 2:, :] + img[:, 1:2, :-2, :]
mid = img[:, :-2, 2:, :] + img[:, 2:, :-2, :]
bottom = img[:, -2:-1, 2:, :] + img[:, -1:, :-2, :]
xy_mid = K.concatenate([top, mid, bottom], axis=1)
xy_mid = ops.concatenate([top, mid, bottom], axis=1)
# Right
top = img[:, 0:1, -1:, :] + img[:, 1:2, -2:-1, :]
inner = img[:, :-2, -1:, :] + img[:, 2:, -2:-1, :]
bottom = img[:, -2:-1, -1:, :] + img[:, -1:, -2:-1, :]
xy_right = K.concatenate([top, inner, bottom], axis=1)
xy_right = ops.concatenate([top, inner, bottom], axis=1)

xy_out1 = K.concatenate([xy_left, xy_mid, xy_right], axis=2)
xy_out2 = K.concatenate([xy_left, xy_mid, xy_right], axis=2)
xy_out1 = ops.concatenate([xy_left, xy_mid, xy_right], axis=2)
xy_out2 = ops.concatenate([xy_left, xy_mid, xy_right], axis=2)
return (xy_out1 - xy_out2) * 0.25


Expand Down Expand Up @@ -415,7 +405,9 @@ def __init__(self,
gaussian_size: int = 5,
gaussian_sigma: float = 1.0) -> None:
self._max_levels = max_levels
self._weights = K.constant([np.power(2., -2 * idx) for idx in range(max_levels + 1)])
self._weights = Variable([np.power(2., -2 * idx)
for idx in range(max_levels + 1)],
trainable=False)
self._gaussian_kernel = self._get_gaussian_kernel(gaussian_size, gaussian_sigma)

@classmethod
Expand All @@ -441,7 +433,7 @@ def _get_gaussian_kernel(cls, size: int, sigma: float) -> torch.Tensor:
kernel = np.exp(- x_2[:, None] - x_2[None, :])
kernel /= kernel.sum()
kernel = np.reshape(kernel, (size, size, 1, 1))
return K.constant(kernel)
return Variable(kernel, trainable=False)

def _conv_gaussian(self, inputs: torch.Tensor) -> torch.Tensor:
""" Perform Gaussian convolution on a batch of images.
Expand All @@ -456,19 +448,20 @@ def _conv_gaussian(self, inputs: torch.Tensor) -> torch.Tensor:
:class:`torch.Tensor`
The convolved images
"""
channels = K.int_shape(inputs)[-1]
gauss = K.tile(self._gaussian_kernel, (1, 1, 1, channels))
channels = inputs.shape[-1]
gauss = ops.tile(self._gaussian_kernel, (1, 1, 1, channels))

# TF doesn't implement replication padding like pytorch. This is an inefficient way to
# implement it for a square Gaussian kernel
# TODO Make this pure pytorch code
size = self._gaussian_kernel.shape[1] // 2
padded_inputs = inputs
for _ in range(size):
padded_inputs = tf.pad(padded_inputs, # noqa,pylint:disable=no-value-for-parameter,unexpected-keyword-arg
([0, 0], [1, 1], [1, 1], [0, 0]),
mode="SYMMETRIC")
padded_inputs = ops.pad(padded_inputs,
([0, 0], [1, 1], [1, 1], [0, 0]),
mode="symmetric")

retval = K.conv2d(padded_inputs, gauss, strides=1, padding="valid")
retval = ops.conv(padded_inputs, gauss, strides=1, padding="valid")
return retval

def _get_laplacian_pyramid(self, inputs: torch.Tensor) -> list[torch.Tensor]:
Expand All @@ -490,7 +483,7 @@ def _get_laplacian_pyramid(self, inputs: torch.Tensor) -> list[torch.Tensor]:
gauss = self._conv_gaussian(current)
diff = current - gauss
pyramid.append(diff)
current = K.pool2d(gauss, (2, 2), strides=(2, 2), padding="valid", pool_mode="avg")
current = ops.average_pool(gauss, (2, 2), strides=(2, 2), padding="valid")
pyramid.append(current)
return pyramid

Expand All @@ -512,9 +505,10 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
pyramid_true = self._get_laplacian_pyramid(y_true)
pyramid_pred = self._get_laplacian_pyramid(y_pred)

losses = K.stack([K.sum(K.abs(ppred - ptrue)) / K.cast(K.prod(K.shape(ptrue)), "float32")
for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
loss = K.sum(losses * self._weights)
losses = ops.stack([ops.sum(ops.abs(ppred - ptrue)) / ops.cast(ops.prod(ops.shape(ptrue)),
"float32")
for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
loss = ops.sum(losses * self._weights)

return loss

Expand All @@ -536,9 +530,9 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
:class:`torch.Tensor`
The loss value
"""
diff = K.abs(y_true - y_pred)
max_loss = K.max(diff, axis=(1, 2), keepdims=True)
loss = K.mean(max_loss, axis=-1)
diff = ops.abs(y_true - y_pred)
max_loss = ops.max(diff, axis=(1, 2), keepdims=True)
loss = ops.mean(max_loss, axis=-1)
return loss


Expand Down

0 comments on commit 8fc5863

Please sign in to comment.