# yolo.py
# Forked from yss9701/Ultra96-Yolov4-tiny-and-Yolo-Fastest.
import os
import numpy as np
import copy
import colorsys
import tensorflow as tf
from timeit import default_timer as timer
from tensorflow.compat.v1.keras import backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda
from tensorflow.keras.models import Model
from PIL import Image, ImageFont, ImageDraw
from nets.yolo4_tiny import yolo_body,yolo_eval
from utils.utils import letterbox_image
class YOLO(object):
    """Wrapper around a YOLOv4-tiny Keras model: load weights, detect, draw.

    Configuration lives in ``_defaults``; any of those entries can be
    overridden by passing a keyword argument of the same name to the
    constructor, e.g. ``YOLO(score=0.3, eager=True)``.
    """

    _defaults = {
        "model_path"        : 'logs_2/last1.h5',
        "anchors_path"      : 'model_data/yolo_anchors.txt',
        "classes_path"      : 'model_data/new_class.txt',
        "score"             : 0.5,
        "iou"               : 0.3,
        "eager"             : False,
        "max_boxes"         : 100,
        # With limited GPU memory 416x416 can be used;
        # with more GPU memory 608x608 can be used.
        "model_image_size"  : (320, 320)
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value registered under ``n``.

        For an unknown name a human-readable message string is returned
        (not raised); callers rely on that, so it is kept as is.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise the detector
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Load class names, anchors and the model.

        Args:
            **kwargs: overrides for the entries of ``_defaults``.
        """
        self.__dict__.update(self._defaults)
        # Overrides must be applied before any path is read below;
        # previously they were accepted but silently discarded.
        self.__dict__.update(kwargs)
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        if not self.eager:
            # Graph mode: inference goes through an explicit TF1 session.
            tf.compat.v1.disable_eager_execution()
            self.sess = K.get_session()
        self.generate()

    #---------------------------------------------------#
    #   Read all class names
    #---------------------------------------------------#
    def _get_class(self):
        """Return the class names, one per line of ``self.classes_path``."""
        classes_path = os.path.expanduser(self.classes_path)
        # Class names may be non-ASCII, so do not depend on the locale codec.
        with open(classes_path, encoding='utf-8') as f:
            return [line.strip() for line in f.readlines()]

    #---------------------------------------------------#
    #   Read all anchor (prior) boxes
    #---------------------------------------------------#
    def _get_anchors(self):
        """Return the anchors as an ``(N, 2)`` float array.

        Only the first line of ``self.anchors_path`` is read; it must be a
        comma-separated list of numbers.
        """
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path, encoding='utf-8') as f:
            first_line = f.readline()
        values = [float(x) for x in first_line.split(',')]
        return np.array(values).reshape(-1, 2)

    #---------------------------------------------------#
    #   Build the network, load weights, wire up decoding
    #---------------------------------------------------#
    def generate(self):
        """Build the model, load its weights and set up box decoding.

        Sets ``self.yolo_model``, ``self.colors`` and
        ``self.input_image_shape``; in graph mode also ``self.boxes``,
        ``self.scores`` and ``self.classes``.
        """
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        # Number of anchors and classes.
        num_anchors = len(self.anchors)
        num_classes = len(self.class_names)

        # The architecture is always rebuilt here and only the weights are
        # taken from the .h5 file.
        self.yolo_model = yolo_body(Input(shape=(None,None,3)), num_anchors//2, num_classes)
        # Use the expanded path that was validated above ("~" support).
        self.yolo_model.load_weights(model_path)
        print('{} model, anchors, and classes loaded.'.format(model_path))

        # One distinct colour per class, evenly spaced around the hue wheel.
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        rgb_unit = [colorsys.hsv_to_rgb(*hsv) for hsv in hsv_tuples]
        self.colors = [(int(r * 255), int(g * 255), int(b * 255))
                       for r, g, b in rgb_unit]

        # Shuffle with a fixed seed so the class -> colour mapping is stable
        # between runs, then restore non-deterministic seeding.
        np.random.seed(10101)
        np.random.shuffle(self.colors)
        np.random.seed(None)

        if self.eager:
            # Eager mode: decoding is appended to the model as a Lambda layer.
            self.input_image_shape = Input([2,],batch_size=1)
            inputs = [*self.yolo_model.output, self.input_image_shape]
            outputs = Lambda(yolo_eval, output_shape=(1,), name='yolo_eval',
                arguments={'anchors': self.anchors, 'num_classes': len(self.class_names), 'image_shape': self.model_image_size,
                'score_threshold': self.score, 'iou_threshold': self.iou, 'eager': True, 'max_boxes': self.max_boxes})(inputs)
            self.yolo_model = Model([self.yolo_model.input, self.input_image_shape], outputs)
        else:
            # Graph mode: decoding tensors are fetched via the session.
            self.input_image_shape = K.placeholder(shape=(2, ))
            self.boxes, self.scores, self.classes = yolo_eval(self.yolo_model.output, self.anchors,
                    num_classes, self.input_image_shape, max_boxes=self.max_boxes,
                    score_threshold=self.score, iou_threshold=self.iou)

    @staticmethod
    def _measure_label(draw, label, font):
        """Return ``(width, height)`` of ``label`` rendered with ``font``.

        ``ImageDraw.textsize`` no longer exists in Pillow 10+, so fall back
        to ``textbbox`` when it is missing.
        """
        if hasattr(draw, 'textsize'):
            width, height = draw.textsize(label, font)
            return width, height
        left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
        return right - left, bottom - top

    #---------------------------------------------------#
    #   Detect objects in an image
    #---------------------------------------------------#
    def detect_image(self, image):
        """Run detection on a PIL image and draw the results onto it.

        Args:
            image: a PIL image; it is modified in place.

        Returns:
            The same ``image`` object with boxes and labels drawn.
        """
        start = timer()

        # Fit the picture to the network input size.
        # NOTE(review): letterbox_image presumably resizes with padding and
        # takes (width, height) — confirm in utils.utils.
        new_image_size = (self.model_image_size[1],self.model_image_size[0])
        boxed_image = letterbox_image(image, new_image_size)
        image_data = np.array(boxed_image, dtype='float32')
        image_data /= 255.
        image_data = np.expand_dims(image_data, 0)  # Add batch dimension.

        # PIL reports size as (width, height); the decoder is fed (h, w).
        if self.eager:
            input_image_shape = np.expand_dims(np.array([image.size[1], image.size[0]], dtype='float32'), 0)
            out_boxes, out_scores, out_classes = self.yolo_model.predict([image_data, input_image_shape])
        else:
            out_boxes, out_scores, out_classes = self.sess.run(
                [self.boxes, self.scores, self.classes],
                feed_dict={
                    self.yolo_model.input: image_data,
                    self.input_image_shape: [image.size[1], image.size[0]],
                    K.learning_phase(): 0
                })

        print('Found {} boxes for {}'.format(len(out_boxes), 'img'))

        # Font size and outline thickness scale with the image; both are
        # clamped to 1 so that small images still get visible annotations.
        font_size = max(1, int(np.floor(3e-2 * image.size[1] + 0.5)))
        font = ImageFont.truetype(font='font/simhei.ttf', size=font_size)
        thickness = max(1, (image.size[0] + image.size[1]) // 300)

        # One drawing context is enough for all detections.
        draw = ImageDraw.Draw(image)

        for i, c in enumerate(out_classes):
            class_index = int(c)
            predicted_class = self.class_names[class_index]
            score = out_scores[i]
            color = self.colors[class_index]

            # Pad the box by 5 px on every side, round, clip to the image.
            top, left, bottom, right = out_boxes[i]
            top = max(0, int(np.floor(top - 5 + 0.5)))
            left = max(0, int(np.floor(left - 5 + 0.5)))
            bottom = min(image.size[1], int(np.floor(bottom + 5 + 0.5)))
            right = min(image.size[0], int(np.floor(right + 5 + 0.5)))

            label = '{} {:.2f}'.format(predicted_class, score)
            label_w, label_h = self._measure_label(draw, label, font)
            # Console output stays in bytes form, exactly as before.
            print(label.encode('utf-8'))

            # Put the label above the box if it fits, otherwise just inside.
            if top - label_h >= 0:
                text_origin = (left, top - label_h)
            else:
                text_origin = (left, top + 1)

            # Draw nested 1-px rectangles to get the requested thickness.
            for inset in range(thickness):
                draw.rectangle(
                    [left + inset, top + inset, right - inset, bottom - inset],
                    outline=color)
            draw.rectangle(
                [text_origin, (text_origin[0] + label_w, text_origin[1] + label_h)],
                fill=color)
            draw.text(text_origin, label, fill=(0, 0, 0), font=font)

        del draw

        end = timer()
        print(end - start)
        return image