-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
163 lines (144 loc) · 7.33 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# -*- coding: utf-8 -*-
# @Time : 2019/8/24 12:06
# @Author : zhoujun
import os
import sys
import pathlib
# Make the project-local packages imported below (data_loader, models,
# post_processing, utils) resolvable when this file is run as a script from
# any working directory.
# NOTE: the path is deliberately not bound to a module-level name `__dir__`:
# `dir(module)` calls a module-level `__dir__` (PEP 562), so binding a Path
# object there made `dir()` on this module raise TypeError.
_this_file = pathlib.Path(os.path.abspath(__file__))
sys.path.append(str(_this_file.parent))  # directory containing this file
sys.path.append(str(_this_file.parent.parent))  # presumably the project root — TODO confirm
import time
import cv2
import torch
from data_loader import get_transforms
from models import build_model
from post_processing import get_post_processing
def resize_image(img, short_size):
    """Resize an image so that its shorter side is about ``short_size`` pixels.

    The aspect ratio is approximately preserved. Both output dimensions are
    rounded to the nearest multiple of 32, with a minimum of 32.

    :param img: image array as returned by ``cv2.imread``, either ``(H, W)``
        (grayscale) or ``(H, W, C)``.
    :param short_size: target length, in pixels, of the shorter side.
    :return: the resized image (same number of channels as ``img``).
    """
    # Only the first two axes are spatial. Slicing them (instead of unpacking
    # exactly three values) also accepts 2-D grayscale images read with flag 0.
    height, width = img.shape[:2]
    if height < width:
        new_height = short_size
        new_width = new_height / height * width
    else:
        new_width = short_size
        new_height = new_width / width * height
    # Snap to multiples of 32, presumably so the size is divisible by the
    # network's total downsampling factor — TODO confirm. Clamp to 32 so a tiny
    # short_size never produces a zero-sized target, which cv2.resize rejects.
    new_height = max(32, int(round(new_height / 32) * 32))
    new_width = max(32, int(round(new_width / 32) * 32))
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(img, (new_width, new_height))
class Pytorch_model:
    """Inference wrapper around a trained DBNet checkpoint.

    Builds the network, post-processor and input transforms from the
    ``config`` stored inside the checkpoint. Exposes :meth:`predict` for
    single-image text detection.
    """

    def __init__(self, model_path, post_p_thre=0.7, gpu_id=None):
        """Initialise the PyTorch model from a checkpoint file.

        :param model_path: path to a checkpoint readable by ``torch.load``. It
            must be a dict holding at least ``'config'`` (the training
            configuration) and ``'state_dict'`` (the model weights).
        :param post_p_thre: value assigned to the post-processor's
            ``box_thresh`` attribute.
        :param gpu_id: index of the CUDA device to run on. Any non-``int``
            value (including ``None``), or a machine without CUDA, selects CPU.
        """
        self.gpu_id = gpu_id
        on_gpu = isinstance(gpu_id, int) and torch.cuda.is_available()
        self.device = torch.device("cuda:%s" % gpu_id if on_gpu else "cpu")
        print('device:', self.device)
        # NOTE(security): torch.load unpickles arbitrary Python objects, so only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)
        config = checkpoint['config']
        # Echo only the config. Printing the checkpoint itself would dump every
        # weight tensor in the state dict to stdout.
        print(config)
        # Weights come from the checkpoint below, so presumably there is no need
        # for the backbone to load its own pretrained weights.
        config['arch']['backbone']['pretrained'] = False
        self.model = build_model(config['arch'])
        self.post_process = get_post_processing(config['post_processing'])
        self.post_process.box_thresh = post_p_thre
        train_args = config['dataset']['train']['dataset']['args']
        self.img_mode = train_args['img_mode']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)
        self.model.eval()
        # Reuse only the ToTensor/Normalize steps of the training pipeline. Any
        # other training transforms (presumably augmentations) are dropped.
        self.transform = get_transforms(
            [t for t in train_args['transforms'] if t['type'] in ('ToTensor', 'Normalize')]
        )

    def predict(self, img_path: str, is_output_polygon=False, short_size: int = 1024):
        """Run text detection on a single image file.

        The image is read from disk with OpenCV, which the original author
        notes is relatively slow.

        :param img_path: path of the image file.
        :param is_output_polygon: forwarded to the post-processor. It also
            selects how all-zero results are filtered (per-item list filtering
            for polygons, array masking for boxes).
        :param short_size: target shorter-side length passed to ``resize_image``.
        :return: tuple ``(pred_map, boxes, scores, elapsed)``. ``pred_map`` is
            ``preds[0, 0]`` of the model output as a NumPy array. ``boxes`` and
            ``scores`` are the post-processed results for this image with
            all-zero entries removed, or two empty lists if nothing was found.
            ``elapsed`` is the seconds spent in the forward pass plus
            post-processing.
        :raises AssertionError: if ``img_path`` does not exist.
        :raises ValueError: if OpenCV cannot decode the file.
        """
        assert os.path.exists(img_path), 'file is not exists'
        img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0)
        if img is None:
            # cv2.imread signals unreadable/undecodable files by returning None.
            raise ValueError(f'failed to read image: {img_path}')
        if self.img_mode == 'RGB':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Size before resizing, handed to the post-processor via `batch`.
        h, w = img.shape[:2]
        img = resize_image(img, short_size)
        # Expected: (H, W[, C]) image -> (1, C, H', W') tensor on the device.
        tensor = self.transform(img).unsqueeze_(0).to(self.device)
        batch = {'shape': [(h, w)]}
        is_cuda = self.device.type == 'cuda'
        with torch.no_grad():
            if is_cuda:
                # Drain pending GPU work so it is not counted in the timing.
                torch.cuda.synchronize(self.device)
            start = time.perf_counter()
            preds = self.model(tensor)
            if is_cuda:
                # CUDA kernels run asynchronously; wait for the forward pass.
                torch.cuda.synchronize(self.device)
            box_list, score_list = self.post_process(batch, preds, is_output_polygon=is_output_polygon)
            box_list, score_list = box_list[0], score_list[0]
            if len(box_list) > 0:
                # Drop boxes whose coordinates are all zero.
                if is_output_polygon:
                    keep = [i for i, box in enumerate(box_list) if box.sum() > 0]
                    box_list = [box_list[i] for i in keep]
                    score_list = [score_list[i] for i in keep]
                else:
                    keep = box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0
                    box_list, score_list = box_list[keep], score_list[keep]
            else:
                box_list, score_list = [], []
            t = time.perf_counter() - start
        return preds[0, 0, :, :].detach().cpu().numpy(), box_list, score_list, t
