-
Notifications
You must be signed in to change notification settings - Fork 59
/
test.py
359 lines (298 loc) · 15.2 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
'''
Test on single GPU:
CUDA_VISIBLE_DEVICES=0 python3 test.py --eval_data data_lmdb_release/evaluation --benchmark_all_eval --Transformation None --FeatureExtraction None --SequenceModeling None --Prediction None --Transformer --sensitive --data_filtering_off --imgH 224 --imgW 224 --TransformerModel=vitstr_small_patch16_224 --saved_model https://github.com/roatienza/deep-text-recognition-benchmark/releases/download/v0.1.0/vitstr_small_patch16_224_aug.pth
To convert to a quantized model, add the following to the command above:
--infer_model=vitstr_small_patch16_quant.pt --quantized
To convert to a standalone TorchScript (JIT) model, add the following to the command above:
--infer_model=vitstr_small_patch16_jit.pt
'''
import os
import time
import string
import argparse
import re
import validators
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
import numpy as np
from nltk.metrics.distance import edit_distance
from utils import CTCLabelConverter, AttnLabelConverter, Averager, TokenLabelConverter
from dataset import hierarchical_dataset, AlignCollate
from model import Model, JitModel
from utils import get_args
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # default device for models and tensors in this script
def benchmark_all_eval(model, criterion, converter, opt):
    """Evaluate ``model`` on the standard scene-text benchmark datasets.

    With ``opt.fast_acc`` the 7-dataset subset is evaluated; otherwise all 10
    dataset variants are.  Per-dataset accuracy and normalized edit distance
    are printed and appended to ``./result/{opt.exp_name}/log_all_evaluation.txt``,
    followed by a summary line with every dataset's accuracy, the
    sample-weighted total accuracy, the averaged inference time (ms per
    image), the parameter count in millions and, when ``opt.flops`` is set,
    the approximate FLOPS.

    Args:
        model: network to evaluate (run through ``validation``).
        criterion: loss function forwarded to ``validation``.
        converter: label converter forwarded to ``validation``.
        opt: parsed options namespace.

    Returns:
        None. Results are only printed and logged.
    """
    if opt.fast_acc:
        # To easily compute the total accuracy of our paper.
        eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']
    else:
        # The evaluation datasets, dataset order is same with Table 1 in our paper.
        eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
                          'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']

    # batch_size should be 1 to calculate the GPU inference time per image.
    evaluation_batch_size = 1 if opt.calculate_infer_time else opt.batch_size

    list_accuracy = []
    total_forward_time = 0
    total_evaluation_data_number = 0
    total_correct_number = 0
    dashed_line = '-' * 80

    # Context manager so the log is closed even if an evaluation step raises.
    with open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') as log:
        print(dashed_line)
        log.write(dashed_line + '\n')
        for eval_data_name in eval_data_list:
            eval_data_path = os.path.join(opt.eval_data, eval_data_name)
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt)
            # Distinct name so the dataset object does not overwrite the loop's dataset name.
            eval_dataset, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_dataset, batch_size=evaluation_batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation, pin_memory=True)

            _, accuracy_by_best_model, norm_ED_by_best_model, _, _, _, infer_time, length_of_data = validation(
                model, criterion, evaluation_loader, converter, opt)
            list_accuracy.append(f'{accuracy_by_best_model:0.3f}')
            total_forward_time += infer_time
            total_evaluation_data_number += len(eval_dataset)
            # accuracy is a percentage, so this accumulates percent * samples;
            # dividing by the total sample count below yields a weighted percentage.
            total_correct_number += accuracy_by_best_model * length_of_data
            log.write(eval_data_log)
            print(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}')
            log.write(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}\n')
            print(dashed_line)
            log.write(dashed_line + '\n')

        averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000  # ms per image
        total_accuracy = total_correct_number / total_evaluation_data_number
        params_num = sum(np.prod(p.size()) for p in model.parameters())

        evaluation_log = 'accuracy: '
        for name, accuracy in zip(eval_data_list, list_accuracy):
            evaluation_log += f'{name}: {accuracy}\t'
        evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t'
        evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
        if opt.flops:
            # get_flops returns its text without a leading separator.
            evaluation_log += '\t' + get_flops(model, opt, converter)
        print(evaluation_log)
        log.write(evaluation_log + '\n')
    return None
def _prune_at_eos(pred, pred_max_prob):
    """Cut ``pred`` and its per-position max probabilities at the first "[s]" token.

    ``str.find`` returns -1 when "[s]" is absent.  Slicing with -1 would
    silently drop the last character, so in that case both inputs are
    returned unchanged.
    """
    pred_EOS = pred.find('[s]')
    if pred_EOS == -1:
        return pred, pred_max_prob
    return pred[:pred_EOS], pred_max_prob[:pred_EOS]


