1 parent 7330ecc · commit b0cb4c4 · Showing 30 changed files with 3,809 additions and 2 deletions.
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 JiaQi Xu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -1,2 +1,58 @@
# mobilenet-yolov4-lite-pytorch
This is a mobilenet-yolov4-lite repository: the YOLOv4 backbone is replaced with MobileNet and the convolutions that make up the PANet are modified, which greatly reduces the number of parameters.
## YOLOV4: A PyTorch implementation of the You Only Look Once object detection model
---

### Contents
1. [Environment](#environment)
2. [Attention](#attention)
3. [TricksSet](#tricksset)
4. [Download](#download)
5. [How2train](#how2train)
6. [Reference](#reference)

### Improvements in YOLOv4
- [x] Backbone feature-extraction network: DarkNet53 => CSPDarkNet53
- [x] Feature pyramid: SPP, PAN
- [x] Training tricks: Mosaic data augmentation, label smoothing, CIOU, cosine-annealing learning-rate decay
- [x] Activation function: Mish (a minimal sketch follows this list)
- [ ] ... and so on
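
For reference, Mish(x) = x · tanh(softplus(x)). A minimal PyTorch sketch of the activation (illustration only; the module actually used by the network lives in this repository's backbone code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Mish(nn.Module):
    # Mish activation: x * tanh(softplus(x))
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))
```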

### Environment
torch==1.2.0

### Attention
The yolo4_weights.pth in the code was trained on 608x608 images, but because of GPU-memory limits the input size in the code has been changed to 416x416; change it back if needed (see the sketch below). The default anchors in the code are for 608x608 images.
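
As an illustration only, the input size is usually exposed as a `model_image_size` setting (that attribute name also appears in get_dr_txt.py below; the surrounding dictionary and the anchors path are assumptions, check yolo.py in this repository):

```python
# Hypothetical sketch of the relevant settings in yolo.py; not the exact file contents.
_defaults = {
    "model_image_size": (416, 416, 3),   # change to (608, 608, 3) to match how yolo4_weights.pth was trained
    "anchors_path": "model_data/yolo_anchors.txt",   # default anchors are for 608x608 inputs
}
```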

### TricksSet
In train.py (see the sketch below):
1. The mosaic flag controls whether Mosaic data augmentation is used.
2. Cosine_scheduler controls whether cosine-annealing learning-rate decay is used.
3. label_smoothing controls whether label smoothing is used.
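
A minimal sketch of how these switches might look near the top of train.py (the names are taken from the list above; the values and the rest of the training script are assumptions):

```python
# Sketch only: the three training-trick switches described above.
mosaic = True              # enable Mosaic data augmentation
Cosine_scheduler = False   # enable cosine-annealing learning-rate decay
label_smoothing = 0        # e.g. 0.01 to enable label smoothing, 0 to disable it
```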

### Download
The yolo4_weights.pth needed for training can be downloaded from Baidu Netdisk.
Link: https://pan.baidu.com/s/1VNSYi39AaqjHVbdNpW_7sw  Extraction code: q2iv
yolo4_weights.pth are the weights trained on the COCO dataset.
yolo4_voc_weights.pth are the weights trained on the VOC dataset.

### How2train
1. This repository uses the VOC format for training.
2. Before training, put the annotation files into the Annotations folder under VOCdevkit/VOC2007.
3. Before training, put the image files into the JPEGImages folder under VOCdevkit/VOC2007.
4. Before training, run voc2yolo4.py to generate the corresponding txt index files.
5. Then run voc_annotation.py in the root directory; change classes to your own classes before running it.
```python
classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
```
6. This generates 2007_train.txt, where each line records an image path and its ground-truth boxes (an assumed example of the format is shown after this list).
7. Before training, edit model_data/voc_classes.txt and change the classes to your own classes.
8. Run train.py to start training.
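
For illustration, each line of the generated 2007_train.txt is assumed to follow the keras-yolo3-style format used by this family of repositories: an image path followed by one `x_min,y_min,x_max,y_max,class_id` group per ground-truth box (the path and numbers below are made up):

```
./VOCdevkit/VOC2007/JPEGImages/000005.jpg 263,211,324,339,8 165,264,253,372,8
```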

### mAP evaluation update
The files get_gt_txt.py, get_dr_txt.py, and get_map.py have been updated.
get_map.py is cloned from https://github.com/Cartucho/mAP
The mAP calculation procedure is explained at: https://www.bilibili.com/video/BV1zE411u7Vw
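
A minimal sketch of the evaluation workflow from the repository root, assuming the trained weights and the test split are already in place (the output folders are the ones created by the scripts in this commit):

```
python get_gt_txt.py    # writes ./input/ground-truth/<image_id>.txt
python get_dr_txt.py    # writes ./input/detection-results/<image_id>.txt
python get_map.py       # computes mAP from the two folders above
```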

### Reference
https://github.com/qqwweee/keras-yolo3/
https://github.com/Cartucho/mAP
https://github.com/Ma-Dan/keras-yolo4
@@ -0,0 +1 @@
Stores the annotation (label) files.
@@ -0,0 +1 @@
Stores the training index files.
@@ -0,0 +1 @@
Stores the image files.
@@ -0,0 +1,44 @@
import os
import random

# Paths to the VOC annotation folder and the ImageSets/Main output folder
xmlfilepath = r'./VOCdevkit/VOC2007/Annotations'
saveBasePath = r"./VOCdevkit/VOC2007/ImageSets/Main/"

# Fraction of the data used for trainval and, within trainval, for train
trainval_percent = 1
train_percent = 1

temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
    if xml.endswith(".xml"):
        total_xml.append(xml)

num = len(total_xml)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(indices, tv)
train = random.sample(trainval, tr)

print("train and val size", tv)
print("train size", tr)
ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')

# Write each sample id (the xml file name without its .xml suffix) to the proper split file
for i in indices:
    name = total_xml[i][:-4] + '\n'
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
@@ -0,0 +1,56 @@
import torch
import math
import numpy as np

def box_ciou(b1, b2):
    """
    Input:
    ----------
    b1: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
    b2: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
    Returns:
    -------
    ciou: tensor, shape=(batch, feat_w, feat_h, anchor_num, 1)
    """
    # Top-left and bottom-right corners of the predicted boxes
    b1_xy = b1[..., :2]
    b1_wh = b1[..., 2:4]
    b1_wh_half = b1_wh/2.
    b1_mins = b1_xy - b1_wh_half
    b1_maxes = b1_xy + b1_wh_half
    # Top-left and bottom-right corners of the ground-truth boxes
    b2_xy = b2[..., :2]
    b2_wh = b2[..., 2:4]
    b2_wh_half = b2_wh/2.
    b2_mins = b2_xy - b2_wh_half
    b2_maxes = b2_xy + b2_wh_half

    # IoU between the ground-truth and predicted boxes
    intersect_mins = torch.max(b1_mins, b2_mins)
    intersect_maxes = torch.min(b1_maxes, b2_maxes)
    intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    b1_area = b1_wh[..., 0] * b1_wh[..., 1]
    b2_area = b2_wh[..., 0] * b2_wh[..., 1]
    union_area = b1_area + b2_area - intersect_area
    iou = intersect_area / (union_area + 1e-7)

    # Squared distance between the two box centers
    center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), dim=-1)
    # Top-left and bottom-right corners of the smallest box enclosing both boxes
    enclose_mins = torch.min(b1_mins, b2_mins)
    enclose_maxes = torch.max(b1_maxes, b2_maxes)
    enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
    # Squared diagonal length of the enclosing box
    enclose_diagonal = torch.sum(torch.pow(enclose_wh, 2), dim=-1)
    ciou = iou - 1.0 * center_distance / (enclose_diagonal + 1e-7)

    # Aspect-ratio consistency term
    v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0]/b1_wh[..., 1]) - torch.atan(b2_wh[..., 0]/b2_wh[..., 1])), 2)
    alpha = v / (1.0 - iou + v)
    ciou = ciou - alpha * v
    return ciou

# Quick sanity check with two boxes in xywh format
box1 = torch.from_numpy(np.array([[25, 25, 40, 40]])).type(torch.FloatTensor)
box2 = torch.from_numpy(np.array([[25, 25, 30, 40]])).type(torch.FloatTensor)

print(box_ciou(box1, box2))
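
In formula form, the function above computes the complete IoU, with ρ the distance between the box centers and c the diagonal of the smallest enclosing box (this matches the code above):

```latex
\mathrm{CIoU} = \mathrm{IoU} - \frac{\rho^2(b, b^{gt})}{c^2} - \alpha v, \qquad
v = \frac{4}{\pi^2}\left(\arctan\frac{w^{gt}}{h^{gt}} - \arctan\frac{w}{h}\right)^2, \qquad
\alpha = \frac{v}{1 - \mathrm{IoU} + v}
```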
@@ -0,0 +1,98 @@
#-------------------------------------#
#   Generates the files needed for the mAP calculation
#   See Bilibili for the detailed tutorial
#   Bubbliiiing
#-------------------------------------#
import colorsys
import os

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
from torch.autograd import Variable
from tqdm import tqdm

from nets.yolo4 import YoloBody
from utils.utils import (DecodeBox, bbox_iou, letterbox_image,
                         non_max_suppression, yolo_correct_boxes)
from yolo import YOLO


class mAP_Yolo(YOLO):
    #---------------------------------------------------#
    #   Detect one image and write its detections to a txt file
    #---------------------------------------------------#
    def detect_image(self, image_id, image):
        self.confidence = 0.01
        self.iou = 0.5
        f = open("./input/detection-results/"+image_id+".txt", "w")
        image_shape = np.array(np.shape(image)[0:2])

        # Letterbox the image to the model input size and normalize it to [0, 1]
        crop_img = np.array(letterbox_image(image, (self.model_image_size[1], self.model_image_size[0])))
        photo = np.array(crop_img, dtype=np.float32)
        photo /= 255.0
        photo = np.transpose(photo, (2, 0, 1))
        photo = photo.astype(np.float32)
        images = []
        images.append(photo)
        images = np.asarray(images)

        with torch.no_grad():
            images = torch.from_numpy(images)
            if self.cuda:
                images = images.cuda()
            outputs = self.net(images)

        # Decode the three YOLO heads and run non-maximum suppression
        output_list = []
        for i in range(3):
            output_list.append(self.yolo_decodes[i](outputs[i]))
        output = torch.cat(output_list, 1)
        batch_detections = non_max_suppression(output, len(self.class_names),
                                               conf_thres=self.confidence,
                                               nms_thres=self.iou)

        try:
            batch_detections = batch_detections[0].cpu().numpy()
        except:
            # No detections for this image; leave its result file empty
            f.close()
            return

        top_index = batch_detections[:, 4]*batch_detections[:, 5] > self.confidence
        top_conf = batch_detections[top_index, 4]*batch_detections[top_index, 5]
        top_label = np.array(batch_detections[top_index, -1], np.int32)
        top_bboxes = np.array(batch_detections[top_index, :4])
        top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:, 0], -1), np.expand_dims(top_bboxes[:, 1], -1), np.expand_dims(top_bboxes[:, 2], -1), np.expand_dims(top_bboxes[:, 3], -1)

        # Remove the gray letterbox padding to map boxes back to the original image
        boxes = yolo_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax, np.array([self.model_image_size[0], self.model_image_size[1]]), image_shape)

        for i, c in enumerate(top_label):
            predicted_class = self.class_names[c]
            score = str(top_conf[i])

            top, left, bottom, right = boxes[i]
            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))

        f.close()
        return

yolo = mAP_Yolo()
image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split()

if not os.path.exists("./input"):
    os.makedirs("./input")
if not os.path.exists("./input/detection-results"):
    os.makedirs("./input/detection-results")
if not os.path.exists("./input/images-optional"):
    os.makedirs("./input/images-optional")

for image_id in tqdm(image_ids):
    image_path = "./VOCdevkit/VOC2007/JPEGImages/"+image_id+".jpg"
    image = Image.open(image_path)
    # Enable the line below to save image copies for visualization when computing mAP later
    # image.save("./input/images-optional/"+image_id+".jpg")
    yolo.detect_image(image_id, image)

print("Conversion completed!")
@@ -0,0 +1,38 @@
#----------------------------------------------------#
#   Extract the ground truth of the test set
#   Video tutorial:
#   https://www.bilibili.com/video/BV1zE411u7Vw
#----------------------------------------------------#
import glob
import os
import sys
import xml.etree.ElementTree as ET

image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split()

if not os.path.exists("./input"):
    os.makedirs("./input")
if not os.path.exists("./input/ground-truth"):
    os.makedirs("./input/ground-truth")

for image_id in image_ids:
    with open("./input/ground-truth/"+image_id+".txt", "w") as new_f:
        root = ET.parse("VOCdevkit/VOC2007/Annotations/"+image_id+".xml").getroot()
        for obj in root.findall('object'):
            # Objects marked as difficult are tagged so the mAP script can ignore them
            difficult_flag = False
            if obj.find('difficult') is not None:
                difficult = obj.find('difficult').text
                if int(difficult) == 1:
                    difficult_flag = True
            obj_name = obj.find('name').text
            bndbox = obj.find('bndbox')
            left = bndbox.find('xmin').text
            top = bndbox.find('ymin').text
            right = bndbox.find('xmax').text
            bottom = bndbox.find('ymax').text
            if difficult_flag:
                new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
            else:
                new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

print("Conversion completed!")