Skip to content

Commit

Permalink
[API] Add passthrough kwargs for torch.load (#13)
Browse files Browse the repository at this point in the history
* [ENH] Add handling of passthrough kwargs to torch.load

Support passing **kwargs from load_latest_checkpoint to torch.load.

* [API] Change to weights_only=True by default to torch.load

This prevents unpickling dangerous contents from the checkpoint.

* [DOC] Improve weights_only documentation
  • Loading branch information
scottclowe authored Sep 11, 2024
1 parent 959edd3 commit 9b2abfb
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions wandb_preempt/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,23 @@ def save_checkpoint(self, extra_info: Dict) -> None:
save(data, tmp_savepath)
rename(tmp_savepath, savepath)

def load_latest_checkpoint(self) -> Tuple[int, Dict]:
def load_latest_checkpoint(
self, weights_only: bool = True, **kwargs
) -> Tuple[int, Dict]:
"""Load the latest checkpoint and set random number generator states.
Updates the model, optimizer, lr scheduler, and gradient scaler states
passed at initialization.
Args:
weights_only: Whether to only unpickle objects that are safe to unpickle.
If `True`, the only types that will be loaded are tensors, primitive
types, dictionaries and types added via
`torch.serialization.add_safe_globals()`.
See `torch.load` for more information.
Default: `True`.
**kwargs: Additional keyword arguments to pass to the `torch.load` function.
Returns:
The epoch number at which training should resume, and the extra information
that was passed by the user as a dictionary to the :meth:`step` function.
Expand All @@ -193,7 +204,7 @@ def load_latest_checkpoint(self) -> Tuple[int, Dict]:

self.maybe_print(f"Loading checkpoint {loadpath}.")

data = load(loadpath)
data = load(loadpath, weights_only=weights_only, **kwargs)
self.maybe_print("Loading model.")
self.model.load_state_dict(data["model"])
self.maybe_print("Loading optimizer.")
Expand Down

0 comments on commit 9b2abfb

Please sign in to comment.