diff --git a/keras/callbacks/swap_ema_weights.py b/keras/callbacks/swap_ema_weights.py index 134550f3157..eed3797b3b9 100644 --- a/keras/callbacks/swap_ema_weights.py +++ b/keras/callbacks/swap_ema_weights.py @@ -6,19 +6,26 @@ @keras_export("keras.callbacks.SwapEMAWeights") class SwapEMAWeights(Callback): - """Swaps EMA weights before and after the evaluation. + """Swaps model weights and EMA weights before and after evaluation. - `SwapEMAWeights` callback is used in conjunction with the optimizer using - `use_ema=True`. + This callbacks replaces the model's weight values with the values of + the optimizer's EMA weights (the exponential moving average of the past + model weights values, implementing "Polyak averaging") before model + evaluation, and restores the previous weights after evaluation. - Note that: we use swapping to save memory. The behavior is undefined if you - modify the EMA weights or model weights in other callbacks. + The `SwapEMAWeights` callback is to be used in conjunction with + an optimizer that sets `use_ema=True`. + + Note that the weights are swapped in-place in order to save memory. + The behavior is undefined if you modify the EMA weights + or model weights in other callbacks. Example: ```python - # Remember to set `use_ema=True` - model.compile(optimizer=SGD(use_ema=True), loss=..., metrics=...) + # Remember to set `use_ema=True` in the optimizer + optimizer = SGD(use_ema=True) + model.compile(optimizer=optimizer, loss=..., metrics=...) # Metrics will be computed with EMA weights model.fit(X_train, Y_train, callbacks=[SwapEMAWeights()]) @@ -33,10 +40,10 @@ class SwapEMAWeights(Callback): ``` Args: - swap_on_epoch: whether to perform swapping `on_epoch_begin` and - `on_epoch_end`. This is useful if you want to use EMA weights for - other callbacks such as `ModelCheckpoint`. Defaults to `False`. - + swap_on_epoch: whether to perform swapping at `on_epoch_begin()` + and `on_epoch_end()`. This is useful if you want to use + EMA weights for other callbacks such as `ModelCheckpoint`. + Defaults to `False`. """ def __init__(self, swap_on_epoch=False): @@ -111,9 +118,8 @@ def _swap_variables(self): if not hasattr(optimizer, "_model_variables_moving_average"): raise ValueError( "SwapEMAWeights must be used when " - "`_model_variables_moving_average` exists in the optimizer. " - "Please verify if you have set `use_ema=True` in your " - f"optimizer. Received: use_ema={optimizer.use_ema}" + "`use_ema=True` is set on the optimizer. " + f"Received: use_ema={optimizer.use_ema}" ) if backend.backend() == "tensorflow": self._tf_swap_variables(optimizer) @@ -129,9 +135,8 @@ def _finalize_ema_values(self): if not hasattr(optimizer, "_model_variables_moving_average"): raise ValueError( "SwapEMAWeights must be used when " - "`_model_variables_moving_average` exists in the optimizer. " - "Please verify if you have set `use_ema=True` in the " - f"optimizer. Received: use_ema={optimizer.use_ema}" + "`use_ema=True` is set on the optimizer. " + f"Received: use_ema={optimizer.use_ema}" ) if backend.backend() == "tensorflow": self._tf_finalize_ema_values(optimizer) diff --git a/keras/callbacks/swap_ema_weights_test.py b/keras/callbacks/swap_ema_weights_test.py index 34df529caf1..c24895d15a2 100644 --- a/keras/callbacks/swap_ema_weights_test.py +++ b/keras/callbacks/swap_ema_weights_test.py @@ -53,10 +53,7 @@ def test_swap_ema_weights_with_invalid_optimizer(self): model = self._get_compiled_model(use_ema=False) with self.assertRaisesRegex( ValueError, - ( - "SwapEMAWeights must be used when " - "`_model_variables_moving_average` exists in the optimizer. " - ), + ("SwapEMAWeights must be used when " "`use_ema=True` is set"), ): model.fit( self.x_train,