From c7898055e61b3c8020a5b98e10c5f23627392cd8 Mon Sep 17 00:00:00 2001
From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com>
Date: Tue, 10 Dec 2024 04:03:27 +0800
Subject: [PATCH] Unscale loss value in TF (#20610)

---
 .../src/backend/tensorflow/distribute_test.py |  8 ++++---
 keras/src/backend/tensorflow/trainer.py       | 22 ++++++-------------
 keras/src/losses/loss.py                      | 15 +++++++++++++
 keras/src/trainers/compile_utils.py           |  4 +++-
 4 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/keras/src/backend/tensorflow/distribute_test.py b/keras/src/backend/tensorflow/distribute_test.py
index f46c524427b..3c29777c582 100644
--- a/keras/src/backend/tensorflow/distribute_test.py
+++ b/keras/src/backend/tensorflow/distribute_test.py
@@ -162,8 +162,8 @@ def test_correctness_with_fit_and_regularizer(self):
         )
         model = models.Model(inputs, layer(inputs))
         model.compile(loss="mse", optimizer="sgd")
-        model.fit(x, y, batch_size=batch_size, epochs=1)
-
+        history = model.fit(x, y, batch_size=batch_size, epochs=1)
+        expected_loss = history.history["loss"]
         expected_weights = keras.ops.convert_to_numpy(layer.kernel)
 
         # Runs with a mirrored strategy.
@@ -177,8 +177,10 @@ def test_correctness_with_fit_and_regularizer(self):
             )
             model = models.Model(inputs, layer(inputs))
             model.compile(loss="mse", optimizer="sgd")
-            model.fit(x, y, batch_size=batch_size, epochs=1)
+            history = model.fit(x, y, batch_size=batch_size, epochs=1)
             weights = strategy.run(lambda: layer.kernel.value).values
+
+        self.assertAllClose(history.history["loss"], expected_loss)
         for w in weights:
             self.assertAllClose(
                 keras.ops.convert_to_numpy(w), expected_weights
diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py
index 556100b14c9..e4f999dcd78 100644
--- a/keras/src/backend/tensorflow/trainer.py
+++ b/keras/src/backend/tensorflow/trainer.py
@@ -5,12 +5,11 @@
 import tensorflow as tf
 from tensorflow.python.eager import context as tf_context
 
-from keras.src import backend as backend_module
 from keras.src import callbacks as callbacks_module
 from keras.src import metrics as metrics_module
-from keras.src import ops as ops_module
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
+from keras.src.losses import loss as loss_module
 from keras.src.trainers import trainer as base_trainer
 from keras.src.trainers.data_adapters import array_slicing
 from keras.src.trainers.data_adapters import data_adapter_utils
@@ -66,7 +65,8 @@ def train_step(self, data):
                 training=True,
             )
             self._loss_tracker.update_state(
-                loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
+                loss_module.unscale_loss_for_distribution(loss),
+                sample_weight=tf.shape(tree.flatten(x)[0])[0],
             )
             if self.optimizer is not None:
                 loss = self.optimizer.scale_loss(loss)
@@ -93,7 +93,8 @@ def test_step(self, data):
             x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
+            loss_module.unscale_loss_for_distribution(loss),
+            sample_weight=tf.shape(tree.flatten(x)[0])[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
 
@@ -710,17 +711,8 @@ def _maybe_symbolic_build(self, iterator=None, data_batch=None):
             self._symbolic_build(data_batch=data_batch)
 
     def _aggregate_additional_loss(self, loss):
-        if not backend_module.is_float_dtype(loss.dtype):
-            loss = ops_module.cast(loss, dtype=backend_module.floatx())
-        loss = ops_module.sum(loss)
-
-        # Scales the loss by the number of replicas in the strategy.
-        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
-        if num_replicas > 1:
-            loss = ops_module.multiply(
-                loss, ops_module.cast(1.0 / num_replicas, loss.dtype)
-            )
-        return loss
+        loss = super()._aggregate_additional_loss(loss)
+        return loss_module.scale_loss_for_distribution(loss)
 
 
 class TFEpochIterator(EpochIterator):
diff --git a/keras/src/losses/loss.py b/keras/src/losses/loss.py
index a47e542c378..6af73902d0f 100644
--- a/keras/src/losses/loss.py
+++ b/keras/src/losses/loss.py
@@ -239,3 +239,18 @@ def scale_loss_for_distribution(value):
             value, ops.cast(1.0 / num_replicas, value.dtype)
         )
     return value
+
+
+def unscale_loss_for_distribution(value):
+    """Unscales the given value by the number of replicas in the strategy.
+
+    Currently, this function is only effective when using the tensorflow
+    backend and `tf.distribute`.
+    """
+    if backend.backend() == "tensorflow":
+        import tensorflow as tf
+
+        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+        if num_replicas > 1:
+            value = ops.multiply(value, ops.cast(num_replicas, value.dtype))
+    return value
diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py
index fc1b46874bb..1ca6e54f21f 100644
--- a/keras/src/trainers/compile_utils.py
+++ b/keras/src/trainers/compile_utils.py
@@ -5,6 +5,7 @@
 from keras.src import ops
 from keras.src import tree
 from keras.src.backend.common.keras_tensor import KerasTensor
+from keras.src.losses import loss as loss_module
 from keras.src.utils.naming import get_object_name
 from keras.src.utils.tracking import Tracker
 
@@ -799,7 +800,8 @@ def resolve_path(path, object):
                 # Record *unweighted* individual losses.
                 if metric:
                     metric.update_state(
-                        value, sample_weight=tree.flatten(y_p)[0].shape[0]
+                        loss_module.unscale_loss_for_distribution(value),
+                        sample_weight=tree.flatten(y_p)[0].shape[0],
                    )
                 if loss_weight is not None:
                     value = ops.multiply(value, loss_weight)
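
Note (not part of the patch): below is a minimal sketch of the round trip the new helper enables once this patch is applied, assuming the Keras backend is set to TensorFlow. `scale_loss_for_distribution` divides a loss by `num_replicas_in_sync` so that per-replica gradients sum to the correct global gradient, and `unscale_loss_for_distribution` multiplies it back so the tracked loss value stays unscaled.

import tensorflow as tf

from keras.src.losses import loss as loss_module

# Sketch setup: MirroredStrategy over whatever devices are visible; with a
# single replica both helpers return the value unchanged.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    n = strategy.num_replicas_in_sync
    loss = tf.constant(6.0)
    scaled = loss_module.scale_loss_for_distribution(loss)        # loss / n
    restored = loss_module.unscale_loss_for_distribution(scaled)  # loss again
    print(n, float(scaled), float(restored))  # e.g. 2 3.0 6.0 on two devices

This is the reason `train_step` and `test_step` now pass the unscaled value to `_loss_tracker.update_state` while the optimizer still receives the scaled loss for the backward pass.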