diff --git a/plugins/train/model/_base/io.py b/plugins/train/model/_base/io.py
index 0cce82d93c..1ce357eced 100644
--- a/plugins/train/model/_base/io.py
+++ b/plugins/train/model/_base/io.py
@@ -23,6 +23,7 @@
 if T.TYPE_CHECKING:
     from .model import ModelBase
+    from keras.optimizers import Optimizer
 
 logger = logging.getLogger(__name__)  # pylint:disable=invalid-name
 
@@ -168,42 +169,46 @@ def load(self) -> kmodels.Model:
         logger.info("Loaded model from disk: '%s'", self.filename)
         return model
 
-    def save(self,
-             is_exit: bool = False,
-             force_save_optimizer: bool = False) -> None:
-        """ Backup and save the model and state file.
+    def _remove_optimizer(self) -> Optimizer:
+        """ The Keras 3 `.keras` format ignores the `include_optimizer` kwarg. To hack around
+        this we remove the optimizer from the model prior to saving and then re-attach it to
+        the model afterwards.
+
+        Returns
+        -------
+        :class:`keras.optimizers.Optimizer`
+            The optimizer detached from the model, to be re-attached once saving is complete
+        """
+        retval = self._plugin.model.optimizer
+        del self._plugin.model.optimizer
+        logger.debug("Removed optimizer for saving: %s", retval)
+        return retval
+
+    def _save_model(self, is_exit: bool, force_save_optimizer: bool) -> None:
+        """ Save the model either with or without the optimizer weights.
+
+        Keras 3 ignores `include_optimizer`, so if the optimizer should not be saved we
+        remove it from the model for saving, then re-attach it.
 
         Parameters
         ----------
-        is_exit: bool, optional
+        is_exit: bool
             ``True`` if the save request has come from an exit process request otherwise
             ``False``.
-            Default: ``False``
-        force_save_optimizer: bool, optional
+        force_save_optimizer: bool
             ``True`` to force saving the optimizer weights with the model, otherwise ``False``.
-            Default:``False``
         """
-        logger.debug("Backing up and saving models")
         include_optimizer = (force_save_optimizer or
                              self._save_optimizer == "always" or
                              (self._save_optimizer == "exit" and is_exit))
-        print("\x1b[2K", end="\r")  # Clear last line
-        logger.info("Saving Model...")
-        self._plugin.model.save(self.filename, include_optimizer=include_optimizer)
+        if not include_optimizer:
+            optimizer = self._remove_optimizer()
+
+        self._plugin.model.save(self.filename)
         self._plugin.state.save()
-        save_average = self._get_save_average()
-        should_backup = self._should_backup(save_average)
-        if save_average and should_backup:
-            self._backup.backup_model(self.filename)
-            self._backup.backup_model(self._plugin.state.filename)
-
-        msg = "[Saved optimizer state for Snapshot]" if force_save_optimizer else "[Saved model]"
-        if save_average:
-            msg += f" - Average total loss since last save: {save_average:.5f}"
-        if should_backup:
-            msg += " [Model backed up]"
-        logger.info(msg)
+        if not include_optimizer:
+            logger.debug("Re-attaching optimizer: %s", optimizer)
+            setattr(self._plugin.model, "optimizer", optimizer)
 
     def _get_save_average(self) -> float:
         """ Return the average loss since the last save iteration and reset historical loss
@@ -215,7 +220,7 @@ def _get_save_average(self) -> float:
         """
         logger.debug("Getting save averages")
         if not self._history:
-            logger.info("No loss in history")
+            logger.debug("No loss in history")
             retval = 0.0
         else:
             retval = sum(self._history) / len(self._history)
@@ -258,6 +263,56 @@ def _should_backup(self, save_average: float) -> bool:
         logger.debug("Should backup: %s", backup)
         return backup
 
+    def _maybe_backup(self) -> tuple[float, bool]:
+        """ Back up the model if the total average loss has dropped for the save iteration.
+
+        Returns
+        -------
+        float
+            The total loss average since the last save iteration
+        bool
+            ``True`` if the model was backed up
+        """
+        save_average = self._get_save_average()
+        should_backup = self._should_backup(save_average)
+        if not save_average or not should_backup:
+            logger.debug("Not backing up model (save_average: %s, should_backup: %s)",
+                         save_average, should_backup)
+            return save_average, False
+
+        logger.debug("Backing up model")
+        self._backup.backup_model(self.filename)
+        self._backup.backup_model(self._plugin.state.filename)
+        return save_average, True
+
+    def save(self,
+             is_exit: bool = False,
+             force_save_optimizer: bool = False) -> None:
+        """ Backup and save the model and state file.
+
+        Parameters
+        ----------
+        is_exit: bool, optional
+            ``True`` if the save request has come from an exit process request otherwise
+            ``False``.
+            Default: ``False``
+        force_save_optimizer: bool, optional
+            ``True`` to force saving the optimizer weights with the model, otherwise ``False``.
+            Default: ``False``
+        """
+        logger.debug("Backing up and saving models")
+        print("\x1b[2K", end="\r")  # Clear last line
+        logger.info("Saving Model...")
+
+        self._save_model(is_exit, force_save_optimizer)
+        save_average, backed_up = self._maybe_backup()
+
+        msg = "[Saved optimizer state for Snapshot]" if force_save_optimizer else "[Saved model]"
+        if save_average:
+            msg += f" - Average total loss since last save: {save_average:.5f}"
+        if backed_up:
+            msg += " [Model backed up]"
+        logger.info(msg)
+
     def snapshot(self) -> None:
         """ Perform a model snapshot.
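
For reference, the Keras 3 workaround that `_remove_optimizer` / `_save_model` implement above boils down to the pattern below. This is a minimal sketch, not faceswap code: the function name is illustrative and the `try`/`finally` is an added safety choice, while the detach and re-attach calls mirror what the patch does.

    import keras


    def save_without_optimizer(model: keras.Model, filename: str) -> None:
        """ Save a compiled Keras 3 model without serializing its optimizer state. """
        optimizer = model.optimizer  # keep a reference so optimizer state survives
        del model.optimizer          # detach so it is not written into the .keras file
        try:
            model.save(filename)     # saves architecture + weights only
        finally:
            model.optimizer = optimizer  # always re-attach, even if save() raises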