lr_scheduler.py
from torch.optim.lr_scheduler import _LRScheduler


class PolyScheduler(_LRScheduler):
    """Polynomial learning-rate decay with a linear warmup phase."""

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2
        # Initialize the parent with last_epoch=-1 so it records each param
        # group's current lr as `initial_lr`, then restore the requested
        # last_epoch (e.g. when resuming training from a checkpoint).
        super().__init__(optimizer, last_epoch=-1, verbose=False)
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        # Linear warmup: scale base_lr by the fraction of warmup completed.
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        else:
            # Polynomial decay: (1 - progress) ** power, where progress runs
            # from 0 to 1 over the steps remaining after warmup.
            alpha = pow(
                1
                - float(self.last_epoch - self.warmup_steps)
                / float(self.max_steps - self.warmup_steps),
                self.power,
            )
            return [self.base_lr * alpha for _ in self.optimizer.param_groups]
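

# Usage sketch (not part of the original file): a minimal, hypothetical
# training loop showing how PolyScheduler might be driven. The model,
# optimizer settings, and step counts below are illustrative assumptions.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(512, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = PolyScheduler(optimizer, base_lr=0.1, max_steps=1000, warmup_steps=100)

    for step in range(1000):
        # ... the real forward/backward pass would precede optimizer.step() ...
        optimizer.step()
        scheduler.step()  # advances last_epoch and recomputes each group's lr
        if step % 100 == 0:
            print(step, scheduler.get_last_lr())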