diff --git a/cyclops/models/wrappers/pt_model.py b/cyclops/models/wrappers/pt_model.py index 50b3a4302..1c54b948c 100644 --- a/cyclops/models/wrappers/pt_model.py +++ b/cyclops/models/wrappers/pt_model.py @@ -1309,7 +1309,7 @@ def save_model(self, filepath: str, overwrite: bool = True, **kwargs): if include_lr_scheduler: state_dict["lr_scheduler"] = self.lr_scheduler_.state_dict() # type: ignore[attr-defined] - epoch = kwargs.get("epoch", None) + epoch = kwargs.get("epoch") if epoch is not None: filename, extension = os.path.basename(filepath).split(".") filepath = join(