Skip to content

Commit

Permalink
Remove cooldown from LR scheduler, as it is deprecated
Browse files Browse the repository at this point in the history
  • Loading branch information
sjfleming committed Jun 19, 2020
1 parent 41d87e3 commit 0feb5e0
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions cellbender/remove_background/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def run_training(model: RemoveBackgroundPyroModel,
train_loader: DataLoader,
test_loader: DataLoader,
epochs: int,
epoch_start_cooldown: int,
# epoch_start_cooldown: int,
test_freq: int = 10) -> Tuple[List[float],
List[float]]:
"""Run an entire course of training, evaluating on a test set periodically.
Expand All @@ -113,7 +113,7 @@ def run_training(model: RemoveBackgroundPyroModel,
train_loader: Dataloader for training set.
test_loader: Dataloader for test set.
epochs: Number of epochs to run training.
epoch_start_cooldown: Epoch at which cooldown starts.
# epoch_start_cooldown: Epoch at which cooldown starts.
test_freq: Test set loss is calculated every test_freq epochs of
training.
Expand Down Expand Up @@ -292,11 +292,16 @@ def run_inference(dataset_obj: SingleCellRNACountsDataset,

# Set up a learning rate scheduler.
minibatches_per_epoch = int(np.ceil(len(train_loader) / train_loader.batch_size).item())
epoch_start_cooldown = max(50, args.epochs - 10) # last 10 epochs (beyond 50) cool off
# epoch_start_cooldown = max(50, args.epochs - 10) # last 10 epochs (beyond 50) cool off
# scheduler_args = {'optimizer': optimizer,
# 'max_lr': args.learning_rate * 10,
# 'steps_per_epoch': minibatches_per_epoch,
# 'epochs': epoch_start_cooldown,
# 'optim_args': optimizer_args}
scheduler_args = {'optimizer': optimizer,
'max_lr': args.learning_rate * 10,
'steps_per_epoch': minibatches_per_epoch,
'epochs': epoch_start_cooldown,
'epochs': args.epochs,
'optim_args': optimizer_args}
scheduler = pyro.optim.OneCycleLR(scheduler_args)

Expand Down Expand Up @@ -326,7 +331,7 @@ def run_inference(dataset_obj: SingleCellRNACountsDataset,
# model.guide(train_loader.__next__()) # TODO: just examine initialization
# with torch.autograd.set_detect_anomaly(True): # TODO: debug only!! doubles runtime!
run_training(model, svi, train_loader, test_loader,
epochs=args.epochs, test_freq=5,
epoch_start_cooldown=epoch_start_cooldown)
epochs=args.epochs, test_freq=5)#,
# epoch_start_cooldown=epoch_start_cooldown)

return model

0 comments on commit 0feb5e0

Please sign in to comment.