diff --git a/lib/keras_utils.py b/lib/keras_utils.py index b0f92e4fac..75ae6b39c4 100644 --- a/lib/keras_utils.py +++ b/lib/keras_utils.py @@ -5,23 +5,24 @@ import numpy as np -from keras import ops, Variable +from keras import ops +from keras.backend.exports import Variable if T.TYPE_CHECKING: - import torch + from keras import KerasTensor # TODO these can probably be switched to pure pytorch -def frobenius_norm(matrix: torch.Tensor, +def frobenius_norm(matrix: KerasTensor, axis: int = -1, keep_dims: bool = True, - epsilon: float = 1e-15) -> torch.Tensor: + epsilon: float = 1e-15) -> KerasTensor: """ Frobenius normalization for Keras Tensor Parameters ---------- - matrix: :class:`torch.Tensor` + matrix: :class:`keras.KerasTensor` The matrix to normalize axis: int, optional The axis to normalize. Default: `-1` @@ -32,13 +33,13 @@ def frobenius_norm(matrix: torch.Tensor, Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The normalized output """ return ops.sqrt(ops.sum(ops.power(matrix, 2), axis=axis, keepdims=keep_dims) + epsilon) -def replicate_pad(image: torch.Tensor, padding: int) -> torch.Tensor: +def replicate_pad(image: KerasTensor, padding: int) -> KerasTensor: """ Apply replication padding to an input batch of images. Expects 4D tensor in BHWC format. 
Notes @@ -49,14 +50,14 @@ def replicate_pad(image: torch.Tensor, padding: int) -> torch.Tensor: Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` Image tensor to pad pad: int The amount of padding to apply to each side of the input image Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The input image with replication padding applied """ top_pad = ops.tile(image[:, :1, ...], (1, padding, 1, 1)) @@ -120,7 +121,7 @@ def __init__(self, from_space: str, to_space: str) -> None: self._xyz_multipliers = Variable([116, 500, 200], dtype="float32", trainable=False) @classmethod - def _get_rgb_xyz_map(cls) -> tuple[torch.Tensor, torch.Tensor]: + def _get_rgb_xyz_map(cls) -> tuple[KerasTensor, KerasTensor]: """ Obtain the mapping and inverse mapping for rgb to xyz color space conversion. Returns @@ -135,38 +136,38 @@ def _get_rgb_xyz_map(cls) -> tuple[torch.Tensor, torch.Tensor]: return (Variable(mapping, dtype="float32", trainable=False), Variable(inverse, dtype="float32", trainable=False)) - def __call__(self, image: torch.Tensor) -> torch.Tensor: + def __call__(self, image: KerasTensor) -> KerasTensor: """ Call the colorspace conversion function. Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in the colorspace defined by :param:`from_space` Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in the colorspace defined by :param:`to_space` """ return self._func(image) - def _rgb_to_lab(self, image: torch.Tensor) -> torch.Tensor: + def _rgb_to_lab(self, image: KerasTensor) -> KerasTensor: """ RGB to LAB conversion. 
Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in RGB format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in LAB format """ converted = self._rgb_to_xyz(image) return self._xyz_to_lab(converted) - def _rgb_xyz_rgb(self, image: torch.Tensor, mapping: torch.Tensor) -> torch.Tensor: + def _rgb_xyz_rgb(self, image: KerasTensor, mapping: KerasTensor) -> KerasTensor: """ RGB to XYZ or XYZ to RGB conversion. Notes @@ -180,16 +181,16 @@ def _rgb_xyz_rgb(self, image: torch.Tensor, mapping: torch.Tensor) -> torch.Tens Parameters ---------- - mapping: :class:`torch.Tensor` + mapping: :class:`keras.KerasTensor` The mapping matrix to perform either the XYZ to RGB or RGB to XYZ color space conversion - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in RGB format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in XYZ format """ dim = image.shape @@ -198,23 +199,23 @@ def _rgb_xyz_rgb(self, image: torch.Tensor, mapping: torch.Tensor) -> torch.Tens converted = ops.transpose(ops.dot(mapping, image), (0, 2, 1)) return ops.reshape(converted, dim) - def _rgb_to_xyz(self, image: torch.Tensor) -> torch.Tensor: + def _rgb_to_xyz(self, image: KerasTensor) -> KerasTensor: """ RGB to XYZ conversion. Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in RGB format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in XYZ format """ return self._rgb_xyz_rgb(image, self._rgb_xyz_map[0]) @classmethod - def _srgb_to_rgb(cls, image: torch.Tensor) -> torch.Tensor: + def _srgb_to_rgb(cls, image: KerasTensor) -> KerasTensor: """ SRGB to RGB conversion. 
Notes @@ -223,12 +224,12 @@ def _srgb_to_rgb(cls, image: torch.Tensor) -> torch.Tensor: Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in SRGB format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in RGB format """ limit = np.float32(0.04045) @@ -236,34 +237,34 @@ def _srgb_to_rgb(cls, image: torch.Tensor) -> torch.Tensor: ops.power((ops.clip(image, limit, np.inf) + 0.055) / 1.055, 2.4), image / 12.92) - def _srgb_to_ycxcz(self, image: torch.Tensor) -> torch.Tensor: + def _srgb_to_ycxcz(self, image: KerasTensor) -> KerasTensor: """ SRGB to YcXcZ conversion. Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in SRGB format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in YcXcZ format """ converted = self._srgb_to_rgb(image) converted = self._rgb_to_xyz(converted) return self._xyz_to_ycxcz(converted) - def _xyz_to_lab(self, image: torch.Tensor) -> torch.Tensor: + def _xyz_to_lab(self, image: KerasTensor) -> KerasTensor: """ XYZ to LAB conversion. Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in XYZ format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in LAB format """ image = image * self._inv_ref_illuminant @@ -280,32 +281,32 @@ def _xyz_to_lab(self, image: torch.Tensor) -> torch.Tensor: self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])], axis=-1) - def _xyz_to_rgb(self, image: torch.Tensor) -> torch.Tensor: + def _xyz_to_rgb(self, image: KerasTensor) -> KerasTensor: """ XYZ to YcXcZ conversion. 
Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in XYZ format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in RGB format """ return self._rgb_xyz_rgb(image, self._rgb_xyz_map[1]) - def _xyz_to_ycxcz(self, image: torch.Tensor) -> torch.Tensor: + def _xyz_to_ycxcz(self, image: KerasTensor) -> KerasTensor: """ XYZ to YcXcZ conversion. Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in XYZ format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in YcXcZ format """ image = image * self._inv_ref_illuminant @@ -313,33 +314,33 @@ def _xyz_to_ycxcz(self, image: torch.Tensor) -> torch.Tensor: self._xyz_multipliers[1:] * (image[..., :2] - image[..., 1:3])], axis=-1) - def _ycxcz_to_rgb(self, image: torch.Tensor) -> torch.Tensor: + def _ycxcz_to_rgb(self, image: KerasTensor) -> KerasTensor: """ YcXcZ to RGB conversion. Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in YcXcZ format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in RGB format """ converted = self._ycxcz_to_xyz(image) return self._xyz_to_rgb(converted) - def _ycxcz_to_xyz(self, image: torch.Tensor) -> torch.Tensor: + def _ycxcz_to_xyz(self, image: KerasTensor) -> KerasTensor: """ YcXcZ to XYZ conversion. Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The image tensor in YcXcZ format Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The image tensor in XYZ format """ ch_y = (image[..., 0:1] + 16.) / self._xyz_multipliers[0] diff --git a/lib/model/autoclip.py b/lib/model/autoclip.py index 2508aab0b6..a58bb79d44 100644 --- a/lib/model/autoclip.py +++ b/lib/model/autoclip.py @@ -1,11 +1,17 @@ """ Auto clipper for clipping gradients. 
""" +from __future__ import annotations + import logging +import typing as T import numpy as np import torch from lib.logger import parse_class_init +if T.TYPE_CHECKING: + from keras import KerasTensor + logger = logging.getLogger(__name__) @@ -29,26 +35,26 @@ def __init__(self, clip_percentile: int, history_size: int = 10000) -> None: self._clip_percentile = clip_percentile self._history_size = history_size - self._grad_history = [] + self._grad_history: list[float] = [] logger.debug("Initialized %s", self.__class__.__name__) - def __call__(self, gradients: list[torch.Tensor]) -> list[torch.Tensor]: + def __call__(self, gradients: list[KerasTensor]) -> list[KerasTensor]: """ Call the AutoClip function. Parameters ---------- - gradients: list[:class:`torch.Tensor`] + gradients: list[:class:`keras.KerasTensor`] The list of gradient tensors for the optimizer Returns ---------- - list[:class:`torch.Tensor`] + list[:class:`keras.KerasTensor`] The autoclipped gradients """ self._grad_history.append(sum(g.data.norm(2).item() ** 2 for g in gradients if g is not None) ** (1. 
/ 2)) self._grad_history = self._grad_history[-self._history_size:] clip_value = np.percentile(self._grad_history, self._clip_percentile) - torch.nn.utils.clip_grad_norm_(gradients, clip_value) + torch.nn.utils.clip_grad_norm_(gradients, T.cast(float, clip_value)) return gradients diff --git a/lib/model/initializers.py b/lib/model/initializers.py index 45ab6a1317..e111b54db0 100644 --- a/lib/model/initializers.py +++ b/lib/model/initializers.py @@ -7,8 +7,10 @@ import inspect import typing as T -from keras import initializers, ops, Variable -from keras.src.initializers.random_initializers import compute_fans +from keras import initializers, ops +from keras.backend.common.variables import KerasVariable +from keras.backend import floatx +from keras.initializers.random_initializers import compute_fans from keras.saving import get_custom_objects import numpy as np @@ -16,7 +18,7 @@ from lib.logger import parse_class_init if T.TYPE_CHECKING: - from torch import Tensor + from keras import KerasTensor logger = logging.getLogger(__name__) @@ -34,7 +36,7 @@ class ICNR(initializers.Initializer): Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The modified kernel weights Example @@ -59,7 +61,7 @@ def __init__(self, def __call__(self, shape: list[int] | tuple[int, ...], dtype: str = "float32", - **kwargs) -> Tensor: + **kwargs) -> KerasTensor: """ Call function for the ICNR initializer. 
Parameters @@ -73,12 +75,16 @@ def __call__(self, Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The modified kernel weights """ shape = list(shape) + if self._scale == 1: - return self._initializer(shape) + initializer = self._initializer + if isinstance(initializer, dict): + initializer = next(i for i in self._initializer.values()) + return initializer(shape) new_shape = shape[:3] + [shape[3] // (self._scale ** 2)] size = [s * self._scale for s in new_shape[:2]] @@ -86,7 +92,7 @@ def __call__(self, if isinstance(self._initializer, dict): self._initializer = initializers.deserialize(self._initializer) - var_x: Tensor = self._initializer(new_shape, dtype) + var_x: KerasTensor = self._initializer(new_shape, dtype) var_x = ops.transpose(var_x, [2, 0, 1, 3]) var_x = ops.image.resize(var_x, size, @@ -98,17 +104,17 @@ def __call__(self, logger.debug("ICNR Output shape: %s", var_x.shape) return var_x - def _space_to_depth(self, input_tensor: Tensor) -> Tensor: + def _space_to_depth(self, input_tensor: KerasTensor) -> KerasTensor: """ Space to depth Keras implementation. 
Parameters ---------- - input_tensor: :class:`torch.Tensor` + input_tensor: :class:`keras.KerasTensor` The tensor to be manipulated Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The manipulated input tensor """ batch, height, width, depth = input_tensor.shape @@ -158,7 +164,7 @@ class ConvolutionAware(initializers.Initializer): Returns ------- - :class:`keras.Variable` + :class:`keras.backend.common.variables.KerasVariable` The modified kernel weights References @@ -174,8 +180,8 @@ def __init__(self, self._eps_std = eps_std self._seed = seed - self._orthogonal = initializers.Orthogonal() - self._he_uniform = initializers.he_uniform() + self._orthogonal = initializers.OrthogonalInitializer() + self._he_uniform = initializers.HeUniform() self._initialized = initialized logger.debug("Initialized %s", self.__class__.__name__) @@ -254,7 +260,7 @@ def _scale_filters(cls, filters: np.ndarray, variance: float) -> np.ndarray: def __call__(self, shape: list[int] | tuple[int, ...], - dtype: str | None = None, **kwargs) -> Variable: + dtype: str | None = None, **kwargs) -> KerasVariable: """ Call function for the ICNR initializer. Parameters @@ -266,11 +272,12 @@ def __call__(self, Returns ------- - :class:`keras.Variable` + :class:`keras.backend.common.variables.KerasVariable` The modified kernel weights """ if self._initialized: # Avoid re-calculating initializer when loading a saved model return self._he_uniform(shape, dtype=dtype) + dtype = floatx if dtype is None else dtype logger.info("Calculating Convolution Aware Initializer for shape: %s", shape) rank = len(shape) if self._seed is not None: @@ -279,6 +286,11 @@ def __call__(self, fan_in, _ = compute_fans(shape) variance = 2 / fan_in + kernel_shape: tuple[int, ...] + transpose_dimensions: tuple[int, ...] 
+ correct_ifft: T.Callable + correct_fft: T.Callable + if rank == 3: row, stack_size, filters_size = shape @@ -306,17 +318,22 @@ def __call__(self, else: self._initialized = True - return Variable(self._orthogonal(shape), dtype=dtype) + return KerasVariable(self._orthogonal(shape), dtype=dtype) kernel_fourier_shape = correct_fft(np.zeros(kernel_shape)).shape - basis = self._create_basis(filters_size, stack_size, np.prod(kernel_fourier_shape), dtype) + basis = self._create_basis(filters_size, + stack_size, + T.cast(int, np.prod(kernel_fourier_shape)), + dtype) basis = basis.reshape((filters_size, stack_size,) + kernel_fourier_shape) randoms = np.random.normal(0, self._eps_std, basis.shape[:-2] + kernel_shape) init = correct_ifft(basis, kernel_shape) + randoms init = self._scale_filters(init, variance) self._initialized = True - retval = Variable(init.transpose(transpose_dimensions), dtype=dtype, name="conv_aware") + retval = KerasVariable(init.transpose(transpose_dimensions), + dtype=dtype, + name="conv_aware") logger.debug("ConvAware output: %s", retval) return retval diff --git a/lib/model/layers.py b/lib/model/layers.py index c761089490..9abbfe3687 100644 --- a/lib/model/layers.py +++ b/lib/model/layers.py @@ -7,29 +7,32 @@ import inspect import typing as T -import keras -from keras.saving import get_custom_objects from keras import ops +from keras.layers.input_spec import InputSpec +from keras.layers.layer import Layer +from keras.saving import get_custom_objects from lib.logger import parse_class_init if T.TYPE_CHECKING: - import torch + from keras import KerasTensor + logger = logging.getLogger(__name__) -class _GlobalPooling2D(keras.layers.Layer): +class _GlobalPooling2D(Layer): # pylint:disable=too-many-ancestors """Abstract class for different global pooling 2D layers. 
""" def __init__(self, data_format: str | None = None, **kwargs) -> None: logger.debug(parse_class_init(locals())) super().__init__(**kwargs) self.data_format = "channels_last" if data_format is None else data_format - self.input_spec = keras.layers.InputSpec(ndim=4) + self.input_spec = InputSpec(ndim=4) logger.debug("Initialized %s", self.__class__.__name__) - def compute_output_shape(self, input_shape): + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: """ Compute the output shape based on the input shape. Parameters @@ -41,13 +44,20 @@ def compute_output_shape(self, input_shape): return (input_shape[0], input_shape[3]) return (input_shape[0], input_shape[1]) - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """ Override to call the layer. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input to the layer + + Returns + ------- + :class:`keras.KerasTensor` + The output from the layer + """ raise NotImplementedError @@ -58,20 +68,21 @@ def get_config(self) -> dict[str, T.Any]: return dict(list(base_config.items()) + list(config.items())) -class GlobalMinPooling2D(_GlobalPooling2D): +class GlobalMinPooling2D(_GlobalPooling2D): # pylint:disable=too-many-ancestors,abstract-method """Global minimum pooling operation for spatial data. """ - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """This is where the layer's logic lives. 
Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - tensor + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ if self.data_format == "channels_last": @@ -81,20 +92,21 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: return pooled -class GlobalStdDevPooling2D(_GlobalPooling2D): +class GlobalStdDevPooling2D(_GlobalPooling2D): # pylint:disable=too-many-ancestors,abstract-method """Global standard deviation pooling operation for spatial data. """ - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """This is where the layer's logic lives. Parameters ---------- - inputs: tensor + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - tensor + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ if self.data_format == "channels_last": @@ -104,7 +116,7 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: return pooled -class KResizeImages(keras.layers.Layer): +class KResizeImages(Layer): # pylint:disable=too-many-ancestors,abstract-method """ A custom upscale function that uses :class:`keras.backend.resize_images` to upsample. 
Parameters @@ -126,17 +138,18 @@ def __init__(self, self.interpolation = interpolation logger.debug("Initialized %s", self.__class__.__name__) - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """ Call the upsample layer Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ size = int(round(inputs.shape[1] * self.size)), int(round(inputs.shape[2] * self.size)) @@ -146,7 +159,8 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: data_format="channels_last") return retval - def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]: + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: """Computes the output shape of the layer. This is the input shape with size dimensions multiplied by :attr:`size` @@ -178,7 +192,7 @@ def get_config(self) -> dict[str, T.Any]: return dict(list(base_config.items()) + list(config.items())) -class L2Normalize(keras.layers.Layer): +class L2Normalize(Layer): # pylint:disable=too-many-ancestors,abstract-method """ Normalizes a tensor w.r.t. the L2 norm alongside the specified axis. Parameters @@ -194,20 +208,21 @@ def __init__(self, axis: int, **kwargs) -> None: super().__init__(**kwargs) logger.debug("Initialized %s", self.__class__.__name__) - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """This is where the layer's logic lives. 
Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ - return keras.ops.normalize(inputs, self.axis, order=2) + return ops.normalize(inputs, self.axis, order=2) def get_config(self) -> dict[str, T.Any]: """Returns the config of the layer. @@ -229,7 +244,7 @@ class name. These are handled by `Network` (one layer of abstraction above). return config -class PixelShuffler(keras.layers.Layer): +class PixelShuffler(Layer): # pylint:disable=too-many-ancestors,abstract-method """ PixelShuffler layer for Keras. This layer requires a Convolution2D prior to it, having output filters computed according to @@ -278,17 +293,18 @@ def __init__(self, self.size = (size, size) if isinstance(size, int) else tuple(size) logger.debug("Initialized %s", self.__class__.__name__) - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """This is where the layer's logic lives. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ input_shape = inputs.shape @@ -321,7 +337,8 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: out = ops.reshape(out, (batch_size, o_height, o_width, o_channels)) return out - def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]: + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: """Computes the output shape of the layer. Assumes that the layer will be built to match that input shape provided. @@ -398,7 +415,7 @@ class name. 
These are handled by `Network` (one layer of abstraction above). return dict(list(base_config.items()) + list(config.items())) -class QuickGELU(keras.layers.Layer): +class QuickGELU(Layer): # pylint:disable=too-many-ancestors,abstract-method """ Applies GELU approximation that is fast but somewhat inaccurate. Parameters @@ -408,29 +425,29 @@ class QuickGELU(keras.layers.Layer): kwargs: dict The standard Keras Layer keyword arguments (if any) """ - def __init__(self, name: str = "QuickGELU", **kwargs) -> None: logger.debug(parse_class_init(locals())) super().__init__(name=name, **kwargs) logger.debug("Initialized %s", self.__class__.__name__) - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """ Call the QuickGELU layerr Parameters ---------- - inputs : :class:`torch.Tensor` + inputs : :class:`keras.KerasTensor` The input Tensor Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The output Tensor """ return inputs * ops.sigmoid(1.702 * inputs) -class ReflectionPadding2D(keras.layers.Layer): +class ReflectionPadding2D(Layer): # pylint:disable=too-many-ancestors,abstract-method """Reflection-padding layer for 2D input (e.g. picture). This layer can add rows and columns at the top, bottom, left and right side of an image tensor. @@ -452,26 +469,27 @@ def __init__(self, stride: int = 2, kernel_size: int = 5, **kwargs) -> None: stride = stride[0] self.stride = stride self.kernel_size = kernel_size - self.input_spec: list[torch.Tensor] | None = None + self.input_spec: list[KerasTensor] | None = None super().__init__(**kwargs) logger.debug("Initialized %s", self.__class__.__name__) - def build(self, input_shape: torch.Tensor) -> None: + def build(self, input_shape: KerasTensor) -> None: """Creates the layer weights. Must be implemented on all layers that have weights. 
Parameters ---------- - input_shape: :class:`torch.Tensor` + input_shape: :class:`keras.KerasTensor` Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to reference for weight shape computations. """ - self.input_spec = [keras.layers.InputSpec(shape=input_shape)] + self.input_spec = [InputSpec(shape=input_shape)] super().build(input_shape) - def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]: + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: """Computes the output shape of the layer. Assumes that the layer will be built to match that input shape provided. @@ -506,17 +524,18 @@ def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]: input_shape[2] + padding_width, input_shape[3]) - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """This is where the layer's logic lives. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ assert self.input_spec is not None @@ -563,7 +582,7 @@ class name. These are handled by `Network` (one layer of abstraction above). return dict(list(base_config.items()) + list(config.items())) -class Swish(keras.layers.Layer): +class Swish(Layer): # pylint:disable=too-many-ancestors,abstract-method """ Swish Activation Layer implementation for Keras. 
Parameters @@ -583,7 +602,8 @@ def __init__(self, beta: float = 1.0, **kwargs) -> None: self.beta = beta logger.debug("Initialized %s", self.__class__.__name__) - def call(self, inputs, *args, **kwargs): + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """ Call the Swish Activation function. Parameters @@ -593,7 +613,7 @@ def call(self, inputs, *args, **kwargs): Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ return ops.nn.swish(inputs * self.beta) diff --git a/lib/model/losses/feature_loss.py b/lib/model/losses/feature_loss.py index 94df8b8bb8..b593c1fed9 100644 --- a/lib/model/losses/feature_loss.py +++ b/lib/model/losses/feature_loss.py @@ -6,7 +6,8 @@ import typing as T import keras -from keras import applications as kapp, ops, Variable +from keras import applications as kapp, ops +from keras.backend.common import KerasVariable from keras.layers import Dropout, Conv2D, Input, Layer, Resizing from keras.models import Model @@ -18,7 +19,7 @@ if T.TYPE_CHECKING: from collections.abc import Callable - import torch + from keras import KerasTensor logger = logging.getLogger(__name__) @@ -90,12 +91,12 @@ def _nets(self) -> dict[str, NetInfo]: outputs=[f"block{i + 1}_conv{2 if i < 2 else 3}" for i in range(5)])} @classmethod - def _normalize_output(cls, inputs: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: + def _normalize_output(cls, inputs: KerasTensor, epsilon: float = 1e-10) -> KerasTensor: """ Normalize the output tensors from the trunk network. Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` An output tensor from the trunk model epsilon: float, optional Epsilon to apply to the normalization operation. 
Default: `1e-10` @@ -189,19 +190,19 @@ def _nets(self) -> dict[str, NetInfo]: "vgg16": NetInfo(model_id=20, model_name="vgg16_lpips_v1.h5")} - def _linear_block(self, net_output_layer: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _linear_block(self, net_output_layer: KerasTensor) -> tuple[KerasTensor, KerasTensor]: """ Build a linear block for a trunk network output. Parameters ---------- - net_output_layer: :class:`tensorflow.Tensor` + net_output_layer: :class:`keras.KerasTensor` An output from the selected trunk network Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The input to the linear block - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The output from the linear block """ in_shape = net_output_layer.shape[1:] @@ -295,12 +296,12 @@ def __init__(self, # pylint:disable=too-many-arguments self._use_lpips = lpips self._normalize = normalize self._ret_per_layer = ret_per_layer - self._shift = Variable(np.array([-.030, -.088, -.188], - dtype="float32")[None, None, None, :], - trainable=False) - self._scale = Variable(np.array([.458, .448, .450], - dtype="float32")[None, None, None, :], - trainable=False) + self._shift = KerasVariable(np.array([-.030, -.088, -.188], + dtype="float32")[None, None, None, :], + trainable=False) + self._scale = KerasVariable(np.array([.458, .448, .450], + dtype="float32")[None, None, None, :], + trainable=False) # Loss needs to be done as fp32. We could cast at output, but better to update the model switch_mixed_precision = keras.mixed_precision.global_policy().name == "mixed_float16" @@ -319,7 +320,7 @@ def __init__(self, # pylint:disable=too-many-arguments keras.mixed_precision.set_global_policy("mixed_float16") logger.debug("Initialized: %s", self.__class__.__name__) - def _process_diffs(self, inputs: list[torch.Tensor]) -> list[torch.Tensor]: + def _process_diffs(self, inputs: list[KerasTensor]) -> list[KerasTensor]: """ Perform processing on the Trunk Network outputs. 
If :attr:`use_ldip` is enabled, process the diff values through the linear network, @@ -327,19 +328,19 @@ def _process_diffs(self, inputs: list[torch.Tensor]) -> list[torch.Tensor]: Parameters ---------- - inputs: list + inputs: list[:class:`keras.KerasTensor`] List of the squared difference of the true and predicted outputs from the trunk network Returns ------- - list + list[:class:`keras.KerasTensor`] List of either the linear network outputs (when using lpips) or summed network outputs """ if self._use_lpips: return self._linear_net(inputs) return [ops.sum(x, axis=-1) for x in inputs] - def _process_output(self, inputs: torch.Tensor, output_dims: tuple) -> torch.Tensor: + def _process_output(self, inputs: KerasTensor, output_dims: tuple) -> KerasTensor: """ Process an individual output based on whether :attr:`is_spatial` has been selected. When spatial output is selected, all outputs are sized to the shape of the original True @@ -347,14 +348,14 @@ def _process_output(self, inputs: torch.Tensor, output_dims: tuple) -> torch.Ten Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` An individual diff output tensor from the linear network or summed output output_dims: tuple The (height, width) of the original true image Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` Either the original tensor resized to the true image dimensions, or the mean value across the height, width axes. """ @@ -362,19 +363,19 @@ def _process_output(self, inputs: torch.Tensor, output_dims: tuple) -> torch.Ten return Resizing(*output_dims, interpolation="bilinear")(inputs) return ops.mean(inputs, axis=(1, 2), keepdims=True) - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Perform the LPIPS Loss Function. 
Parameters ---------- - y_true: :class:`tensorflow.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth batch of images - y_pred: :class:`tensorflow.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted batch of images Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The final loss value """ if self._normalize: diff --git a/lib/model/losses/loss.py b/lib/model/losses/loss.py index fc2092bbb2..ea031f8133 100644 --- a/lib/model/losses/loss.py +++ b/lib/model/losses/loss.py @@ -7,7 +7,8 @@ import numpy as np from keras.losses import Loss -from keras import ops, Variable +from keras import ops +from keras.backend.common import KerasVariable import torch @@ -15,6 +16,7 @@ if T.TYPE_CHECKING: from collections.abc import Callable + from keras import KerasTensor logger = logging.getLogger(__name__) @@ -69,17 +71,17 @@ def __init__(self, self._dims: tuple[int, int] = (0, 0) logger.debug("Initialized: %s", self.__class__.__name__) - def _get_patches(self, inputs: torch.Tensor) -> torch.Tensor: + def _get_patches(self, inputs: KerasTensor) -> KerasTensor: """ Crop the incoming batch of images into patches as defined by :attr:`_patch_factor. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` A batch of images to be converted into patches Returns ------- - :class`torch.Tensor`` + :class:`keras.KerasTensor` The incoming batch converted into patches """ rows, cols = self._dims @@ -97,17 +99,17 @@ def _get_patches(self, inputs: torch.Tensor) -> torch.Tensor: retval = ops.stack(patch_list, axis=1) return retval - def _tensor_to_frequency_spectrum(self, patch: torch.Tensor) -> torch.Tensor: + def _tensor_to_frequency_spectrum(self, patch: KerasTensor) -> KerasTensor: """ Perform FFT to create the orthonomalized DFT frequencies. 
Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The incoming batch of patches to convert to the frequency spectrum Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The DFT frequencies split into real and imaginary numbers as float32 """ patch = ops.transpose(patch, (0, 1, 4, 2, 3)) # move channels to first @@ -116,19 +118,19 @@ def _tensor_to_frequency_spectrum(self, patch: torch.Tensor) -> torch.Tensor: freq = ops.transpose(freq, (0, 1, 3, 4, 2, 5)) # channels to last return freq - def _get_weight_matrix(self, freq_true: torch.Tensor, freq_pred: torch.Tensor) -> torch.Tensor: + def _get_weight_matrix(self, freq_true: KerasTensor, freq_pred: KerasTensor) -> KerasTensor: """ Calculate a continuous, dynamic weight matrix based on current Euclidean distance. Parameters ---------- - freq_true: :class:`torch.Tensor` + freq_true: :class:`keras.KerasTensor` The real and imaginary DFT frequencies for the true batch of images - freq_pred: :class:`torch.Tensor` + freq_pred: :class:`keras.KerasTensor` The real and imaginary DFT frequencies for the predicted batch of images Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The weights matrix for prioritizing hard frequencies """ weights = ops.square(freq_pred - freq_true) @@ -150,20 +152,20 @@ def _get_weight_matrix(self, freq_true: torch.Tensor, freq_pred: torch.Tensor) - @classmethod def _calculate_loss(cls, - freq_true: torch.Tensor, - freq_pred: torch.Tensor, - weight_matrix: torch.Tensor) -> torch.Tensor: + freq_true: KerasTensor, + freq_pred: KerasTensor, + weight_matrix: KerasTensor) -> KerasTensor: """ Perform the loss calculation on the DFT spectrum applying the weights matrix. 
Parameters ---------- - freq_true: :class:`torch.Tensor` + freq_true: :class:`keras.KerasTensor` The real and imaginary DFT frequencies for the true batch of images - freq_pred: :class:`torch.Tensor` + freq_pred: :class:`keras.KerasTensor` The real and imaginary DFT frequencies for the predicted batch of images Returns - :class:`torch.Tensor` + :class:`keras.KerasTensor` The final loss matrix """ @@ -174,19 +176,19 @@ def _calculate_loss(cls, return loss - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Call the Focal Frequency Loss Function. Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth batch of images - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted batch of images Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The loss for this batch of images """ if not all(self._dims): @@ -238,19 +240,19 @@ def __init__(self, alpha: float = 1.0, beta: float = 1.0/255.0) -> None: self._beta = beta logger.debug("Initialized: %s", self.__class__.__name__) - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Call the Generalized Loss Function Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth value - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted value Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The loss value from the results of function(y_pred - y_true) """ diff = y_pred - y_true @@ -283,7 +285,7 @@ def __init__(self) -> None: logger.debug("Initialized: %s", self.__class__.__name__) @classmethod - def _diff_x(cls, img: torch.Tensor) -> torch.Tensor: + def _diff_x(cls, img: KerasTensor) -> KerasTensor: """ X Difference """ x_left = img[:, :, 
1:2, :] - img[:, :, 0:1, :] x_inner = img[:, :, 2:, :] - img[:, :, :-2, :] @@ -292,7 +294,7 @@ def _diff_x(cls, img: torch.Tensor) -> torch.Tensor: return x_out * 0.5 @classmethod - def _diff_y(cls, img: torch.Tensor) -> torch.Tensor: + def _diff_y(cls, img: KerasTensor) -> KerasTensor: """ Y Difference """ y_top = img[:, 1:2, :, :] - img[:, 0:1, :, :] y_inner = img[:, 2:, :, :] - img[:, :-2, :, :] @@ -301,7 +303,7 @@ def _diff_y(cls, img: torch.Tensor) -> torch.Tensor: return y_out * 0.5 @classmethod - def _diff_xx(cls, img: torch.Tensor) -> torch.Tensor: + def _diff_xx(cls, img: KerasTensor) -> KerasTensor: """ X-X Difference """ x_left = img[:, :, 1:2, :] + img[:, :, 0:1, :] x_inner = img[:, :, 2:, :] + img[:, :, :-2, :] @@ -310,7 +312,7 @@ def _diff_xx(cls, img: torch.Tensor) -> torch.Tensor: return x_out - 2.0 * img @classmethod - def _diff_yy(cls, img: torch.Tensor) -> torch.Tensor: + def _diff_yy(cls, img: KerasTensor) -> KerasTensor: """ Y-Y Difference """ y_top = img[:, 1:2, :, :] + img[:, 0:1, :, :] y_inner = img[:, 2:, :, :] + img[:, :-2, :, :] @@ -319,7 +321,7 @@ def _diff_yy(cls, img: torch.Tensor) -> torch.Tensor: return y_out - 2.0 * img @classmethod - def _diff_xy(cls, img: torch.Tensor) -> torch.Tensor: + def _diff_xy(cls, img: KerasTensor) -> KerasTensor: """ X-Y Difference """ # xout1 # Left @@ -359,19 +361,19 @@ def _diff_xy(cls, img: torch.Tensor) -> torch.Tensor: xy_out2 = ops.concatenate([xy_left, xy_mid, xy_right], axis=2) return (xy_out1 - xy_out2) * 0.25 - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Call the gradient loss function. 
Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth value - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted value Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The loss value """ loss = 0.0 @@ -418,14 +420,14 @@ def __init__(self, logger.debug(parse_class_init(locals())) super().__init__(name=self.__class__.__name__) self._max_levels = max_levels - self._weights = Variable([np.power(2., -2 * idx) - for idx in range(max_levels + 1)], - trainable=False) + self._weights = KerasVariable([np.power(2., -2 * idx) + for idx in range(max_levels + 1)], + trainable=False) self._gaussian_kernel = self._get_gaussian_kernel(gaussian_size, gaussian_sigma) logger.debug("Initialized: %s", self.__class__.__name__) @classmethod - def _get_gaussian_kernel(cls, size: int, sigma: float) -> torch.Tensor: + def _get_gaussian_kernel(cls, size: int, sigma: float) -> KerasTensor: """ Obtain the base gaussian kernel for the Laplacian Pyramid. Parameters @@ -437,7 +439,7 @@ def _get_gaussian_kernel(cls, size: int, sigma: float) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The base single channel Gaussian kernel """ assert size % 2 == 1, ("kernel size must be uneven") @@ -447,19 +449,19 @@ def _get_gaussian_kernel(cls, size: int, sigma: float) -> torch.Tensor: kernel = np.exp(- x_2[:, None] - x_2[None, :]) kernel /= kernel.sum() kernel = np.reshape(kernel, (size, size, 1, 1)) - return Variable(kernel, trainable=False) + return KerasVariable(kernel, trainable=False) - def _conv_gaussian(self, inputs: torch.Tensor) -> torch.Tensor: + def _conv_gaussian(self, inputs: KerasTensor) -> KerasTensor: """ Perform Gaussian convolution on a batch of images. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input batch of images to perform Gaussian convolution on. 
Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The convolved images """ channels = inputs.shape[-1] @@ -478,12 +480,12 @@ def _conv_gaussian(self, inputs: torch.Tensor) -> torch.Tensor: retval = ops.conv(padded_inputs, gauss, strides=1, padding="valid") return retval - def _get_laplacian_pyramid(self, inputs: torch.Tensor) -> list[torch.Tensor]: + def _get_laplacian_pyramid(self, inputs: KerasTensor) -> list[KerasTensor]: """ Obtain the Laplacian Pyramid. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input batch of images to run through the Laplacian Pyramid Returns @@ -501,19 +503,19 @@ def _get_laplacian_pyramid(self, inputs: torch.Tensor) -> list[torch.Tensor]: pyramid.append(current) return pyramid - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Calculate the Laplacian Pyramid Loss. Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth value - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted value Returns ------- - :class: `torch.Tensor` + :class:`keras.KerasTensor` The loss value """ pyramid_true = self._get_laplacian_pyramid(y_true) @@ -534,19 +536,19 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, name=self.__class__.__name__, **kwargs) logger.debug("Initialized: %s", self.__class__.__name__) - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Call the L-inf norm loss function. 
Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth value - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted value Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The loss value """ diff = ops.abs(y_true - y_pred) @@ -606,7 +608,7 @@ def add_loss(self, self._loss_weights.append(weight) self._mask_channels.append(mask_channel) - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Call the sub loss functions for the loss wrapper. Loss is returned as the weighted sum of the chosen losses. @@ -617,14 +619,14 @@ def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: Parameters ---------- - y_true: :class:`tensorflow.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth batch of images, with any required masks stacked on the end - y_pred: :class:`tensorflow.Tensor` + y_pred: :class:`keras.KerasTensor` The batch of model predictions Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The final weighted loss """ loss = 0.0 @@ -640,18 +642,18 @@ def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: @classmethod def _apply_mask(cls, - y_true: torch.Tensor, - y_pred: torch.Tensor, + y_true: KerasTensor, + y_pred: KerasTensor, mask_channel: int, - mask_prop: float = 1.0) -> tuple[torch.Tensor, torch.Tensor]: + mask_prop: float = 1.0) -> tuple[KerasTensor, KerasTensor]: """ Apply the mask to the input y_true and y_pred. If a mask is not required then return the unmasked inputs. 
Parameters ---------- - y_true: tensor or variable + y_true: :class:`keras.KerasTensor` The ground truth value - y_pred: tensor or variable + y_pred: :class:`keras.KerasTensor` The predicted value mask_channel: int The channel within y_true that the required mask resides in @@ -660,9 +662,9 @@ def _apply_mask(cls, Returns ------- - torch.Tensor + :class:`keras.KerasTensor` The ground truth batch of images, with the required mask applied - torch.Tensor + :class:`keras.KerasTensor` The predicted batch of images with the required mask applied """ if mask_channel == -1: diff --git a/lib/model/losses/perceptual_loss.py b/lib/model/losses/perceptual_loss.py index d11e5c13b9..878333459e 100644 --- a/lib/model/losses/perceptual_loss.py +++ b/lib/model/losses/perceptual_loss.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """ TF Keras implementation of Perceptual Loss Functions for faceswap.py """ +from __future__ import annotations import logging import typing as T @@ -8,11 +9,15 @@ import torch import keras -from keras import ops, Variable +from keras import ops +from keras.backend.common import KerasVariable from lib.keras_utils import ColorSpaceConvert, frobenius_norm, replicate_pad from lib.logger import parse_class_init +if T.TYPE_CHECKING: + from keras import KerasTensor + logger = logging.getLogger(__name__) @@ -61,12 +66,12 @@ def __init__(self, self._c2 = ((k_2 * max_value) ** 2) * compensation logger.debug("Initialized: %s", self.__class__.__name__) - def _get_kernel(self) -> torch.Tensor: + def _get_kernel(self) -> KerasTensor: """ Obtain the base kernel for performing depthwise convolution. 
Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The gaussian kernel based on selected size and sigma """ coords = np.arange(self._filter_size, dtype="float32") @@ -75,46 +80,46 @@ def _get_kernel(self) -> torch.Tensor: kernel = np.square(coords) kernel *= -0.5 / np.square(self._filter_sigma) kernel = np.reshape(kernel, (1, -1)) + np.reshape(kernel, (-1, 1)) - kernel = Variable(np.reshape(kernel, (1, -1)), trainable=False) + kernel = KerasVariable(np.reshape(kernel, (1, -1)), trainable=False) kernel = ops.softmax(kernel) kernel = ops.reshape(kernel, (self._filter_size, self._filter_size, 1, 1)) return kernel @classmethod - def _depthwise_conv2d(cls, image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor: + def _depthwise_conv2d(cls, image: KerasTensor, kernel: KerasTensor) -> KerasTensor: """ Perform a standardized depthwise convolution. Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` Batch of images, channels last, to perform depthwise convolution - kernel: :class:`torch.Tensor` + kernel: :class:`keras.KerasTensor` convolution kernel Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The output from the convolution """ return ops.depthwise_conv(image, kernel, strides=(1, 1), padding="valid") def _get_ssim(self, - y_true: torch.Tensor, - y_pred: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + y_true: KerasTensor, + y_pred: KerasTensor) -> tuple[KerasTensor, KerasTensor]: """ Obtain the structural similarity between a batch of true and predicted images. 
Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The input batch of ground truth images - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The input batch of predicted images Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The SSIM for the given images - :class:`torch.Tensor` + :class:`keras.KerasTensor` The Contrast for the given images """ channels = y_true.shape[-1] @@ -140,19 +145,19 @@ def _get_ssim(self, return ssim, contrast - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Call the DSSIM or MS-DSSIM Loss Function. Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The input batch of ground truth images - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The input batch of predicted images Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The DSSIM or MS-DSSIM for the given images """ ssim = self._get_ssim(y_true, y_pred)[0] @@ -174,41 +179,41 @@ class GMSDLoss(keras.losses.Loss): def __init__(self, *args, **kwargs) -> None: logger.debug(parse_class_init(locals())) super().__init__(*args, name=self.__class__.__name__, **kwargs) - self._scharr_edges = Variable(np.array([[[[0.00070, 0.00070]], - [[0.00520, 0.00370]], - [[0.03700, 0.00000]], - [[0.00520, -0.0037]], - [[0.00070, -0.0007]]], - [[[0.00370, 0.00520]], - [[0.11870, 0.11870]], - [[0.25890, 0.00000]], - [[0.11870, -0.1187]], - [[0.00370, -0.0052]]], - [[[0.00000, 0.03700]], - [[0.00000, 0.25890]], - [[0.00000, 0.00000]], - [[0.00000, -0.2589]], - [[0.00000, -0.0370]]], - [[[-0.0037, 0.00520]], - [[-0.1187, 0.11870]], - [[-0.2589, 0.00000]], - [[-0.1187, -0.1187]], - [[-0.0037, -0.0052]]], - [[[-0.0007, 0.00070]], - [[-0.0052, 0.00370]], - [[-0.0370, 0.00000]], - [[-0.0052, -0.0037]], - [[-0.0007, -0.0007]]]]), - dtype="float32", - 
trainable=False) + self._scharr_edges = KerasVariable(np.array([[[[0.00070, 0.00070]], + [[0.00520, 0.00370]], + [[0.03700, 0.00000]], + [[0.00520, -0.0037]], + [[0.00070, -0.0007]]], + [[[0.00370, 0.00520]], + [[0.11870, 0.11870]], + [[0.25890, 0.00000]], + [[0.11870, -0.1187]], + [[0.00370, -0.0052]]], + [[[0.00000, 0.03700]], + [[0.00000, 0.25890]], + [[0.00000, 0.00000]], + [[0.00000, -0.2589]], + [[0.00000, -0.0370]]], + [[[-0.0037, 0.00520]], + [[-0.1187, 0.11870]], + [[-0.2589, 0.00000]], + [[-0.1187, -0.1187]], + [[-0.0037, -0.0052]]], + [[[-0.0007, 0.00070]], + [[-0.0052, 0.00370]], + [[-0.0370, 0.00000]], + [[-0.0052, -0.0037]], + [[-0.0007, -0.0007]]]]), + dtype="float32", + trainable=False) logger.debug("Initialized: %s", self.__class__.__name__) - def _map_scharr_edges(self, image: torch.Tensor, magnitude: bool) -> torch.Tensor: + def _map_scharr_edges(self, image: KerasTensor, magnitude: bool) -> KerasTensor: """ Returns a tensor holding modified Scharr edge maps. Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` Image tensor with shape [batch_size, h, w, d] and type float32. The image(s) must be 2x2 or larger. magnitude: bool @@ -216,7 +221,7 @@ def _map_scharr_edges(self, image: torch.Tensor, magnitude: bool) -> torch.Tenso Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` Tensor holding edge maps for each channel. Returns a tensor with shape `[batch_size, h, w, d, 2]` where the last two dimensions hold `[[dy[0], dx[0]], [dy[1], dx[1]], ..., [dy[d-1], dx[d-1]]]` calculated using the Scharr filter. 
@@ -243,19 +248,19 @@ def _map_scharr_edges(self, image: torch.Tensor, magnitude: bool) -> torch.Tenso # magnitude of edges -- unified x & y edges don't work well with Neural Networks return output - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Return the Gradient Magnitude Similarity Deviation Loss. Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth value - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted value Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The loss value """ true_edge = self._map_scharr_edges(y_true, True) @@ -349,24 +354,28 @@ def __init__(self, self._feature_detector = _FeatureDetection(pixels_per_degree) self._col_conv = {"rgb2lab": ColorSpaceConvert(from_space="rgb", to_space="lab"), "rgb2ycxcz": ColorSpaceConvert("srgb", "ycxcz")} - self._hunt = {"green": Variable([[[[0.0, 1.0, 0.0]]]], dtype="float32", trainable=False), - "blue": Variable([[[[0.0, 0.0, 1.0]]]], dtype="float32", trainable=False)} + self._hunt = {"green": KerasVariable([[[[0.0, 1.0, 0.0]]]], + dtype="float32", + trainable=False), + "blue": KerasVariable([[[[0.0, 0.0, 1.0]]]], + dtype="float32", + trainable=False)} logger.debug("Initialized: %s ", self.__class__.__name__) - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Call the LDR Flip Loss Function Parameters ---------- - y_true: :class:`tensorflow.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth batch of images - y_pred: :class:`tensorflow.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted batch of images Returns ------- - :class::class:`tensorflow.Tensor` + :class::class:`keras.KerasTensor` The calculated Flip loss value """ if self._color_order == "bgr": # Switch models 
training in bgr order to rgb @@ -385,19 +394,19 @@ def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: loss = ops.power(delta_e_color, 1 - delta_e_features) return loss - def _color_pipeline(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def _color_pipeline(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Perform the color processing part of the FLIP loss function Parameters ---------- - y_true: :class:`tensorflow.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth batch of images in YCxCz color space - y_pred: :class:`tensorflow.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted batch of images in YCxCz color space Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The exponentiated, maximum HyAB difference between two colors in Hunt-adjusted L*A*B* space """ @@ -416,19 +425,19 @@ def _color_pipeline(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.T self._computed_distance_exponent) return self._redistribute_errors(power_delta, cmax) - def _process_features(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def _process_features(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Perform the color processing part of the FLIP loss function Parameters ---------- - y_true: :class:`tensorflow.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth batch of images in YCxCz color space - y_pred: :class:`tensorflow.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted batch of images in YCxCz color space Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The exponentiated features delta """ col_y_true = (y_true[..., 0:1] + 16) / 116. 
@@ -446,17 +455,17 @@ def _process_features(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch return ops.power(((1 / np.sqrt(2)) * delta), self._feature_exponent) @classmethod - def _hunt_adjustment(cls, image: torch.Tensor) -> torch.Tensor: + def _hunt_adjustment(cls, image: KerasTensor) -> KerasTensor: """ Apply Hunt-adjustment to an image in L*a*b* color space Parameters ---------- - image: :class:`tensorflow.Tensor` + image: :class:`keras.KerasTensor` The batch of images in L*a*b* to adjust Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The hunt adjusted batch of images in L*a*b color space """ ch_l = image[..., 0:1] @@ -468,14 +477,14 @@ def _hyab(self, y_true, y_pred): Parameters ---------- - y_true: :class:`tensorflow.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth batch of images in standard or Hunt-adjusted L*A*B* color space - y_pred: :class:`tensorflow.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted batch of images in in standard or Hunt-adjusted L*A*B* color space Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` image tensor containing the per-pixel HyAB distances between true and predicted images """ delta = y_true - y_pred @@ -490,15 +499,15 @@ def _redistribute_errors(self, power_delta_e_hyab, cmax): Parameters ---------- - power_delta_e_hyab: :class:`tensorflow.Tensor` + power_delta_e_hyab: :class:`keras.KerasTensor` The exponentiated HyAb distance - cmax: :class:`tensorflow.Tensor` + cmax: :class:`keras.KerasTensor` The exponentiated, maximum HyAB difference between two colors in Hunt-adjusted L*A*B* space Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The redistributed per-pixel HyAB distances (in range [0,1]) """ pccmax = self._pc * cmax @@ -528,7 +537,7 @@ def __init__(self, pixels_per_degree: float) -> None: self._ycxcz2rgb = ColorSpaceConvert(from_space="ycxcz", to_space="rgb") logger.debug("Initialized: %s", 
self.__class__.__name__) - def _generate_spatial_filters(self) -> tuple[torch.Tensor, int]: + def _generate_spatial_filters(self) -> tuple[KerasTensor, int]: """ Generates spatial contrast sensitivity filters with width depending on the number of pixels per degree of visual angle of the observer for channels "A", "RG" and "BY" @@ -552,7 +561,7 @@ def _generate_spatial_filters(self) -> tuple[torch.Tensor, int]: weights = np.array([self._generate_weights(mapping[channel], domain) for channel in ("A", "RG", "BY")]) - vweights = Variable(np.moveaxis(weights, 0, -1), dtype="float32", trainable=False) + vweights = KerasVariable(np.moveaxis(weights, 0, -1), dtype="float32", trainable=False) return vweights, radius @@ -582,17 +591,17 @@ def _generate_weights(cls, channel: dict[str, float], domain: np.ndarray) -> np. grad = np.reshape(grad, (*grad.shape, 1)) return grad - def __call__(self, image: torch.Tensor) -> torch.Tensor: + def __call__(self, image: KerasTensor) -> KerasTensor: """ Call the spacial filtering. 
Parameters ---------- - image: Tensor + image: :class:`keras.KerasTensor` Image tensor to filter in YCxCz color space Returns ------- - Tensor + :class:`keras.KerasTensor` The input image transformed to linear RGB after filtering with spatial contrast sensitivity functions """ @@ -625,24 +634,24 @@ def __init__(self, pixels_per_degree: float) -> None: gradient = np.exp(-(grid[0] ** 2 + grid[1] ** 2) / (2 * (self._std ** 2))) self._grads = { - "edge": Variable(np.multiply(-grid[0], gradient), trainable=False), - "point": Variable(np.multiply(grid[0] ** 2 / (self._std ** 2) - 1, gradient), - trainable=False)} + "edge": KerasVariable(np.multiply(-grid[0], gradient), trainable=False), + "point": KerasVariable(np.multiply(grid[0] ** 2 / (self._std ** 2) - 1, gradient), + trainable=False)} logger.debug("Initialized: %s", self.__class__.__name__) - def __call__(self, image: torch.Tensor, feature_type: str) -> torch.Tensor: + def __call__(self, image: KerasTensor, feature_type: str) -> KerasTensor: """ Run the feature detection Parameters ---------- - image: Tensor + image: :class:`keras.KerasTensor` Batch of images in YCxCz color space with normalized Y values feature_type: str Type of features to detect (`"edge"` or `"point"`) Returns ------- - Tensor + :class:`keras.KerasTensor` Detected features in the 0-1 range """ feature_type = feature_type.lower() @@ -706,29 +715,29 @@ def __init__(self, logger.debug(parse_class_init(locals())) super().__init__(name=self.__class__.__name__) self.filter_size = filter_size - self._filter_sigma = Variable(filter_sigma, dtype="float32", trainable=False) + self._filter_sigma = KerasVariable(filter_sigma, dtype="float32", trainable=False) self._k_1 = k_1 self._k_2 = k_2 self._max_value = max_value self._power_factors = power_factors self._divisor = [1, 2, 2, 1] - self._divisor_tensor = Variable(self._divisor[1:], dtype="int32", trainable=False) + self._divisor_tensor = KerasVariable(self._divisor[1:], dtype="int32", trainable=False) 
logger.debug("Initialized: %s", self.__class__.__name__) @classmethod - def _reducer(cls, image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor: + def _reducer(cls, image: KerasTensor, kernel: KerasTensor) -> KerasTensor: """ Computes local averages from a set of images Parameters ---------- - image: :class:`torch.Tensor` + image: :class:`keras.KerasTensor` The images to be processed - kernel: :class:`torch.Tensor` + kernel: :class:`keras.KerasTensor` The kernel to apply Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The reduced image """ shape = image.shape @@ -737,25 +746,25 @@ def _reducer(cls, image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor: return ops.reshape(var_y, (*shape[:-3], *var_y.shape[1:])) def _ssim_helper(self, - image1: torch.Tensor, - image2: torch.Tensor, - kernel: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + image1: KerasTensor, + image2: KerasTensor, + kernel: KerasTensor) -> tuple[KerasTensor, KerasTensor]: """ Helper function for computing SSIM Parameters ---------- - image1: :class:`torch.Tensor` + image1: :class:`keras.KerasTensor` The first set of images - image2: :class:`torch.Tensor` + image2: :class:`keras.KerasTensor` The second set of images - kernel: :class:`torch.Tensor` + kernel: :class:`keras.KerasTensor` The gaussian kernel Returns ------- - :class:`torch.Tensor`: + :class:`keras.KerasTensor`: The channel-wise SSIM - :class:`torch.Tensor`: + :class:`keras.KerasTensor`: The channel-wise contrast-structure """ c_1 = (self._k_1 * self._max_value) ** 2 @@ -773,7 +782,7 @@ def _ssim_helper(self, return luminance, cs_ - def _fspecial_gauss(self, size: int) -> torch.Tensor: + def _fspecial_gauss(self, size: int) -> KerasTensor: """Function to mimic the 'fspecial' gaussian MATLAB function. 
Parameters @@ -783,7 +792,7 @@ def _fspecial_gauss(self, size: int) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The gaussian kernel """ coords = ops.cast(range(size), self._filter_sigma.dtype) @@ -798,9 +807,9 @@ def _fspecial_gauss(self, size: int) -> torch.Tensor: return ops.reshape(gauss, [size, size, 1, 1]) def _ssim_per_channel(self, - image1: torch.Tensor, - image2: torch.Tensor, - filter_size: int) -> tuple[torch.Tensor]: + image1: KerasTensor, + image2: KerasTensor, + filter_size: int) -> tuple[KerasTensor, KerasTensor]: """Computes SSIM index between image1 and image2 per color channel. This function matches the standard SSIM implementation from: @@ -810,18 +819,18 @@ def _ssim_per_channel(self, Parameters ---------- - image1: :class;`torch.Tensor` + image1: :class:`keras.KerasTensor` The first image batch - image2: :class;`torch.Tensor` + image2: :class:`keras.KerasTensor` The second image batch. filter_size: int size of gaussian filter). 
Returns ------- - :class:`torch.Tensor`: + :class:`keras.KerasTensor`: The channel-wise SSIM - :class:`torch.Tensor`: + :class:`keras.KerasTensor`: The channel-wise contrast-structure """ shape = image1.shape @@ -837,19 +846,19 @@ def _ssim_per_channel(self, return ssim_val, cs_ @classmethod - def _do_pad(cls, images: torch.Tensor, remainder: torch.Tensor) -> list[torch.Tensor]: + def _do_pad(cls, images: KerasTensor, remainder: KerasTensor) -> list[KerasTensor]: """ Pad images Parameters ---------- - images: :class:`torch.Tensor` + images: :class:`keras.KerasTensor` Images to pad - remainder: :class:`torch.Tensor` + remainder: :class:`keras.KerasTensor` Remainding images to pad Returns ------- - list[:class:`torch.Tensor`] + list[:class:`keras.KerasTensor`] Padded images """ padding = ops.expand_dims(remainder, axis=-1) @@ -857,18 +866,18 @@ def _do_pad(cls, images: torch.Tensor, remainder: torch.Tensor) -> list[torch.Te return [ops.pad(x, padding, mode="symmetric") for x in images] def _mssism(self, - y_true: torch.Tensor, - y_pred: torch.Tensor, - filter_size: int) -> torch.Tensor: + y_true: KerasTensor, + y_pred: KerasTensor, + filter_size: int) -> KerasTensor: """ Perform the MSSISM calculation. 
Ported from Tensorflow implementation `tf.image.ssim_multiscale` Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth value - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted value filter_size: int The filter size to use @@ -887,9 +896,11 @@ def _mssism(self, remainder = tails[0] % self._divisor_tensor need_padding = ops.any(ops.not_equal(remainder, 0)) - padded = ops.cond(need_padding, - lambda: self._do_pad(flat_images, remainder), - lambda: flat_images) + padded = ops.cond( + need_padding, + lambda: self._do_pad(flat_images, # pylint:disable=cell-var-from-loop + remainder), # pylint:disable=cell-var-from-loop + lambda: flat_images) # pylint:disable=cell-var-from-loop downscaled = [ops.average_pool(x, self._divisor[1:3], @@ -902,7 +913,7 @@ def _mssism(self, for x, h, t in zip(downscaled, heads, tails)] # Overwrite previous ssim value since we only need the last one. - ssim_per_channel, cs_ = self._ssim_per_channel(*images, filter_size) + ssim_per_channel, cs_ = self._ssim_per_channel(images[0], images[1], filter_size) mcs.append(ops.relu(cs_)) mcs.pop() # Remove the cs score for the last scale. @@ -912,19 +923,19 @@ def _mssism(self, return ops.mean(ms_ssim, [-1]) # Avg over color channels. - def call(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def call(self, y_true: KerasTensor, y_pred: KerasTensor) -> KerasTensor: """ Call the MS-SSIM Loss Function. 
Parameters ---------- - y_true: :class:`torch.Tensor` + y_true: :class:`keras.KerasTensor` The ground truth value - y_pred: :class:`torch.Tensor` + y_pred: :class:`keras.KerasTensor` The predicted value Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The MS-SSIM Loss value """ im_size = y_true.shape[1] diff --git a/lib/model/networks/clip.py b/lib/model/networks/clip.py index 1095be9a1c..38ee48d246 100644 --- a/lib/model/networks/clip.py +++ b/lib/model/networks/clip.py @@ -11,14 +11,18 @@ from dataclasses import dataclass -from keras import layers, Model, ops, Variable +from keras import layers, Model, ops +from keras.backend.common import KerasVariable from keras.saving import get_custom_objects import numpy as np -import torch from lib.model.layers import QuickGELU from lib.utils import GetModel +if T.TYPE_CHECKING: + from keras import KerasTensor + + logger = logging.getLogger(__name__) TypeModels = T.Literal["RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B-16", @@ -98,7 +102,7 @@ class Transformer(): The number of layers in the Transformer. heads: int The number of attention heads. - attn_mask: torch.Tensor, optional + attn_mask: :class:`keras.KerasTensor`, optional The attention mask, by default None. name: str, optional The name of the Transformer model, by default "transformer". 
@@ -115,7 +119,7 @@ def __init__(self, width: int, num_layers: int, heads: int, - attn_mask: torch.Tensor = None, + attn_mask: KerasTensor = None, name: str = "transformer") -> None: logger.debug("Initializing: %s (width: %s, num_layers: %s, heads: %s, attn_mask: %s, " "name: %s)", @@ -150,12 +154,12 @@ def _get_name(cls, name: str) -> str: return name @classmethod - def _mlp(cls, inputs: torch.Tensor, key_dim: int, name: str) -> torch.Tensor: + def _mlp(cls, inputs: KerasTensor, key_dim: int, name: str) -> KerasTensor: """" Multilayer Perecptron for Block Ateention Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` The input to the MLP key_dim: int key dimension per head for MultiHeadAttention @@ -164,7 +168,7 @@ def _mlp(cls, inputs: torch.Tensor, key_dim: int, name: str) -> torch.Tensor: Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The output from the MLP """ name = f"{name}.mlp" @@ -174,29 +178,29 @@ def _mlp(cls, inputs: torch.Tensor, key_dim: int, name: str) -> torch.Tensor: return var_x def residual_attention_block(self, - inputs: torch.Tensor, + inputs: KerasTensor, key_dim: int, num_heads: int, - attn_mask: torch.Tensor, - name: str = "ResidualAttentionBlock") -> torch.Tensor: + attn_mask: KerasTensor, + name: str = "ResidualAttentionBlock") -> KerasTensor: """ Call the residual attention block Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input Tensor key_dim: int key dimension per head for MultiHeadAttention num_heads: int Number of heads for MultiHeadAttention - attn_mask: :class:`tensorflow.Tensor`, optional + attn_mask: :class:`keras.KerasTensor`, optional Default: ``None`` name: str, optional The name for the layer. 
Default: "ResidualAttentionBlock" Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The return Tensor """ name = self._get_name(name) @@ -212,17 +216,17 @@ def residual_attention_block(self, var_x = layers.Add()([var_y, self._mlp(var_x, key_dim, name)]) return var_x - def __call__(self, inputs: torch.Tensor) -> torch.Tensor: + def __call__(self, inputs: KerasTensor) -> KerasTensor: """ Call the Transformer layers Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input Tensor Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The return Tensor """ logger.debug("Calling %s with input: %s", self.__class__.__name__, inputs.shape) @@ -236,7 +240,7 @@ def __call__(self, inputs: torch.Tensor) -> torch.Tensor: return var_x -class EmbeddingLayer(layers.Layer): +class EmbeddingLayer(layers.Layer): # pylint:disable=too-many-ancestors,abstract-method """ Parent class for trainable embedding variables Parameters @@ -260,7 +264,7 @@ def __init__(self, super().__init__(name=name, dtype=dtype, *args, **kwargs) self._input_shape = input_shape self._scale = scale - self._var: torch.Tensor + self._var: KerasTensor def build(self, input_shape: tuple[int, ...]) -> None: """ Add the weights @@ -270,9 +274,9 @@ def build(self, input_shape: tuple[int, ...]) -> None: input_shape: tuple[int, ... 
The input shape of the incoming tensor """ - self._var = Variable(self._scale * np.random.normal(size=(self._input_shape)), - trainable=True, - dtype=self.dtype) + self._var = KerasVariable(self._scale * np.random.normal(size=self._input_shape), + trainable=True, + dtype=self.dtype) super().build(input_shape) def get_config(self) -> dict[str, T.Any]: @@ -289,55 +293,58 @@ def get_config(self) -> dict[str, T.Any]: return retval -class ClassEmbedding(EmbeddingLayer): +class ClassEmbedding(EmbeddingLayer): # pylint:disable=too-many-ancestors,abstract-method """ Trainable Class Embedding layer """ - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """ Get the Class Embedding layer Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor to the embedding layer Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The class embedding layer shaped for the input tensor """ return ops.tile(self._var[None, None], [inputs.shape[0], 1, 1]) -class PositionalEmbedding(EmbeddingLayer): +class PositionalEmbedding(EmbeddingLayer): # pylint:disable=too-many-ancestors,abstract-method """ Trainable Positional Embedding layer """ - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """ Get the Positional Embedding layer Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor to the embedding layer Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The positional embedding layer shaped for the input tensor """ return ops.tile(self._var[None], [inputs.shape[0], 1, 1]) -class Projection(EmbeddingLayer): +class Projection(EmbeddingLayer): # pylint:disable=too-many-ancestors,abstract-method 
""" Trainable Projection Embedding Layer """ - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """ Get the Projection layer Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor to the embedding layer Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The Projection layer expanded to the batch dimension and transposed for matmul """ return ops.tile(ops.transpose(self._var)[None], [inputs.shape[0], 1, 1]) @@ -398,11 +405,11 @@ def __call__(self) -> Model: The Visual Transformer model. """ inputs = layers.Input([self._input_resolution, self._input_resolution, 3]) - var_x: torch.Tensor = layers.Conv2D(self._width, # shape = [*, grid, grid, width] - self._patch_size, - strides=self._patch_size, - use_bias=False, - name=f"{self._name}.conv1")(inputs) + var_x: KerasTensor = layers.Conv2D(self._width, # shape = [*, grid, grid, width] + self._patch_size, + strides=self._patch_size, + use_bias=False, + name=f"{self._name}.conv1")(inputs) var_x = layers.Reshape((-1, self._width))(var_x) # shape = [*, grid ** 2, width] @@ -465,17 +472,17 @@ def __init__(self, self._name = name logger.debug("Initialized: %s", self.__class__.__name__) - def _downsample(self, inputs: torch.Tensor) -> torch.Tensor: + def _downsample(self, inputs: KerasTensor) -> KerasTensor: """ Perform downsample if required Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` The input the downsample Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The original tensor, if downsizing not required, otherwise the downsized tensor """ if self._stride <= 1 and self._inplanes == self._planes * self.expansion: @@ -491,7 +498,7 @@ def _downsample(self, inputs: torch.Tensor) -> torch.Tensor: out = layers.BatchNormalization(name=f"{name}.1", 
epsilon=1e-5)(out) return out - def __call__(self, inputs: torch.Tensor) -> torch.Tensor: + def __call__(self, inputs: KerasTensor) -> KerasTensor: """ Performs the forward pass for a Bottleneck block. All conv layers have stride 1. an avgpool is performed after the second convolution when @@ -499,12 +506,12 @@ def __call__(self, inputs: torch.Tensor) -> torch.Tensor: Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` The input tensor to the Bottleneck block. Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The result of the forward pass through the Bottleneck block. """ out = layers.Conv2D(self._planes, 1, use_bias=False, name=f"{self._name}.conv1")(inputs) @@ -566,19 +573,19 @@ def __init__(self, self._name = name logger.debug("Initialized: %s", self.__class__.__name__) - def __call__(self, inputs: torch.Tensor) -> torch.Tensor: + def __call__(self, inputs: KerasTensor) -> KerasTensor: """Performs the attention pooling operation on the input tensor. Parameters ---------- - inputs: :class:`tensorflow.Tensor`: + inputs: :class:`keras.KerasTensor`: The input tensor of shape [batch_size, height, width, embed_dim]. Returns ------- - :class:`tensorflow.Tensor`:: The result of the attention pooling operation + :class:`keras.KerasTensor`:: The result of the attention pooling operation """ - var_x: torch.Tensor + var_x: KerasTensor var_x = layers.Reshape((-1, inputs.shape[-1]))(inputs) # NHWC -> N(HW)C var_x = layers.Concatenate(axis=1)([ops.mean(var_x, axis=1, # N(HW)C -> N(HW+1)C keepdims=True), var_x]) @@ -636,19 +643,19 @@ def __init__(self, self._output_dim = output_dim self._name = name - def _stem(self, inputs: torch.Tensor) -> torch.Tensor: + def _stem(self, inputs: KerasTensor) -> KerasTensor: """ Applies the stem operation to the input tensor, which consists of 3 convolutional layers with BatchNormalization and ReLU activation, followed by an average pooling layer. 
Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` The input tensor Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` The output tensor after applying the stem operation. """ var_x = inputs @@ -667,17 +674,17 @@ def _stem(self, inputs: torch.Tensor) -> torch.Tensor: return var_x def _bottleneck(self, - inputs: torch.Tensor, + inputs: KerasTensor, planes: int, blocks: int, stride: int = 1, - name: str = "layer") -> torch.Tensor: + name: str = "layer") -> KerasTensor: """ A private method that creates a sequential layer of Bottleneck blocks for the ModifiedResNet model. Parameters ---------- - inputs: :class:`tensorflow.Tensor` + inputs: :class:`keras.KerasTensor` The input tensor planes: int The number of output channels for the layer. @@ -690,10 +697,10 @@ def _bottleneck(self, Returns ------- - :class:`tensorflow.Tensor` + :class:`keras.KerasTensor` Sequential block of bottlenecks """ - retval: torch.Tensor + retval: KerasTensor retval = Bottleneck(planes, planes, stride, name=f"{name}.0")(inputs) for i in range(1, blocks): retval = Bottleneck(planes * Bottleneck.expansion, diff --git a/lib/model/networks/simple_nets.py b/lib/model/networks/simple_nets.py index f8f8b40cf8..da699f8b6b 100644 --- a/lib/model/networks/simple_nets.py +++ b/lib/model/networks/simple_nets.py @@ -10,7 +10,7 @@ from lib.logger import parse_class_init if T.TYPE_CHECKING: - import torch + from keras import KerasTensor logger = logging.getLogger(__name__) @@ -60,19 +60,19 @@ def __init__(self, input_shape: tuple[int, int, int] | None = None) -> None: @classmethod def _conv_block(cls, - inputs: torch.Tensor, + inputs: KerasTensor, padding: int, filters: int, kernel_size: int, strides: int, block_idx: int, - max_pool: bool) -> torch.Tensor: + max_pool: bool) -> KerasTensor: """ The Convolutional block for AlexNet Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input tensor to the 
block padding: int The amount of zero paddin to apply prior to convolution @@ -89,13 +89,13 @@ def _conv_block(cls, Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The output of the Convolutional block """ name = f"features.{block_idx}" var_x = inputs if max_pool: - var_x = layers.MaxPool2D(pool_size=3, strides=2, name=f"{name}.pool")(var_x) + var_x = layers.MaxPooling2D(pool_size=3, strides=2, name=f"{name}.pool")(var_x) var_x = layers.ZeroPadding2D(padding=padding, name=f"{name}.pad")(var_x) var_x = layers.Conv2D(filters, kernel_size=kernel_size, @@ -152,15 +152,15 @@ class SqueezeNet(_net): @classmethod def _fire(cls, - inputs: torch.Tensor, + inputs: KerasTensor, squeeze_planes: int, expand_planes: int, - block_idx: int) -> torch.Tensor: + block_idx: int) -> KerasTensor: """ The fire block for SqueezeNet. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input to the fire block squeeze_planes: int The number of filters for the squeeze convolution @@ -171,7 +171,7 @@ def _fire(cls, Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The output of the SqueezeNet fire block """ name = f"features.{block_idx}" @@ -202,7 +202,7 @@ def __call__(self) -> Model: expand = 64 for idx in range(4): if idx < 3: - var_x = layers.MaxPool2D(pool_size=3, strides=2)(var_x) + var_x = layers.MaxPooling2D(pool_size=3, strides=2)(var_x) block_idx += 1 var_x = self._fire(var_x, squeeze, expand, block_idx) block_idx += 1 diff --git a/lib/model/nn_blocks.py b/lib/model/nn_blocks.py index 565801866e..99482e283e 100644 --- a/lib/model/nn_blocks.py +++ b/lib/model/nn_blocks.py @@ -17,7 +17,6 @@ if T.TYPE_CHECKING: import keras - from tensorflow import Tensor logger = logging.getLogger(__name__) @@ -99,7 +98,7 @@ def _get_default_initializer( return retval -class Conv2D(KConv2D): # pylint:disable=too-many-ancestors +class Conv2D(KConv2D): # pylint:disable=too-many-ancestors,abstract-method """ A standard 
Keras Convolution 2D layer with parameters updated to be more appropriate for Faceswap architecture. @@ -132,7 +131,7 @@ def __init__(self, *args, padding: str = "same", is_upscale: bool = False, **kwa logger.debug("Initialized %s", self.__class__.__name__) -class DepthwiseConv2D(KDepthwiseConv2d): # noqa,pylint:disable=too-many-ancestors +class DepthwiseConv2D(KDepthwiseConv2d): # noqa,pylint:disable=too-many-ancestors,abstract-method """ A standard Keras Depthwise Convolution 2D layer with parameters updated to be more appropriate for Faceswap architecture. @@ -204,17 +203,17 @@ def __init__(self, self._activation = Activation(activation, dtype="float32", name=name) logger.debug("Initialized %s", self.__class__.__name__) - def __call__(self, inputs: Tensor) -> Tensor: + def __call__(self, inputs: keras.KerasTensor) -> keras.KerasTensor: """ Call the Faceswap Convolutional Output Layer. Parameters ---------- - inputs: Tensor + inputs: :class:`keras.KerasTensor` The input to the layer Returns ------- - Tensor + :class:`keras.KerasTensor` The output tensor from the Convolution 2D Layer """ var_x = self._conv(inputs) @@ -338,17 +337,17 @@ def _get_layers(self) -> list[keras.layers.Layer]: logger.debug("%s layers: %s", self.__class__.__name__, retval) return retval - def __call__(self, inputs: Tensor) -> Tensor: + def __call__(self, inputs: keras.KerasTensor) -> keras.KerasTensor: """ Call the Faceswap Convolutional Layer. Parameters ---------- - inputs: Tensor + inputs: :class:`keras.KerasTensor` The input to the layer Returns ------- - Tensor + :class:`keras.KerasTensor` The output tensor from the Convolution 2D Layer """ var_x = inputs @@ -397,17 +396,17 @@ def __init__(self, self._activation = Activation("relu", name=f"{name}_relu") logger.debug("Initialized %s", self.__class__.__name__) - def __call__(self, inputs: Tensor) -> Tensor: + def __call__(self, inputs: keras.KerasTensor) -> keras.KerasTensor: """ Call the Faceswap Separable Convolutional 2D Block. 
Parameters ---------- - inputs: Tensor + inputs: :class:`keras.KerasTensor` The input to the layer Returns ------- - Tensor + :class:`keras.KerasTensor` The output tensor from the Upscale Layer """ var_x = self._conv(inputs) @@ -469,17 +468,17 @@ def __init__(self, self._shuffle = PixelShuffler(name=f"{name}_pixelshuffler", size=scale_factor) logger.debug("Initialized %s", self.__class__.__name__) - def __call__(self, inputs: Tensor) -> Tensor: + def __call__(self, inputs: keras.KerasTensor) -> keras.KerasTensor: """ Call the Faceswap Convolutional Layer. Parameters ---------- - inputs: Tensor + inputs: :class:`keras.KerasTensor` The input to the layer Returns ------- - Tensor + :class:`keras.KerasTensor` The output tensor from the Upscale Layer """ var_x = self._conv(inputs) @@ -563,17 +562,17 @@ def __init__(self, logger.debug("Initialized %s", self.__class__.__name__) - def __call__(self, inputs: Tensor) -> Tensor: + def __call__(self, inputs: keras.KerasTensor) -> keras.KerasTensor: """ Call the Faceswap Upscale 2x Layer. Parameters ---------- - inputs: Tensor + inputs: :class:`keras.KerasTensor` The input to the layer Returns ------- - Tensor + :class:`keras.KerasTensor` The output tensor from the Upscale Layer """ var_x = inputs @@ -659,17 +658,17 @@ def __init__(self, self._acivation = PReLU(name=f"{name}_prelu") logger.debug("Initialized %s", self.__class__.__name__) - def __call__(self, inputs: Tensor) -> Tensor: + def __call__(self, inputs: keras.KerasTensor) -> keras.KerasTensor: """ Call the Faceswap Resize Images Layer. 
Parameters ---------- - inputs: Tensor + inputs: :class:`keras.KerasTensor` The input to the layer Returns ------- - Tensor + :class:`keras.KerasTensor` The output tensor from the Upscale Layer """ var_x = inputs @@ -741,17 +740,17 @@ def __init__(self, for idx in range(2)] logger.debug("Initialized %s", self.__class__.__name__) - def __call__(self, inputs: Tensor) -> Tensor: + def __call__(self, inputs: keras.KerasTensor) -> keras.KerasTensor: """ Call the UpscaleDNY block Parameters ---------- - inputs: :class;`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input to the block Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The output from the block """ var_x = self._upsample(inputs) @@ -844,17 +843,17 @@ def _get_layers(self) -> list[keras.layers.Layer]: logger.debug("%s layers: %s", self.__class__.__name__, retval) return retval - def __call__(self, inputs: Tensor) -> Tensor: + def __call__(self, inputs: keras.KerasTensor) -> keras.KerasTensor: """ Call the Faceswap Residual Block. Parameters ---------- - inputs: Tensor + inputs: :class:`keras.KerasTensor` The input to the layer Returns ------- - Tensor + :class:`keras.KerasTensor` The output tensor from the Upscale Layer """ var_x = inputs diff --git a/lib/model/normalization.py b/lib/model/normalization.py index cf68ec206d..34a341722e 100644 --- a/lib/model/normalization.py +++ b/lib/model/normalization.py @@ -8,17 +8,18 @@ import typing as T from keras import constraints, initializers, layers, ops, regularizers +from keras.layers.input_spec import InputSpec from keras.saving import get_custom_objects from lib.logger import parse_class_init if T.TYPE_CHECKING: - import torch + from keras import KerasTensor logger = logging.getLogger(__name__) -class AdaInstanceNormalization(layers.Layer): +class AdaInstanceNormalization(layers.Layer): # pylint:disable=too-many-ancestors,abstract-method """ Adaptive Instance Normalization Layer for Keras. 
Parameters @@ -62,7 +63,7 @@ def __init__(self, self.scale = scale logger.debug("Initialized %s", self.__class__.__name__) - def build(self, input_shape: tuple[int, ...]) -> None: + def build(self, input_shape: tuple[tuple[int, ...], ...]) -> None: """Creates the layer weights. Parameters @@ -80,17 +81,18 @@ def build(self, input_shape: tuple[int, ...]) -> None: super().build(input_shape) - def call(self, inputs: torch.Tensor) -> torch.Tensor: + def call(self, inputs: KerasTensor # pylint:disable=arguments-differ + ) -> KerasTensor: """This is where the layer's logic lives. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ input_shape = inputs[0].shape @@ -129,7 +131,8 @@ def get_config(self) -> dict[str, T.Any]: base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) - def compute_output_shape(self, input_shape: tuple[int, ...]) -> int: + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> int: """ Calculate the output shape from this layer. Parameters @@ -145,7 +148,7 @@ def compute_output_shape(self, input_shape: tuple[int, ...]) -> int: return input_shape[0] -class GroupNormalization(layers.Layer): +class GroupNormalization(layers.Layer): # pylint:disable=too-many-ancestors,abstract-method """ Group Normalization Parameters @@ -213,7 +216,7 @@ def build(self, input_shape: tuple[int, ...]) -> None: Keras tensor (future input to layer) or ``list``/``tuple`` of Keras tensors to reference for weight shape computations. 
""" - input_spec = [layers.InputSpec(shape=input_shape)] + input_spec = [InputSpec(shape=input_shape)] self.input_spec = input_spec # pylint:disable=attribute-defined-outside-init shape = [1 for _ in input_shape] if self.data_format == 'channels_last': @@ -234,17 +237,17 @@ def build(self, input_shape: tuple[int, ...]) -> None: name='beta') self.built = True # pylint:disable=attribute-defined-outside-init - def _process_4_channel(self, inputs: torch.Tensor) -> torch.Tensor: + def _process_4_channel(self, inputs: KerasTensor) -> KerasTensor: """ Logic for processing 4 channel inputs Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input to the layer Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ input_shape = inputs.shape @@ -292,17 +295,18 @@ def _process_4_channel(self, inputs: torch.Tensor) -> torch.Tensor: var_x = ops.reshape(var_x, (batch_size, channels, height, width)) return self.gamma * var_x + self.beta - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """This is where the layer's logic lives. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ input_shape = inputs.shape @@ -347,7 +351,7 @@ def get_config(self) -> dict[str, T.Any]: return dict(list(base_config.items()) + list(config.items())) -class InstanceNormalization(layers.Layer): +class InstanceNormalization(layers.Layer): # pylint:disable=too-many-ancestors,abstract-method """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016). Normalize the activations of the previous layer at each step, i.e. 
applies a transformation @@ -436,8 +440,7 @@ def build(self, input_shape: tuple[int, ...]) -> None: if (self.axis is not None) and (ndim == 2): raise ValueError("Cannot specify axis for rank 1 tensor") - self.input_spec = layers.InputSpec( # pylint:disable=attribute-defined-outside-init - ndim=ndim) + self.input_spec = InputSpec(ndim=ndim) # pylint:disable=attribute-defined-outside-init if self.axis is None: shape = (1,) @@ -462,17 +465,18 @@ def build(self, input_shape: tuple[int, ...]) -> None: self.beta = None self.built = True # pylint:disable=attribute-defined-outside-init - def call(self, inputs: torch.Tensor) -> torch.Tensor: + def call(self, inputs: KerasTensor # pylint:disable=arguments-differ + ) -> KerasTensor: """This is where the layer's logic lives. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ input_shape = inputs.shape @@ -530,7 +534,7 @@ class name. These are handled by `Network` (one layer of abstraction above). return dict(list(base_config.items()) + list(config.items())) -class RMSNormalization(layers.Layer): +class RMSNormalization(layers.Layer): # pylint:disable=too-many-ancestors,abstract-method """ Root Mean Square Layer Normalization (Biao Zhang, Rico Sennrich, 2019) RMSNorm is a simplification of the original layer normalization (LayerNorm). 
LayerNorm is a @@ -568,10 +572,10 @@ def __init__(self, axis: int = -1, epsilon: float = 1e-8, partial: float = 0.0, - bias: bool = False, **kwargs) -> False: + bias: bool = False, + **kwargs) -> None: logger.debug(parse_class_init(locals())) self.scale = None - self.offset = 0 super().__init__(**kwargs) # Checks @@ -622,17 +626,18 @@ def build(self, input_shape: tuple[int, ...]) -> None: self.built = True # pylint:disable=attribute-defined-outside-init - def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def call(self, inputs: KerasTensor, *args, **kwargs # pylint:disable=arguments-differ + ) -> KerasTensor: """ Call Root Mean Square Layer Normalization Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` Input tensor, or list/tuple of input tensors Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` A tensor or list/tuple of tensors """ # Compute the axes along which to reduce the mean / variance @@ -653,7 +658,8 @@ def call(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor: output = self.scale * inputs * recip_square_root + self.offset return output - def compute_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]: + def compute_output_shape(self, input_shape: tuple[int, ...] # pylint:disable=arguments-differ + ) -> tuple[int, ...]: """ The output shape of the layer is the same as the input shape. 
Parameters diff --git a/lib/model/optimizers.py b/lib/model/optimizers.py index 107bbb2394..57574f3536 100644 --- a/lib/model/optimizers.py +++ b/lib/model/optimizers.py @@ -226,9 +226,9 @@ def _maybe_rectify(self, Parameters ---------- - momentum: :class:`keras.Tensor` + momentum: :class:`keras.KerasTensor` The momentum update - velocity: :class:`keras.Tensor` + velocity: :class:`keras.KerasTensor` The velocity update local_step: :class:`keras.KerasTensor` The current training step diff --git a/plugins/train/model/_base/settings.py b/plugins/train/model/_base/settings.py index f2bb9ec19c..7df492ab10 100644 --- a/plugins/train/model/_base/settings.py +++ b/plugins/train/model/_base/settings.py @@ -18,7 +18,6 @@ from contextlib import nullcontext -import torch import keras from keras import losses as k_losses from keras.config import set_dtype_policy @@ -33,6 +32,7 @@ from collections.abc import Callable from contextlib import AbstractContextManager as ContextManager from argparse import Namespace + from keras import KerasTensor from .model import State logger = logging.getLogger(__name__) @@ -53,8 +53,8 @@ class LossClass: kwargs: dict Any keyword arguments to supply to the loss function at initialization. """ - function: Callable[[torch.Tensor, torch.Tensor], - torch.Tensor] | T.Any = k_losses.MeanSquaredError + function: Callable[[KerasTensor, KerasTensor], + KerasTensor] | T.Any = k_losses.MeanSquaredError init: bool = True kwargs: dict[str, T.Any] = field(default_factory=dict) @@ -138,7 +138,7 @@ def configure(self, model: keras.models.Model) -> None: self._set_loss_functions(model.output_names) self._names.insert(0, "total") - def _set_loss_names(self, outputs: list[torch.Tensor]) -> None: + def _set_loss_names(self, outputs: list[KerasTensor]) -> None: """ Name the losses based on model output. This is used for correct naming in the state file, for display purposes only. 
@@ -147,7 +147,7 @@ def _set_loss_names(self, outputs: list[torch.Tensor]) -> None: Parameters ---------- - outputs: list + outputs: list[:class:`keras.KerasTensor`] A list of output tensors from the model plugin """ # TODO Use output names if/when these are fixed upstream @@ -163,7 +163,7 @@ def _set_loss_names(self, outputs: list[torch.Tensor]) -> None: self._names.append(f"{name}_{side}{suffix}") logger.debug(self._names) - def _get_function(self, name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: + def _get_function(self, name: str) -> Callable[[KerasTensor, KerasTensor], KerasTensor]: """ Obtain the requested Loss function Parameters diff --git a/plugins/train/model/phaze_a.py b/plugins/train/model/phaze_a.py index 40abc83791..3c6b4fdb76 100644 --- a/plugins/train/model/phaze_a.py +++ b/plugins/train/model/phaze_a.py @@ -22,7 +22,7 @@ from ._base import ModelBase, get_all_sub_models if T.TYPE_CHECKING: - from torch import Tensor + from keras import KerasTensor logger = logging.getLogger(__name__) @@ -328,12 +328,12 @@ def _validate_encoder_architecture(self) -> None: f"minimum version required is {keras_min} whilst you have version " f"{keras_ver} installed.") - def build_model(self, inputs: list[Tensor]) -> keras.models.Model: + def build_model(self, inputs: list[KerasTensor]) -> keras.models.Model: """ Create the model's structure. Parameters ---------- - inputs: list + inputs: list[:class:`keras.KerasTensor`] A list of input tensors for the model. This will be a list of 2 tensors of shape :attr:`input_shape`, the first for side "a", the second for side "b". 
@@ -353,12 +353,12 @@ def build_model(self, inputs: list[Tensor]) -> keras.models.Model: autoencoder = keras.models.Model(inputs, outputs, name=self.model_name) return autoencoder - def _build_encoders(self, inputs: list[Tensor]) -> dict[str, keras.models.Model]: + def _build_encoders(self, inputs: list[KerasTensor]) -> dict[str, keras.models.Model]: """ Build the encoders for Phaze-A Parameters ---------- - inputs: list + inputs: list[:class:`keras.KerasTensor`] A list of input tensors for the model. This will be a list of 2 tensors of shape :attr:`input_shape`, the first for side "a", the second for side "b". @@ -492,12 +492,13 @@ def _build_decoders(self, return retval -def _bottleneck(inputs: Tensor, bottleneck: str, size: int, normalization: str) -> Tensor: +def _bottleneck(inputs: KerasTensor, bottleneck: str, size: int, normalization: str + ) -> KerasTensor: """ The bottleneck fully connected layer. Can be called from Encoder or FullyConnected layers. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input to the bottleneck layer bottleneck: str or ``None`` The type of layer to use for the bottleneck. 
``None`` to not use a bottleneck @@ -508,7 +509,7 @@ def _bottleneck(inputs: Tensor, bottleneck: str, size: int, normalization: str) Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The output from the bottleneck """ norms = {"layer": kl.LayerNormalization, @@ -777,17 +778,17 @@ def __init__(self, config: dict) -> None: self._kernel_size = 3 if self._is_alt else 5 self._strides = 1 if self._is_alt else 2 - def __call__(self, inputs: Tensor) -> Tensor: + def __call__(self, inputs: KerasTensor) -> KerasTensor: """ Call the original Faceswap Encoder Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input tensor to the Faceswap Encoder Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The output tensor from the Faceswap Encoder """ var_x = inputs @@ -820,7 +821,7 @@ def __call__(self, inputs: Tensor) -> Tensor: strides=self._strides, relu_alpha=self._relu_alpha, name=f"{name}_convblk_{i}_1")(var_x) - var_x = kl.MaxPool2D(2, name=f"{name}_pool_{i}")(var_x) + var_x = kl.MaxPooling2D(2, name=f"{name}_pool_{i}")(var_x) return var_x @@ -901,17 +902,17 @@ def _scale_filters(self, original_filters: int) -> int: logger.debug("original_filters: %s, scaled_filters: %s", original_filters, retval) return retval - def _do_upsampling(self, inputs: Tensor) -> Tensor: + def _do_upsampling(self, inputs: KerasTensor) -> KerasTensor: """ Perform the upsampling at the end of the fully connected layers. 
Parameters ---------- - inputs: Tensor + inputs: :class:`keras.KerasTensor` The input to the upsample layers Returns ------- - Tensor + :class:`keras.KerasTensor` The output from the upsample layers """ upsample_filts = self._scale_filters(self._config["fc_upsample_filters"]) @@ -1011,7 +1012,7 @@ def __init__(self, self._layer_indicies = layer_indicies logger.debug("Initialized: %s", self.__class__.__name__,) - def _reshape_for_output(self, inputs: Tensor) -> Tensor: + def _reshape_for_output(self, inputs: KerasTensor) -> KerasTensor: """ Reshape the input for arbitrary output sizes. The number of filters in the input will have been scaled to the model output size allowing @@ -1019,12 +1020,12 @@ def _reshape_for_output(self, inputs: Tensor) -> Tensor: Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The tensor that is to be reshaped Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The tensor shaped correctly to upscale to output size """ var_x = inputs @@ -1039,17 +1040,17 @@ def _reshape_for_output(self, inputs: Tensor) -> Tensor: return var_x def _upscale_block(self, - inputs: Tensor, + inputs: KerasTensor, filters: int, skip_residual: bool = False, - is_mask: bool = False) -> Tensor: + is_mask: bool = False) -> KerasTensor: """ Upscale block for Phaze-A Decoder. Uses requested upscale method, adds requested regularization and activation function. 
Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input tensor for the upscale block filters: int The number of filters to use for the upscale @@ -1061,7 +1062,7 @@ def _upscale_block(self, Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The output tensor from the upscale block """ upscaler = _get_upscale_layer(self._config["dec_upscale_method"].lower(), @@ -1084,17 +1085,17 @@ def _upscale_block(self, var_x = kl.LeakyReLU(alpha=0.1)(var_x) return var_x - def _normalization(self, inputs: Tensor) -> Tensor: + def _normalization(self, inputs: KerasTensor) -> KerasTensor: """ Add a normalization layer if requested. Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input tensor to apply normalization to. Returns -------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The tensor with any normalization applied """ if not self._config["dec_norm"]: @@ -1106,17 +1107,17 @@ def _normalization(self, inputs: Tensor) -> Tensor: "rms": RMSNormalization} return norms[self._config["dec_norm"]]()(inputs) - def _dny_entry(self, inputs: Tensor) -> Tensor: + def _dny_entry(self, inputs: KerasTensor) -> KerasTensor: """ Entry convolutions for using the upscale_dny method. Parameters ---------- - inputs: Tensor + inputs: :class:`keras.KerasTensor` The inputs to the dny entry block Returns ------- - Tensor + :class:`keras.KerasTensor` The output from the dny entry block """ var_x = Conv2DBlock(self._config["dec_max_filters"], @@ -1131,11 +1132,11 @@ def _dny_entry(self, inputs: Tensor) -> Tensor: relu_alpha=0.2)(var_x) return var_x - def __call__(self, inputs: Tensor | list[Tensor]) -> Tensor | list[Tensor]: + def __call__(self, inputs: KerasTensor | list[KerasTensor]) -> KerasTensor | list[KerasTensor]: """ Upscale Network. 
Parameters - inputs: Tensor or list of tensors + inputs: :class:`keras.KerasTensor` | list[:class:`keras.KerasTensor`] Input tensor(s) to upscale block. This will be a single tensor if learn mask is not selected or if this is the first call to the upscale blocks. If learn mask is selected and this is not the first call to upscale blocks, then this will be a list of the face @@ -1143,13 +1144,15 @@ def __call__(self, inputs: Tensor | list[Tensor]) -> Tensor | list[Tensor]: Returns ------- - Tensor or list of tensors + :class:`keras.KerasTensor` | list[:class:`keras.KerasTensor`] The output of encoder blocks. Either a single tensor (if learn mask is not enabled) or list of tensors (if learn mask is enabled) """ start_idx, end_idx = (0, None) if self._layer_indicies is None else self._layer_indicies end_idx = None if end_idx == -1 else end_idx + var_x: KerasTensor + var_y: KerasTensor if self._config["learn_mask"] and start_idx == 0: # Mask needs to be created var_x = inputs @@ -1222,17 +1225,17 @@ def __init__(self, @classmethod def _g_block(cls, - inputs: Tensor, - style: Tensor, + inputs: KerasTensor, + style: KerasTensor, filters: int, - recursions: int = 2) -> Tensor: + recursions: int = 2) -> KerasTensor: """ G_block adapted from ADAIN StyleGAN. 
Parameters ---------- - inputs: :class:`torch.Tensor` + inputs: :class:`keras.KerasTensor` The input tensor to the G-Block model - style: :class:`torch.Tensor` + style: :class:`keras.KerasTensor` The input combined 'style' tensor to the G-Block model filters: int The number of filters to use for the G-Block Convolutional layers @@ -1241,7 +1244,7 @@ def _g_block(cls, Returns ------- - :class:`torch.Tensor` + :class:`keras.KerasTensor` The output tensor from the G-Block model """ var_x = inputs diff --git a/plugins/train/trainer/_base.py b/plugins/train/trainer/_base.py index c630fa5a36..c87fbe08dc 100644 --- a/plugins/train/trainer/_base.py +++ b/plugins/train/trainer/_base.py @@ -26,6 +26,7 @@ if T.TYPE_CHECKING: from collections.abc import Callable + from keras import KerasTensor from plugins.train.model._base import ModelBase from lib.config import ConfigValueType @@ -167,15 +168,16 @@ def _handle_lr_finder(self) -> bool: f"{learning_rate:.1e}") return False - def _set_tensorboard(self) -> TorchTensorBoard: + def _set_tensorboard(self) -> TorchTensorBoard | None: """ Set up Tensorboard callback for logging loss. Bypassed if command line option "no-logs" has been selected. Returns ------- - :class:`keras.callbacks.TensorBoard` - Tensorboard object for the the current training session. + :class:`keras.callbacks.TensorBoard` | None + Tensorboard object for the the current training session. ``None`` if Tensorboard + logging is not selected """ if self._model.state.current_session["no_logs"]: logger.verbose("TensorBoard logging disabled") # type: ignore @@ -377,6 +379,7 @@ def save(self, is_exit: bool = False) -> None: ``True`` if save has been called on model exit. 
Default: ``False`` """ self._model.io.save(is_exit=is_exit) + assert self._tensorboard is not None self._tensorboard.on_save() if is_exit: self._clear_tensorboard() @@ -477,27 +480,27 @@ def _resize_sample(cls, logger.debug("Resizing sample: (side: '%s', sample.shape: %s, target_size: %s, scale: %s)", side, sample.shape, target_size, scale) interpn = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA - retval = np.array([cv2.resize(img, (target_size, target_size), interpn) + retval = np.array([cv2.resize(img, (target_size, target_size), interpolation=interpn) for img in sample]) logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape) return retval - def _filter_multiscale_output(self, standard: list[torch.Tensor], swapped: list[torch.Tensor] - ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + def _filter_multiscale_output(self, standard: list[KerasTensor], swapped: list[KerasTensor] + ) -> tuple[list[KerasTensor], list[KerasTensor]]: """ Only return the largest predictions if the model has multi-scaled output Parameters ---------- - standard: list[:class:`torch.Tensor`] + standard: list[:class:`keras.KerasTensor`] The standard output from the model - swapped: list[:class:`torch.Tensor`] + swapped: list[:class:`keras.KerasTensor`] The swapped output from the model Returns ------- - standard: list[:class:`torch.Tensor`] + standard: list[:class:`keras.KerasTensor`] The standard output from the model, filtered to just the largest output - swapped: list[:class:`torch.Tensor`] + swapped: list[:class:`keras.KerasTensor`] The swapped output from the model, filtered to just the largest output """ sizes = set(p.shape[1] for p in standard) @@ -512,16 +515,16 @@ def _filter_multiscale_output(self, standard: list[torch.Tensor], swapped: list[ [s.shape for s in standard], [s.shape for s in swapped]) return standard, swapped - def _collate_output(self, standard: list[torch.Tensor], swapped: list[torch.Tensor] + def _collate_output(self, standard: 
list[KerasTensor], swapped: list[KerasTensor] ) -> tuple[list[np.ndarray], list[np.ndarray]]: """ Merge the mask onto the preview image's 4th channel if learn mask is selected. Return as numpy array Parameters ---------- - standard: list[:class:`torch.Tensor`] + standard: list[:class:`keras.KerasTensor`] The standard output from the model - swapped: list[:class:`torch.Tensor`] + swapped: list[:class:`keras.KerasTensor`] The swapped output from the model Returns @@ -601,8 +604,10 @@ def _compile_preview(self, predictions: dict[T.Literal["a_a", "a_b", "b_b", "b_a for side, samples in self.images.items(): other_side = "a" if side == "b" else "b" - preds = [predictions[f"{side}_{side}"], - predictions[f"{other_side}_{side}"]] + preds = [predictions[T.cast(T.Literal["a_a", "a_b", "b_b", "b_a"], + f"{side}_{side}")], + predictions[T.cast(T.Literal["a_a", "a_b", "b_b", "b_a"], + f"{other_side}_{side}")]] display = self._to_full_frame(side, samples, preds) headers[side] = self._get_headers(side, display[0].shape[1]) figures[side] = np.stack([display[0], display[1], display[2], ], axis=1) diff --git a/requirements/_requirements_base.txt b/requirements/_requirements_base.txt index a78ebb1b88..bcba3ae559 100644 --- a/requirements/_requirements_base.txt +++ b/requirements/_requirements_base.txt @@ -1,9 +1,9 @@ tqdm>=4.65 psutil>=5.9.0 -numexpr>=2.8.7 -numpy>=1.26.0 +numexpr>=2.9.0 +numpy>=1.26.4 opencv-python>=4.9.0.0 -pillow>=9.4.0,<10.0.0 +pillow>10.0.0 scikit-learn>=1.3.0 fastcluster>=1.2.6 matplotlib>=3.8.0 @@ -11,3 +11,9 @@ imageio>=2.33.1 imageio-ffmpeg>=0.4.9 ffmpy>=0.3.0 pywin32>=305 ; sys_platform == "win32" +pytorch>=2.2.1 +torchaudio>=2.2.1 +torchvision>=0.17.1 +tensorboard>=2.12.1 +keras>=3.2.1 + diff --git a/requirements/requirements_nvidia.txt b/requirements/requirements_nvidia.txt index de9c9c7bb3..cce7de596a 100644 --- a/requirements/requirements_nvidia.txt +++ b/requirements/requirements_nvidia.txt @@ -2,4 +2,4 @@ # Exclude badly numbered Python2 version of 
nvidia-ml-py nvidia-ml-py>=12.535,<300 pynvx==1.0.0 ; sys_platform == "darwin" -pytorch-cuda==12.1 +pytorch-cuda>=12.1