# scheduler.py
import math
import warnings
from typing import List
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class WarmupCosineLR(_LRScheduler):
    """
    Sets the learning rate of each parameter group to follow a linear warmup schedule
    between ``warmup_start_lr`` and the base lr, followed by a cosine annealing schedule
    between the base lr and ``eta_min``.

    .. warning::
        It is recommended to call :func:`.step()` for :class:`WarmupCosineLR`
        after each iteration, since calling it only after each epoch keeps the learning
        rate at ``warmup_start_lr`` for the entire first epoch.

    .. warning::
        Passing ``epoch`` to :func:`.step()` is deprecated and triggers an
        ``EPOCH_DEPRECATION_WARNING``. In that case the scheduler uses
        :func:`_get_closed_form_lr()` instead of :func:`get_lr()`. Although this does
        not change the schedule itself, when passing the ``epoch`` param the user should
        call :func:`.step()` before calling the train and validation methods.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup_epochs (int): Number of epochs (or iterations, if stepped per iteration)
            used for the linear warmup.
        max_epochs (int): Total number of epochs (or iterations).
        warmup_start_lr (float): Learning rate at the start of the linear warmup.
            Default: 1e-8.
        eta_min (float): Minimum learning rate. Default: 1e-8.
        last_epoch (int): The index of the last epoch. Default: -1.

    Example:
        >>> layer = nn.Linear(10, 1)
        >>> optimizer = Adam(layer.parameters(), lr=0.02)
        >>> scheduler = WarmupCosineLR(optimizer, warmup_epochs=10, max_epochs=40)
        >>> #
        >>> # the default case
        >>> for epoch in range(40):
        ...     # train(...)
        ...     # validate(...)
        ...     scheduler.step()
        >>> #
        >>> # passing the epoch param
        >>> for epoch in range(40):
        ...     scheduler.step(epoch)
        ...     # train(...)
        ...     # validate(...)
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 1e-8,
        eta_min: float = 1e-8,
        last_epoch: int = -1,
    ) -> None:
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """
        Compute the learning rate using the chainable form of the scheduler.
        """
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )

        # First call (from the base-class constructor): start at the warmup lr.
        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        # Linear warmup: add a constant increment to the previous lr each step.
        elif self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # End of warmup: the lr reaches the base lr exactly.
        elif self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        # First step of a (restarted) cosine cycle.
        elif (self.last_epoch - 1 - self.max_epochs) % (
            2 * (self.max_epochs - self.warmup_epochs)
        ) == 0:
            return [
                group["lr"]
                + (base_lr - self.eta_min)
                * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs)))
                / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        # Cosine annealing (chainable form): scale the previous lr by the ratio of
        # consecutive cosine factors, then shift by eta_min.
        return [
            (
                1
                + math.cos(
                    math.pi
                    * (self.last_epoch - self.warmup_epochs)
                    / (self.max_epochs - self.warmup_epochs)
                )
            )
            / (
                1
                + math.cos(
                    math.pi
                    * (self.last_epoch - self.warmup_epochs - 1)
                    / (self.max_epochs - self.warmup_epochs)
                )
            )
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """
        Called when the epoch is passed as a param to the `step` function of the scheduler.
        """
        # Linear warmup region, interpolated directly from the epoch index.
        if self.last_epoch < self.warmup_epochs:
            return [
                self.warmup_start_lr
                + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr in self.base_lrs
            ]

        # Cosine annealing region, evaluated in closed form from the epoch index.
        return [
            self.eta_min
            + 0.5
            * (base_lr - self.eta_min)
            * (
                1
                + math.cos(
                    math.pi
                    * (self.last_epoch - self.warmup_epochs)
                    / (self.max_epochs - self.warmup_epochs)
                )
            )
            for base_lr in self.base_lrs
        ]
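

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# prints the learning rate after each step to show the warmup-then-cosine
# shape of the schedule. The layer size, SGD settings, warmup_epochs, and
# max_epochs below are arbitrary assumptions chosen only for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch import nn
    from torch.optim import SGD

    model = nn.Linear(10, 1)
    optimizer = SGD(model.parameters(), lr=0.1)
    scheduler = WarmupCosineLR(optimizer, warmup_epochs=5, max_epochs=20)

    for epoch in range(20):
        # train(...) / validate(...) would run here
        optimizer.step()
        scheduler.step()
        print(f"epoch {epoch:2d}: lr = {scheduler.get_last_lr()[0]:.6f}")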