Commit 24d4d59 by Arian Jamasb, committed Feb 1, 2024 (1 parent: 5572403)
Showing 2 changed files with 84 additions and 0 deletions.
@@ -0,0 +1,32 @@ (new file: inverse square root scheduler config)
scheduler:
  _target_: proteinworkshop.utils.schedulers.InverseSquareRootScheduler
  _partial_: true
  warmup_steps: 1
  last_epoch: -1

# The unit of the scheduler's step size: 'epoch' updates the scheduler
# on epoch end, whereas 'step' updates it after each optimizer update.
# It is recommended to call step() for LinearWarmupCosineAnnealingLR
# after each iteration, as calling it after each epoch will keep the
# starting lr at warmup_start_lr for the first epoch, which is 0 in most cases.
interval: "step"

# How many epochs/steps should pass between calls to
# `scheduler.step()`. 1 corresponds to updating the learning
# rate after every epoch/step.
frequency: 1

# Metric to monitor for schedulers like `ReduceLROnPlateau`
monitor: "val/loss/total"

# If set to `True`, enforces that the value specified in 'monitor'
# is available when the scheduler is updated, and stops training if
# it is not found. If set to `False`, it only produces a warning.
strict: True

# If using the `LearningRateMonitor` callback to monitor the
# learning rate progress, this keyword can be used to specify
# a custom logged name
name: learning_rate
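For context, a minimal sketch (not part of this commit) of how such a Hydra scheduler config is typically turned into PyTorch Lightning's lr_scheduler dict. The toy model, optimizer, config path, and loading code are illustrative assumptions, not the repository's actual wiring, and it assumes the `_target_` path resolves to the scheduler class added below.

    # Sketch (not from this commit): mapping the config above onto Lightning's
    # lr_scheduler dict. Model, optimizer, and config path are assumptions.
    import hydra
    import torch
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("scheduler_config.yaml")  # hypothetical path to the file above

    model = torch.nn.Linear(8, 1)                  # toy model, assumption
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # `_partial_: true` makes Hydra return a partial with warmup_steps and
    # last_epoch bound; calling it with the optimizer yields the scheduler.
    scheduler = hydra.utils.instantiate(cfg.scheduler)(optimizer)

    lr_scheduler_config = {
        "scheduler": scheduler,
        "interval": cfg.interval,     # "step": update after every optimizer step
        "frequency": cfg.frequency,   # 1
        "monitor": cfg.monitor,       # "val/loss/total"
        "strict": cfg.strict,
        "name": cfg.name,
    }
    # In a LightningModule, `configure_optimizers` would return
    # {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}.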
@@ -0,0 +1,52 @@ (new file: custom learning rate scheduler implementation)
"""Implement custom learning rate schedulers.""" | ||
|
||
import warnings | ||
|
||
from torch.optim.lr_scheduler import _LRScheduler | ||
from torch.optim.optimizer import Optimizer | ||
|
||
|
||
class InverseSquareRootLR(_LRScheduler): | ||
"""Implement the InverseSquareRootLR learning rate scheduler. | ||
:param optimizer: The optimizer. | ||
:type optimizer: Optimizer | ||
:param warmup_steps: The number of warmup steps. | ||
:type warmup_steps: int | ||
:param last_epoch: The index of the last epoch. If -1, the scheduler will | ||
start at the initial learning rate. | ||
:type last_epoch: int | ||
""" | ||
|
||
def __init__( | ||
self, optimizer: Optimizer, warmup_steps: int, last_epoch: int = -1 | ||
): | ||
if warmup_steps <= 0: | ||
raise ValueError("warmup_steps must be > 0") | ||
self._warmup_steps = warmup_steps | ||
self._lr_steps = [ | ||
param_group["lr"] / warmup_steps | ||
for param_group in optimizer.param_groups | ||
] | ||
self._decay_factors = [ | ||
param_group["lr"] * warmup_steps**0.5 | ||
for param_group in optimizer.param_groups | ||
] | ||
|
||
super().__init__(optimizer, last_epoch) | ||
|
||
def get_lr(self): | ||
if not self._get_lr_called_within_step: | ||
warnings.warn( | ||
"To get the last learning rate computed by the scheduler, " | ||
"please use `get_last_lr()`.", | ||
UserWarning, | ||
) | ||
|
||
if self.last_epoch < self._warmup_steps: | ||
return [self.last_epoch * lr_step for lr_step in self._lr_steps] | ||
else: | ||
return [ | ||
decay_factor * self.last_epoch**-0.5 | ||
for decay_factor in self._decay_factors | ||
] |
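For reference, a short usage sketch (not part of this commit) that exercises the scheduler class above; the toy model, optimizer, and hyperparameter values are illustrative assumptions.

    # Illustrative usage of the InverseSquareRootLR class above (assumed values):
    # with warmup_steps=4 and a base lr of 0.1, the lr ramps linearly to 0.1
    # over four steps, then decays as 0.1 * sqrt(4 / step).
    import torch

    model = torch.nn.Linear(8, 1)  # toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = InverseSquareRootLR(optimizer, warmup_steps=4)

    for step in range(1, 9):
        optimizer.step()
        scheduler.step()
        print(step, scheduler.get_last_lr())
    # Expected progression: 0.025, 0.05, 0.075, 0.1, then
    # 0.1 * (4/5)**0.5 ≈ 0.0894, 0.1 * (4/6)**0.5 ≈ 0.0816, ...

Note that because the warmup ramps from zero, the very first optimizer update runs at lr 0 when the scheduler is constructed with the default last_epoch=-1.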