diff --git a/keras_core/callbacks/__init__.py b/keras_core/callbacks/__init__.py
index dfee7abee..a13cbc927 100644
--- a/keras_core/callbacks/__init__.py
+++ b/keras_core/callbacks/__init__.py
@@ -1,3 +1,4 @@
+from keras_core.callbacks.backup_and_restore_callback import BackupAndRestore
 from keras_core.callbacks.callback import Callback
 from keras_core.callbacks.callback_list import CallbackList
 from keras_core.callbacks.csv_logger import CSVLogger
diff --git a/keras_core/callbacks/backup_and_restore_callback.py b/keras_core/callbacks/backup_and_restore_callback.py
new file mode 100644
index 000000000..fcba7d6c6
--- /dev/null
+++ b/keras_core/callbacks/backup_and_restore_callback.py
@@ -0,0 +1,216 @@
+import os
+import warnings
+
+from keras_core.api_export import keras_core_export
+from keras_core.callbacks.callback import Callback
+from keras_core.utils import file_utils
+
+
+@keras_core_export("keras_core.callbacks.BackupAndRestore")
+class BackupAndRestore(Callback):
+    """Callback to back up and restore the training state.
+
+    The `BackupAndRestore` callback is intended to recover training from an
+    interruption that happened in the middle of a `Model.fit` execution, by
+    backing up the training state in a temporary checkpoint file at the end
+    of each epoch. Each backup overwrites the previously written checkpoint
+    file, so at any given time there is at most one such checkpoint file for
+    backup/restoring purposes.
+
+    If training restarts before completion, the training state (which
+    includes the `Model` weights and epoch number) is restored to the most
+    recently saved state at the beginning of a new `Model.fit` run. At the
+    completion of a `Model.fit` run, the temporary checkpoint file is
+    deleted.
+
+    Note that the user is responsible for bringing jobs back after the
+    interruption. This callback is important for the backup and restore
+    mechanism for fault tolerance purposes, and the model restored from a
+    previous checkpoint is expected to be the same as the one used to back
+    it up. If the user changes arguments passed to `compile` or `fit`, the
+    checkpoint saved for fault tolerance can become invalid.
+
+    Example:
+
+    >>> class InterruptingCallback(keras.callbacks.Callback):
+    ...     def on_epoch_begin(self, epoch, logs=None):
+    ...         if epoch == 4:
+    ...             raise RuntimeError('Interrupting!')
+    >>> callback = keras.callbacks.BackupAndRestore(
+    ...     file_path="/tmp/backup/checkpoint.weights.h5")
+    >>> model = keras.models.Sequential([keras.layers.Dense(10)])
+    >>> model.compile(keras.optimizers.SGD(), loss='mse')
+    >>> try:
+    ...     model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
+    ...               batch_size=1,
+    ...               callbacks=[callback, InterruptingCallback()],
+    ...               verbose=0)
+    ... except:
+    ...     pass
+    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
+    ...                     epochs=10, batch_size=1, callbacks=[callback],
+    ...                     verbose=0)
+    >>> # Only 6 more epochs are run, since the first training run was
+    >>> # interrupted at zero-indexed epoch 4, so the second run continues
+    >>> # from epoch 4 to 9.
+    >>> len(history.history['loss'])
+    6
+
+    Args:
+        file_path: String, path of the checkpoint file used to back up the
+            training state, e.g.
+            `file_path = os.path.join(working_dir, "backup.weights.h5")`.
+            The system stores a temporary checkpoint file at this path to
+            recover the model from jobs terminated unexpectedly. The path
+            cannot be reused elsewhere to store other files, e.g. by the
+            `BackupAndRestore` callback of another training run, or by
+            another callback (e.g. `ModelCheckpoint`) of the same training
+            run.
+        save_freq: `"epoch"`, integer, or `False`.
+            When set to `"epoch"`, the callback saves the checkpoint at the
+            end of each epoch. When set to an integer, the callback saves
+            the checkpoint every `save_freq` batches. Set `save_freq` to
+            `False` if only using preemption checkpointing (with
+            `save_before_preemption=True`).
+        delete_checkpoint: Boolean, defaults to `True`. This
+            `BackupAndRestore` callback works by saving a checkpoint to back
+            up the training state. If `delete_checkpoint=True`, the
+            checkpoint will be deleted after training is finished. Use
+            `False` if you'd like to keep the checkpoint for future usage.
+        save_before_preemption: A boolean value instructing whether to turn
+            on the automatic checkpoint saving for preemption/maintenance
+            events.
+    """
+
+    def __init__(
+        self,
+        file_path,
+        save_freq="epoch",
+        delete_checkpoint=True,
+        save_before_preemption=False,
+    ):
+        super().__init__()
+        self._current_epoch = 0
+        self.save_freq = save_freq
+        self.delete_checkpoint = delete_checkpoint
+        self.save_before_preemption = save_before_preemption
+        self._batches_seen_since_last_saving = 0
+        self._last_batch_seen = 0
+
+        if not file_path:
+            raise ValueError("Empty `file_path` argument passed")
+        self.file_path = file_path
+
+        if not save_freq and not save_before_preemption:
+            raise ValueError(
+                "Either `save_freq` or `save_before_preemption` must be set."
+            )
+
+        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
+            raise ValueError(
+                f"Unrecognized save_freq: {self.save_freq}. "
+                "Expected save_freq to be 'epoch' or an integer value."
+            )
+        if self.save_before_preemption:
+            warnings.warn("`save_before_preemption` not yet implemented")
+
+    def on_train_begin(self, logs=None):
+        """Restore the training state from the temporary checkpoint file."""
+        if self._check_checkpoints_exists(self.file_path):
+            self._model.load_weights(filepath=self.file_path)
+
+    def on_train_end(self, logs=None):
+        if self.delete_checkpoint and self._check_checkpoints_exists(
+            self.file_path
+        ):
+            self._cleanup_checkpoint()
+
+    def on_epoch_begin(self, epoch, logs=None):
+        self._current_epoch = epoch
+
+    def on_epoch_end(self, epoch, logs=None):
+        if self.save_freq == "epoch":
+            self._save_model(epoch=epoch, batch=None, logs=logs)
+
+    def on_train_batch_end(self, batch, logs=None):
+        if self._should_save_on_batch(batch):
+            self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)
+
+    def _save_model(self, epoch, batch, logs):
+        """Saves the model.
+
+        Args:
+            epoch: the epoch this iteration is in.
+            batch: the batch this iteration is in. `None` if the `save_freq`
+                is set to `"epoch"`.
+            logs: the `logs` dict passed in to `on_batch_end` or
+                `on_epoch_end`.
+        """
+        logs = logs or {}
+
+        filepath = self._get_file_path(epoch, batch, logs)
+        # Create host directory if it doesn't exist.
+        dirname = os.path.dirname(filepath)
+        if dirname and not file_utils.exists(dirname):
+            file_utils.makedirs(dirname)
+
+        try:
+            self._model.save_weights(filepath=filepath, overwrite=True)
+        except IsADirectoryError:  # h5py 3.x
+            raise IOError(
+                "Please specify a non-directory filepath for "
+                "BackupAndRestore. Filepath used is an existing "
+                f"directory: {filepath}"
+            )
+        except IOError as e:  # h5py 2.x
+            # `e.errno` appears to be `None` so checking the content of
+            # `e.args[0]`.
+            if "is a directory" in str(e.args[0]).lower():
+                raise IOError(
+                    "Please specify a non-directory filepath for "
+                    "BackupAndRestore. Filepath used is an existing "
+                    f"directory: {filepath}"
+                )
+            # Re-throw the error for any other causes.
+            raise e
+
+    def _get_file_path(self, epoch, batch, logs):
+        """Returns the file path for the checkpoint."""
+
+        try:
+            # `file_path` may contain placeholders such as `{epoch:02d}`,
+            # `{batch:02d}` and `{mape:.2f}`. A mismatch between logged
+            # metrics and the path's placeholders can cause formatting to
+            # fail.
+            if batch is None or "batch" in logs:
+                file_path = self.file_path.format(epoch=epoch + 1, **logs)
+            else:
+                file_path = self.file_path.format(
+                    epoch=epoch + 1, batch=batch + 1, **logs
+                )
+        except KeyError as e:
+            raise KeyError(
+                f'Failed to format this callback filepath: "{self.file_path}". '
+                f"Reason: {e}"
+            )
+        return file_path
+
+    def _should_save_on_batch(self, batch):
+        """Handles batch-level saving logic, supports steps_per_execution."""
+        if self.save_freq == "epoch":
+            return False
+        if batch <= self._last_batch_seen:  # New epoch.
+            add_batches = batch + 1  # batches are zero-indexed.
+        else:
+            add_batches = batch - self._last_batch_seen
+        self._batches_seen_since_last_saving += add_batches
+        self._last_batch_seen = batch
+
+        if self._batches_seen_since_last_saving >= self.save_freq:
+            self._batches_seen_since_last_saving = 0
+            return True
+        return False
+
+    def _cleanup_checkpoint(self):
+        """Delete the backup checkpoint file (if present)."""
+        if self._check_checkpoints_exists(filepath=self.file_path):
+            file_utils.rmtree(self.file_path)
+
+    def _check_checkpoints_exists(self, filepath):
+        return file_utils.exists(filepath)
diff --git a/keras_core/callbacks/backup_and_restore_callback_test.py b/keras_core/callbacks/backup_and_restore_callback_test.py
new file mode 100644
index 000000000..e75e8a091
--- /dev/null
+++ b/keras_core/callbacks/backup_and_restore_callback_test.py
@@ -0,0 +1,224 @@
+import os
+
+import numpy as np
+import pytest
+
+from keras_core import callbacks
+from keras_core import layers
+from keras_core import testing
+from keras_core.models import Sequential
+from keras_core.utils import file_utils
+
+
+class InterruptingCallback(callbacks.Callback):
+    """A callback to intentionally introduce an interruption to training."""
+
+    def __init__(self, steps_int, epoch_int):
+        self.batch_count = 0
+        self.epoch_count = 0
+        self.steps_int = steps_int
+        self.epoch_int = epoch_int
+
+    def on_epoch_end(self, epoch, log=None):
+        self.epoch_count += 1
+        if self.epoch_int is not None and self.epoch_count == self.epoch_int:
+            raise RuntimeError("EpochInterruption")
+
+    def on_batch_end(self, batch, logs=None):
+        self.batch_count += 1
+        if self.steps_int is not None and self.batch_count == self.steps_int:
+            raise RuntimeError("StepsInterruption")
+
+
+class BackupAndRestoreCallbackTest(testing.TestCase):
+    # Check that an empty `file_path` argument is rejected.
+    def test_empty_file_path(self):
+        with self.assertRaisesRegex(
+            ValueError, expected_regex="Empty `file_path`"
+        ):
+            callbacks.BackupAndRestore(file_path=None)
+
+    # Check that `save_freq` and `save_before_preemption` cannot both be
+    # unset.
+    def test_save_set_error(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            expected_regex="`save_freq` or `save_before_preemption` "
+            "must be set",
+        ):
+            callbacks.BackupAndRestore(
+                file_path="backup_dir",
+                save_freq=None,
+                save_before_preemption=False,
+            )
+
+    # Check invalid `save_freq` values, both string and non-integer.
+    def test_save_freq_unknown_error(self):
+        with self.assertRaisesRegex(
+            ValueError, expected_regex="Unrecognized save_freq"
+        ):
+            callbacks.BackupAndRestore(
+                file_path="backup_dir", save_freq="batch"
+            )
+
+        with self.assertRaisesRegex(
+            ValueError, expected_regex="Unrecognized save_freq"
+        ):
+            callbacks.BackupAndRestore(file_path="backup_dir", save_freq=0.15)
+
+    # Check that after an interruption, the correct model parameters and
+    # weights are restored with step-wise (batch-level) backup.
+    @pytest.mark.requires_trainable_backend
+    def test_best_case_step(self):
+        def make_model():
+            np.random.seed(1337)
+            model = Sequential(
+                [
+                    layers.Dense(2, activation="relu"),
+                    layers.Dense(1),
+                ]
+            )
+            model.compile(
+                loss="mse",
+                optimizer="sgd",
+                metrics=["mse"],
+            )
+            return model
+
+        temp_dir = self.get_temp_dir()
+        filepath = os.path.join(temp_dir, "subdir", "checkpoint.weights.h5")
+        file_utils.rmtree(filepath)
+        self.assertFalse(os.path.exists(filepath))
+
+        model = make_model()
+        cbk = callbacks.BackupAndRestore(file_path=filepath, save_freq=1)
+
+        x_train = np.random.random((10, 3))
+        y_train = np.random.random((10, 1))
+
+        try:
+            model.fit(
+                x_train,
+                y_train,
+                batch_size=4,
+                callbacks=[
+                    cbk,
+                    InterruptingCallback(steps_int=2, epoch_int=None),
+                ],
+                epochs=2,
+                verbose=0,
+            )
+        except RuntimeError:
+            self.assertTrue(os.path.exists(filepath))
+            self.assertEqual(cbk._current_epoch, 0)
+            self.assertEqual(cbk._last_batch_seen, 1)
+
+        hist = model.fit(
+            x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5
+        )
+
+        self.assertEqual(cbk._current_epoch, 4)
+        self.assertEqual(hist.epoch[-1], 4)
+
+    # Check that after an interruption, the correct model parameters and
+    # weights are restored with epoch-wise backup.
+    @pytest.mark.requires_trainable_backend
+    def test_best_case_epoch(self):
+        def make_model():
+            np.random.seed(1337)
+            model = Sequential(
+                [
+                    layers.Dense(2, activation="relu"),
+                    layers.Dense(1),
+                ]
+            )
+            model.compile(
+                loss="mse",
+                optimizer="sgd",
+                metrics=["mse"],
+            )
+            return model
+
+        temp_dir = self.get_temp_dir()
+        filepath = os.path.join(temp_dir, "subdir", "checkpoint.weights.h5")
+        file_utils.rmtree(filepath)
+        self.assertFalse(os.path.exists(filepath))
+
+        model = make_model()
+        cbk = callbacks.BackupAndRestore(file_path=filepath, save_freq="epoch")
+
+        x_train = np.random.random((10, 3))
+        y_train = np.random.random((10, 1))
+
+        try:
+            model.fit(
+                x_train,
+                y_train,
+                batch_size=4,
+                callbacks=[
+                    cbk,
+                    InterruptingCallback(steps_int=None, epoch_int=2),
+                ],
+                epochs=6,
+                verbose=0,
+            )
+        except RuntimeError:
+            self.assertEqual(cbk._current_epoch, 1)
+            self.assertTrue(os.path.exists(filepath))
+
+        hist = model.fit(
+            x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5
+        )
+        self.assertEqual(cbk._current_epoch, 4)
+        self.assertEqual(hist.epoch[-1], 4)
+
+    # Check that training restarts cleanly when the checkpoint file is
+    # deleted after an interruption.
+    @pytest.mark.requires_trainable_backend
+    def test_model_deleted_case_epoch(self):
+        def make_model():
+            np.random.seed(1337)
+            model = Sequential(
+                [
+                    layers.Dense(2, activation="relu"),
+                    layers.Dense(1),
+                ]
+            )
+            model.compile(
+                loss="mse",
+                optimizer="sgd",
+                metrics=["mse"],
+            )
+            return model
+
+        temp_dir = self.get_temp_dir()
+        filepath = os.path.join(temp_dir, "subdir", "checkpoint.weights.h5")
+        file_utils.rmtree(filepath)
+        self.assertFalse(os.path.exists(filepath))
+
+        model = make_model()
+        cbk = callbacks.BackupAndRestore(file_path=filepath, save_freq="epoch")
+
+        x_train = np.random.random((10, 3))
+        y_train = np.random.random((10, 1))
+
+        try:
+            model.fit(
+                x_train,
+                y_train,
+                batch_size=4,
+                callbacks=[
+                    cbk,
+                    InterruptingCallback(steps_int=None, epoch_int=2),
+                ],
+                epochs=6,
+                verbose=0,
+            )
+        except RuntimeError:
+            self.assertEqual(cbk._current_epoch, 1)
+            self.assertTrue(os.path.exists(filepath))
+            file_utils.rmtree(filepath)
+
+        model.fit(x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5)
+        self.assertEqual(cbk._current_epoch, 4)
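
A minimal usage sketch of the callback added in this diff (the model, data,
and backup path below are illustrative assumptions, not part of the change):

    import numpy as np

    from keras_core import callbacks
    from keras_core import layers
    from keras_core.models import Sequential

    model = Sequential([layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

    # Back up the training state at the end of every epoch. If `fit` is
    # re-run after an interruption, the backed-up state is restored from
    # this file; with the default `delete_checkpoint=True`, the file is
    # removed once training completes.
    backup = callbacks.BackupAndRestore(
        file_path="/tmp/backup/checkpoint.weights.h5", save_freq="epoch"
    )
    model.fit(
        np.random.random((8, 4)),
        np.random.random((8, 1)),
        epochs=3,
        callbacks=[backup],
    )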