Skip to content

Commit

Permalink
fix(automl/tuners): save models at the end of the training
Browse files Browse the repository at this point in the history
  • Loading branch information
ruancomelli committed Jul 3, 2022
1 parent e13342f commit f40a847
Showing 1 changed file with 47 additions and 7 deletions.
54 changes: 47 additions & 7 deletions boiling_learning/automl/tuners.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,45 @@ def set_state(self, state: Dict[str, Any]) -> None:
self.goal = state['goal']


class _SaveBestModelAtTrainingEnd(ak.engine.tuner.AutoTuner):
    """Only save models at the end of the training.

    If early stopping was defined as a callback, replace ``KerasTuner``'s ``SaveBestEpoch`` with
    a similar, custom implementation that only saves models at the end of the training.
    """

    def _build_and_fit_model(
        self, trial: kt.engine.trial.Trial, *fit_args: Any, **fit_kwargs: Any
    ) -> Dict[str, List[Any]]:
        """Swap the checkpoint callback when early stopping is used, then delegate.

        Args:
            trial: The trial being run; its ``trial_id`` selects the checkpoint path.
            *fit_args: Positional arguments forwarded unchanged to the parent class.
            **fit_kwargs: Keyword arguments forwarded to the parent class. When
                ``callbacks`` contains an ``EarlyStopping`` instance, the first
                ``kt.engine.tuner_utils.SaveBestEpoch`` in it is replaced by this
                module's ``SaveBestEpoch``.

        Returns:
            Whatever the parent ``_build_and_fit_model`` returns, cast to
            ``Dict[str, List[Any]]``.
        """
        # ``.get`` also covers ``callbacks=None`` and an empty list: there is
        # nothing to replace in either case, so the kwargs pass through untouched.
        callbacks = fit_kwargs.get('callbacks')

        if callbacks and any(
            isinstance(callback, tf.keras.callbacks.EarlyStopping) for callback in callbacks
        ):
            # Work on a copy so the list object handed in by the caller is not mutated.
            callbacks = list(callbacks)
            for index, callback in enumerate(callbacks):
                if isinstance(callback, kt.engine.tuner_utils.SaveBestEpoch):
                    # Keep the callback's position in the list; only its
                    # implementation changes.
                    callbacks[index] = SaveBestEpoch(
                        filepath=self._get_checkpoint_fname(trial.trial_id)
                    )
                    break
            fit_kwargs['callbacks'] = callbacks

        return typing.cast(
            Dict[str, List[Any]],
            super()._build_and_fit_model(trial, *fit_args, **fit_kwargs),
        )


# Previous name of this class, kept as an alias so remaining references still resolve.
_NoAutomaticSaveBestModel = _SaveBestModelAtTrainingEnd


class _FixedMaxModelSizeGreedy(_NoAutomaticSaveBestModel):
class _FixedMaxModelSizeGreedy(_SaveBestModelAtTrainingEnd):
def on_trial_end(self, trial: kt.engine.trial.Trial) -> None:
# Send status to Logger
if self.logger:
Expand Down Expand Up @@ -170,3 +191,22 @@ def on_epoch_end(


# Large float constant (one hundred thousand).
_HUGE_NUMBER = 100_000.0


class SaveBestEpoch(tf.keras.callbacks.Callback):
    """Keras callback that writes the model weights once, when training finishes."""

    def __init__(self, filepath: str) -> None:
        """Remember where the weights should be written.

        Args:
            filepath: Destination path handed to ``write_filepath`` at train end.
        """
        super().__init__()
        self.filepath = filepath

    def on_train_end(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Save the attached model's weights to ``self.filepath``.

        Args:
            logs: Accepted for the Keras callback signature; not used here.
        """
        strategy = self.model.distribute_strategy

        # Non-chief workers get a temporary path to write their files to.
        target_path = kt.distribute.utils.write_filepath(self.filepath, strategy)
        self.model.save_weights(target_path)

        # Clean up the temporary files created on non-chief workers.
        kt.distribute.utils.remove_temp_dir_with_filepath(target_path, strategy)

0 comments on commit f40a847

Please sign in to comment.