
Merge pull request #14 from f-dangel/optimizer-None
[ADD] Allow using `None` for `optimizer` argument
scottclowe authored Sep 11, 2024
2 parents 0fb6333 + 658cc17 commit ad7e403
Showing 1 changed file with 10 additions and 5 deletions.
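
A brief usage sketch of the new behavior, assuming the class in wandb_preempt/checkpointer.py is exposed as `Checkpointer` and constructed with the arguments visible in the diff below; the run id is a placeholder:

from torch.nn import Linear
from wandb_preempt.checkpointer import Checkpointer  # assumed import path

model = Linear(10, 2)

# Passing None skips optimizer checkpointing entirely, e.g. when the
# optimizer does not implement `.state_dict()` / `.load_state_dict()`.
checkpointer = Checkpointer(
    run_id="my-run-id",  # placeholder wandb run id
    model=model,
    optimizer=None,
    lr_scheduler=None,
    scaler=None,
    savedir="checkpoints",
)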
15 changes: 10 additions & 5 deletions wandb_preempt/checkpointer.py
@@ -35,7 +35,7 @@ def __init__(
         self,
         run_id: str,
         model: Module,
-        optimizer: Optimizer,
+        optimizer: Union[Optimizer, None],
         lr_scheduler: Optional[LRScheduler] = None,
         scaler: Optional[GradScaler] = None,
         savedir: str = "checkpoints",
@@ -46,7 +46,10 @@
         Args:
             run_id: A unique identifier for this run.
             model: The model that is trained and checkpointed.
-            optimizer: The optimizer that is used for training and checkpointed.
+            optimizer: The optimizer that is used for training and should be
+                checkpointed. Use `None` to explicitly ignore the optimizer. This can
+                be useful if your optimizer does not implement `.state_dict` and
+                `.load_state_dict`.
             lr_scheduler: The learning rate scheduler that is used for training. If
                 `None`, no learning rate scheduler is assumed. Default: `None`.
             scaler: The gradient scaler that is used when training in mixed precision.
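
The docstring above names the motivating case: an optimizer that offers no `.state_dict`/`.load_state_dict`. A hypothetical sketch of such an object (class name and details are made up for illustration, not part of the library); you would keep using it in the training loop but pass `optimizer=None` to the checkpointer so no optimizer state is saved or restored:

class SignSGD:
    """Toy optimizer without state_dict/load_state_dict support (illustrative only)."""

    def __init__(self, params, lr: float = 0.01):
        self.params = list(params)
        self.lr = lr

    def step(self) -> None:
        # Update each parameter by the sign of its gradient.
        for p in self.params:
            if p.grad is not None:
                p.data -= self.lr * p.grad.sign()

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    # No state_dict()/load_state_dict(): checkpointing this object would fail,
    # so the checkpointer is given optimizer=None instead.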
@@ -151,12 +154,13 @@ def save_checkpoint(self, extra_info: Dict) -> None:
         }
         data = {
             "model": self.model.state_dict(),
-            "optimizer": self.optimizer.state_dict(),
             "rng_states": rng_states,
             "checkpoint_step": self.step_count,
             "resumes": self.num_resumes,
             "extra_info": extra_info,
         }
+        if self.optimizer is not None:
+            data["optimizer"] = self.optimizer.state_dict()
         if self.lr_scheduler is not None:
             data["lr_scheduler"] = self.lr_scheduler.state_dict()
         if self.scaler is not None:
@@ -209,8 +213,9 @@ def load_latest_checkpoint(
         data = load(loadpath, weights_only=weights_only, **kwargs)
         self.maybe_print("Loading model.")
         self.model.load_state_dict(data["model"])
-        self.maybe_print("Loading optimizer.")
-        self.optimizer.load_state_dict(data["optimizer"])
+        if self.optimizer is not None:
+            self.maybe_print("Loading optimizer.")
+            self.optimizer.load_state_dict(data["optimizer"])
         if self.lr_scheduler is not None:
             self.maybe_print("Loading lr scheduler.")
             self.lr_scheduler.load_state_dict(data["lr_scheduler"])
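
For reference, a standalone sketch of the same conditional save/load pattern in isolation; this is not the library code, just the logic the two hunks above implement, with made-up helper names:

import torch
from typing import Union
from torch.nn import Linear, Module
from torch.optim import Optimizer


def save_checkpoint(path: str, model: Module, optimizer: Union[Optimizer, None]) -> None:
    data = {"model": model.state_dict()}
    if optimizer is not None:  # only store optimizer state when one is given
        data["optimizer"] = optimizer.state_dict()
    torch.save(data, path)


def load_checkpoint(path: str, model: Module, optimizer: Union[Optimizer, None]) -> None:
    data = torch.load(path)
    model.load_state_dict(data["model"])
    if optimizer is not None:  # skip the optimizer entry when None was passed
        optimizer.load_state_dict(data["optimizer"])


model = Linear(4, 2)
save_checkpoint("ckpt.pt", model, None)  # checkpoint contains only the model
load_checkpoint("ckpt.pt", model, None)  # restores the model, skips the optimizer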
