Skip to content

Commit

Permalink
Regression in LR schedulers with metric tracking
Browse files Browse the repository at this point in the history
Fixes #635
  • Loading branch information
mittagessen committed Aug 16, 2024
1 parent b96d0cd commit 6fa8f79
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion kraken/lib/pretrain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,10 @@ def lr_scheduler_step(self, scheduler, metric):
scheduler.step()
# step every other scheduler epoch-wise
elif self.trainer.is_last_batch:
scheduler.step()
if metric is None:
scheduler.step()
else:
scheduler.step(metric)

def setup(self, stage: Optional[str] = None):
# finalize models in case of appending/loading
Expand Down
5 changes: 4 additions & 1 deletion kraken/lib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,10 @@ def lr_scheduler_step(self, scheduler, metric):
scheduler.step()
# step every other scheduler epoch-wise
elif self.trainer.is_last_batch:
scheduler.step()
if metric is None:
scheduler.step()
else:
scheduler.step(metric)


def _configure_optimizer_and_lr_scheduler(hparams, params, len_train_set=None, loss_tracking_mode='max'):
Expand Down

0 comments on commit 6fa8f79

Please sign in to comment.