Skip to content

Commit

Permalink
Docstring fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Dec 24, 2023
1 parent 41bad26 commit d1df2ae
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 21 deletions.
39 changes: 22 additions & 17 deletions keras/callbacks/swap_ema_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,26 @@

@keras_export("keras.callbacks.SwapEMAWeights")
class SwapEMAWeights(Callback):
"""Swaps EMA weights before and after the evaluation.
"""Swaps model weights and EMA weights before and after evaluation.
`SwapEMAWeights` callback is used in conjunction with the optimizer using
`use_ema=True`.
This callbacks replaces the model's weight values with the values of
the optimizer's EMA weights (the exponential moving average of the past
model weights values, implementing "Polyak averaging") before model
evaluation, and restores the previous weights after evaluation.
Note that: we use swapping to save memory. The behavior is undefined if you
modify the EMA weights or model weights in other callbacks.
The `SwapEMAWeights` callback is to be used in conjunction with
an optimizer that sets `use_ema=True`.
Note that the weights are swapped in-place in order to save memory.
The behavior is undefined if you modify the EMA weights
or model weights in other callbacks.
Example:
```python
# Remember to set `use_ema=True`
model.compile(optimizer=SGD(use_ema=True), loss=..., metrics=...)
# Remember to set `use_ema=True` in the optimizer
optimizer = SGD(use_ema=True)
model.compile(optimizer=optimizer, loss=..., metrics=...)
# Metrics will be computed with EMA weights
model.fit(X_train, Y_train, callbacks=[SwapEMAWeights()])
Expand All @@ -33,10 +40,10 @@ class SwapEMAWeights(Callback):
```
Args:
swap_on_epoch: whether to perform swapping `on_epoch_begin` and
`on_epoch_end`. This is useful if you want to use EMA weights for
other callbacks such as `ModelCheckpoint`. Defaults to `False`.
swap_on_epoch: whether to perform swapping at `on_epoch_begin()`
and `on_epoch_end()`. This is useful if you want to use
EMA weights for other callbacks such as `ModelCheckpoint`.
Defaults to `False`.
"""

def __init__(self, swap_on_epoch=False):
Expand Down Expand Up @@ -111,9 +118,8 @@ def _swap_variables(self):
if not hasattr(optimizer, "_model_variables_moving_average"):
raise ValueError(
"SwapEMAWeights must be used when "
"`_model_variables_moving_average` exists in the optimizer. "
"Please verify if you have set `use_ema=True` in your "
f"optimizer. Received: use_ema={optimizer.use_ema}"
"`use_ema=True` is set on the optimizer. "
f"Received: use_ema={optimizer.use_ema}"
)
if backend.backend() == "tensorflow":
self._tf_swap_variables(optimizer)
Expand All @@ -129,9 +135,8 @@ def _finalize_ema_values(self):
if not hasattr(optimizer, "_model_variables_moving_average"):
raise ValueError(
"SwapEMAWeights must be used when "
"`_model_variables_moving_average` exists in the optimizer. "
"Please verify if you have set `use_ema=True` in the "
f"optimizer. Received: use_ema={optimizer.use_ema}"
"`use_ema=True` is set on the optimizer. "
f"Received: use_ema={optimizer.use_ema}"
)
if backend.backend() == "tensorflow":
self._tf_finalize_ema_values(optimizer)
Expand Down
5 changes: 1 addition & 4 deletions keras/callbacks/swap_ema_weights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ def test_swap_ema_weights_with_invalid_optimizer(self):
model = self._get_compiled_model(use_ema=False)
with self.assertRaisesRegex(
ValueError,
(
"SwapEMAWeights must be used when "
"`_model_variables_moving_average` exists in the optimizer. "
),
("SwapEMAWeights must be used when " "`use_ema=True` is set"),
):
model.fit(
self.x_train,
Expand Down

0 comments on commit d1df2ae

Please sign in to comment.