-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_detector.py
63 lines (53 loc) · 2.06 KB
/
train_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# USAGE
# python train_detector.py --class stop_sign_images --annotations stop_sign_annotations \
# --output output/stop_sign_detector.svm
# import the necessary packages
from __future__ import print_function

import argparse
import os
import sys

import dlib
from imutils import paths
from scipy.io import loadmat
from skimage import io
# Python 3 folded the ``long`` type into ``int``. Alias the old name so the
# rectangle-building code can keep calling long() on either interpreter.
if sys.version_info[0] >= 3:
    long = int
# Command-line interface: every option below is mandatory.
ap = argparse.ArgumentParser()
for flags, helpText in (
        (("-c", "--class"), "Path to the CALTECH-101 class images"),
        (("-a", "--annotations"), "Path to the CALTECH-101 class annotations"),
        (("-o", "--output"), "Path to the output detector")):
    ap.add_argument(*flags, required=True, help=helpText)

# Expose the parsed options as a plain dict, e.g. args["class"].
args = vars(ap.parse_args())
# Grab the default training options for the HOG + linear SVM detector, then
# collect the training images together with their ground-truth bounding boxes.
print("[INFO] gathering images and bounding boxes...")
options = dlib.simple_object_detector_training_options()
images = []
boxes = []

# loop over the image paths
for imagePath in paths.list_images(args["class"]):
    # Derive the image ID from the file name, e.g. "image_0001.jpg" -> "0001".
    # os.path.basename/splitext (rather than rfind("/") and a hard-coded
    # ".jpg") keep this correct with Windows path separators and with any
    # image extension paths.list_images yields.
    fileName = os.path.basename(imagePath)
    imageID = os.path.splitext(fileName)[0].split("_")[1]

    # Each image has a matching annotation_<ID>.mat file in the annotations
    # directory; its "box_coord" entry holds the bounding boxes.
    p = os.path.join(args["annotations"], "annotation_{}.mat".format(imageID))
    annotations = loadmat(p)["box_coord"]

    # Each box_coord row is unpacked as (top, bottom, left, right) and turned
    # into a dlib.rectangle. long() (an alias of int on Python 3) converts the
    # loaded values to plain Python integers before handing them to dlib.
    bb = [dlib.rectangle(left=long(left), top=long(top),
                         right=long(right), bottom=long(bottom))
          for (top, bottom, left, right) in annotations]
    boxes.append(bb)

    # add the image to the list of images
    images.append(io.imread(imagePath))
# Fit the detector to the gathered images and their bounding boxes.
print("[INFO] training detector...")
detector = dlib.train_simple_object_detector(images, boxes, options)

# Persist the trained detector to the path given via --output.
print("[INFO] dumping classifier to file...")
outputPath = args["output"]
detector.save(outputPath)

# Display the trained detector in a dlib window (presumably dlib renders its
# learned HOG filter here; see the dlib docs), then wait for Enter.
window = dlib.image_window()
window.set_image(detector)
dlib.hit_enter_to_continue()