[ADD] Support gradient scaler
f-dangel committed Aug 26, 2024
1 parent 7d1af10 commit 8570b80
Showing 3 changed files with 31 additions and 10 deletions.
example/launch.sh (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
#!/bin/bash
-#SBATCH --partition=rtx6000
+#SBATCH --partition=a40
#SBATCH --nodes=1
#SBATCH --tasks-per-node=1
#SBATCH --gres=gpu:1
example/train.py (22 changes: 15 additions & 7 deletions)
@@ -6,7 +6,8 @@
from argparse import ArgumentParser

import wandb
-from torch import cuda, device, manual_seed
+from torch import autocast, bfloat16, cuda, device, manual_seed
+from torch.cuda.amp import GradScaler
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
@@ -59,6 +60,7 @@
print(f"Using SGD with learning rate {args.lr}.")
optimizer = SGD(model.parameters(), lr=args.lr)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.max_epochs)
+scaler = GradScaler()

# NOTE: Set up a check-pointer which will load and save checkpoints.
# Pass the run ID to obtain unique file names for the checkpoints.
@@ -67,6 +69,7 @@
    model,
    optimizer,
    lr_scheduler=lr_scheduler,
+    scaler=scaler,
    savedir=SAVEDIR,
    verbose=VERBOSE,
)
@@ -86,24 +89,29 @@
    # normal training loop
    for step, (inputs, target) in enumerate(train_loader):
        optimizer.zero_grad()
-        loss = loss_func(model(inputs.to(DEV)), target.to(DEV))
-        loss.backward()
+
+        with autocast(device_type="cuda", dtype=bfloat16):
+            output = model(inputs.to(DEV))
+            loss = loss_func(output, target.to(DEV))
+
        if step % LOGGING_INTERVAL == 0:
-            loss = loss.item()
-            print(f"Epoch {epoch}, Step {step}, Loss {loss:.5e}")
+            print(f"Epoch {epoch}, Step {step}, Loss {loss.item():.5e}")
            # NOTE: Only call `wandb.log` inside `CheckpointAtEnd`.
            # Otherwise, runs might contain duplicate logs.
            wandb.log(
                {
                    "global_step": epoch * STEPS_PER_EPOCH + step,
-                    "loss": loss,
+                    "loss": loss.item(),
                    "epoch": epoch + step / STEPS_PER_EPOCH,
                    "lr": optimizer.param_groups[0]["lr"],
+                    "loss_scale": scaler.get_scale(),
                }
            )

-        optimizer.step()  # update neural network parameters
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)  # update neural network parameters
+        scaler.update()  # update the gradient scaler

    lr_scheduler.step()  # update learning rate

wandb.finish()
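
For readers less familiar with the pattern, here is a small standalone sketch of the autocast/GradScaler flow that this diff wires into train.py: run the forward pass under autocast, call backward on the scaled loss, step the optimizer through the scaler, then update the scale factor. The toy model, data, and loop below are made up for illustration and are not part of the commit.

from torch import autocast, bfloat16, cuda, device, manual_seed, randint, randn
from torch.cuda.amp import GradScaler
from torch.nn import CrossEntropyLoss, Linear
from torch.optim import SGD

manual_seed(0)
DEV = device("cuda" if cuda.is_available() else "cpu")

model = Linear(16, 4).to(DEV)
loss_func = CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.1)
scaler = GradScaler()  # disables itself (with a warning) if CUDA is unavailable

for step in range(3):
    inputs = randn(8, 16, device=DEV)
    target = randint(0, 4, (8,), device=DEV)
    optimizer.zero_grad()

    # forward pass runs in reduced precision inside the autocast region
    with autocast(device_type=DEV.type, dtype=bfloat16):
        loss = loss_func(model(inputs), target)

    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)  # unscales gradients, then calls optimizer.step()
    scaler.update()  # adjust the scale factor for the next iteration
    print(f"step {step}: loss {loss.item():.4f}, scale {scaler.get_scale()}")
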
wandb_preempt/checkpointer.py (17 changes: 15 additions & 2 deletions)
@@ -11,6 +11,7 @@

import wandb
from torch import cuda, device, get_rng_state, load, save, set_rng_state
+from torch.cuda.amp import GradScaler
from torch.optim.lr_scheduler import LRScheduler
from wandb import Api

@@ -70,6 +71,7 @@ def __init__(
        model,
        optimizer,
        lr_scheduler: Optional[LRScheduler] = None,
+        scaler: Optional[GradScaler] = None,
        savedir: str = "checkpoints",
        verbose: bool = False,
    ) -> None:
@@ -81,6 +83,8 @@ def __init__(
            optimizer: The optimizer that is used for training and checkpointed.
            lr_scheduler: The learning rate scheduler that is used for training. If `None`,
                no learning rate scheduler is assumed. Default: `None`.
+            scaler: The gradient scaler that is used when training in mixed precision. If
+                `None`, no gradient scaler is assumed. Default: `None`.
            savedir: Directory to store checkpoints in. Default: `'checkpoints'`.
            verbose: Whether to print messages about saving and loading checkpoints.
                Default: `False`
@@ -95,6 +99,7 @@ def __init__(
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
+        self.scaler = scaler
        self.verbose = verbose
        self.marked_preempted = False

@@ -148,7 +153,8 @@ def checkpoint_path(self, epoch: int) -> str:
    def save_checkpoint(self, epoch: int) -> None:
        """Save a checkpoint for a given epoch.

-        Stores optimizer, model, and random number generator states.
+        Stores optimizer, model, lr scheduler, gradient scaler, and random number
+        generator states.

        Args:
            epoch: The epoch number.
@@ -172,13 +178,17 @@ def save_checkpoint(self, epoch: int) -> None:
        }
        if self.lr_scheduler is not None:
            data["lr_scheduler"] = self.lr_scheduler.state_dict()
+        if self.scaler is not None:
+            data["scaler"] = self.scaler.state_dict()

        self.maybe_print(f"Saving checkpoint {savepath}.")
        save(data, savepath)

    def load_latest_checkpoint(self) -> int:
        """Load the latest checkpoint and set random number generator states.

-        Updates the model and optimizer states passed at initialization.
+        Updates the model, optimizer, lr scheduler, and gradient scaler states
+        passed at initialization.

        Returns:
            The epoch number at which training should resume.
@@ -198,6 +208,9 @@ def load_latest_checkpoint(self) -> int:
        if self.lr_scheduler is not None:
            self.maybe_print("Loading lr scheduler.")
            self.lr_scheduler.load_state_dict(data["lr_scheduler"])
+        if self.scaler is not None:
+            self.maybe_print("Loading gradient scaler.")
+            self.scaler.load_state_dict(data["scaler"])

        # restore random number generator states for all devices
        self.maybe_print("Setting RNG states.")
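
To see what the new Checkpointer fields amount to on disk, here is a minimal standalone sketch of the state_dict round-trip that save_checkpoint and load_latest_checkpoint now perform for the scaler. It is illustrative only: the bare dict and the file name checkpoint.pt stand in for the Checkpointer's own checkpoint format.

from torch import load, save
from torch.cuda.amp import GradScaler

scaler = GradScaler()
save({"scaler": scaler.state_dict()}, "checkpoint.pt")  # mirrors what save_checkpoint stores

restored = GradScaler()
data = load("checkpoint.pt")
restored.load_state_dict(data["scaler"])  # mirrors what load_latest_checkpoint restores
assert restored.get_scale() == scaler.get_scale()
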
