model.py (forked from Spijkervet/SimCLR)
import os
import torch
from modules import SimCLR, LARS


def load_optimizer(args, model):
    scheduler = None
    if args.optimizer == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # TODO: LARS
    elif args.optimizer == "LARS":
        # optimized using LARS with linear learning rate scaling
        # (i.e. LearningRate = 0.3 × BatchSize / 256) and weight decay of 1e-6.
        learning_rate = 0.3 * args.batch_size / 256
        optimizer = LARS(
            model.parameters(),
            lr=learning_rate,
            weight_decay=args.weight_decay,
            exclude_from_weight_decay=["batch_normalization", "bias"],
        )

        # "decay the learning rate with the cosine decay schedule without restarts"
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs, eta_min=0, last_epoch=-1
        )
    else:
        raise NotImplementedError

    return optimizer, scheduler


def save_model(args, model, optimizer):
    out = os.path.join(args.model_path, "checkpoint_{}.tar".format(args.current_epoch))

    # To save a DataParallel model generically, save the model.module.state_dict().
    # This way, you have the flexibility to load the model any way you want to any device you want.
    if isinstance(model, torch.nn.DataParallel):
        torch.save(model.module.state_dict(), out)
    else:
        torch.save(model.state_dict(), out)
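

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of how load_optimizer and save_model above might be
# driven. The attribute names on `args` (optimizer, batch_size, weight_decay,
# epochs, model_path, current_epoch) are the ones the functions actually read;
# the concrete values and the stand-in model are assumptions for illustration,
# not the repository's real training setup (which builds a SimCLR model from
# the modules package and parses args elsewhere).
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        optimizer="Adam",            # or "LARS" to use the paper's optimizer
        batch_size=256,              # used for linear LR scaling when optimizer == "LARS"
        weight_decay=1e-6,           # forwarded to LARS as weight_decay
        epochs=100,                  # T_max for the cosine annealing schedule
        model_path="./checkpoints",  # directory for checkpoint_{epoch}.tar files
        current_epoch=0,
    )
    os.makedirs(args.model_path, exist_ok=True)

    # Stand-in module so the sketch runs end to end; the real training code
    # would construct a SimCLR model (see the modules package for its signature).
    model = torch.nn.Linear(10, 10)

    optimizer, scheduler = load_optimizer(args, model)
    save_model(args, model, optimizer)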