Skip to content

Commit

Permalink
✨ Add a CLI arg to enable resuming training
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Aug 26, 2023
1 parent 13a6439 commit 9ed8c4c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/models/test_resnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_main(self):


class TestMIMOResnet:
"""Testing the ResNet mimo class."""
"""Testing the ResNet MIMO class."""

def test_main(self):
model = mimo_resnet34(1, 10, 2, style="cifar")
Expand Down
7 changes: 6 additions & 1 deletion torch_uncertainty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def init_args(
action="store_true",
help="Use channels last memory format",
)
parser.add_argument(
"--enable_resume",
action="store_true",
help="Allow resuming the training (save optimizer's states)",
)

parser = pl.Trainer.add_argparse_args(parser)
if network is not None:
Expand Down Expand Up @@ -105,7 +110,7 @@ def cli_main(
monitor=monitor,
mode=mode,
save_last=True,
save_weights_only=True,
save_weights_only=not args.enable_resume,
)

# Select the best model, monitor the lr and stop if NaN
Expand Down

0 comments on commit 9ed8c4c

Please sign in to comment.