# loader.py

from pathlib import Path
from os import sep as file_path_separator

import numpy as np
import tensorflow as tf

from augmenter import get_augmentation_pipeline

# shorthand aliases used throughout
tfds = tf.data.Dataset
MAP_PARALLELISM = tf.data.experimental.AUTOTUNE


def filepath_to_label(fp):
    """Return the label of an image, i.e. the name of its parent directory."""
    return Path(fp).parent.name


def is_image(fp, extensions=('jpg', 'jpeg')):
    return Path(fp).suffix[1:].lower() in extensions


def get_image_filepaths(image_dir, png=False):
    """Get image file paths from a directory of subdirectory-labeled images."""
    extensions = ('png',) if png else ('jpg', 'jpeg')
    image_dir = Path(image_dir)
    assert image_dir.exists(), f"image directory not found: {image_dir}"
    image_filepaths = [str(fp) for fp in image_dir.glob('*/*')
                       if is_image(fp, extensions)]
    assert len(image_filepaths), f"no images found in {image_dir}"
    return image_filepaths
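
# The expected directory layout is one subdirectory per class, e.g. (the
# names below are illustrative):
#
#     image_dir/
#         cat/
#             001.jpg
#             002.jpg
#         dog/
#             003.jpg
#
# so that filepath_to_label maps each file to its parent directory's name.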

@tf.function
def load_jpeg(image_path):
    img = tf.io.read_file(image_path)
    img = tf.io.decode_jpeg(img, channels=3)
    return img


@tf.function
def load_png(image_path):
    img = tf.io.read_file(image_path)
    img = tf.io.decode_png(img, channels=3)
    return img


@tf.function
def load_grayscale_jpeg(image_path):
    # decode a single channel, then replicate it to form a 3-channel image
    img = tf.io.read_file(image_path)
    img = tf.squeeze(tf.io.decode_jpeg(img, channels=1))
    return tf.stack([img, img, img], axis=2)


@tf.function
def load_grayscale_png(image_path):
    # decode a single channel, then replicate it to form a 3-channel image
    img = tf.io.read_file(image_path)
    img = tf.squeeze(tf.io.decode_png(img, channels=1))
    return tf.stack([img, img, img], axis=2)


@tf.function
def extract_label(file_path):
    # the label is the name of the file's parent directory
    return tf.strings.split(file_path, sep=file_path_separator)[-2]


def shape_setter(shape):
    """Return a map function that restores static shape information, which
    ops like tf.numpy_function and tf.image.resize fail to propagate."""
    @tf.function
    def shape_setter_func(img):
        img.set_shape(shape)
        return img
    return shape_setter_func
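
# Example (illustrative): tf.numpy_function erases static shape information,
# so downstream code would otherwise see <unknown> shapes:
#
#     ds = ds.map(lambda img: tf.numpy_function(augment, [img], tf.uint8))
#     ds = ds.map(shape_setter([None, None, 3]))  # any HxW, 3 channels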


def load(file_paths, augmentation_func=None, size=None, class_names=None,
         include_filepaths=False, grayscale=False, png=False, standardize=False):
    labels = [filepath_to_label(fp) for fp in file_paths]
    if class_names is None:
        # sorted so the label -> index mapping is deterministic and agrees
        # across separate calls that see the same label set
        class_names = sorted(set(labels))
    label_encoder = {label: k for k, label in enumerate(class_names)}
    encoded_labels = [label_encoder[label] for label in labels]
    ds_file_paths = tfds.from_tensor_slices(file_paths)
    ds_labels = tfds.from_tensor_slices(encoded_labels)
    if grayscale:
        load_func = load_grayscale_png if png else load_grayscale_jpeg
    else:
        load_func = load_png if png else load_jpeg
    ds_images = ds_file_paths.map(load_func, num_parallel_calls=MAP_PARALLELISM)
    if augmentation_func is not None:
        # cache the decoded images so every epoch re-augments without
        # re-reading and re-decoding the files
        ds_images = ds_images.cache()
        ds_images = ds_images.map(
            map_func=lambda img: tf.numpy_function(func=augmentation_func,
                                                   inp=[img], Tout=tf.uint8),
            num_parallel_calls=MAP_PARALLELISM
        )
    ds_images = ds_images.map(shape_setter([None, None, 3]),
                              num_parallel_calls=MAP_PARALLELISM)
    # scale pixel values to [0, 1] *before* any resize: tf.image.resize
    # returns float32, and tf.image.convert_image_dtype does not rescale
    # float inputs, so converting afterwards would leave values in [0, 255];
    # see the common image input conventions
    # https://www.tensorflow.org/hub/common_signatures/images#input
    ds_images = ds_images.map(
        map_func=lambda img: tf.image.convert_image_dtype(img, tf.float32),
        num_parallel_calls=MAP_PARALLELISM
    )
    if size is not None:
        ds_images = ds_images.map(lambda img: tf.image.resize(img, list(size)),
                                  num_parallel_calls=MAP_PARALLELISM)
        ds_images = ds_images.map(shape_setter(list(size) + [3]),
                                  num_parallel_calls=MAP_PARALLELISM)
    if standardize:
        ds_images = ds_images.map(tf.image.per_image_standardization)
    # zip images with labels (and optionally file paths); shuffling and
    # batching are left to the caller
    if include_filepaths:
        ds = tfds.zip((ds_images, ds_labels, ds_file_paths))
    else:
        ds = tfds.zip((ds_images, ds_labels))
    return ds, class_names
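
# Example usage (illustrative; assumes the directory layout sketched above):
#
#     file_paths = get_image_filepaths('image_dir')
#     ds, class_names = load(file_paths, size=(224, 224))
#     for images, labels in ds.batch(32).take(1):
#         ...  # images: float32 in [0, 1], shape (32, 224, 224, 3)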


def prepare_data(args):
    assert (args.test_dir and args.test_part == 0) or 0 < args.test_part < 1
    assert (args.val_dir and args.val_part == 0) or 0 < args.val_part < 1
    assert args.test_part + args.val_part < 1
    file_paths = get_image_filepaths(args.image_dir, args.png)
    np.random.shuffle(file_paths)
    labels = [filepath_to_label(fp) for fp in file_paths]
    # sorted so the label -> index mapping is identical for every split
    class_names = sorted(set(labels))
    # split file paths into train, val, and test sets, stratified by label
    train_file_paths, val_file_paths, test_file_paths = [], [], []
label_distribution, train_label_distribution = {}, {}
for label in class_names:
label_fps = [fp for fp, l in zip(file_paths, labels) if l == label]
n_val = int(np.ceil(len(label_fps) * args.val_part))
n_test = int(np.ceil(len(label_fps) * args.test_part))
n_train = len(label_fps) - n_val - n_test
train_file_paths += label_fps[:n_train]
if args.val_part > 0:
assert n_val > 0
val_file_paths += label_fps[n_train:n_train + n_val]
if args.test_part > 0:
assert n_test > 0
test_file_paths += label_fps[n_train + n_val:]
# record how many examples there are of each label
label_distribution[label] = len(label_fps)
train_label_distribution[label] = n_train
if args.val_dir is not None:
val_file_paths = get_image_filepaths(args.val_dir, args.png)
if args.test_dir is not None:
test_file_paths = get_image_filepaths(args.test_dir, args.png)
np.random.shuffle(train_file_paths)
np.random.shuffle(val_file_paths)
np.random.shuffle(test_file_paths)
ds_train, train_class_names = load(
file_paths=train_file_paths,
augmentation_func=get_augmentation_pipeline(args),
size=args.image_dimensions,
grayscale=args.grayscale,
png=args.png,
standardize=args.standardize,
)
ds_val, val_class_names = load(
file_paths=val_file_paths,
augmentation_func=None,
size=args.image_dimensions,
grayscale=args.grayscale,
png=args.png,
standardize=args.standardize,
)
ds_test, test_class_names = load(
file_paths=test_file_paths,
augmentation_func=None,
size=args.image_dimensions,
grayscale=args.grayscale,
png=args.png,
standardize=args.standardize,
)
    # identical label sets (together with the sorted ordering in load)
    # guarantee that integer encodings agree between train, val, and test
    assert set(test_class_names) == set(val_class_names) == \
        set(train_class_names) == set(class_names)
ds_train = ds_train.shuffle(
buffer_size=min(10 * args.batch_size, len(train_file_paths)))
ds_train = ds_train.batch(args.batch_size)
ds_val = ds_val.batch(args.batch_size)
ds_test = ds_test.batch(args.batch_size)
    def optimize(ds):
        # limit the input pipeline to one intra-op thread, reserving the
        # rest for the training step
        options = tf.data.Options()
        options.experimental_threading.max_intra_op_parallelism = 1
        return ds.with_options(options)
    # currently disabled; to enable: ds_train = optimize(ds_train)
    ds_val = ds_val.cache()

    def count_batches(file_paths):
        return int(np.ceil(len(file_paths) / args.batch_size))
    # prefetch the training and val sets; the test set is not prefetched
    ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
    ds_val = ds_val.prefetch(count_batches(val_file_paths))
return ds_train, ds_val, ds_test, class_names, train_label_distribution
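
# Example (illustrative): prepare_data expects an argparse-style namespace.
# A minimal stand-in might look like the following, though
# get_augmentation_pipeline(args) may require additional attributes:
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(
#         image_dir='image_dir', val_dir=None, test_dir=None,
#         val_part=0.1, test_part=0.1, png=False, grayscale=False,
#         standardize=False, image_dimensions=(224, 224), batch_size=32)
#     ds_train, ds_val, ds_test, class_names, dist = prepare_data(args)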


def load_test(args):
    """Visually check the augmentation pipeline, one image at a time."""
    from tempfile import gettempdir
    from os import system as system_call
    from augmenter import augment
    temp_image_file_path = str(Path(gettempdir(), 'temp_image.jpg'))
    file_paths = get_image_filepaths(args.image_dir, args.png)
    np.random.shuffle(file_paths)
    ds, class_names = load(
        file_paths=file_paths,
        augmentation_func=augment,
        size=(100, 100),
        include_filepaths=True,
        grayscale=args.grayscale,
        png=args.png,
    )
    ds = ds.shuffle(buffer_size=min(10 * args.batch_size, len(file_paths)))
    for augmented_image, label, original_path in ds:
        print(f"label: {class_names[int(label)]}\n"
              f"file path: {original_path.numpy().decode()}\n"
              f"pixel range: {tf.reduce_min(augmented_image)} - "
              f"{tf.reduce_max(augmented_image)}")
        load_func = load_png if args.png else load_jpeg
        original_image = load_func(original_path)
        # convert_image_dtype rescales the [0, 1] floats back to [0, 255]
        augmented_image = tf.image.convert_image_dtype(augmented_image,
                                                       tf.uint8)
        w = max(original_image.shape[1], augmented_image.shape[1])
        h = max(original_image.shape[0], augmented_image.shape[0])
        original_image = tf.image.resize_with_pad(original_image, h, w)
        augmented_image = tf.image.resize_with_pad(augmented_image, h, w)
        # write a side-by-side comparison to disk and open it for viewing
        side_by_side = tf.concat([original_image, augmented_image], axis=1)
        side_by_side = tf.io.encode_jpeg(tf.cast(side_by_side, tf.uint8))
        tf.io.write_file(temp_image_file_path, side_by_side)
        system_call(f'open {temp_image_file_path}')  # 'open' is macOS-specific
        user_says = input("Press enter to see next image (q to quit).")
        if user_says.strip() == 'q':
            return


def benchmark_input(args):
    from tensorflow_datasets.core import benchmark
    ds_train, ds_val, ds_test, _, _ = prepare_data(args)

    def report_stats(ds, report_title='Benchmark Statistics',
                     num_iter=None, batch_size=args.batch_size):
        stats = benchmark(ds, num_iter=num_iter, batch_size=batch_size)
        print(f"\n{report_title}\n{'-' * len(report_title)}")
        for k, v in stats.items():
            if isinstance(v, dict):
                print(f'{k}:')
                for kk, vv in v.items():
                    print(f'\t{kk}: {vv}')
            else:
                print(f'{k}: {v}')
        return stats

    report_stats(ds=ds_train, report_title="Train Statistics")
    # report_stats(ds=ds_val, report_title="Val Statistics")
    # report_stats(ds=ds_test, report_title="Test Statistics")
    print()


if __name__ == '__main__':
    print("\nTo run load_test or benchmark_input, use main.py with "
          "--test_load or --benchmark_input, respectively.\n")