Fix optimizer saving
torzdf committed Mar 26, 2024
1 parent ed0fe85 commit 3fbfef2
Showing 1 changed file with 80 additions and 25 deletions.
105 changes: 80 additions & 25 deletions plugins/train/model/_base/io.py
@@ -23,6 +23,7 @@
 
 if T.TYPE_CHECKING:
     from .model import ModelBase
+    from keras.optimizers import Optimizer
 
 logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
 
@@ -168,42 +169,46 @@ def load(self) -> kmodels.Model:
         logger.info("Loaded model from disk: '%s'", self.filename)
         return model
 
-    def save(self,
-             is_exit: bool = False,
-             force_save_optimizer: bool = False) -> None:
-        """ Backup and save the model and state file.
+    def _remove_optimizer(self) -> Optimizer:
+        """ Keras 3's `.keras` format ignores the `include_optimizer` kwarg. To hack around this
+        we remove the optimizer from the model prior to saving and then re-attach it afterwards.
+
+        Returns
+        -------
+        :class:`keras.optimizers.Optimizer`
+            The optimizer that was detached from the model
+        """
+        retval = self._plugin.model.optimizer
+        del self._plugin.model.optimizer
+        logger.debug("Removed optimizer for saving: %s", retval)
+        return retval
+
+    def _save_model(self, is_exit: bool, force_save_optimizer: bool) -> None:
+        """ Save the model either with or without the optimizer weights.
+
+        Keras 3 ignores `include_optimizer`, so if the optimizer should not be saved we remove
+        it from the model prior to saving, then re-attach it afterwards.
 
         Parameters
         ----------
-        is_exit: bool, optional
+        is_exit: bool
             ``True`` if the save request has come from an exit process request otherwise ``False``.
-            Default: ``False``
-        force_save_optimizer: bool, optional
+        force_save_optimizer: bool
             ``True`` to force saving the optimizer weights with the model, otherwise ``False``.
-            Default:``False``
         """
-        logger.debug("Backing up and saving models")
         include_optimizer = (force_save_optimizer or
                              self._save_optimizer == "always" or
                              (self._save_optimizer == "exit" and is_exit))
-
-        print("\x1b[2K", end="\r")  # Clear last line
-        logger.info("Saving Model...")
-        self._plugin.model.save(self.filename, include_optimizer=include_optimizer)
+        if not include_optimizer:
+            optimizer = self._remove_optimizer()
+
+        self._plugin.model.save(self.filename)
         self._plugin.state.save()
 
-        save_average = self._get_save_average()
-        should_backup = self._should_backup(save_average)
-        if save_average and should_backup:
-            self._backup.backup_model(self.filename)
-            self._backup.backup_model(self._plugin.state.filename)
-
-        msg = "[Saved optimizer state for Snapshot]" if force_save_optimizer else "[Saved model]"
-        if save_average:
-            msg += f" - Average total loss since last save: {save_average:.5f}"
-        if should_backup:
-            msg += " [Model backed up]"
-        logger.info(msg)
+        if not include_optimizer:
+            logger.debug("Re-attaching optimizer: %s", optimizer)
+            setattr(self._plugin.model, "optimizer", optimizer)
 
     def _get_save_average(self) -> float:
         """ Return the average loss since the last save iteration and reset historical loss
@@ -215,7 +220,7 @@
         """
         logger.debug("Getting save averages")
         if not self._history:
-            logger.info("No loss in history")
+            logger.debug("No loss in history")
             retval = 0.0
         else:
             retval = sum(self._history) / len(self._history)
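
The heart of the fix is the detach/save/re-attach pattern in `_remove_optimizer` / `_save_model` above. Below is a minimal sketch of the same idea on a toy model, outside the plugin machinery; the toy model and filename are hypothetical, and it assumes, as the diff does, that a compiled Keras 3 model tolerates having its `optimizer` attribute deleted around a `save()` call:

```python
# Minimal sketch of the workaround used in _remove_optimizer()/_save_model().
# Assumes Keras 3, where saving to the `.keras` format always serializes
# optimizer state. The toy model and filename are illustrative only.
import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

optimizer = model.optimizer  # keep a reference so training can resume
del model.optimizer          # nothing left for the saver to serialize
model.save("toy.keras")      # model saved without optimizer state
model.optimizer = optimizer  # re-attach, mirroring the setattr() in the diff
```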
@@ -258,6 +263,56 @@ def _should_backup(self, save_average: float) -> bool:
         logger.debug("Should backup: %s", backup)
         return backup
 
+    def _maybe_backup(self) -> tuple[float, bool]:
+        """ Backup the model if the total average loss has dropped for the save iteration.
+
+        Returns
+        -------
+        float
+            The total loss average since the last save iteration
+        bool
+            ``True`` if the model was backed up
+        """
+        save_average = self._get_save_average()
+        should_backup = self._should_backup(save_average)
+        if not save_average or not should_backup:
+            logger.debug("Not backing up model (save_average: %s, should_backup: %s)",
+                         save_average, should_backup)
+            return save_average, False
+
+        logger.debug("Backing up model")
+        self._backup.backup_model(self.filename)
+        self._backup.backup_model(self._plugin.state.filename)
+        return save_average, True
+
+    def save(self,
+             is_exit: bool = False,
+             force_save_optimizer: bool = False) -> None:
+        """ Backup and save the model and state file.
+
+        Parameters
+        ----------
+        is_exit: bool, optional
+            ``True`` if the save request has come from an exit process request otherwise ``False``.
+            Default: ``False``
+        force_save_optimizer: bool, optional
+            ``True`` to force saving the optimizer weights with the model, otherwise ``False``.
+            Default: ``False``
+        """
+        logger.debug("Backing up and saving models")
+        print("\x1b[2K", end="\r")  # Clear last line
+        logger.info("Saving Model...")
+
+        self._save_model(is_exit, force_save_optimizer)
+        save_average, backed_up = self._maybe_backup()
+
+        msg = "[Saved optimizer state for Snapshot]" if force_save_optimizer else "[Saved model]"
+        if save_average:
+            msg += f" - Average total loss since last save: {save_average:.5f}"
+        if backed_up:
+            msg += " [Model backed up]"
+        logger.info(msg)
 
     def snapshot(self) -> None:
         """ Perform a model snapshot.

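For reference, the decision in `_save_model` about when the optimizer is kept reduces to a small predicate over the plugin's `save_optimizer` setting, the exit flag, and the snapshot override. A hedged restatement follows; "always" and "exit" appear in the diff, while "never" is assumed to be the remaining config value, and the helper name is illustrative:

```python
# Restatement of the include_optimizer expression from _save_model() above.
def includes_optimizer(save_optimizer: str, is_exit: bool, force: bool) -> bool:
    """ True when optimizer weights should be written alongside the model """
    return (force
            or save_optimizer == "always"
            or (save_optimizer == "exit" and is_exit))

assert includes_optimizer("always", is_exit=False, force=False)
assert includes_optimizer("exit", is_exit=True, force=False)
assert not includes_optimizer("exit", is_exit=False, force=False)
assert includes_optimizer("never", is_exit=False, force=True)  # snapshot override
```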