-
Notifications
You must be signed in to change notification settings - Fork 250
/
Copy pathtracker.py
130 lines (101 loc) · 4.33 KB
/
tracker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import cv2
import torch
import numpy as np
from deep_sort.utils.parser import get_config
from deep_sort.deep_sort import DeepSort
# Build the configuration object and overlay the DeepSORT YAML on top of it.
# The path is relative to the current working directory.
cfg = get_config()
cfg.merge_from_file("./deep_sort/configs/deep_sort.yaml")

# Module-level tracker instance; update() below feeds detections into it.
deepsort = DeepSort(
    cfg.DEEPSORT.REID_CKPT,
    max_dist=cfg.DEEPSORT.MAX_DIST,
    min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
    nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP,
    max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
    max_age=cfg.DEEPSORT.MAX_AGE,
    n_init=cfg.DEEPSORT.N_INIT,
    nn_budget=cfg.DEEPSORT.NN_BUDGET,
    use_cuda=True,
)
def draw_bboxes(image, bboxes, line_thickness):
    """Draw tracked boxes, their labels and a check-point marker on ``image``.

    For each box this draws an outline, a filled strip above the top-left
    corner containing the text ``"<cls_id> ID-<pos_id>"``, and a small filled
    square centred on the box's left edge, 60% of the way down (the point
    used for line-crossing checks).

    :param image: array drawn on in place; ``image.shape[0]`` and
        ``image.shape[1]`` are read to derive a default thickness
    :param bboxes: iterable of ``(x1, y1, x2, y2, cls_id, pos_id)`` tuples
    :param line_thickness: outline thickness; a falsy value (``None``/``0``)
        selects a thickness derived from the image size
    :return: the same ``image`` object that was passed in
    """
    line_thickness = line_thickness or round(
        0.002 * (image.shape[0] + image.shape[1]) * 0.5) + 1

    # None of these depend on the individual box, so compute them once.
    font_thickness = max(line_thickness - 1, 1)
    font_scale = line_thickness / 3
    box_color = (0, 255, 0)
    marker_color = (0, 0, 255)
    point_radius = 4

    for (x1, y1, x2, y2, cls_id, pos_id) in bboxes:
        # The cv2 drawing calls take integer pixel coordinates; normalise so
        # NumPy scalars or floats behave the same as plain ints.
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        top_left = (x1, y1)

        cv2.rectangle(image, top_left, (x2, y2), box_color,
                      thickness=line_thickness, lineType=cv2.LINE_AA)

        # Measure the exact string that is rendered, so the filled background
        # spans the whole label and not only the class name.
        text = '{} ID-{}'.format(cls_id, pos_id)
        text_w, text_h = cv2.getTextSize(
            text, 0, fontScale=font_scale, thickness=font_thickness)[0]
        label_corner = (x1 + text_w, y1 - text_h - 3)
        cv2.rectangle(image, top_left, label_corner, box_color, -1, cv2.LINE_AA)  # filled
        cv2.putText(image, text, (x1, y1 - 2), 0, font_scale,
                    [225, 255, 255], thickness=font_thickness, lineType=cv2.LINE_AA)

        # Line-crossing check point: left edge, 60% of the box height down.
        check_point_x = x1
        check_point_y = int(y1 + ((y2 - y1) * 0.6))
        marker_pts = np.array(
            [
                [check_point_x - point_radius, check_point_y - point_radius],
                [check_point_x - point_radius, check_point_y + point_radius],
                [check_point_x + point_radius, check_point_y + point_radius],
                [check_point_x + point_radius, check_point_y - point_radius],
            ],
            np.int32,
        )
        cv2.fillPoly(image, [marker_pts], color=marker_color)

    return image
def update(bboxes, image):
    """Pass the current detections to the tracker and label the resulting tracks.

    Each detection is converted from corner form to
    ``[center_x, center_y, width, height]`` before being handed to
    ``deepsort.update`` together with its confidence and the frame. Every
    row the tracker returns is then given the label of the nearest input
    detection via ``search_label`` (20.0 threshold per axis).

    :param bboxes: sized sequence of ``(x1, y1, x2, y2, lbl, conf)``
    :param image: frame forwarded unchanged to ``deepsort.update``
    :return: list of ``(x1, y1, x2, y2, label, track_id)`` tuples; an empty
        list when ``bboxes`` is empty (the tracker is not called in that case)
    """
    tracked = []
    if len(bboxes) == 0:
        return tracked

    # Corner boxes -> centre/size rows, with the centre truncated to int.
    boxes_xywh = [
        [int((x1 + x2) * 0.5), int((y1 + y2) * 0.5), x2 - x1, y2 - y1]
        for x1, y1, x2, y2, _lbl, _conf in bboxes
    ]
    scores = [conf for _x1, _y1, _x2, _y2, _lbl, conf in bboxes]

    outputs = deepsort.update(torch.Tensor(boxes_xywh), torch.Tensor(scores), image)

    for left, top, right, bottom, track_id in outputs:
        # The tracker rows carry no class, so borrow it from the detection
        # whose centre lies closest to this track's centre.
        label = search_label(
            center_x=(left + right) * 0.5,
            center_y=(top + bottom) * 0.5,
            bboxes_xyxy=bboxes,
            max_dist_threshold=20.0,
        )
        tracked.append((left, top, right, bottom, label, track_id))

    return tracked
def search_label(center_x, center_y, bboxes_xyxy, max_dist_threshold):
    """Return the label of the detection whose centre is nearest to a point.

    A detection is only considered when both its horizontal and its vertical
    centre offset from ``(center_x, center_y)`` are strictly below
    ``max_dist_threshold``. Among those, the one with the smallest mean of
    the two offsets wins; on a tie the earliest detection is kept.

    :param center_x: x coordinate of the query point
    :param center_y: y coordinate of the query point
    :param bboxes_xyxy: iterable of ``(x1, y1, x2, y2, lbl, conf)``
    :param max_dist_threshold: exclusive per-axis limit on the centre offset
    :return: the matching ``lbl``, or ``''`` when no detection qualifies
    """
    best_label = ''
    best_dist = None  # None until the first qualifying detection is seen

    for left, top, right, bottom, candidate, _conf in bboxes_xyxy:
        offset_x = abs((left + right) * 0.5 - center_x)
        offset_y = abs((top + bottom) * 0.5 - center_y)

        # Both axes must be inside the tolerance for the box to count.
        if not (offset_x < max_dist_threshold and offset_y < max_dist_threshold):
            continue

        # Score the candidate by the mean of its x and y offsets.
        mean_offset = (offset_x + offset_y) * 0.5
        if best_dist is None or mean_offset < best_dist:
            best_dist = mean_offset
            best_label = candidate

    return best_label