-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
204 lines (168 loc) · 6.89 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from __future__ import absolute_import
import os
import sys
from tqdm import tqdm
import torch
# config
from utils.config import opt
#dataset
from torch.utils.data import DataLoader
from data.dataset import Dataset
# model
from model import FasterRCNNVGG16, FPNFasterRCNNVGG16
from torchnet.meter import AverageValueMeter
from model.frcnn_bottleneck import Losses
# utils
from utils import array_tool as at
from utils.eval_tool import voc_ap
# def setup_logger():
# if not os.path.exists(opt.save_dir):
# os.makedirs(opt.save_dir)
# f = open(f'{opt.save_dir}/log.txt', 'w')
# sys.stdout = f
# return f
def update_meters(meters, losses):
    """Fold the latest batch losses into their running average meters.

    ``meters`` maps a loss name to an AverageValueMeter; ``losses`` is the
    namedtuple of loss tensors returned by the network's forward pass.
    """
    for name, meter in meters.items():
        # look the matching field up on the namedtuple directly
        meter.add(at.scalar(getattr(losses, name)))
def reset_meters(meters):
    """Zero out every running meter (called at the start of each epoch)."""
    for meter in meters.values():
        meter.reset()
def get_meter_data(meters):
    """Return a snapshot ``{loss_name: current mean}`` of all meters."""
    snapshot = {}
    for name, meter in meters.items():
        # AverageValueMeter.value() returns (mean, std); keep the mean
        snapshot[name] = meter.value()[0]
    return snapshot
def save_model(model, model_name, dataset_name, epoch):
    """Save the model's state_dict checkpoint and return its path.

    Checkpoints are written to
    ``./checkpoints/<model_name>/<dataset_name>/checkpoint<epoch>.pth``.

    :param model: torch module whose ``state_dict`` is saved
    :param model_name: architecture subdirectory name
    :param dataset_name: dataset subdirectory name
    :param epoch: epoch number embedded in the file name
    :return: path of the written checkpoint file
    """
    path = f'./checkpoints/{model_name}/{dataset_name}/checkpoint{epoch}.pth'
    # exist_ok=True avoids the race between an existence check and makedirs,
    # and we no longer shadow the builtin `dir`.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    return path
def build_optimizer(net):
    """Construct the SGD optimizer for ``net``.

    Bias parameters get twice the base learning rate and no weight decay;
    every other trainable parameter uses the base learning rate with
    ``opt.weight_decay``. Override this function to plug in a different
    optimizer.
    """
    base_lr = opt.lr
    param_groups = []
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if 'bias' in name:
            param_groups.append(
                {'params': [param], 'lr': base_lr * 2, 'weight_decay': 0})
        else:
            param_groups.append(
                {'params': [param], 'lr': base_lr,
                 'weight_decay': opt.weight_decay})
    return torch.optim.SGD(param_groups, momentum=0.9)
def _n_fg_classes(database):
    """Number of foreground classes for the given dataset name.

    Raises ValueError for an unknown dataset instead of letting the caller
    fail later with an unbound model variable.
    """
    if database == 'voc':
        return 20
    if database == 'kitti':
        return 3  # 3 classes: Car, Pedestrian, Cyclist
    raise ValueError(f'unknown database: {database}')


def _build_model(device):
    """Construct the detector variant selected by ``opt`` and move it to *device*."""
    n_fg_class = _n_fg_classes(opt.database)
    if opt.apply_fpn:
        if opt.deformable:
            print('load Deformable FPN Faster RCNN Model')
        else:
            print('load FPN Faster RCNN Model')
        return FPNFasterRCNNVGG16(n_fg_class=n_fg_class).to(device)
    if opt.deformable:
        print('load Deformable Faster RCNN Model')
    else:
        print('load Faster RCNN Model')
    return FasterRCNNVGG16(n_fg_class=n_fg_class).to(device)


