Skip to content

Commit

Permalink
Fix per side loss output
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Mar 6, 2024
1 parent 6d49eaa commit ac47e02
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 34 deletions.
2 changes: 1 addition & 1 deletion plugins/train/model/_base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python3
""" Base class for Models plugins ALL Models should at least inherit from this class. """

from .model import get_all_sub_models, ModelBase
from .model import get_all_sub_models, FSModel, ModelBase
167 changes: 162 additions & 5 deletions plugins/train/model/_base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
import sys
import time
import typing as T
import warnings

from collections import OrderedDict

import numpy as np
import keras
import keras.backend as K
from keras.trainers.data_adapters import data_adapter_utils
from keras import backend, ops
import torch

from lib.serializer import get_serializer
from lib.model.nn_blocks import set_config as set_nnblock_config
Expand All @@ -34,6 +37,160 @@
_CONFIG: dict[str, ConfigValueType] = {}



def fs_compile_loss_call(self, y_true, y_pred, sample_weight=None):
""" Override keras CompileLoss.call to do return the loss for each output, rather than the
summed loss for all outputs
"""
if not self.built:
self.build(y_true, y_pred)

y_true = self._flatten_y(y_true)
y_pred = self._flatten_y(y_pred)

if sample_weight is not None:
sample_weight = self._flatten_y(sample_weight)
# For multi-outputs, repeat sample weights for n outputs.
if len(sample_weight) < len(y_true):
sample_weight = [sample_weight[0] for _ in range(len(y_true))]
else:
sample_weight = [None for _ in y_true]

loss_values = []
for loss, y_t, y_p, loss_weight, sample_weight in zip(
self.flat_losses,
y_true,
y_pred,
self.flat_loss_weights,
sample_weight,
):
if loss:
value = loss_weight * ops.cast(
loss(y_t, y_p, sample_weight), dtype=backend.floatx()
)
loss_values.append(value)
if loss_values:
return loss_values
# original code is next 2 lines
#total_loss = sum(loss_values)
#return total_loss
return None


setattr(keras.trainers.compile_utils.CompileLoss, "call", fs_compile_loss_call)


class FSModel(keras.models.Model):
""" Overriden Keras model with custom training code """
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._models: dict[T.Literal["a", "b"], keras.models.Model] = {}

@property
def metrics(self) -> list[keras.metrics.Metric]:
"""list[:class:`keras.metrics.Metric] The list of metrics for the model
As we have patched :attr:`_loss_tracker` to be a list, we also need to patch metrics
to return correctly
"""
metrics = self._loss_tracker if self.compiled else []
metrics.extend(self._metrics[:])
if self.compiled and self._compile_metrics is not None:
metrics += [self._compile_metrics]
return metrics

def _patched_loss_handler(self, all_loss: list[torch.Tensor]) -> None:
""" Keras only outputs and tracks a single loss scalar. We need to handle loss for each
output.
On first call we remove the :attr:`_loss_tracker` and replace it with a list. A Mean
metrics handler is then added to the list for each model output/loss. These are then
used for tracking going forwards
Backprop is run for each loss output to accumulate gradients for each loss function
along its correct path
Parameters
----------
all_loss: list[:class:`torch.Tensor`]
The list of loss for each output
"""
patch_loss_tracker = not isinstance(self._loss_tracker, list)

if patch_loss_tracker:
logger.debug("Patching loss tracker for all outputs")
del self._loss_tracker
self._loss_tracker = []

for idx, (loss, output_name) in enumerate(zip(all_loss, self.output_names)):

if patch_loss_tracker:
loss_name = f"loss_{output_name.split('_', maxsplit=1)[-1]}" # TODO make sure works with multi-out
tracker = keras.metrics.Mean(name=loss_name)
logger.debug("Adding loss tracker: %s", tracker)
self._loss_tracker.append(tracker)
self._metrics.remove(tracker) # Loss shouldn't be added to metrics

