-
Notifications
You must be signed in to change notification settings - Fork 41
/
detect_main.py
114 lines (97 loc) · 4.15 KB
/
detect_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import cv2
import os
import time
import torch
import argparse
from nanodet.util import cfg, load_config, Logger
from nanodet.model.arch import build_model
from nanodet.util import load_model_weight
from nanodet.data.transform import Pipeline
# File extensions (lower-case, with the leading dot) that get_image_list
# accepts as images.
image_ext = [
    '.jpg',
    '.jpeg',
    '.webp',
    '.bmp',
    '.png',
]
# Video container extensions (note: listed WITHOUT a leading dot).
# NOTE(review): not referenced anywhere else in this file.
video_ext = [
    'mp4',
    'mov',
    'avi',
    'mkv',
]

# Usage examples. The bare strings below are no-op section titles.
# Object detection on a single image (or a directory of images):
'''目标检测-图片'''
# python detect_main.py image --config ./config/nanodet-m.yml --model model/nanodet_m.pth --path street.png
# Object detection on a video file:
'''目标检测-视频文件'''
# python detect_main.py video --config ./config/nanodet-m.yml --model model/nanodet_m.pth --path test.mp4
# Object detection on a webcam stream (the camera is selected with --camid):
'''目标检测-摄像头'''
# python detect_main.py webcam --config ./config/nanodet-m.yml --model model/nanodet_m.pth --camid 0
def parse_args():
    """Parse the command-line options of the demo script.

    Options:
        demo:     positional demo type: ``image``, ``video`` or ``webcam``.
                  Optional; falls back to ``image`` when omitted.
        --config: model config file path (required).
        --model:  model checkpoint file path (required).
        --path:   image file, image directory or video file
                  (default ``./demo``).
        --camid:  integer camera id used by the webcam demo (default 0).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    # nargs='?' is what makes the default effective: without it argparse
    # treats the positional as mandatory and never uses ``default``.
    parser.add_argument('demo', nargs='?', default='image',
                        choices=['image', 'video', 'webcam'],
                        help='demo type, eg. image, video and webcam')
    # Both paths are consumed unconditionally by the caller, so fail early
    # with a usage message instead of a later error on ``None``.
    parser.add_argument('--config', required=True, help='model config file path')
    parser.add_argument('--model', required=True, help='model file path')
    parser.add_argument('--path', default='./demo', help='path to images or video')
    parser.add_argument('--camid', type=int, default=0, help='webcam demo camera id')
    args = parser.parse_args()
    return args
class Predictor(object):
    """Wrap a detection model for single-image inference and display.

    The constructor builds the model from ``cfg.model``, loads the
    checkpoint, puts the model in eval mode on ``device`` and creates the
    preprocessing pipeline configured under ``cfg.data.val``.
    """

    def __init__(self, cfg, model_path, logger, device='cuda:0'):
        """Build the model, load its weights and set up preprocessing.

        Args:
            cfg: Loaded configuration object; ``cfg.model`` and
                ``cfg.data.val`` are read here, and
                ``cfg.data.val.input_size`` later in ``inference``.
            model_path: Checkpoint path handed to ``torch.load``.
            logger: Logger passed through to ``load_model_weight``.
            device: Torch device string the model and inputs are moved to.
        """
        self.cfg = cfg
        self.device = device
        model = build_model(cfg.model)
        # Keep the deserialized storages where they are (CPU) regardless of
        # the device they were saved from; the model is moved afterwards.
        ckpt = torch.load(model_path, map_location=lambda storage, loc: storage)
        load_model_weight(model, ckpt, logger)
        self.model = model.to(device).eval()
        self.pipeline = Pipeline(cfg.data.val.pipeline, cfg.data.val.keep_ratio)

    def inference(self, img):
        """Run the model on one image.

        Args:
            img: Either a path string, which is read with ``cv2.imread``,
                or an already-decoded image array.

        Returns:
            A tuple ``(meta, results)``. ``meta`` is the dict returned by
            the pipeline, with ``meta['img']`` replaced by a batched tensor
            on ``self.device``. ``results`` is whatever
            ``self.model.inference(meta)`` returns.

        Raises:
            OSError: If ``img`` is a path that ``cv2.imread`` cannot read.
            ValueError: If ``img`` is ``None``.
        """
        img_info = {}
        if isinstance(img, str):
            img_path = img
            img_info['file_name'] = os.path.basename(img_path)
            img = cv2.imread(img_path)
            # cv2.imread signals failure by returning None, not by raising.
            if img is None:
                raise OSError('failed to read image: {}'.format(img_path))
        else:
            if img is None:
                raise ValueError('img must be a file path or an image array, got None')
            img_info['file_name'] = None

        height, width = img.shape[:2]
        img_info['height'] = height
        img_info['width'] = width
        meta = dict(img_info=img_info,
                    raw_img=img,
                    img=img)
        meta = self.pipeline(meta, self.cfg.data.val.input_size)
        # Move the last axis first and add a leading batch axis
        # (assumes the pipeline returns an H x W x C array -- TODO confirm).
        meta['img'] = torch.from_numpy(meta['img'].transpose(2, 0, 1)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            results = self.model.inference(meta)
        return meta, results

    def visualize(self, dets, meta, class_names, score_thres, wait=0):
        """Show the detections on the raw image and print the time taken.

        Delegates to ``self.model.head.show_result`` with ``show=True``.

        Args:
            dets: Detection results as returned by ``inference``.
            meta: Meta dict from ``inference``; ``meta['raw_img']`` is used.
            class_names: Class names forwarded to ``show_result``.
            score_thres: Score threshold forwarded to ``show_result``.
            wait: Unused; kept so existing callers keep working.
        """
        time1 = time.time()
        self.model.head.show_result(meta['raw_img'], dets, class_names, score_thres=score_thres, show=True)
        print('viz time: {:.3f}s'.format(time.time()-time1))
def get_image_list(path):
    """Collect the image files found under ``path``, recursively.

    Walks ``path`` with ``os.walk`` and keeps every file whose extension is
    listed in the module-level ``image_ext``. The comparison is
    case-insensitive, so ``photo.JPG`` is accepted as well as ``photo.jpg``.

    Args:
        path: Directory to search.

    Returns:
        A list of file paths (each joined with its containing directory),
        in ``os.walk`` order; the list is not sorted. Empty when nothing
        matches or ``path`` is not a directory.
    """
    image_names = []
    for maindir, _subdirs, file_name_list in os.walk(path):
        for filename in file_name_list:
            # Lower-case the extension: image_ext only holds lower-case entries.
            ext = os.path.splitext(filename)[1].lower()
            if ext in image_ext:
                image_names.append(os.path.join(maindir, filename))
    return image_names
def main():
    """Entry point: run the image, video or webcam detection demo.

    Parses the command line, loads the config and model, then:

    * ``image``: runs inference on ``--path`` (a single file, or every image
      found under a directory, in sorted order), waiting for a key press
      after each result.
    * ``video`` / ``webcam``: reads frames from ``--path`` or camera
      ``--camid`` until the stream ends or cannot be read.

    In both modes Esc, ``q`` or ``Q`` stops the loop.
    """
    args = parse_args()
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    load_config(cfg, args.config)
    logger = Logger(-1, use_tensorboard=False)
    # Fall back to CPU instead of failing when no CUDA device is available.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    predictor = Predictor(cfg, args.model, logger, device=device)
    logger.log('Press "Esc", "q" or "Q" to exit.')
    exit_keys = (27, ord('q'), ord('Q'))

    if args.demo == 'image':
        if os.path.isdir(args.path):
            files = get_image_list(args.path)
        else:
            files = [args.path]
        files.sort()
        for image_name in files:
            meta, res = predictor.inference(image_name)
            predictor.visualize(res, meta, cfg.class_names, 0.35)
            # Block until a key is pressed before moving to the next image.
            ch = cv2.waitKey(0)
            if ch in exit_keys:
                break
    elif args.demo == 'video' or args.demo == 'webcam':
        cap = cv2.VideoCapture(args.path if args.demo == 'video' else args.camid)
        try:
            while True:
                ret_val, frame = cap.read()
                # read() reports end of stream / capture failure through
                # ret_val; the frame is not usable in that case.
                if not ret_val:
                    break
                meta, res = predictor.inference(frame)
                predictor.visualize(res, meta, cfg.class_names, 0.35)
                ch = cv2.waitKey(1)
                if ch in exit_keys:
                    break
        finally:
            # Always hand the camera / file handle back.
            cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()