diff --git a/wandb_preempt/checkpointer.py b/wandb_preempt/checkpointer.py
index 48e3934..f22d656 100644
--- a/wandb_preempt/checkpointer.py
+++ b/wandb_preempt/checkpointer.py
@@ -176,13 +176,24 @@ def save_checkpoint(self, extra_info: Dict) -> None:
         save(data, tmp_savepath)
         rename(tmp_savepath, savepath)
 
-    def load_latest_checkpoint(self) -> Tuple[int, Dict]:
+    def load_latest_checkpoint(
+        self, weights_only: bool = True, **kwargs
+    ) -> Tuple[int, Dict]:
         """Load the latest checkpoint and set random number generator states.
 
         Updates the model, optimizer, lr scheduler, and gradient scaler states
         passed at initialization.
 
+        Args:
+            weights_only: Whether to only unpickle objects that are safe to unpickle.
+                If `True`, the only types that will be loaded are tensors, primitive
+                types, dictionaries and types added via
+                `torch.serialization.add_safe_globals()`.
+                See `torch.load` for more information.
+                Default: `True`.
+            **kwargs: Additional keyword arguments to pass to the `torch.load` function.
+
         Returns:
             The epoch number at which training should resume, and the extra
             information that was passed by the user as a dictionary to the
             :meth:`step` function.
@@ -193,7 +204,7 @@ def load_latest_checkpoint(self) -> Tuple[int, Dict]:
 
         self.maybe_print(f"Loading checkpoint {loadpath}.")
 
-        data = load(loadpath)
+        data = load(loadpath, weights_only=weights_only, **kwargs)
         self.maybe_print("Loading model.")
         self.model.load_state_dict(data["model"])
         self.maybe_print("Loading optimizer.")
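The change above makes safe unpickling the default, which will reject checkpoints containing arbitrary Python objects until their types are allow-listed. The following self-contained sketch illustrates the `weights_only` semantics that `load_latest_checkpoint` now forwards to `torch.load`; it uses plain `torch.save`/`torch.load` rather than the checkpointer class, whose constructor is not shown in this diff, and `RunInfo` is a hypothetical stand-in for a custom object stored as extra info. `torch.serialization.add_safe_globals` requires PyTorch 2.4 or newer.

```python
from dataclasses import dataclass

import torch


@dataclass
class RunInfo:
    """Hypothetical custom object, standing in for user-provided extra info."""

    note: str


# A checkpoint mixing tensors/primitives with a custom Python object.
torch.save({"epoch": 3, "extra_info": RunInfo("demo")}, "checkpoint.pt")

# Safe unpickling accepts tensors, primitive types, and dicts, but rejects
# the custom class until it is explicitly allow-listed.
try:
    torch.load("checkpoint.pt", weights_only=True)
except Exception as err:  # pickle.UnpicklingError in recent PyTorch
    print(f"Refused to unpickle custom class: {err}")

# Allow-list the class, after which safe loading succeeds.
torch.serialization.add_safe_globals([RunInfo])
data = torch.load("checkpoint.pt", weights_only=True)
print(data["extra_info"])  # RunInfo(note='demo')

# Alternatively, opt out of safe loading for fully trusted checkpoints. Extra
# keyword arguments such as `map_location` are forwarded unchanged, just as
# `load_latest_checkpoint(weights_only=False, map_location="cpu")` would do.
data = torch.load("checkpoint.pt", weights_only=False, map_location="cpu")
```

Defaulting to `weights_only=True` matches the direction of PyTorch itself, which flipped the `torch.load` default to safe unpickling in version 2.6, while the `**kwargs` passthrough keeps the full flexibility of `torch.load` available.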