From 15a8452aaafd1a32d63d4eace50272e8c403f412 Mon Sep 17 00:00:00 2001
From: Scott Lowe
Date: Mon, 9 Sep 2024 15:37:18 -0400
Subject: [PATCH] [REF] Convert train script into a function (#2)

* [REF] Change example train script to have fn and __main__ block
* [FIX] Make train.py executable, add shebang
* [ENH] Make hard-coded params adjustable in example script
* [REF] Remove cli function
* [REF] Reduce number of lines for argparsing
* [API] Change max_epochs -> epochs to make it more concise
* [STY] Make isort happy, even though black indifferent
---
 example/train.py | 212 +++++++++++++++++++++++++----------------
 1 file changed, 115 insertions(+), 97 deletions(-)
 mode change 100644 => 100755 example/train.py

diff --git a/example/train.py b/example/train.py
old mode 100644
new mode 100755
index fbbb3fd..3522925
--- a/example/train.py
+++ b/example/train.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 """Train a simple CNN on MNIST using checkpoints, integrated with Weights & Biases.
 
 The changes required to integrate checkpointing with wandb are tagged with 'NOTE'.
@@ -17,100 +18,117 @@ from wandb_preempt.checkpointer import Checkpointer, get_resume_value
 
-parser = ArgumentParser("Train a simple CNN on MNIST using SGD.")
-parser.add_argument("--lr", type=float, default=0.01, help="SGD's learning rate.")
-parser.add_argument(
-    "--max_epochs",
-    type=int,
-    default=10,
-    help="Number of epochs to train for.",
-)
-args = parser.parse_args()
-
-LOGGING_INTERVAL = 50  # print and log loss at this frequency
-BATCH_SIZE = 256
-VERBOSE = True
-
-manual_seed(0)  # make deterministic
-DEV = device("cuda" if cuda.is_available() else "cpu")
-
-# NOTE: Define the directory where checkpoints are stored
-SAVEDIR = "checkpoints"
-
-# NOTE: Figure out the `resume` value and pass it to wandb
-run = wandb.init(resume=get_resume_value(verbose=VERBOSE))
-
-# Set up the data, neural net, loss function, and optimizer
-train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
-train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
-model = Sequential(
-    Conv2d(1, 3, kernel_size=5, stride=2),
-    ReLU(),
-    Flatten(),
-    Linear(432, 50),
-    ReLU(),
-    Linear(50, 10),
-).to(DEV)
-loss_func = CrossEntropyLoss().to(DEV)
-print(f"Using SGD with learning rate {args.lr}.")
-optimizer = SGD(model.parameters(), lr=args.lr)
-lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.max_epochs)
-scaler = GradScaler()
-
-# NOTE: Set up a check-pointer which will load and save checkpoints.
-# Pass the run ID to obtain unique file names for the checkpoints.
-checkpointer = Checkpointer(
-    run.id,
-    model,
-    optimizer,
-    lr_scheduler=lr_scheduler,
-    scaler=scaler,
-    savedir=SAVEDIR,
-    verbose=VERBOSE,
-)
-
-# NOTE: If existing, load model, optimizer, and learning rate scheduler state from
-# latest checkpoint, set random number generator states, and recover the epoch to start
-# training from. Does nothing if there was no checkpoint.
-start_epoch = checkpointer.load_latest_checkpoint()
-
-# training
-for epoch in range(start_epoch, args.max_epochs):
-    model.train()
-    for step, (inputs, target) in enumerate(train_loader):
-        optimizer.zero_grad()
-
-        with autocast(device_type="cuda", dtype=bfloat16):
-            output = model(inputs.to(DEV))
-            loss = loss_func(output, target.to(DEV))
-
-        if step % LOGGING_INTERVAL == 0:
-            print(f"Epoch {epoch}, Step {step}, Loss {loss.item():.5e}")
-            wandb.log(
-                {
-                    "loss": loss.item(),
-                    "lr": optimizer.param_groups[0]["lr"],
-                    "loss_scale": scaler.get_scale(),
-                    "resumes": checkpointer.num_resumes,
-                }
-            )
-
-        scaler.scale(loss).backward()
-        scaler.step(optimizer)  # update neural network parameters
-        scaler.update()  # update the gradient scaler
-
-    lr_scheduler.step()  # update learning rate
-
-    # NOTE Put validation code here
-    # eval(model, ...)
-
-    # NOTE Call checkpointer.step() at the end of the epoch to save a checkpoint.
-    # If SLURM sent us a signal that our time for this job is running out, it will now
-    # also take care of pre-empting the wandb job and requeuing the SLURM job, killing
-    # the current python training script to resume with the requeued job.
-    checkpointer.step()
-
-wandb.finish()
-# NOTE Remove all created checkpoints once we are done training. If you want to
-# keep the trained model, remove this line.
-checkpointer.remove_checkpoints()
+LOGGING_INTERVAL = 50  # Num batches between logging to stdout and wandb
+VERBOSE = True  # Enable verbose output
+
+
+def get_parser():
+    r"""Create argument parser."""
+    parser = ArgumentParser("Train a simple CNN on MNIST using SGD.")
+    parser.add_argument(
+        "--lr", type=float, default=0.01, help="Learning rate. Default: %(default)s"
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=10, help="Number of epochs. Default: %(default)s"
+    )
+    parser.add_argument(
+        "--batch_size", type=int, default=256, help="Batch size. Default: %(default)s"
+    )
+    parser.add_argument(
+        "--checkpoint_dir", type=str, default="checkpoints", help="Checkpoint save dir."
+    )
+    return parser
+
+
+def main(args):
+    r"""Train model."""
+    manual_seed(0)  # make deterministic
+    DEV = device("cuda" if cuda.is_available() else "cpu")
+
+    # NOTE: Figure out the `resume` value and pass it to wandb
+    run = wandb.init(resume=get_resume_value(verbose=VERBOSE))
+
+    # Set up the data, neural net, loss function, and optimizer
+    train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
+    train_loader = DataLoader(
+        dataset=train_dataset, batch_size=args.batch_size, shuffle=True
+    )
+    model = Sequential(
+        Conv2d(1, 3, kernel_size=5, stride=2),
+        ReLU(),
+        Flatten(),
+        Linear(432, 50),
+        ReLU(),
+        Linear(50, 10),
+    ).to(DEV)
+    loss_func = CrossEntropyLoss().to(DEV)
+    print(f"Using SGD with learning rate {args.lr}.")
+    optimizer = SGD(model.parameters(), lr=args.lr)
+    lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
+    scaler = GradScaler()
+
+    # NOTE: Set up a check-pointer which will load and save checkpoints.
+    # Pass the run ID to obtain unique file names for the checkpoints.
+    checkpointer = Checkpointer(
+        run.id,
+        model,
+        optimizer,
+        lr_scheduler=lr_scheduler,
+        scaler=scaler,
+        savedir=args.checkpoint_dir,
+        verbose=VERBOSE,
+    )
+
+    # NOTE: If existing, load model, optimizer, and learning rate scheduler state from
+    # latest checkpoint, set random number generator states, and recover the epoch to
+    # start training from. Does nothing if there was no checkpoint.
+    start_epoch = checkpointer.load_latest_checkpoint()
+
+    # training
+    for epoch in range(start_epoch, args.epochs):
+        model.train()
+        for step, (inputs, target) in enumerate(train_loader):
+            optimizer.zero_grad()
+
+            with autocast(device_type="cuda", dtype=bfloat16):
+                output = model(inputs.to(DEV))
+                loss = loss_func(output, target.to(DEV))
+
+            if step % LOGGING_INTERVAL == 0:
+                print(f"Epoch {epoch}, Step {step}, Loss {loss.item():.5e}")
+                wandb.log(
+                    {
+                        "loss": loss.item(),
+                        "lr": optimizer.param_groups[0]["lr"],
+                        "loss_scale": scaler.get_scale(),
+                        "resumes": checkpointer.num_resumes,
+                    }
+                )
+
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)  # update neural network parameters
+            scaler.update()  # update the gradient scaler
+
+        lr_scheduler.step()  # update learning rate
+
+        # NOTE Put validation code here
+        # eval(model, ...)
+
+        # NOTE Call checkpointer.step() at the end of the epoch to save a
+        # checkpoint. If SLURM sent us a signal that our time for this job is
+        # running out, it will now also take care of pre-empting the wandb job
+        # and requeuing the SLURM job, killing the current python training script
+        # to resume with the requeued job.
+        checkpointer.step()
+
+    wandb.finish()
+    # NOTE Remove all created checkpoints once we are done training. If you want to
+    # keep the trained model, remove this line.
+    checkpointer.remove_checkpoints()
+
+
+if __name__ == "__main__":
+    # Run as a script
+    parser = get_parser()
+    args = parser.parse_args()
+    main(args)
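Usage sketch (not part of the patch above): with the training logic now wrapped in get_parser() and main(), the example can be driven programmatically as well as from the command line. The snippet below is a minimal sketch; it assumes the patched example/train.py is importable as `train` (for instance by running from inside the example/ directory), and the hyperparameter values are arbitrary examples. The import path and values are assumptions, not something the patch itself establishes.

    #!/usr/bin/env python
    """Drive the refactored MNIST example programmatically (sketch)."""

    from train import get_parser, main  # import path is an assumption

    # Reuse the same parser the __main__ block builds, overriding a few of the
    # defaults exposed by this patch (--lr, --epochs, --batch_size, --checkpoint_dir).
    parser = get_parser()
    args = parser.parse_args(
        ["--lr", "0.05", "--epochs", "3", "--batch_size", "128", "--checkpoint_dir", "ckpts"]
    )

    # main() runs the full training loop, including the wandb resume handling and
    # the Checkpointer load/step/remove lifecycle shown in the diff above.
    main(args)

Equivalently, because the patch adds a shebang and marks the file executable, the same run can be launched directly from a shell, e.g. ./train.py --lr 0.05 --epochs 3, typically from inside a SLURM batch script so that the requeue behaviour described in the checkpointer NOTE can take effect.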