-
Notifications
You must be signed in to change notification settings - Fork 66
/
compute_features.py
executable file
·131 lines (112 loc) · 4.46 KB
/
compute_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
'''
File that computes features for a set of images
ex. python compute_features.py --data_dir=/mnt/images/ --model=vgg_19 --checkpoint_file=./vgg_19.ckpt
'''
import scipy.misc as misc
import cPickle as pickle
import tensorflow as tf
from tqdm import tqdm
import numpy as np
import argparse
import fnmatch
import sys
import os
sys.path.insert(0, 'nets/')
slim = tf.contrib.slim
def getPaths(data_dir):
    '''
    Recursively obtains all images in the directory specified.

    Args:
        data_dir: root directory to search.

    Returns:
        A list of full paths to every file whose extension (matched
        case-insensitively) is one of: jpg, jpeg, bmp, png.
    '''
    # add more extensions if need be
    exts = ('.jpg', '.jpeg', '.bmp', '.png')
    image_paths = []
    # Walk the tree exactly once; the previous version re-walked the whole
    # directory tree once per extension pattern (8 traversals).
    for d, _, fList in os.walk(data_dir):
        for filename in fList:
            if os.path.splitext(filename)[1].lower() in exts:
                image_paths.append(os.path.join(d, filename))
    return image_paths
# Entry point: build the requested slim model graph, restore its weights from
# a checkpoint, run every image under --data_dir through it, and pickle a
# {image_path: feature_vector} dict to features/<model>_features.pkl.
# (Python 2 code throughout: print statement, cPickle import at top of file.)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', required=True, type=str, help='Directory images are in. Searches recursively.')
    parser.add_argument('--model', required=True, type=str, help='Model to use')
    parser.add_argument('--checkpoint_file', required=True, type=str, help='Model file')
    a = parser.parse_args()

    data_dir = a.data_dir
    model = a.model
    checkpoint_file = a.checkpoint_file

    # Pick the network's expected input resolution.
    # I only have these because I thought some take in size of (299,299), but maybe not
    # NOTE: these tests run in order, so 'inception_resnet_v2' matches the generic
    # 'inception'/'resnet' substring checks first and is then overridden by the
    # exact-name check on the last line — order matters here.
    # NOTE(review): if --model matches none of these, height/width/channels are
    # never bound and the placeholder line below raises NameError; consider
    # validating the model name and exiting with a clear message.
    if 'inception' in model: height, width, channels = 224, 224, 3
    if 'resnet' in model: height, width, channels = 224, 224, 3
    if 'vgg' in model: height, width, channels = 224, 224, 3
    if model == 'inception_resnet_v2': height, width, channels = 299, 299, 3

    # Single-image input placeholder; images are fed one at a time below.
    x = tf.placeholder(tf.float32, shape=(1, height, width, channels))

    # load up model specific stuff
    # Each branch imports the model definition from nets/ (added to sys.path at
    # the top of the file), builds the graph under its slim arg_scope, and picks
    # the end_points entry used as the feature vector (pooled pre-logit
    # activations for the inception/resnet models, fc8 logits for VGG).
    if model == 'inception_v1':
        from inception_v1 import *
        arg_scope = inception_v1_arg_scope()
        with slim.arg_scope(arg_scope):
            logits, end_points = inception_v1(x, is_training=False, num_classes=1001)
            features = end_points['AvgPool_0a_7x7']
    elif model == 'inception_v2':
        from inception_v2 import *
        arg_scope = inception_v2_arg_scope()
        with slim.arg_scope(arg_scope):
            logits, end_points = inception_v2(x, is_training=False, num_classes=1001)
            features = end_points['AvgPool_1a']
    elif model == 'inception_v3':
        from inception_v3 import *
        arg_scope = inception_v3_arg_scope()
        with slim.arg_scope(arg_scope):
            logits, end_points = inception_v3(x, is_training=False, num_classes=1001)
            features = end_points['AvgPool_1a']
    elif model == 'inception_resnet_v2':
        from inception_resnet_v2 import *
        arg_scope = inception_resnet_v2_arg_scope()
        with slim.arg_scope(arg_scope):
            logits, end_points = inception_resnet_v2(x, is_training=False, num_classes=1001)
            features = end_points['PreLogitsFlatten']
    elif model == 'resnet_v1_50':
        from resnet_v1 import *
        arg_scope = resnet_arg_scope()
        with slim.arg_scope(arg_scope):
            logits, end_points = resnet_v1_50(x, is_training=False, num_classes=1000)
            features = end_points['global_pool']
    elif model == 'resnet_v1_101':
        from resnet_v1 import *
        arg_scope = resnet_arg_scope()
        with slim.arg_scope(arg_scope):
            logits, end_points = resnet_v1_101(x, is_training=False, num_classes=1000)
            features = end_points['global_pool']
    elif model == 'vgg_16':
        from vgg import *
        arg_scope = vgg_arg_scope()
        with slim.arg_scope(arg_scope):
            logits, end_points = vgg_16(x, is_training=False)
            features = end_points['vgg_16/fc8']
    elif model == 'vgg_19':
        from vgg import *
        arg_scope = vgg_arg_scope()
        with slim.arg_scope(arg_scope):
            logits, end_points = vgg_19(x, is_training=False)
            features = end_points['vgg_19/fc8']

    # Restore the pretrained weights into the freshly-built graph.
    sess = tf.Session()
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_file)

    feat_dict = {}
    paths = getPaths(data_dir)
    print 'Computing features...'
    # One forward pass per image; feature vector is squeezed to drop the
    # batch (and any singleton spatial) dimensions before storing.
    # NOTE(review): a grayscale image comes back from imread as 2-D, so
    # expand_dims yields shape (1, h, w) which will not match the
    # (1, h, w, 3) placeholder — presumably all inputs are RGB; verify.
    for path in tqdm(paths):
        image = misc.imread(path)
        image = misc.imresize(image, (height, width))
        image = np.expand_dims(image, 0)
        feat = np.squeeze(sess.run(features, feed_dict={x:image}))
        feat_dict[path] = feat

    # Best-effort mkdir: the bare except deliberately ignores "already
    # exists" (no os.makedirs(exist_ok=...) in Python 2), but it also hides
    # real errors such as permission failures.
    try: os.makedirs('features/')
    except: pass

    # Serialize the whole {path: feature} dict in one shot.
    exp_pkl = open('features/'+model+'_features.pkl', 'wb')
    data = pickle.dumps(feat_dict)
    exp_pkl.write(data)
    exp_pkl.close()