diff --git a/wandb_preempt/checkpointer.py b/wandb_preempt/checkpointer.py index 103e41a..14eae69 100644 --- a/wandb_preempt/checkpointer.py +++ b/wandb_preempt/checkpointer.py @@ -35,7 +35,7 @@ def __init__( self, run_id: str, model: Module, - optimizer: Optimizer, + optimizer: Union[Optimizer, None], lr_scheduler: Optional[LRScheduler] = None, scaler: Optional[GradScaler] = None, savedir: str = "checkpoints", @@ -46,7 +46,10 @@ def __init__( Args: run_id: A unique identifier for this run. model: The model that is trained and checkpointed. - optimizer: The optimizer that is used for training and checkpointed. + optimizer: The optimizer that is used for training and should be + checkpointed. Use `None` to explicitly ignore the optimizer. This can + be useful if your optimizer does not implement `.state_dict` and + `.load_state_dict`. lr_scheduler: The learning rate scheduler that is used for training. If `None`, no learning rate scheduler is assumed. Default: `None`. scaler: The gradient scaler that is used when training in mixed precision. @@ -151,12 +154,13 @@ def save_checkpoint(self, extra_info: Dict) -> None: } data = { "model": self.model.state_dict(), - "optimizer": self.optimizer.state_dict(), "rng_states": rng_states, "checkpoint_step": self.step_count, "resumes": self.num_resumes, "extra_info": extra_info, } + if self.optimizer is not None: + data["optimizer"] = self.optimizer.state_dict() if self.lr_scheduler is not None: data["lr_scheduler"] = self.lr_scheduler.state_dict() if self.scaler is not None: @@ -209,8 +213,9 @@ def load_latest_checkpoint( data = load(loadpath, weights_only=weights_only, **kwargs) self.maybe_print("Loading model.") self.model.load_state_dict(data["model"]) - self.maybe_print("Loading optimizer.") - self.optimizer.load_state_dict(data["optimizer"]) + if self.optimizer is not None: + self.maybe_print("Loading optimizer.") + self.optimizer.load_state_dict(data["optimizer"]) if self.lr_scheduler is not None: self.maybe_print("Loading lr scheduler.") self.lr_scheduler.load_state_dict(data["lr_scheduler"])