def save_depoly(model, input, save_path):
    """Export *model* as a TorchScript file.

    The model is traced with ``torch.jit.trace`` using *input* as the example
    input, and the traced module is serialised to *save_path*.
    """
    torch.jit.trace(model, input).save(save_path)
def init_args(argv=None):
    """Parse the command-line options of this prediction script.

    :param argv: list of argument strings to parse. ``None`` (the default)
        makes argparse read ``sys.argv[1:]``, i.e. the normal CLI behaviour.
    :return: ``argparse.Namespace`` with the attributes ``model_path``,
        ``input_folder``, ``output_folder``, ``thre``, ``polygon``, ``show``
        and ``save_result``.
    """
    import argparse
    parser = argparse.ArgumentParser(description='DBNet.pytorch')
    # NOTE(review): the path defaults are absolute paths on one specific
    # machine; on any other machine pass them explicitly on the command line.
    parser.add_argument('--model_path', default=r'/home/share/gaoluoluo/dbnet/output/DBNet_resnet18_FPN_DBHead/checkpoint/model_latest.pth', type=str)
    parser.add_argument('--input_folder', default='/home/share/gaoluoluo/dbnet/test/input', type=str, help='img path for predict')
    parser.add_argument('--output_folder', default='/home/share/gaoluoluo/dbnet/test/output', type=str, help='img path for output')
    parser.add_argument('--thre', default=0.1, type=float, help='the thresh of post_processing')
    parser.add_argument('--polygon', action='store_true', help='output polygon or box')
    parser.add_argument('--show', action='store_true', help='show result')
    parser.add_argument('--save_result', action='store_true', help='save box and score to txt file')
    return parser.parse_args(argv)
if __name__ == '__main__':
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    from utils.util import show_img, draw_bbox, save_result, get_file_list

    args = init_args()
    print(args)
    # Expose only GPU 0 to this process. This takes effect only if CUDA has not
    # been initialised yet; the model below is built after this line.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # Initialise the network; --thre (default 0.1) becomes the box threshold.
    model = Pytorch_model(args.model_path, post_p_thre=args.thre, gpu_id=0)
    # Create the output folder once, up front, rather than on every image.
    os.makedirs(args.output_folder, exist_ok=True)
    for img_path in tqdm(get_file_list(args.input_folder, p_postfix=['.jpg', '.png'])):
        preds, boxes_list, score_list, t = model.predict(img_path, is_output_polygon=args.polygon)
        # cv2.imread returns BGR; reverse the channels to RGB for drawing/display.
        img = draw_bbox(cv2.imread(img_path)[:, :, ::-1], boxes_list)
        if args.show:
            show_img(preds)
            show_img(img, title=os.path.basename(img_path))
            plt.show()
        # Save results next to each other in the output folder:
        #   <stem>_result.jpg - the image with the detections drawn on it
        #   <stem>_pred.jpg   - the prediction map multiplied by 255
        #   <stem>.txt        - boxes and scores, written by save_result
        stem = pathlib.Path(img_path).stem
        output_path = os.path.join(args.output_folder, stem + '_result.jpg')
        pred_path = os.path.join(args.output_folder, stem + '_pred.jpg')
        txt_path = os.path.join(args.output_folder, stem + '.txt')
        # Reverse back to BGR, the channel order cv2.imwrite expects.
        cv2.imwrite(output_path, img[:, :, ::-1])
        # Presumably preds holds probabilities in [0, 1] — TODO confirm.
        cv2.imwrite(pred_path, preds * 255)
        save_result(txt_path, boxes_list, score_list, args.polygon)