def train(**kwargs):
    """Train a (optionally FPN / deformable) Faster R-CNN on VOC or KITTI.

    Keyword arguments override fields of the global ``opt`` config.
    During training the best checkpoint (by test-set mAP) is saved via
    ``save_model``; the final weights are written under ``opt.save_dir``.

    :raises ValueError: if ``opt.database`` is neither 'voc' nor 'kitti'
    """
    # set up cuda
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # parse model parameters from config
    opt.f_parse_args(kwargs)
    if opt.database == 'voc':
        print('load voc data')
    elif opt.database == 'kitti':
        print('load kitti data')
    # load training dataset
    train_data = Dataset(opt, mode='train')
    train_dataloader = DataLoader(train_data,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=opt.train_num_workers)
    # load testing dataset
    test_data = Dataset(opt, mode='test')
    test_dataloader = DataLoader(test_data,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=opt.test_num_workers)
    # model construction (raises ValueError for an unknown opt.database,
    # instead of crashing later with an unbound `net`)
    net = _build_model(device)
    print('Load SGD optimizer')
    # optimizer construction
    optimizer = build_optimizer(net)
    # fitting
    meters = {k: AverageValueMeter() for k in Losses._fields}
    best_mAP = 0
    best_path = None
    lr = opt.lr
    print('Start training...')
    for epoch in range(1, opt.epoch + 1):
        # switch to train mode
        net.train()
        print(f'epoch #{epoch}')
        # reset meters
        reset_meters(meters)
        # train batch
        for img, bboxes, labels, scale in tqdm(train_dataloader):
            # prepare data
            scale = at.scalar(scale)
            img, bboxes, labels = img.to(device).float(), bboxes.to(device), labels.to(device)
            # forward + backward
            optimizer.zero_grad()
            # call the module (not .forward directly) so nn.Module hooks run
            losses = net(img, bboxes, labels, scale)
            losses.total_loss.backward()
            optimizer.step()
            update_meters(meters, losses)
        # print loss
        loss_metadata = get_meter_data(meters)
        rpn_loc_loss = loss_metadata['rpn_loc_loss']
        rpn_cls_loss = loss_metadata['rpn_cls_loss']
        roi_loc_loss = loss_metadata['roi_loc_loss']
        roi_cls_loss = loss_metadata['roi_cls_loss']
        total_loss = loss_metadata['total_loss']
        print('lr=={} | rpn_loc_loss=={:.4f} | rpn_cls_loss=={:.4f} | roi_loc_loss=={:.4f} | roi_cls_loss=={:.4f} | total_loss=={:.4f}'.format(lr,
                                                                                                                                              rpn_loc_loss,
                                                                                                                                              rpn_cls_loss,
                                                                                                                                              roi_loc_loss,
                                                                                                                                              roi_cls_loss,
                                                                                                                                              total_loss))
        # evaluate
        net.eval()
        mAP = voc_ap(net, test_dataloader)
        # save model (if best model)
        if mAP > best_mAP:
            best_mAP = mAP
            best_path = save_model(net, opt.model, opt.database, epoch)
        # learning rate decay
        if epoch == opt.epoch_decay:
            # reload the best checkpoint, if one was saved, before decaying
            # (previously this crashed with TypeError when best_path was None)
            if best_path is not None:
                net.load_state_dict(torch.load(best_path))
            for param_group in optimizer.param_groups:
                param_group['lr'] *= opt.lr_decay
            lr = lr * opt.lr_decay
    # save final model
    if opt.deformable:
        model_name = 'deformable_frcnn_vgg16' if not opt.apply_fpn else 'deformable_fpn_frcnn_vgg16'
    else:
        model_name = 'frcnn_vgg16' if not opt.apply_fpn else 'fpn_frcnn_vgg16'
    PATH = f'{opt.save_dir}/{opt.database}/{model_name}.pth'
    # exist_ok=True avoids the check-then-create race
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save(net.state_dict(), PATH)
# Script entry point: run training with the defaults from utils.config.opt.
if __name__ == '__main__':
    train()