-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_gender.py
168 lines (151 loc) · 6.66 KB
/
predict_gender.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import cv2
import numpy as np
# Gender network architecture definition (.prototxt)
# https://drive.google.com/open?id=1W_moLzMlGiELyPxWiYQJ9KFaXroQ_NFQ
# NOTE: the GENDER_* names read the opposite way round from FACE_PROTO / FACE_MODEL
# below: GENDER_MODEL holds the .prototxt and GENDER_PROTO holds the .caffemodel.
# They are still passed to readNetFromCaffe in (prototxt, caffemodel) order.
GENDER_MODEL = 'weights/deploy_gender.prototxt'
# Gender network pre-trained weights (.caffemodel)
# https://drive.google.com/open?id=1AW3WduLk1haTVAxHOkVS_BEzel1WXQHP
GENDER_PROTO = 'weights/gender_net.caffemodel'
# Each Caffe model imposes the shape of its input image and needs preprocessing such as mean
# subtraction, which reduces the effect of illumination changes.
# These per-channel means are passed as `mean=` when building the gender network's input blob.
MODEL_MEAN_VALUES = (78.4263377603, 87.7689143744, 114.895847746)
# Gender class labels, indexed by the argmax of the gender network's output
GENDER_LIST = ['Male', 'Female']
# Face detector architecture definition
# https://raw.githubusercontent.com/opencv/opencv/master/samples/dnn/face_detector/deploy.prototxt
FACE_PROTO = "weights/deploy.prototxt.txt"
# Face detector pre-trained weights
# https://raw.githubusercontent.com/opencv/opencv_3rdparty/dnn_samples_face_detector_20180205_fp16/res10_300x300_ssd_iter_140000_fp16.caffemodel
FACE_MODEL = "weights/res10_300x300_ssd_iter_140000_fp16.caffemodel"
# cv2.dnn.readNetFromCaffe(prototxt, caffemodel) loads a network stored in Caffe format.
# Both networks are loaded at import time, so the weight files must exist at these
# relative paths when the module is imported.
# Load the face detection model
face_net = cv2.dnn.readNetFromCaffe(FACE_PROTO, FACE_MODEL)
# Load the gender prediction model
gender_net = cv2.dnn.readNetFromCaffe(GENDER_MODEL, GENDER_PROTO)
def get_faces(frame, confidence_threshold=0.5):
    '''
    Detect faces in an image using the module-level ``face_net`` detector.

    Parameters:
        frame: image array; ``frame.shape[0]`` is used as the height and
            ``frame.shape[1]`` as the width.
        confidence_threshold: detections whose confidence is not strictly
            greater than this value are discarded.

    Returns:
        A list of ``(start_x, start_y, end_x, end_y)`` tuples of Python ints.
        Each box is widened by 10 pixels on every side and then clamped to
        the frame bounds; boxes that end up empty are dropped.
    '''
    frame_h, frame_w = frame.shape[:2]
    # convert the frame into a blob to be ready for NN input
    blob = cv2.dnn.blobFromImage(frame, 1.0, (300, 300), (104, 177.0, 123.0))
    # set the image as input to the NN
    face_net.setInput(blob)
    # Perform inference. Each detection is a row of 7 values (confidence at
    # index 2, relative box corners at 3..6). reshape(-1, 7) keeps the result
    # 2-D even when there is exactly one detection, where np.squeeze would
    # collapse it to 1-D and break row indexing.
    detections = face_net.forward().reshape(-1, 7)
    # factors that turn relative box coordinates into pixel coordinates
    to_pixels = np.array([frame_w, frame_h, frame_w, frame_h])
    padding = 10
    faces = []
    for detection in detections:
        confidence = detection[2]
        if confidence <= confidence_threshold:
            continue
        # The builtin int replaces np.int, which was removed in NumPy 1.24.
        start_x, start_y, end_x, end_y = (detection[3:7] * to_pixels).astype(int)
        # widen the box a little, then keep every corner inside the frame
        start_x = max(0, int(start_x) - padding)
        start_y = max(0, int(start_y) - padding)
        end_x = min(frame_w, max(0, int(end_x) + padding))
        end_y = min(frame_h, max(0, int(end_y) + padding))
        # a box with no area would yield an empty crop downstream
        if end_x <= start_x or end_y <= start_y:
            continue
        faces.append((start_x, start_y, end_x, end_y))
    return faces