def validation(model, criterion, evaluation_loader, converter, opt):
    """Run ``model`` over ``evaluation_loader`` and compute evaluation metrics.

    The forward/decoding path is chosen in this order:
    ``'CTC' in opt.Prediction`` (greedy CTC decoding), then ``opt.Transformer``
    (token decoding of positions 1..batch_max_length-1), otherwise attention
    decoding.

    Args:
        model: network under evaluation.
        criterion: loss used for the reported validation loss.
        evaluation_loader: iterable of ``(image_tensors, labels)`` batches.
        converter: label converter used to encode labels and decode predictions.
        opt: parsed options namespace.

    Returns:
        tuple: ``(valid_loss, accuracy, norm_ED, preds_str, confidence_score_list,
        labels, infer_time, length_of_data)``.
        - ``accuracy`` is a percentage over all samples.
        - ``norm_ED`` is the mean ICDAR2019 score ``1 - NED`` (higher is better).
        - ``infer_time`` is the summed forward time in seconds.
        - ``preds_str``, ``confidence_score_list`` and ``labels`` describe the
          last batch only.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Used to evaluate a 'case sensitive model' in the alphanumeric, case-insensitive setting.
    out_of_alphanumeric_case_insensitve = re.compile('[^0123456789abcdefghijklmnopqrstuvwxyz]')

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data += batch_size
        image = image_tensors.to(device)
        # For max length prediction
        if opt.Transformer:
            target = converter.encode(labels)
        else:
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time
            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' to use CTCloss format
            if opt.baiduCTC:
                cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size
            else:
                cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
            # Select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            if opt.baiduCTC:
                preds_index = preds_index.view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)
        elif opt.Transformer:
            preds = model(image, text=target, seqlen=converter.batch_max_length)
            _, preds_index = preds.topk(1, dim=-1, largest=True, sorted=True)
            preds_index = preds_index.view(-1, converter.batch_max_length)
            forward_time = time.time() - start_time
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))
            length_for_pred = torch.IntTensor([converter.batch_max_length - 1] * batch_size).to(device)
            # Position 0 is skipped when decoding.
            preds_str = converter.decode(preds_index[:, 1:], length_for_pred)
        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time
            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))
            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy & confidence score
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if opt.Transformer:
                # NOTE(review): preds_str is decoded from position 1 onward while
                # pred_max_prob starts at position 0; confirm whether this slice
                # should be offset by one.
                pred, pred_max_prob = _prune_at_eos(pred, pred_max_prob)
            elif 'Attn' in opt.Prediction:
                gt_EOS = gt.find('[s]')
                if gt_EOS != -1:
                    gt = gt[:gt_EOS]
                pred, pred_max_prob = _prune_at_eos(pred, pred_max_prob)

            # To evaluate 'case sensitive model' with alphanumeric and case insensitive setting.
            if opt.sensitive and opt.data_filtering_off:
                pred = out_of_alphanumeric_case_insensitve.sub('', pred.lower())
                gt = out_of_alphanumeric_case_insensitve.sub('', gt.lower())

            if pred == gt:
                n_correct += 1

            # ICDAR2019 Normalized Edit Distance: accumulate 1 - ED / max(len(gt), len(pred)).
            # An empty gt or pred contributes 0.  (The older ICDAR2017 variant
            # normalized by len(gt) only.)
            if gt and pred:
                norm_ED += 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

            # calculate confidence score (= product of pred_max_prob);
            # an empty prediction (EOS at the first position) has no probabilities, so it scores 0.
            if len(pred_max_prob) > 0:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            else:
                confidence_score = 0
            confidence_score_list.append(confidence_score)

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance
    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
def get_state_dict(state_dict):
    """Return a copy of ``state_dict`` with the ``'module.'`` key prefix removed.

    ``torch.nn.DataParallel`` stores parameters as ``'module.<name>'``.
    Stripping the prefix lets the weights load into an unwrapped module.
    Keys without the prefix are kept as-is rather than losing their first
    seven characters.

    Args:
        state_dict: mapping from parameter names to tensors.

    Returns:
        collections.OrderedDict: new mapping preserving the original key order.
    """
    from collections import OrderedDict
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict
# https://pytorch.org/tutorials/beginner/saving_loading_models.html
def get_infer_model(model, opt):
    """Export ``model`` as a TorchScript file at ``opt.infer_model``.

    Steps:
    1. The (DataParallel-wrapped) weights are copied, with their ``module.``
       prefix removed, into a fresh ``JitModel`` put in eval mode.
    2. When ``opt.quantized`` is set the model is quantized first:
       - with ``opt.static``: static quantization using the qnnpack backend
         (work in progress; prepare/convert with no calibration pass);
       - otherwise: dynamic int8 quantization of ``torch.nn.Linear`` layers.
    3. The float or quantized model is scripted with ``torch.jit.script``
       and saved.

    Returns:
        None.
    """
    stripped_weights = get_state_dict(model.state_dict())
    float_model = JitModel(opt)
    float_model.load_state_dict(stripped_weights)
    float_model.eval()

    export_model = float_model
    if opt.quantized:
        if opt.static:
            # static quantization : Work in progress
            backend = "qnnpack"
            float_model.qconfig = torch.quantization.get_default_qconfig(backend)
            torch.backends.quantized.engine = backend
            prepared_model = torch.quantization.prepare(float_model, inplace=False)
            export_model = torch.quantization.convert(prepared_model, inplace=False)
        else:
            # support for dynamic quantization
            from torch.quantization import quantize_dynamic
            export_model = quantize_dynamic(
                model=float_model,
                qconfig_spec={torch.nn.Linear},
                dtype=torch.qint8,
                inplace=False,
            )

    # quantized model save/load https://pytorch.org/docs/stable/quantization.html
    # torch.jit.script returns an already-scripted module unchanged, so a single call suffices.
    scripted_model = torch.jit.script(export_model)
    scripted_model.save(opt.infer_model)
def test(opt):
    """Build the model described by ``opt``, load its weights, then export or evaluate it.

    Steps:
    1. Choose the label converter and construct ``Model`` wrapped in
       ``DataParallel``.
    2. Load weights from ``opt.saved_model`` (a URL or a local path).
    3. If ``opt.infer_model`` is set, export an inference model and return.
    4. Otherwise evaluate on ``opt.eval_data``, with logs written under
       ``./result/{opt.exp_name}/``.
    """
    import shutil

    # model configuration
    if opt.Transformer:
        converter = TokenLabelConverter(opt)
    elif 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    if validators.url(opt.saved_model):
        model.load_state_dict(torch.hub.load_state_dict_from_url(opt.saved_model, progress=True, map_location=device))
    else:
        model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    opt.exp_name = '_'.join(opt.saved_model.split('/')[1:])

    if opt.infer_model is not None:
        get_infer_model(model, opt)
        return

    # keep evaluation model and result logs
    os.makedirs(f'./result/{opt.exp_name}', exist_ok=True)
    # Copy without going through a shell: a shell string breaks on paths with
    # spaces/metacharacters, and a URL checkpoint is not a local file to copy.
    if os.path.isfile(opt.saved_model):
        shutil.copy(opt.saved_model, f'./result/{opt.exp_name}/')

    # setup loss
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # evaluation
    model.eval()
    opt.eval = True
    with torch.no_grad():
        if opt.benchmark_all_eval:  # evaluation with the benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            # Context manager so the log is closed even if evaluation raises.
            with open(f'./result/{opt.exp_name}/log_evaluation.txt', 'a') as log:
                AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, opt=opt)
                eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
                evaluation_loader = torch.utils.data.DataLoader(
                    eval_data, batch_size=opt.batch_size,
                    shuffle=False,
                    num_workers=int(opt.workers),
                    collate_fn=AlignCollate_evaluation, pin_memory=True)
                _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                    model, criterion, evaluation_loader, converter, opt)
                log.write(eval_data_log)
                print(f'{accuracy_by_best_model:0.3f}')
                log.write(f'{accuracy_by_best_model:0.3f}\n')
# https://github.com/clovaai/deep-text-recognition-benchmark/issues/125
def get_flops(model, opt, converter):
    """Estimate the FLOPS of one forward pass using ``thop.profile``.

    The profiled input is a single random image of shape
    ``(1, opt.input_channel, opt.imgH, opt.imgW)``, so it matches the
    channel count the model was built for (e.g. 3 with ``--rgb``).
    ``thop`` reports MACs, which are doubled to approximate FLOPS.

    Args:
        model: network to profile (moved to ``device``).
        opt: parsed options namespace.
        converter: label converter; its ``batch_max_length`` is used on the
            Transformer path.

    Returns:
        str: ``'Approximate FLOPS: <value>'``.
    """
    from thop import profile
    dummy_image = torch.randn(1, opt.input_channel, opt.imgH, opt.imgW).to(device)
    model = model.to(device)
    if opt.Transformer:
        seqlen = converter.batch_max_length
        text_for_pred = torch.LongTensor(1, seqlen).fill_(0).to(device)
        # NOTE(review): the positional True is presumably Model.forward's
        # is_train flag — confirm against Model's signature.
        MACs, params = profile(model, inputs=(dummy_image, text_for_pred, True, seqlen))
    else:
        text_for_pred = torch.LongTensor(1, opt.batch_max_length + 1).fill_(0).to(device)
        MACs, params = profile(model, inputs=(dummy_image, text_for_pred, ))
    flops = 2 * MACs  # approximate FLOPS
    return f'Approximate FLOPS: {flops:0.3f}'
if __name__ == '__main__':
    # Parse evaluation-time options.
    opt = get_args(is_train=False)

    # vocab / character number configuration
    if opt.sensitive:
        # Same as the ASTER setting: the 94 printable characters (whitespace dropped).
        opt.character = string.printable[:-6]

    # cuDNN flags, then record the number of visible GPUs.
    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    test(opt)