From 8fc5863aef0eada9ad4fb8ea03e4cee97c957e43 Mon Sep 17 00:00:00 2001
From: torzdf <36920800+torzdf@users.noreply.github.com>
Date: Wed, 13 Mar 2024 06:34:42 +0000
Subject: [PATCH] lib.model.loss to keras3

---
 lib/model/losses/loss.py | 114 +++++++++++++++++++--------------------
 1 file changed, 54 insertions(+), 60 deletions(-)

diff --git a/lib/model/losses/loss.py b/lib/model/losses/loss.py
index 586693f51ee..bbeb7afa731 100644
--- a/lib/model/losses/loss.py
+++ b/lib/model/losses/loss.py
@@ -6,9 +6,9 @@
 import typing as T
 
 import numpy as np
-import keras.backend as K
 from keras.losses import Loss
-from keras import ops
+from keras import ops, Variable
+
 import torch
 
 if T.TYPE_CHECKING:
@@ -89,7 +89,7 @@ def _get_patches(self, inputs: torch.Tensor) -> torch.Tensor:
                 col_to = (j + 1) * patch_cols
                 patch_list.append(inputs[:, row_from: row_to, col_from: col_to, :])
 
-        retval = K.stack(patch_list, axis=1)
+        retval = ops.stack(patch_list, axis=1)
         return retval
 
     def _tensor_to_frequency_spectrum(self, patch: torch.Tensor) -> torch.Tensor:
@@ -105,20 +105,10 @@ def _tensor_to_frequency_spectrum(self, patch: torch.Tensor) -> torch.Tensor:
         :class:`torch.Tensor`
             The DFT frequencies split into real and imaginary numbers as float32
         """
-        # TODO fix this for when self._patch_factor != 1.
-        rows, cols = self._dims
-        patch = K.permute_dimensions(patch, (0, 1, 4, 2, 3))  # move channels to first
-
-        patch = patch / np.sqrt(rows * cols)  # Orthonormalization
-
-        patch = K.cast(patch, "complex64")
-        freq = tf.signal.fft2d(patch)[..., None]
-
-        freq = K.concatenate([tf.math.real(freq), tf.math.imag(freq)], axis=-1)
-        freq = K.cast(freq, "float32")
-
-        freq = K.permute_dimensions(freq, (0, 1, 3, 4, 2, 5))  # channels to last
-
+        patch = ops.transpose(patch, (0, 1, 4, 2, 3))  # move channels to first
+        freq = torch.fft.fft2(patch, norm="ortho")
+        freq = ops.stack([freq.real, freq.imag], axis=-1)
+        freq = ops.transpose(freq, (0, 1, 3, 4, 2, 5))  # channels to last
         return freq
 
     def _get_weight_matrix(self, freq_true: torch.Tensor, freq_pred: torch.Tensor) -> torch.Tensor:
@@ -136,20 +126,20 @@ def _get_weight_matrix(self, freq_true: torch.Tensor, freq_pred: torch.Tensor) -
         :class:`torch.Tensor`
             The weights matrix for prioritizing hard frequencies
         """
-        weights = K.square(freq_pred - freq_true)
-        weights = K.sqrt(weights[..., 0] + weights[..., 1])
-        weights = K.pow(weights, self._alpha)
+        weights = ops.square(freq_pred - freq_true)
+        weights = ops.sqrt(weights[..., 0] + weights[..., 1])
+        weights = ops.power(weights, self._alpha)
 
         if self._log_matrix:  # adjust the spectrum weight matrix by logarithm
-            weights = K.log(weights + 1.0)
+            weights = ops.log(weights + 1.0)
 
         if self._batch_matrix:  # calculate the spectrum weight matrix using batch-based statistics
-            weights = weights / K.max(weights)
+            weights = weights / ops.max(weights)
         else:
-            weights = weights / K.max(K.max(weights, axis=-2), axis=-2)[..., None, None, :]
+            weights = weights / ops.max(ops.max(weights, axis=-2), axis=-2)[..., None, None, :]
 
-        weights = K.switch(tf.math.is_nan(weights), K.zeros_like(weights), weights)
-        weights = K.clip(weights, min_value=0.0, max_value=1.0)
+        weights = ops.where(torch.isnan(weights), ops.zeros_like(weights), weights)
+        weights = ops.clip(weights, x_min=0.0, x_max=1.0)
 
         return weights
 
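Review note (not part of the patch): the `_tensor_to_frequency_spectrum` rewrite above works because `torch.fft.fft2(..., norm="ortho")` folds the deleted manual orthonormalisation (`patch / np.sqrt(rows * cols)`) into the transform itself, which is what lets the `rows, cols` bookkeeping and the complex64 casting go. A standalone sanity check of that equivalence (shapes are illustrative):

```python
import torch

# (batch, patch, channel, rows, cols), the layout after ops.transpose
patch = torch.rand(2, 1, 3, 8, 8)
rows, cols = patch.shape[-2:]

manual = torch.fft.fft2(patch / (rows * cols) ** 0.5)  # old path: scale, then plain FFT
ortho = torch.fft.fft2(patch, norm="ortho")            # new path: normalised transform

assert torch.allclose(manual, ortho, atol=1e-6)
```

Similarly, `ops.stack([freq.real, freq.imag], axis=-1)` collapses the old `[..., None]` expand-then-concatenate idiom into a single call.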
@@ -172,7 +162,7 @@ def _calculate_loss(cls,
             The final loss matrix
         """
-        tmp = K.square(freq_pred - freq_true)  # freq distance using squared Euclidean distance
+        tmp = ops.square(freq_pred - freq_true)  # freq distance using squared Euclidean distance
         freq_distance = tmp[..., 0] + tmp[..., 1]
 
         loss = weight_matrix * freq_distance  # dynamic spectrum weighting (Hadamard product)
 
@@ -195,7 +185,7 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
             The loss for this batch of images
         """
         if not all(self._dims):
-            rows, cols = K.int_shape(y_true)[1:3]
+            rows, cols = y_true.shape[1:3]
             assert cols % self._patch_factor == 0 and rows % self._patch_factor == 0, (
                 "Patch factor must be a divisor of the image height and width")
             self._dims = (rows, cols)
@@ -207,8 +197,8 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
         freq_pred = self._tensor_to_frequency_spectrum(patches_pred)
 
         if self._ave_spectrum:  # whether to use minibatch average spectrum
-            freq_true = K.mean(freq_true, axis=0, keepdims=True)
-            freq_pred = K.mean(freq_pred, axis=0, keepdims=True)
+            freq_true = ops.mean(freq_true, axis=0, keepdims=True)
+            freq_pred = ops.mean(freq_pred, axis=0, keepdims=True)
 
         weight_matrix = self._get_weight_matrix(freq_true, freq_pred)
         return self._calculate_loss(freq_true, freq_pred, weight_matrix)
@@ -256,10 +246,10 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
             The loss value from the results of function(y_pred - y_true)
         """
         diff = y_pred - y_true
-        second = (K.pow(K.pow(diff/self._beta, 2.) / K.abs(2. - self._alpha) + 1.,
-                        (self._alpha / 2.)) - 1.)
-        loss = (K.abs(2. - self._alpha)/self._alpha) * second
-        loss = K.mean(loss, axis=-1) * self._beta
+        second = (ops.power(ops.power(diff/self._beta, 2.) / ops.abs(2. - self._alpha) + 1.,
+                            (self._alpha / 2.)) - 1.)
+        loss = (ops.abs(2. - self._alpha)/self._alpha) * second
+        loss = ops.mean(loss, axis=-1) * self._beta
         return loss
 
 
@@ -317,7 +307,7 @@ def _diff_x(cls, img: torch.Tensor) -> torch.Tensor:
         x_left = img[:, :, 1:2, :] - img[:, :, 0:1, :]
         x_inner = img[:, :, 2:, :] - img[:, :, :-2, :]
         x_right = img[:, :, -1:, :] - img[:, :, -2:-1, :]
-        x_out = K.concatenate([x_left, x_inner, x_right], axis=2)
+        x_out = ops.concatenate([x_left, x_inner, x_right], axis=2)
         return x_out * 0.5
 
     @classmethod
@@ -326,7 +316,7 @@ def _diff_y(cls, img: torch.Tensor) -> torch.Tensor:
         y_top = img[:, 1:2, :, :] - img[:, 0:1, :, :]
         y_inner = img[:, 2:, :, :] - img[:, :-2, :, :]
         y_bot = img[:, -1:, :, :] - img[:, -2:-1, :, :]
-        y_out = K.concatenate([y_top, y_inner, y_bot], axis=1)
+        y_out = ops.concatenate([y_top, y_inner, y_bot], axis=1)
         return y_out * 0.5
 
     @classmethod
@@ -335,7 +325,7 @@ def _diff_xx(cls, img: torch.Tensor) -> torch.Tensor:
         x_left = img[:, :, 1:2, :] + img[:, :, 0:1, :]
         x_inner = img[:, :, 2:, :] + img[:, :, :-2, :]
         x_right = img[:, :, -1:, :] + img[:, :, -2:-1, :]
-        x_out = K.concatenate([x_left, x_inner, x_right], axis=2)
+        x_out = ops.concatenate([x_left, x_inner, x_right], axis=2)
         return x_out - 2.0 * img
 
     @classmethod
@@ -344,7 +334,7 @@ def _diff_yy(cls, img: torch.Tensor) -> torch.Tensor:
         y_top = img[:, 1:2, :, :] + img[:, 0:1, :, :]
         y_inner = img[:, 2:, :, :] + img[:, :-2, :, :]
         y_bot = img[:, -1:, :, :] + img[:, -2:-1, :, :]
-        y_out = K.concatenate([y_top, y_inner, y_bot], axis=1)
+        y_out = ops.concatenate([y_top, y_inner, y_bot], axis=1)
         return y_out - 2.0 * img
 
     @classmethod
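Review note (not part of the patch): the `GeneralizedLoss.__call__` change above is a mechanical `K.pow`/`K.abs`/`K.mean` → `ops.power`/`ops.abs`/`ops.mean` swap, so the Barron general loss itself is unchanged. A quick standalone spot-check (assumes a Keras 3 install with the torch backend; the `alpha`/`beta` values are illustrative): at `alpha = beta = 1` the formula reduces to the Charbonnier form `sqrt(x**2 + 1) - 1`.

```python
import os
os.environ["KERAS_BACKEND"] = "torch"  # assumption: torch backend available
from keras import ops

diff = ops.arange(-3.0, 3.0, 0.5)  # stand-in for y_pred - y_true
alpha, beta = 1.0, 1.0

# the same expression as the patched __call__, with scalars for self._alpha/_beta
second = ops.power(ops.power(diff / beta, 2.) / ops.abs(2. - alpha) + 1.,
                   alpha / 2.) - 1.
loss = (ops.abs(2. - alpha) / alpha) * second

charbonnier = ops.sqrt(ops.power(diff, 2.) + 1.) - 1.
assert float(ops.max(ops.abs(loss - charbonnier))) < 1e-6
```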
@@ -355,37 +345,37 @@ def _diff_xy(cls, img: torch.Tensor) -> torch.Tensor:
         top = img[:, 1:2, 1:2, :] + img[:, 0:1, 0:1, :]
         inner = img[:, 2:, 1:2, :] + img[:, :-2, 0:1, :]
         bottom = img[:, -1:, 1:2, :] + img[:, -2:-1, 0:1, :]
-        xy_left = K.concatenate([top, inner, bottom], axis=1)
+        xy_left = ops.concatenate([top, inner, bottom], axis=1)
         # Mid
         top = img[:, 1:2, 2:, :] + img[:, 0:1, :-2, :]
         mid = img[:, 2:, 2:, :] + img[:, :-2, :-2, :]
         bottom = img[:, -1:, 2:, :] + img[:, -2:-1, :-2, :]
-        xy_mid = K.concatenate([top, mid, bottom], axis=1)
+        xy_mid = ops.concatenate([top, mid, bottom], axis=1)
         # Right
         top = img[:, 1:2, -1:, :] + img[:, 0:1, -2:-1, :]
         inner = img[:, 2:, -1:, :] + img[:, :-2, -2:-1, :]
         bottom = img[:, -1:, -1:, :] + img[:, -2:-1, -2:-1, :]
-        xy_right = K.concatenate([top, inner, bottom], axis=1)
+        xy_right = ops.concatenate([top, inner, bottom], axis=1)
 
         # Xout2
         # Left
         top = img[:, 0:1, 1:2, :] + img[:, 1:2, 0:1, :]
         inner = img[:, :-2, 1:2, :] + img[:, 2:, 0:1, :]
         bottom = img[:, -2:-1, 1:2, :] + img[:, -1:, 0:1, :]
-        xy_left = K.concatenate([top, inner, bottom], axis=1)
+        xy_left = ops.concatenate([top, inner, bottom], axis=1)
         # Mid
         top = img[:, 0:1, 2:, :] + img[:, 1:2, :-2, :]
         mid = img[:, :-2, 2:, :] + img[:, 2:, :-2, :]
         bottom = img[:, -2:-1, 2:, :] + img[:, -1:, :-2, :]
-        xy_mid = K.concatenate([top, mid, bottom], axis=1)
+        xy_mid = ops.concatenate([top, mid, bottom], axis=1)
         # Right
         top = img[:, 0:1, -1:, :] + img[:, 1:2, -2:-1, :]
         inner = img[:, :-2, -1:, :] + img[:, 2:, -2:-1, :]
         bottom = img[:, -2:-1, -1:, :] + img[:, -1:, -2:-1, :]
-        xy_right = K.concatenate([top, inner, bottom], axis=1)
+        xy_right = ops.concatenate([top, inner, bottom], axis=1)
 
-        xy_out1 = K.concatenate([xy_left, xy_mid, xy_right], axis=2)
-        xy_out2 = K.concatenate([xy_left, xy_mid, xy_right], axis=2)
+        xy_out1 = ops.concatenate([xy_left, xy_mid, xy_right], axis=2)
+        xy_out2 = ops.concatenate([xy_left, xy_mid, xy_right], axis=2)
 
         return (xy_out1 - xy_out2) * 0.25
 
@@ -415,7 +405,9 @@ def __init__(self,
                  gaussian_size: int = 5,
                  gaussian_sigma: float = 1.0) -> None:
         self._max_levels = max_levels
-        self._weights = K.constant([np.power(2., -2 * idx) for idx in range(max_levels + 1)])
+        self._weights = Variable([np.power(2., -2 * idx)
+                                  for idx in range(max_levels + 1)],
+                                 trainable=False)
         self._gaussian_kernel = self._get_gaussian_kernel(gaussian_size, gaussian_sigma)
 
     @classmethod
@@ -441,7 +433,7 @@ def _get_gaussian_kernel(cls, size: int, sigma: float) -> torch.Tensor:
         kernel = np.exp(- x_2[:, None] - x_2[None, :])
         kernel /= kernel.sum()
         kernel = np.reshape(kernel, (size, size, 1, 1))
-        return K.constant(kernel)
+        return Variable(kernel, trainable=False)
 
     def _conv_gaussian(self, inputs: torch.Tensor) -> torch.Tensor:
         """ Perform Gaussian convolution on a batch of images.
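Review note (not part of the patch): `K.constant` has no one-to-one `keras.ops` replacement, so the kernel (and the pyramid weights in `__init__` above) become non-trainable `keras.Variable`s; `trainable=False` keeps them out of the gradient tape while still placing them on the backend device. For reference, the wrapped kernel is an ordinary normalised 2D Gaussian. A standalone numpy sketch (the `x_2` construction is assumed from context, since the hunk only shows its use):

```python
import numpy as np

size, sigma = 5, 1.0  # the defaults from __init__
x_1 = np.linspace(-(size // 2), size // 2, size)
x_2 = x_1 ** 2 / (2.0 * sigma ** 2)

kernel = np.exp(- x_2[:, None] - x_2[None, :])  # separable: exp(-a) * exp(-b)
kernel /= kernel.sum()                          # normalise to sum to 1
kernel = np.reshape(kernel, (size, size, 1, 1))

assert np.isclose(kernel.sum(), 1.0)
```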
@@ -456,19 +448,20 @@ def _conv_gaussian(self, inputs: torch.Tensor) -> torch.Tensor:
         :class:`torch.Tensor`
             The convolved images
         """
-        channels = K.int_shape(inputs)[-1]
-        gauss = K.tile(self._gaussian_kernel, (1, 1, 1, channels))
+        channels = inputs.shape[-1]
+        gauss = ops.tile(self._gaussian_kernel, (1, 1, 1, channels))
 
         # TF doesn't implement replication padding like pytorch. This is an inefficient way to
         # implement it for a square guassian kernel
+        # TODO Make this pure pytorch code
         size = self._gaussian_kernel.shape[1] // 2
         padded_inputs = inputs
         for _ in range(size):
-            padded_inputs = tf.pad(padded_inputs,  # noqa,pylint:disable=no-value-for-parameter,unexpected-keyword-arg
-                                   ([0, 0], [1, 1], [1, 1], [0, 0]),
-                                   mode="SYMMETRIC")
+            padded_inputs = ops.pad(padded_inputs,
+                                    ([0, 0], [1, 1], [1, 1], [0, 0]),
+                                    mode="symmetric")
 
-        retval = K.conv2d(padded_inputs, gauss, strides=1, padding="valid")
+        retval = ops.conv(padded_inputs, gauss, strides=1, padding="valid")
         return retval
 
     def _get_laplacian_pyramid(self, inputs: torch.Tensor) -> list[torch.Tensor]:
@@ -490,7 +483,7 @@ def _get_laplacian_pyramid(self, inputs: torch.Tensor) -> list[torch.Tensor]:
             gauss = self._conv_gaussian(current)
             diff = current - gauss
             pyramid.append(diff)
-            current = K.pool2d(gauss, (2, 2), strides=(2, 2), padding="valid", pool_mode="avg")
+            current = ops.average_pool(gauss, (2, 2), strides=(2, 2), padding="valid")
         pyramid.append(current)
         return pyramid
 
@@ -512,9 +505,10 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
         pyramid_true = self._get_laplacian_pyramid(y_true)
         pyramid_pred = self._get_laplacian_pyramid(y_pred)
 
-        losses = K.stack([K.sum(K.abs(ppred - ptrue)) / K.cast(K.prod(K.shape(ptrue)), "float32")
-                          for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
-        loss = K.sum(losses * self._weights)
+        losses = ops.stack([ops.sum(ops.abs(ppred - ptrue)) / ops.cast(ops.prod(ops.shape(ptrue)),
+                                                                       "float32")
+                            for ptrue, ppred in zip(pyramid_true, pyramid_pred)])
+        loss = ops.sum(losses * self._weights)
         return loss
 
 
@@ -536,9 +530,9 @@ def __call__(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
         :class:`torch.Tensor`
             The loss value
         """
-        diff = K.abs(y_true - y_pred)
-        max_loss = K.max(diff, axis=(1, 2), keepdims=True)
-        loss = K.mean(max_loss, axis=-1)
+        diff = ops.abs(y_true - y_pred)
+        max_loss = ops.max(diff, axis=(1, 2), keepdims=True)
+        loss = ops.mean(max_loss, axis=-1)
         return loss
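Review note (not part of the patch) on the `ops.pad` loop in `_conv_gaussian`: padding one pixel per iteration in `"symmetric"` mode is equivalent to a single replication ("edge") pad of the full kernel radius, because each pass re-mirrors the edge row/column produced by the previous pass. That is why the retained comment calls it inefficient, and what the new `# TODO Make this pure pytorch code` points at (e.g. torch's `ReplicationPad2d`). A standalone numpy illustration:

```python
import numpy as np

img = np.arange(16.0).reshape(4, 4)
size = 2  # kernel radius, as computed in _conv_gaussian

stepped = img
for _ in range(size):  # the patch's loop: repeated 1px symmetric pads
    stepped = np.pad(stepped, 1, mode="symmetric")

edged = np.pad(img, size, mode="edge")  # one replication pad of the full width

assert np.array_equal(stepped, edged)
```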