self._loss_tracker[idx].update_state(loss)

if self.optimizer is not None:
loss = self.optimizer.scale_loss(loss)

# Compute gradients
if self.trainable_weights:
# Call torch.Tensor.backward() on the loss to compute gradients
# for the weights.
loss.backward()


def train_step(self, data: tuple[list[list[np.ndarray]], list[list[np.ndarray]], None]
) -> dict[str, float]:
""" Original keras train_step code, but with loss split out for A and B to perform 2
backprops
Parameters
----------
data: tuple[list[list[class:`numpy.ndarray`]], list[list[class:`numpy.ndarray`]], None]
The training data for a train step
Returns
-------
dict[str, float]
The loss output from the train step
"""
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)

# Compute predictions
if self._call_has_training_arg:
y_pred = self(x, training=True)
else:
y_pred = self(x)

# Call torch.nn.Module.zero_grad() to clear the leftover gradients
# for the weights from the previous train step.
self.zero_grad()

loss = self.compute_loss(x=x, # Updated function to return all loss outputs
y=y,
y_pred=y_pred,
sample_weight=sample_weight)

self._patched_loss_handler(loss)

# Apply gradients
if self.trainable_weights:
trainable_weights = self.trainable_weights[:]
gradients = [v.value.grad for v in trainable_weights]

# Update weights
with torch.no_grad():
self.optimizer.apply(gradients, trainable_weights)
else:
warnings.warn("The model does not have any trainable weights.")

return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)


