-
Notifications
You must be signed in to change notification settings - Fork 5
/
opt.py
59 lines (52 loc) · 1.94 KB
/
opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# coding=utf-8
import torch
def get_params(alg, args, alg_name, inner=False, alias=True):
    """Build per-module parameter groups with individually scaled learning rates.

    Args:
        alg: either an object exposing ``featurizer``/``classifier`` sub-modules
            (when ``alias`` is True and ``inner`` is False) or an indexable
            pair of modules ``(featurizer, classifier)``.
        args: hyper-parameter namespace; reads ``schuse``, ``schusech``,
            ``lr``, ``inner_lr``, ``lr_decay1``, ``lr_decay2``.
        alg_name: algorithm identifier; names containing 'DANN'/'CDANN' get
            extra parameter groups for their adversarial components.
        inner: when True, use ``args.inner_lr`` as the base rate (unless a
            scheduler is in use) and index ``alg`` as a pair.
        alias: when True (and not ``inner``), access sub-modules by attribute.

    Returns:
        A list of ``{'params': ..., 'lr': ...}`` dicts for ``torch.optim``.
    """
    # Base learning rate. With a non-cosine scheduler the groups start at a
    # 1.0 multiplier — the scheduler's lambda folds args.lr back in later.
    if args.schuse:
        initlr = args.lr if args.schusech == 'cos' else 1.0
    else:
        initlr = args.inner_lr if inner else args.lr

    # Resolve featurizer/classifier either by index or by attribute.
    if inner or not alias:
        featurizer, classifier = alg[0], alg[1]
    else:
        featurizer, classifier = alg.featurizer, alg.classifier
    params = [
        {'params': featurizer.parameters(), 'lr': args.lr_decay1 * initlr},
        {'params': classifier.parameters(), 'lr': args.lr_decay2 * initlr},
    ]

    # Adversarial variants also train a discriminator (and, for CDANN,
    # class embeddings) at the classifier's rate.
    if ('DANN' in alg_name) or ('CDANN' in alg_name):
        params.append({'params': alg.discriminator.parameters(),
                       'lr': args.lr_decay2 * initlr})
    if 'CDANN' in alg_name:
        params.append({'params': alg.class_embeddings.parameters(),
                       'lr': args.lr_decay2 * initlr})
    return params
def get_optimizer(alg, args, inner=False, alias=True):
    """Create a Nesterov-SGD optimizer over the algorithm's parameter groups.

    Args:
        alg: algorithm object or module pair (see ``get_params``).
        args: hyper-parameter namespace; reads ``DGalgorithm``, ``lr``,
            ``momentum`` and ``weight_decay`` (plus what ``get_params`` reads).
        inner: forwarded to ``get_params``.
        alias: forwarded to ``get_params``.

    Returns:
        A ``torch.optim.SGD`` instance with per-group learning rates.
    """
    groups = get_params(alg, args, args.DGalgorithm, inner, alias)
    return torch.optim.SGD(groups,
                           lr=args.lr,
                           momentum=args.momentum,
                           weight_decay=args.weight_decay,
                           nesterov=True)
def get_scheduler(optimizer, args):
    """Return the learning-rate scheduler selected by ``args``, or None.

    Args:
        optimizer: the optimizer whose learning rates are scheduled.
        args: hyper-parameter namespace; reads ``schuse``, ``schusech``, and
            either ``max_epoch``/``steps_per_epoch`` (cosine) or
            ``lr``/``lr_gamma``/``lr_decay`` (inverse decay).

    Returns:
        ``None`` when scheduling is disabled, otherwise a cosine-annealing or
        inverse-decay scheduler.
    """
    if not args.schuse:
        return None
    if args.schusech == 'cos':
        # Anneal over the total number of optimization steps.
        total_steps = args.max_epoch * args.steps_per_epoch
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps)

    # Inverse-power decay; multiplies the group's base lr by
    # args.lr * (1 + gamma * step) ** (-decay) at each step.
    def inv_decay(step):
        return args.lr * (1. + args.lr_gamma * float(step)) ** (-args.lr_decay)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, inv_decay)