-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathdataset.py
86 lines (68 loc) · 2.15 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import sys
import os
from collections import defaultdict
import numpy as np
import scipy.misc
def preprocess_input(x0):
x = x0 / 255.
x -= 0.5
x *= 2.
return x
def reverse_preprocess_input(x0):
x = x0 / 2.0
x += 0.5
x *= 255.
return x
def dataset(base_dir, n):
d = defaultdict(list)
for root, subdirs, files in os.walk(base_dir):
for filename in files:
file_path = os.path.join(root, filename)
assert file_path.startswith(base_dir)
suffix = file_path[len(base_dir):]
suffix = suffix.lstrip("/")
label = suffix.split("/")[0]
d[label].append(file_path)
tags = sorted(d.keys())
processed_image_count = 0
useful_image_count = 0
X = []
y = []
for class_index, class_name in enumerate(tags):
filenames = d[class_name]
for filename in filenames:
processed_image_count += 1
img = scipy.misc.imread(filename)
height, width, chan = img.shape
assert chan == 3
aspect_ratio = float(max((height, width))) / min((height, width))
if aspect_ratio > 2:
continue
# We pick the largest center square.
centery = height // 2
centerx = width // 2
radius = min((centerx, centery))
img = img[centery-radius:centery+radius, centerx-radius:centerx+radius]
img = scipy.misc.imresize(img, size=(n, n), interp='bilinear')
X.append(img)
y.append(class_index)
useful_image_count += 1
print "processed %d, used %d" % (processed_image_count, useful_image_count)
X = np.array(X).astype(np.float32)
X = X.transpose((0, 3, 1, 2))
X = preprocess_input(X)
y = np.array(y)
perm = np.random.permutation(len(y))
X = X[perm]
y = y[perm]
print "classes:"
for class_index, class_name in enumerate(tags):
print class_name, sum(y==class_index)
print
return X, y, tags
def main():
in_prefix, n = sys.argv[1:]
X, y, tags = dataset(sys.stdin, in_prefix, n)
print X.shape
if __name__ == "__main__":
main()