'''
Script to use ViTSTR to convert a scene text image to text.

Usage:
    python3 infer.py --image demo_image/demo_1.png --model https://github.com/roatienza/deep-text-recognition-benchmark/releases/download/v0.1.0/vitstr_small_patch16_224_aug_infer.pth

    --image: path to the image file to convert to text

Inference timing:
    Quantized on CPU:
        python3 infer.py --model vitstr_small_patch16_quant.pt --time --quantized
        Average inference time per image: 2.22e-02 sec
    CPU:
        python3 infer.py --model vitstr_small_patch16_224_aug_infer.pth --time
        Average inference time per image: 3.24e-02 sec
        With JIT:
            python3 infer.py --model vitstr_small_patch16_jit.pt --time
            Average inference time per image: 2.75e-02 sec
    GPU:
        python3 infer.py --model vitstr_small_patch16_224_aug_infer.pth --time --gpu
        Average inference time per image: 3.50e-03 sec
        With JIT:
            python3 infer.py --model vitstr_small_patch16_jit.pt --time --gpu
            Average inference time per image: 2.56e-03 sec
    RPi 4 CPU Quantized:
        python3 infer.py --model vitstr_small_patch16_quant.pt --time --rpi --quantized
        Average inference time per image: 3.59e-01 sec
    RPi 4 CPU JIT:
        python3 infer.py --model vitstr_small_patch16_jit.pt --time --rpi
        Average inference time per image: 4.64e-01 sec
To generate the torchscript JIT model, simplify the forward() methods to:

model.py:
    def forward(self, input, seqlen: int = 25):  # was: (self, input, text, is_train=True, seqlen=25)
        """ Transformation stage """
        # if not self.stages['Trans'] == "None":
        #     input = self.Transformation(input)
        # if self.stages['ViTSTR']:
        prediction = self.vitstr(input, seqlen=seqlen)
        return prediction

modules/vitstr.py:
    def forward(self, x, seqlen: int = 25):
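
A possible export sketch (illustration only, not code from this repo; the loaded
`model` variable and the 1x1x224x224 grayscale input shape are assumptions based
on the vitstr_small_patch16_224 configuration):

    model.eval()
    dummy = torch.zeros(1, 1, 224, 224)        # batch, channel, height, width
    jit_model = torch.jit.trace(model, dummy)  # or torch.jit.script(model), if the model scripts cleanly
    jit_model.save("vitstr_small_patch16_jit.pt")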
'''
import os
import torch
import string
import validators
import time
from infer_utils import TokenLabelConverter, NormalizePAD, ViTSTRFeatureExtractor
from infer_utils import get_args


def img2text(model, images, converter):
    """Run greedy decoding on a list of preprocessed image tensors."""
    pred_strs = []
    with torch.no_grad():
        for img in images:
            pred = model(img, seqlen=converter.batch_max_length)
            # Greedy decoding: take the most likely token at each position.
            _, pred_index = pred.topk(1, dim=-1, largest=True, sorted=True)
            pred_index = pred_index.view(-1, converter.batch_max_length)
            length_for_pred = torch.IntTensor([converter.batch_max_length - 1])
            pred_str = converter.decode(pred_index[:, 1:], length_for_pred)
            # Keep everything before the end-of-sequence token '[s]'.
            pred_EOS = pred_str[0].find('[s]')
            pred_str = pred_str[0][:pred_EOS]
            pred_strs.append(pred_str)
    return pred_strs


def infer(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    converter = TokenLabelConverter(args)
    args.num_class = len(converter.character)
    extractor = ViTSTRFeatureExtractor()

    if args.time:
        # Benchmark mode: run on the bundled demo images.
        files = ["demo_1.png", "demo_2.jpg", "demo_3.png", "demo_4.png", "demo_5.png",
                 "demo_6.png", "demo_7.png", "demo_8.jpg", "demo_9.jpg", "demo_10.jpg"]
        images = []
        for f in files:
            f = os.path.join("demo_image", f)
            img = extractor(f)
            if args.gpu:
                img = img.to(device)
            images.append(img)
    else:
        assert args.image is not None
        files = [args.image]
        img = extractor(args.image)
        if args.gpu:
            img = img.to(device)
        images = [img]

    if args.quantized:
        # Select the quantized backend for the target CPU architecture.
        if args.rpi:
            backend = "qnnpack"  # ARM
        else:
            backend = "fbgemm"  # x86
        torch.backends.quantized.engine = backend

    # Download the checkpoint if a URL was given; otherwise use the local file.
    if validators.url(args.model):
        checkpoint = args.model.rsplit('/', 1)[-1]
        torch.hub.download_url_to_file(args.model, checkpoint)
    else:
        checkpoint = args.model

    if args.quantized:
        model = torch.jit.load(checkpoint)
    else:
        model = torch.load(checkpoint)

    if args.gpu:
        model.to(device)
    model.eval()

    if args.time:
        n_times = 10
        n_total = len(images) * n_times
        # Warm-up pass so lazy initialization does not skew the timing.
        [img2text(model, images, converter) for _ in range(n_times)]
        start_time = time.time()
        [img2text(model, images, converter) for _ in range(n_times)]
        end_time = time.time()
        ave_time = (end_time - start_time) / n_total
        print("Average inference time per image: %0.2e sec" % ave_time)

    pred_strs = img2text(model, images, converter)
    return zip(files, pred_strs)


if __name__ == '__main__':
    args = get_args()
    # Character set: string.printable minus its 6 trailing whitespace characters.
    args.character = string.printable[:-6]
    data = infer(args)
    for filename, text in data:
        print(filename, "\t: ", text)