add inversesquareroot scheduler
Arian Jamasb committed Feb 1, 2024
1 parent 5572403 commit 24d4d59
Showing 2 changed files with 84 additions and 0 deletions.
32 changes: 32 additions & 0 deletions proteinworkshop/config/scheduler/inverse_square_root.yaml
@@ -0,0 +1,32 @@
scheduler:
  _target_: proteinworkshop.utils.schedulers.InverseSquareRootLR
  _partial_: true
  warmup_steps: 1
  last_epoch: -1

# The unit of the scheduler's step size: 'epoch' updates the scheduler
# at the end of each epoch, whereas 'step' updates it after each
# optimizer update.

# It is recommended to call step() for this scheduler after each
# iteration, as calling it only at epoch end would hold the learning
# rate at its initial warmup value (0 in most cases) for the whole
# first epoch.
interval: "step"

# How many epochs/steps should pass between calls to
# `scheduler.step()`. 1 corresponds to updating the learning
# rate after every epoch/step.
frequency: 1

# Metric to monitor for schedulers like `ReduceLROnPlateau`
monitor: "val/loss/total"

# If set to `True`, enforces that the metric specified in 'monitor'
# is available when the scheduler is updated, stopping training if it
# is not found. If set to `False`, only a warning is produced.
strict: True

# If using the `LearningRateMonitor` callback to monitor the
# learning rate progress, this keyword can be used to specify
# a custom logged name
name: learning_rate
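
The keys outside the `scheduler:` block mirror PyTorch Lightning's `lr_scheduler` configuration dictionary. The sketch below is a rough illustration of how they are consumed; the toy model, optimizer, and variable names are illustrative and not part of this commit, and the partial is completed by hand rather than by Hydra.

import torch
from proteinworkshop.utils.schedulers import InverseSquareRootLR

# Toy model/optimizer purely for illustration.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# `_partial_: true` means Hydra builds a functools.partial of the scheduler
# class and the training code completes it with the optimizer; done by hand here.
scheduler = InverseSquareRootLR(optimizer, warmup_steps=1, last_epoch=-1)

# The remaining keys become the scheduler dictionary typically returned from
# a LightningModule's `configure_optimizers`.
lr_scheduler_config = {
    "scheduler": scheduler,
    "interval": "step",           # update after every optimizer step
    "frequency": 1,               # ...and on every such step
    "monitor": "val/loss/total",  # only consulted by metric-driven schedulers
    "strict": True,
    "name": "learning_rate",      # label used by the LearningRateMonitor callback
}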
52 changes: 52 additions & 0 deletions proteinworkshop/utils/schedulers.py
@@ -0,0 +1,52 @@
"""Implement custom learning rate schedulers."""

import warnings

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer


class InverseSquareRootLR(_LRScheduler):
    """Implement the InverseSquareRootLR learning rate scheduler.

    The learning rate rises linearly from zero to each parameter group's base
    learning rate over ``warmup_steps`` steps, then decays in proportion to
    the inverse square root of the step count.

    :param optimizer: The optimizer.
    :type optimizer: Optimizer
    :param warmup_steps: The number of warmup steps.
    :type warmup_steps: int
    :param last_epoch: The index of the last epoch. If -1, the scheduler
        starts at the initial learning rate.
    :type last_epoch: int
    """

    def __init__(
        self, optimizer: Optimizer, warmup_steps: int, last_epoch: int = -1
    ):
        if warmup_steps <= 0:
            raise ValueError("warmup_steps must be > 0")
        self._warmup_steps = warmup_steps
        # Per-parameter-group warmup increment: the learning rate rises by
        # base_lr / warmup_steps on every step until it reaches base_lr.
        self._lr_steps = [
            param_group["lr"] / warmup_steps
            for param_group in optimizer.param_groups
        ]
        # Per-parameter-group decay factor, chosen so that
        # decay_factor * step**-0.5 equals base_lr at step == warmup_steps.
        self._decay_factors = [
            param_group["lr"] * warmup_steps**0.5
            for param_group in optimizer.param_groups
        ]

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Mirror the standard PyTorch check that `get_lr` is only called from
        # within `step()`.
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )

        if self.last_epoch < self._warmup_steps:
            # Linear warmup: lr = base_lr * step / warmup_steps.
            return [self.last_epoch * lr_step for lr_step in self._lr_steps]
        else:
            # Inverse square root decay: lr = base_lr * sqrt(warmup_steps / step),
            # which matches the warmup value exactly at step == warmup_steps.
            return [
                decay_factor * self.last_epoch**-0.5
                for decay_factor in self._decay_factors
            ]
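
A minimal standalone usage sketch (not part of the commit) showing the per-step update that the `interval: "step"` setting above corresponds to; the 1000-step warmup, base learning rate, and toy model are arbitrary illustrative choices.

import torch
from proteinworkshop.utils.schedulers import InverseSquareRootLR

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = InverseSquareRootLR(optimizer, warmup_steps=1000)

for step in range(5000):
    optimizer.step()    # forward/backward pass omitted for brevity
    scheduler.step()    # called once per optimizer update ("step" interval)

# At step 1000 the learning rate reaches the base value (3e-4); by step 4000 it
# has decayed to 3e-4 * sqrt(1000 / 4000) = 1.5e-4.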
