Skip to content

Commit

Permalink
Unscale loss value in TF (#20610)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Dec 9, 2024
1 parent b1e4057 commit c789805
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 19 deletions.
8 changes: 5 additions & 3 deletions keras/src/backend/tensorflow/distribute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def test_correctness_with_fit_and_regularizer(self):
)
model = models.Model(inputs, layer(inputs))
model.compile(loss="mse", optimizer="sgd")
model.fit(x, y, batch_size=batch_size, epochs=1)

history = model.fit(x, y, batch_size=batch_size, epochs=1)
expected_loss = history.history["loss"]
expected_weights = keras.ops.convert_to_numpy(layer.kernel)

# Runs with a mirrored strategy.
Expand All @@ -177,8 +177,10 @@ def test_correctness_with_fit_and_regularizer(self):
)
model = models.Model(inputs, layer(inputs))
model.compile(loss="mse", optimizer="sgd")
model.fit(x, y, batch_size=batch_size, epochs=1)
history = model.fit(x, y, batch_size=batch_size, epochs=1)
weights = strategy.run(lambda: layer.kernel.value).values

self.assertAllClose(history.history["loss"], expected_loss)
for w in weights:
self.assertAllClose(
keras.ops.convert_to_numpy(w), expected_weights
Expand Down
22 changes: 7 additions & 15 deletions keras/src/backend/tensorflow/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import tensorflow as tf
from tensorflow.python.eager import context as tf_context

from keras.src import backend as backend_module
from keras.src import callbacks as callbacks_module
from keras.src import metrics as metrics_module
from keras.src import ops as ops_module
from keras.src import optimizers as optimizers_module
from keras.src import tree
from keras.src.losses import loss as loss_module
from keras.src.trainers import trainer as base_trainer
from keras.src.trainers.data_adapters import array_slicing
from keras.src.trainers.data_adapters import data_adapter_utils
Expand Down Expand Up @@ -66,7 +65,8 @@ def train_step(self, data):
training=True,
)
self._loss_tracker.update_state(
loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
loss_module.unscale_loss_for_distribution(loss),
sample_weight=tf.shape(tree.flatten(x)[0])[0],
)
if self.optimizer is not None:
loss = self.optimizer.scale_loss(loss)
Expand All @@ -93,7 +93,8 @@ def test_step(self, data):
x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
)
self._loss_tracker.update_state(
loss, sample_weight=tf.shape(tree.flatten(x)[0])[0]
loss_module.unscale_loss_for_distribution(loss),
sample_weight=tf.shape(tree.flatten(x)[0])[0],
)
return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

Expand Down Expand Up @@ -710,17 +711,8 @@ def _maybe_symbolic_build(self, iterator=None, data_batch=None):
self._symbolic_build(data_batch=data_batch)

def _aggregate_additional_loss(self, loss):
if not backend_module.is_float_dtype(loss.dtype):
loss = ops_module.cast(loss, dtype=backend_module.floatx())
loss = ops_module.sum(loss)

# Scales the loss by the number of replicas in the strategy.
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
if num_replicas > 1:
loss = ops_module.multiply(
loss, ops_module.cast(1.0 / num_replicas, loss.dtype)
)
return loss
loss = super()._aggregate_additional_loss(loss)
return loss_module.scale_loss_for_distribution(loss)


class TFEpochIterator(EpochIterator):
Expand Down
15 changes: 15 additions & 0 deletions keras/src/losses/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,18 @@ def scale_loss_for_distribution(value):
value, ops.cast(1.0 / num_replicas, value.dtype)
)
return value


def unscale_loss_for_distribution(value):
"""Unscales the given value by the number of replicas in the strategy.
Currently, this function is only effective when using the tensorflow backend
and `tf.distribute`.
"""
if backend.backend() == "tensorflow":
import tensorflow as tf

num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
if num_replicas > 1:
value = ops.multiply(value, ops.cast(num_replicas, value.dtype))
return value
4 changes: 3 additions & 1 deletion keras/src/trainers/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.src import ops
from keras.src import tree
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.losses import loss as loss_module
from keras.src.utils.naming import get_object_name
from keras.src.utils.tracking import Tracker

Expand Down Expand Up @@ -799,7 +800,8 @@ def resolve_path(path, object):
# Record *unweighted* individual losses.
if metric:
metric.update_state(
value, sample_weight=tree.flatten(y_p)[0].shape[0]
loss_module.unscale_loss_for_distribution(value),
sample_weight=tree.flatten(y_p)[0].shape[0],
)
if loss_weight is not None:
value = ops.multiply(value, loss_weight)
Expand Down

0 comments on commit c789805

Please sign in to comment.