class ModelBase():
""" Base class that all model plugins should inherit from.
Expand Down Expand Up @@ -260,7 +417,7 @@ def build(self) -> None:
inputs = self._get_inputs()
if not self._settings.use_mixed_precision and not is_summary:
# Store layer names which can be switched to mixed precision

# TODO Re-enable mixed precision switching
self._model = self.build_model(inputs)
#model, mp_layers = self._settings.get_mixed_precision_layers(self.build_model,
Expand Down Expand Up @@ -362,13 +519,13 @@ def build_model(self, inputs: list[keras.layers.Input]) -> keras.models.Model:

def _summary_to_log(self, summary: str, line_break: bool) -> None:
""" Function to output Keras model summary to log file at verbose log level
Parameters
----------
summary, str
The model summary output from keras
line_breal: bool
line_break: bool
Unused, but required by Keras for print_fn as of keras 3.0.5
"""
for line in summary.splitlines():
Expand Down
4 changes: 2 additions & 2 deletions plugins/train/model/original.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from keras.models import Model as KModel

from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, UpscaleBlock
from ._base import ModelBase
from ._base import ModelBase, FSModel


class Model(ModelBase):
Expand Down Expand Up @@ -89,7 +89,7 @@ def build_model(self, inputs):

outputs = self.decoder("a")(encoder_a) + self.decoder("b")(encoder_b)

autoencoder = KModel(inputs, outputs, name=self.model_name)
autoencoder = FSModel(inputs, outputs, name=self.model_name)
return autoencoder

def encoder(self):
Expand Down
50 changes: 24 additions & 26 deletions plugins/train/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ def train_one_step(self,
model_inputs, model_targets = self._feeder.get_batch()

try:
loss: list[float] = self._model.model.train_on_batch(model_inputs, y=model_targets)
loss: dict[str, float] = self._model.model.train_on_batch(model_inputs,
y=model_targets,
return_dict=True)
except OutOfMemoryError as err:
msg = ("You do not have enough GPU memory available to train the selected model at "
"the selected settings. You can try a number of things:"
Expand All @@ -259,39 +261,36 @@ def train_one_step(self,
"(in config) if it has one.")
raise FaceswapError(msg) from err
self._log_tensorboard(loss)
#loss = self._collate_and_store_loss(loss[1:]) # TODO
#self._print_loss(loss) # TODO
print(loss)
loss = self._collate_and_store_loss(loss)
self._print_loss(loss)
if do_snapshot:
self._model.io.snapshot()
self._update_viewers(viewer, timelapse_kwargs)

def _log_tensorboard(self, loss: list[float]) -> None:
def _log_tensorboard(self, loss: dict[str, float]) -> None:
""" Log current loss to Tensorboard log files
Parameters
----------
loss: list
The list of loss ``floats`` output from the model
loss: dict[str, float]
The dictionary od loss per output
"""
if not self._tensorboard:
return
logger.trace("Updating TensorBoard log") # type: ignore
logs = {log[0]: log[1]
for log in zip(self._model.state.loss_names, loss)}

# Bug in TF 2.8/2.9/2.10 where batch recording got deleted.
# ref: https://github.com/keras-team/keras/issues/16173
with tf.summary.record_if(True), self._tensorboard._train_writer.as_default(): # noqa:E501 pylint:disable=protected-access,not-context-manager
for name, value in logs.items():
for name, value in loss.items():
tf.summary.scalar(
"batch_" + name,
value,
step=self._tensorboard._train_step) # pylint:disable=protected-access
# TODO revert this code if fixed in tensorflow
# self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
# self._tensorboard.on_train_batch_end(self._model.iterations, logs=loss)

def _collate_and_store_loss(self, loss: list[float]) -> list[float]:
def _collate_and_store_loss(self, loss: dict[str, float]) -> dict[str, float]:
""" Collate the loss into totals for each side.
The losses are summed into a total for each side. Loss totals are added to
Expand All @@ -301,43 +300,42 @@ def _collate_and_store_loss(self, loss: list[float]) -> list[float]:
Parameters
----------
loss: list
The list of loss ``floats`` for each side this iteration (excluding total combined
loss)
loss: dict[str, float]
The loss keys per output with the loss ``floats`` for each key)
Returns
-------
list
List of 2 ``floats`` which is the total loss for each side (eg sum of face + mask loss)
dict[str, float]
The combined loss for each side of the model
Raises
------
FaceswapError
If a NaN is detected, a :class:`FaceswapError` will be raised
"""
# NaN protection
if self._config["nan_protection"] and not all(np.isfinite(val) for val in loss):
if self._config["nan_protection"] and not all(np.isfinite(val) for val in loss.values()):
logger.critical("NaN Detected. Loss: %s", loss)
raise FaceswapError("A NaN was detected and you have NaN protection enabled. Training "
"has been terminated.")

split = len(loss) // 2
combined_loss = [sum(loss[:split]), sum(loss[split:])]
self._model.add_history(combined_loss)
logger.trace("original loss: %s, combined_loss: %s", loss, combined_loss) # type: ignore
a_loss = [v for k, v in loss.items() if k.endswith("_a")]
b_loss = [v for k, v in loss.items() if k.endswith("_b")]
combined_loss = {"loss_a": sum(a_loss), "loss_b": sum(b_loss)}
self._model.add_history(list(loss.values()))
logger.debug("original loss: %s, combined_loss: %s", loss, combined_loss) # type: ignore
return combined_loss

def _print_loss(self, loss: list[float]) -> None:
def _print_loss(self, loss: dict[str, float]) -> None:
""" Outputs the loss for the current iteration to the console.
Parameters
----------
loss: list
loss: dict[str, float]
The loss for each side. List should contain 2 ``floats`` side "a" in position 0 and
side "b" in position `.
"""
output = ", ".join([f"Loss {side}: {side_loss:.5f}"
for side, side_loss in zip(("A", "B"), loss)])
output = ", ".join([f"{' '.join(k.split('_')).title()}: {v:.5f}" for k, v in loss.items()])
timestamp = time.strftime("%H:%M:%S")
output = f"[{timestamp}] [#{self._model.iterations:05d}] {output}"
try:
Expand Down

0 comments on commit ac47e02

Please sign in to comment.