def display_img(title, img):
    '''
    Show ``img`` in a window named ``title`` and block until a key is pressed.

    All OpenCV windows are closed once the key press is received.
    '''
    cv2.imshow(title, img)
    # a delay of 0 means: wait for a key press with no timeout
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def get_optimal_font_scale(text, width, font_face=None, thickness=2):
    '''
    Determine the largest font scale at which ``text`` fits in ``width`` pixels.

    Parameters:
        text: the string that will be drawn.
        width: available horizontal space in pixels.
        font_face: OpenCV font used for measuring. Defaults to
            ``cv2.FONT_HERSHEY_SIMPLEX``, the font ``predict_gender`` passes
            to ``cv2.putText``.
        thickness: stroke thickness used for measuring. Defaults to 2, the
            thickness ``predict_gender`` passes to ``cv2.putText``.

    Returns:
        The largest scale in 5.9, 5.8, ..., 0.1 whose rendered text width is
        at most ``width``, or 1 if no candidate fits.
    '''
    if font_face is None:
        font_face = cv2.FONT_HERSHEY_SIMPLEX
    # Scale 0 is deliberately excluded: it has zero width, so it would always
    # "fit" and the function would return 0.0, i.e. invisible text.
    for tenths in range(59, 0, -1):
        scale = tenths / 10
        (text_width, _text_height), _baseline = cv2.getTextSize(
            text, fontFace=font_face, fontScale=scale, thickness=thickness)
        if text_width <= width:
            return scale
    # nothing fits: fall back to the default scale
    return 1
# from: https://stackoverflow.com/questions/44650888/resize-an-image-without-distortion-opencv
def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    '''
    Resize an image to a target width or height while keeping its aspect ratio.

    Parameters:
        image: source image; its first two shape entries are read as
            (height, width).
        width: target width in pixels. When given, it determines the scale
            and ``height`` is ignored.
        height: target height in pixels, used only when ``width`` is None.
        inter: interpolation flag forwarded to ``cv2.resize``.

    Returns:
        The resized image, or ``image`` itself when both ``width`` and
        ``height`` are None.
    '''
    src_h, src_w = image.shape[:2]

    # nothing requested: hand back the original untouched
    if width is None and height is None:
        return image

    if width is not None:
        # scale driven by the requested width
        ratio = width / float(src_w)
        target_size = (width, int(src_h * ratio))
    else:
        # scale driven by the requested height
        ratio = height / float(src_h)
        target_size = (int(src_w * ratio), height)

    return cv2.resize(image, target_size, interpolation=inter)
def predict_gender(input_path: str):
    '''
    Predict the gender of each face found in an image and display the result.

    Every detected face is classified with the module-level ``gender_net``,
    its label is printed, and a coloured box plus label is drawn on a copy of
    the image, which is then shown on screen until a key is pressed.

    Parameters:
        input_path: path of the image file to read.

    Raises:
        ValueError: if ``cv2.imread`` could not load ``input_path``.
    '''
    # Read Input Image
    img = cv2.imread(input_path)
    # cv2.imread reports failure by returning None instead of raising
    if img is None:
        raise ValueError("Could not read image: {}".format(input_path))
    # Draw on a copy so the image that was read stays untouched
    frame = img.copy()
    # predict the faces
    faces = get_faces(frame)
    # Loop over the faces detected
    for start_x, start_y, end_x, end_y in faces:
        face_img = frame[start_y:end_y, start_x:end_x]
        # an empty crop cannot be turned into a blob; skip that detection
        if face_img.size == 0:
            continue
        # image       --> input image to preprocess before classification.
        # scalefactor --> optional scaling applied after mean subtraction (1 -> no scaling).
        # size        --> spatial size the network is fed (227x227 here).
        # mean        --> values subtracted from every channel of the image.
        # swapRB      --> False: channels are passed in the order OpenCV loaded them (BGR).
        #                 NOTE(review): this assumes MODEL_MEAN_VALUES is given in that
        #                 same channel order; confirm against the model's training setup.
        blob = cv2.dnn.blobFromImage(
            image=face_img, scalefactor=1.0, size=(227, 227),
            mean=MODEL_MEAN_VALUES, swapRB=False, crop=False)
        # Predict Gender
        gender_net.setInput(blob)
        gender_preds = gender_net.forward()
        best_class = gender_preds[0].argmax()
        gender = GENDER_LIST[best_class]
        gender_confidence_score = gender_preds[0][best_class]
        label = "{}-{:.2f}%".format(gender, gender_confidence_score*100)
        print(label)
        # Place the label above the box, pushed down far enough to stay
        # inside the image when the box touches the top edge
        y_pos = start_y - 15
        while y_pos < 15:
            y_pos += 15
        # get the font scale for this box width
        optimal_font_scale = get_optimal_font_scale(label, (end_x - start_x) + 25)
        box_color = (255, 0, 0) if gender == "Male" else (147, 20, 255)
        # Draw the box
        cv2.rectangle(frame, (start_x, start_y), (end_x, end_y), box_color, 2)
        # Label processed image
        cv2.putText(frame, label, (start_x, y_pos),
                    cv2.FONT_HERSHEY_SIMPLEX, optimal_font_scale, box_color, 2)
    # Display processed image
    display_img("Gender Estimator", frame)
    # To save the annotated image, call cv2.imwrite("output.jpg", frame) here
    # Cleanup
    cv2.destroyAllWindows()
if __name__ == '__main__':
    # The image path is taken from the first command-line argument
    import sys
    # Exit with a usage hint instead of an IndexError when no path is given
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <image_path>".format(sys.argv[0]))
    predict_gender(sys.argv[1])