-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path main.py
64 lines (47 loc) · 2.46 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
import torch.optim as optim
from config import config, preprocess
from models import SimSiam_CIFAR
from train import _train_SimSiam, _train_classifier
# Run on the first CUDA GPU when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Experiment settings shared by both training stages below.
args = config()
def train_SimSiam():
    """Pre-train a SimSiam_CIFAR model with SGD and cosine-annealed learning rate.

    Hyper-parameters (lr, momentum, weight_decay, num_epochs, save_dir,
    save_acc) are read from the module-level ``args``. The train, bank and
    query loaders from ``preprocess`` are passed on to ``_train_SimSiam``,
    which receives the model, optimizer and scheduler built here. The fourth
    loader returned by ``preprocess`` is not used in this stage.
    """
    train_loader, bank_loader, query_loader, _unused = preprocess(args)
    net = SimSiam_CIFAR()

    sgd = optim.SGD(
        net.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    # The cosine schedule spans the full run: T_max equals the epoch count.
    cosine = optim.lr_scheduler.CosineAnnealingLR(sgd, T_max=args.num_epochs)

    _train_SimSiam(
        device=device,
        trainloader=train_loader,
        bankloader=bank_loader,
        queryloader=query_loader,
        model=net,
        optimizer=sgd,
        scheduler=cosine,
        num_epochs=args.num_epochs,
        base_dir=args.save_dir,
        best_acc=args.save_acc,
    )
# Accepted values for train_classifier's ``classifier_type`` argument.
_CLASSIFIER_TYPES = ('lin_small', 'lin_large', 'nonlin_large')


def _build_classifier(classifier_type, feature_dim, num_classes):
    """Build the classifier head selected by *classifier_type*."""
    if classifier_type == 'lin_small':
        return nn.Linear(feature_dim, num_classes)
    if classifier_type == 'lin_large':
        # Two stacked linear layers with no activation between them.
        return nn.Sequential(nn.Linear(feature_dim, feature_dim),
                             nn.Linear(feature_dim, num_classes))
    if classifier_type == 'nonlin_large':
        return nn.Sequential(nn.Linear(feature_dim, feature_dim),
                             nn.BatchNorm1d(feature_dim),
                             nn.ReLU(),
                             nn.Linear(feature_dim, num_classes))
    raise ValueError(f"unknown classifier_type {classifier_type!r}")


def train_classifier(weight_path, classifier_type='lin_small', *,
                     feature_dim=2048, num_classes=10, num_epochs=200):
    """Train a classifier head on top of a pre-trained SimSiam_CIFAR model.

    :param weight_path: path to a checkpoint file loadable by ``torch.load``
        that contains a ``'model_state_dict'`` entry for SimSiam_CIFAR.
        Only load checkpoints you trust: ``torch.load`` unpickles the file.
    :param classifier_type:
        - 'lin_small' >> Default, simplest linear classifier
        - 'lin_large' >> large linear classifier without nonlinearity
        - 'nonlin_large' >> large linear classifier with nonlinearity; solely for performance
    :param feature_dim: input width of the classifier head. It must match the
        size of the features the model hands to the classifier (the original
        code hard-coded 2048).
    :param num_classes: number of output classes (default 10).
    :param num_epochs: number of classifier training epochs (default 200).
    :raises ValueError: if ``classifier_type`` is not one of the values above.
    """
    if classifier_type not in _CLASSIFIER_TYPES:
        # Fail fast, before building the data pipeline and loading weights.
        raise ValueError(f"unknown classifier_type {classifier_type!r}; "
                         f"expected one of {_CLASSIFIER_TYPES}")

    _, _, queryloader, classifier_trainloader = preprocess(args)

    model = SimSiam_CIFAR()
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # host (and vice versa) instead of failing at deserialization time.
    chkpt = torch.load(weight_path, map_location=device)
    model.load_state_dict(chkpt['model_state_dict'])

    # Built after the model so that torch RNG consumption order (and hence the
    # head's initial weights under a fixed seed) matches the original script.
    classifier = _build_classifier(classifier_type, feature_dim, num_classes)

    # Only the head's parameters are handed to the optimizer.
    optimizer = optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9,
                          weight_decay=5e-4)
    # mode='max': the LR is cut by 10x after 10 epochs without the monitored
    # metric increasing (presumably query accuracy; see _train_classifier).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                     factor=0.1, patience=10)
    criterion = nn.CrossEntropyLoss()

    _train_classifier(device=device, criterion=criterion,
                      classifier_trainloader=classifier_trainloader,
                      queryloader=queryloader, model=model,
                      classifier=classifier, optimizer=optimizer,
                      scheduler=scheduler, num_epochs=num_epochs,
                      base_dir=args.save_dir, best_acc=args.save_acc)
if __name__ == '__main__':
    # Guarded so that importing this module does not start training.

    # SimSiam module training
    train_SimSiam()

    # classifier training
    # Replace the placeholder with the path of a checkpoint FILE written by the
    # SimSiam stage (torch.load needs a file, not a directory); it must contain
    # a 'model_state_dict' entry.
    train_classifier('YOUR WEIGHT DIRECTORY', classifier_type='lin_small')