Skip to content

Commit

Permalink
[API] Return loaded checkpoint index instead of next epoch
Browse files Browse the repository at this point in the history
or None if no checkpoint was found to load. The user must interpret
this to work out how much longer to train for to pick up from where
they left off.
  • Loading branch information
scottclowe committed Sep 11, 2024
1 parent 0fb6333 commit 153b8a3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
8 changes: 5 additions & 3 deletions example/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ def main(args):
)

# NOTE: If existing, load model, optimizer, and learning rate scheduler state from
# latest checkpoint, set random number generator states, and recover the epoch to
# start training from. Does nothing if there was no checkpoint.
start_epoch, _ = checkpointer.load_latest_checkpoint()
# latest checkpoint, set random number generator states. If there was no checkpoint
# to load, it does nothing and returns `None` for the step count.
last_epoch, _ = checkpointer.load_latest_checkpoint()
# Select the remaining epochs to train
start_epoch = 0 if last_epoch is None else last_epoch + 1

# training
for epoch in range(start_epoch, args.epochs):
Expand Down
15 changes: 10 additions & 5 deletions wandb_preempt/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def save_checkpoint(self, extra_info: Dict) -> None:

def load_latest_checkpoint(
self, weights_only: bool = True, **kwargs
) -> Tuple[int, Dict]:
) -> Tuple[Union[int, None], Dict]:
"""Load the latest checkpoint and set random number generator states.
Updates the model, optimizer, lr scheduler, and gradient scaler states
Expand All @@ -195,14 +195,16 @@ def load_latest_checkpoint(
**kwargs: Additional keyword arguments to pass to the `torch.load` function.
Returns:
epoch: The epoch number at which training should resume.
loaded_step: The index of the checkpoint that was loaded, or `None` if no
checkpoint was found.
extra_info: Extra information that was passed by the user to the `step`
function.
function when the checkpoint was saved, or an empty dictionary if there
is no extra information.
"""
loadpath = self.latest_checkpoint()
if loadpath is None:
self.maybe_print("No checkpoint found. Starting from scratch.")
return 0, {}
return None, {}

self.maybe_print(f"Loading checkpoint {loadpath}.")

Expand All @@ -229,7 +231,10 @@ def load_latest_checkpoint(
else:
set_rng_state(rng_state)

return self.step_count, data["extra_info"]
# N.B. We returns the checkpoint step index of the saved file that was loaded,
# but the checkpointer.step_count is one larger than that because we increment it
# after saving - it tracks the index of the next checkpoint to be saved.
return data["checkpoint_step"], data["extra_info"]

def remove_checkpoints(self, keep_latest: bool = False):
"""Remove checkpoints.
Expand Down

0 comments on commit 153b8a3

Please sign in to comment.