From 0feb5e0f8867332d3c55711e4390dd4f5b03fa18 Mon Sep 17 00:00:00 2001
From: Stephen Fleming
Date: Fri, 19 Jun 2020 10:19:50 -0400
Subject: [PATCH] Remove cooldown from LR scheduler, as it is deprecated

---
 cellbender/remove_background/train.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/cellbender/remove_background/train.py b/cellbender/remove_background/train.py
index 9f45b12..bec9b7c 100644
--- a/cellbender/remove_background/train.py
+++ b/cellbender/remove_background/train.py
@@ -102,7 +102,7 @@ def run_training(model: RemoveBackgroundPyroModel,
                  train_loader: DataLoader,
                  test_loader: DataLoader,
                  epochs: int,
-                 epoch_start_cooldown: int,
+                 # epoch_start_cooldown: int,
                  test_freq: int = 10) -> Tuple[List[float], List[float]]:
     """Run an entire course of training, evaluating on a tests set
     periodically.
@@ -113,7 +113,7 @@ def run_training(model: RemoveBackgroundPyroModel,
         train_loader: Dataloader for training set.
         test_loader: Dataloader for tests set.
         epochs: Number of epochs to run training.
-        epoch_start_cooldown: Epoch at which cooldown starts.
+        # epoch_start_cooldown: Epoch at which cooldown starts.
         test_freq: Test set loss is calculated every test_freq epochs
             of training.
 
@@ -292,11 +292,16 @@ def run_inference(dataset_obj: SingleCellRNACountsDataset,
     # Set up a learning rate scheduler.
     minibatches_per_epoch = int(np.ceil(len(train_loader)
                                         / train_loader.batch_size).item())
-    epoch_start_cooldown = max(50, args.epochs - 10)  # last 10 epochs (beyond 50) cool off
+    # epoch_start_cooldown = max(50, args.epochs - 10)  # last 10 epochs (beyond 50) cool off
+    # scheduler_args = {'optimizer': optimizer,
+    #                   'max_lr': args.learning_rate * 10,
+    #                   'steps_per_epoch': minibatches_per_epoch,
+    #                   'epochs': epoch_start_cooldown,
+    #                   'optim_args': optimizer_args}
     scheduler_args = {'optimizer': optimizer,
                       'max_lr': args.learning_rate * 10,
                       'steps_per_epoch': minibatches_per_epoch,
-                      'epochs': epoch_start_cooldown,
+                      'epochs': args.epochs,
                       'optim_args': optimizer_args}
     scheduler = pyro.optim.OneCycleLR(scheduler_args)
 
@@ -326,7 +331,7 @@ def run_inference(dataset_obj: SingleCellRNACountsDataset,
     # model.guide(train_loader.__next__())  # TODO: just examine initialization
     # with torch.autograd.set_detect_anomaly(True):  # TODO: debug only!! doubles runtime!
     run_training(model, svi, train_loader, test_loader,
-                 epochs=args.epochs, test_freq=5,
-                 epoch_start_cooldown=epoch_start_cooldown)
+                 epochs=args.epochs, test_freq=5)#,
+                 # epoch_start_cooldown=epoch_start_cooldown)
 
     return model