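"""Stage-2 training script: fine-tune the classifier head (model.fc) of a
stage-1 pretrained ResNet-34 on long-tailed CIFAR-100 (CIFAR-100-LT)."""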
from __future__ import print_function
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data as data
import torch.optim as optim
from torch.optim import lr_scheduler
from datasets.Cifar100LT import get_cifar100
from models.resnet import *
from train.train import train_base
from train.validate import valid_base
from utils.config import *
from utils.common import hms_string
from utils.logger import logger
import copy
args = parse_args()
reproducibility(args.seed)
args.logger = logger(args)
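# Best-accuracy trackers, shared between train_stage2 and module scope.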
best_acc = 0 # best test accuracy
many_best, med_best, few_best = 0, 0, 0
best_model = None
def train_stage2(args, model, trainloader, testloader, N_SAMPLES_PER_CLASS):
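    """Fine-tune the classifier with label smoothing and a cosine LR schedule,
    tracking the best overall test accuracy and its many/medium/few-shot split."""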
global best_acc, many_best, med_best, few_best, best_model
train_criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smooth)
test_criterion = nn.CrossEntropyLoss() # For test, validation
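    # Only the classifier parameters (model.fc) are passed to the optimizer;
    # the backbone weights are left untouched during stage 2.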
optimizer = optim.SGD(model.fc.parameters(), lr=args.finetune_lr, momentum=0.9, weight_decay=args.finetune_wd)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.finetune_epoch, eta_min=0.0)
best_model = None
test_accs = []
start_time = time.time()
for epoch in range(args.finetune_epoch):
train_loss, train_acc = train_base(trainloader, model, optimizer, train_criterion)
test_loss, test_acc, test_cls = valid_base(testloader, model, test_criterion, N_SAMPLES_PER_CLASS,
num_class=args.num_class, mode='test Valid')
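        # Record the LR used this epoch before advancing the cosine schedule.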
lr = scheduler.get_last_lr()[0]
scheduler.step()
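        # Keep the checkpoint with the highest overall test accuracy, together
        # with its many/medium/few-shot class accuracies.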
if best_acc <= test_acc:
best_acc = test_acc
many_best = test_cls[0]
med_best = test_cls[1]
few_best = test_cls[2]
best_model = copy.deepcopy(model)
test_accs.append(test_acc)
args.logger(f'Epoch: [{epoch + 1} | {args.finetune_epoch}]', level=1)
        args.logger(f'[Train]\tLoss:\t{train_loss:.4f}\tAcc:\t{train_acc:.4f}', level=2)
args.logger(f'[Test ]\tLoss:\t{test_loss:.4f}\tAcc:\t{test_acc:.4f}', level=2)
args.logger(f'[Stats]\tMany:\t{test_cls[0]:.4f}\tMedium:\t{test_cls[1]:.4f}\tFew:\t{test_cls[2]:.4f}', level=2)
args.logger(
f'[Best ]\tAcc:\t{np.max(test_accs):.4f}\tMany:\t{100 * many_best:.4f}\tMedium:\t{100 * med_best:.4f}\tFew:\t{100 * few_best:.4f}',
level=2)
args.logger(f'[Param]\tLR:\t{lr:.8f}', level=2)
end_time = time.time()
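    # Note: this pickles the entire module object, not just its state_dict.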
file_name = os.path.join(args.out, 'best_model_stage2.pth')
torch.save(best_model, file_name)
# Print the final results
    args.logger(f'Finish Training Stage 2...', level=1)
    args.logger(f'Final performance...', level=1)
    args.logger(f'best bAcc (test):\t{np.max(test_accs):.4f}', level=2)
args.logger(f'best statistics:\tMany:\t{many_best}\tMed:\t{med_best}\tFew:\t{few_best}', level=2)
args.logger(f'Training Time: {hms_string(end_time - start_time)}', level=1)
def load_model(args, model, testloader, N_SAMPLES_PER_CLASS):
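    """Load the stage-1 checkpoint and report its test performance before fine-tuning."""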
if args.pretrained_pth is not None:
pth = args.pretrained_pth
else:
pth = f'pretrained/cifar100/IR={args.imb_ratio}/best_model_stage1.pt'
    # The stage-1 checkpoint is expected to hold a state_dict; if stage 1
    # saved the full module instead, load it with `model = torch.load(pth)`.
    state_dict = torch.load(pth)
    model.load_state_dict(state_dict)
test_criterion = nn.CrossEntropyLoss() # For test, validation
test_loss, test_acc, test_cls = valid_base(testloader, model, test_criterion, N_SAMPLES_PER_CLASS,
num_class=args.num_class, mode='test Valid')
args.logger(f'Loaded performance...', level=1)
args.logger(f'[Test ]\t Acc:\t{test_acc:.4f}', level=2)
args.logger(f'[Stats]\tMany:\t{test_cls[0]:.4f}\tMedium:\t{test_cls[1]:.4f}\tFew:\t{test_cls[2]:.4f}', level=2)
return model
def main():
print(f'==> Preparing imbalanced CIFAR-100')
trainset, testset = get_cifar100(os.path.join(args.data_dir, 'cifar100/'), args)
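    # Per-class sample counts of the long-tailed training set, used by
    # valid_base to split classes into many/medium/few-shot groups.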
N_SAMPLES_PER_CLASS = trainset.img_num_list
trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
drop_last=False, pin_memory=True, sampler=None)
    testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
                                 pin_memory=True)
# Model
print("==> creating {}".format(args.network))
    # ResNet-34 is hardcoded here even though args.network is logged above.
    model = resnet34(num_classes=args.num_class, pool_size=4).cuda()
model = load_model(args, model, testloader, N_SAMPLES_PER_CLASS)
train_stage2(args, model, trainloader, testloader, N_SAMPLES_PER_CLASS)
if __name__ == '__main__':
main()
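
# Example invocation (a sketch; the exact flag names are assumptions and
# depend on parse_args in utils/config.py):
#   python main_stage2.py --imb_ratio 100 --finetune_epoch 30 --finetune_lr 0.01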