-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsched_lr.py
32 lines (32 loc) · 1.08 KB
/
sched_lr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def adjust_learning_rate(optimizer, epoch, schedule, total_epochs):
    """Decay the learning rate of every param group by 10x at schedule milestones.

    At each milestone epoch of the chosen schedule, every param group's ``lr``
    is multiplied by 0.1; at all other epochs it is left unchanged (multiplied
    by 1). Intended to be called once per epoch.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key), e.g. a ``torch.optim.Optimizer``.
        epoch: current epoch number (compared for exact equality to milestones).
        schedule: one of 'cifar', 'cifar_long', 'cifar_swa', 'svhn'.
        total_epochs: total number of training epochs; used by the relative
            'cifar' schedule and as the final milestone of 'cifar_long'.

    Raises:
        ValueError: if ``schedule`` is not a recognized name.
    """
    # Milestone epochs at which the LR is decayed by 0.1 for each schedule.
    if schedule == 'cifar':
        # Relative milestones at 75% and 90% of training, plus the final epoch.
        milestones = {int(0.75 * total_epochs), int(0.9 * total_epochs), total_epochs}
    elif schedule == 'cifar_long':
        milestones = {100, 150, total_epochs}
    elif schedule == 'cifar_swa':
        # SWA schedule: no decay at the end (a constant LR tail is expected).
        milestones = {50, 150}
    elif schedule == 'svhn':
        milestones = {50, 75}
    else:
        raise ValueError('Unknown LR schedule %s' % schedule)

    # A milestone epoch decays by 10x; every other epoch is a no-op multiply.
    decay_factor = 0.1 if epoch in milestones else 1
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * decay_factor
        print(f'Update learning rate to: {param_group["lr"]}')