[ADD] Support lr scheduler
f-dangel committed Aug 26, 2024
1 parent 00f9f67 commit 6a7cf93
Showing 2 changed files with 23 additions and 4 deletions.
17 changes: 13 additions & 4 deletions example/train.py
@@ -9,6 +9,7 @@
from torch import cuda, device, manual_seed
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential
from torch.optim import SGD
+ from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
@@ -57,16 +58,22 @@
loss_func = CrossEntropyLoss().to(DEV)
print(f"Using SGD with learning rate {args.lr}.")
optimizer = SGD(model.parameters(), lr=args.lr)
+ lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.max_epochs)

# NOTE: Set up a check-pointer which will load and save checkpoints.
# Pass the run ID to obtain unique file names for the checkpoints.
checkpoint_handler = CheckpointHandler(
- run.id, model, optimizer, savedir=SAVEDIR, verbose=VERBOSE
+ run.id,
+ model,
+ optimizer,
+ lr_scheduler=lr_scheduler,
+ savedir=SAVEDIR,
+ verbose=VERBOSE,
)

- # NOTE: If existing, load model and optimizer state from latest checkpoint, set
- # random number generator states, and recover the epoch to start training from.
- # Does nothing if there was no checkpoint.
+ # NOTE: If existing, load model, optimizer, and learning rate scheduler state from
+ # latest checkpoint, set random number generator states, and recover the epoch to start
+ # training from. Does nothing if there was no checkpoint.
start_epoch = checkpoint_handler.load_latest_checkpoint()

# training
@@ -92,10 +99,12 @@
"global_step": epoch * STEPS_PER_EPOCH + step,
"loss": loss,
"epoch": epoch + step / STEPS_PER_EPOCH,
"lr": optimizer.param_groups[0]["lr"],
}
)

optimizer.step() # update neural network parameters
+ lr_scheduler.step() # update learning rate

wandb.finish()
# NOTE Remove all created checkpoints once we are done training. If you want to
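
Context for the updated NOTE above: if a resumed run restored only the model and optimizer, the cosine schedule would restart from its initial learning rate. The following is a minimal, self-contained sketch in plain PyTorch of the behavior the extended checkpoint preserves; it uses a toy model and made-up numbers and is not code from this repository.

from torch.nn import Linear
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

def make_training_objects():
    # Toy model, optimizer, and a 10-step cosine schedule.
    model = Linear(2, 2)
    optimizer = SGD(model.parameters(), lr=0.1)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)
    return model, optimizer, scheduler

# First run: advance the schedule for 4 steps, then pretend we get preempted.
model1, opt1, sched1 = make_training_objects()
for _ in range(4):
    opt1.step()  # parameter update (no gradients needed for this illustration)
    sched1.step()  # decay the learning rate along the cosine curve
checkpoint = {"optimizer": opt1.state_dict(), "lr_scheduler": sched1.state_dict()}

# Restarted run: load the saved state into freshly constructed objects.
model2, opt2, sched2 = make_training_objects()
opt2.load_state_dict(checkpoint["optimizer"])
sched2.load_state_dict(checkpoint["lr_scheduler"])

# The learning rate picks up where the first run stopped ...
assert opt2.param_groups[0]["lr"] == opt1.param_groups[0]["lr"]
# ... and both runs keep producing identical learning rates afterwards.
sched1.step()
sched2.step()
assert opt2.param_groups[0]["lr"] == opt1.param_groups[0]["lr"]
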
10 changes: 10 additions & 0 deletions wandb_preempt/checkpointer.py
@@ -11,6 +11,7 @@

import wandb
from torch import cuda, device, get_rng_state, load, save, set_rng_state
+ from torch.optim.lr_scheduler import LRScheduler
from wandb import Api


@@ -68,6 +69,7 @@ def __init__(
run_id: str,
model,
optimizer,
+ lr_scheduler: Optional[LRScheduler] = None,
savedir: str = "checkpoints",
verbose: bool = False,
) -> None:
@@ -77,6 +79,8 @@ def __init__(
run_id: A unique identifier for this run.
model: The model that is trained and checkpointed.
optimizer: The optimizer that is used for training and checkpointed.
+ lr_scheduler: The learning rate scheduler that is used for training. If `None`,
+ no learning rate scheduler is assumed. Default: `None`.
savedir: Directory to store checkpoints in. Default: `'checkpoints'`.
verbose: Whether to print messages about saving and loading checkpoints.
Default: `False`
@@ -90,6 +94,7 @@ def __init__(
self.run_id = run_id
self.model = model
self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
self.verbose = verbose
self.marked_preempted = False

@@ -165,6 +170,8 @@ def save_checkpoint(self, epoch: int) -> None:
"rng_states": rng_states,
"epoch": epoch,
}
+ if self.lr_scheduler is not None:
+ data["lr_scheduler"] = self.lr_scheduler.state_dict()
self.maybe_print(f"Saving checkpoint {savepath}.")
save(data, savepath)
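
For reference, `LRScheduler.state_dict()` contains every attribute of the scheduler except its reference to the optimizer, so for the built-in schedulers it is a small dictionary of plain Python values and adds little to the checkpoint. A quick way to inspect it (toy objects for illustration; the exact keys depend on the PyTorch version):

from torch.nn import Linear
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = SGD(Linear(2, 2).parameters(), lr=0.1)
scheduler = CosineAnnealingLR(optimizer, T_max=100)

# Typically shows entries such as 'T_max', 'eta_min', 'base_lrs', 'last_epoch',
# and '_last_lr'; none of them hold tensors or a reference to the optimizer.
print(scheduler.state_dict())
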

@@ -188,6 +195,9 @@ def load_latest_checkpoint(self) -> int:
self.model.load_state_dict(data["model"])
self.maybe_print("Loading optimizer.")
self.optimizer.load_state_dict(data["optimizer"])
+ if self.lr_scheduler is not None:
+ self.maybe_print("Loading lr scheduler.")
+ self.lr_scheduler.load_state_dict(data["lr_scheduler"])

# restore random number generator states for all devices
self.maybe_print("Setting RNG states.")
