forked from google-research/simclr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_util.py
519 lines (435 loc) · 17.6 KB
/
data_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
# coding=utf-8
# Copyright 2020 The SimCLR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data preprocessing and augmentation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import flags
import tensorflow.compat.v1 as tf
# Global absl flag values; `FLAGS.color_jitter_strength` is read by
# `random_color_jitter`.
FLAGS = flags.FLAGS
# Fraction of the image kept along the less-cropped side by the evaluation
# center crop in `preprocess_for_eval`.
CROP_PROPORTION = 0.875  # Standard for ImageNet.
def random_apply(func, p, x):
  """Returns `func(x)` with probability `p`, otherwise `x` unchanged.

  Args:
    func: Callable applied to `x` when the random draw succeeds.
    p: Probability of applying `func` (Python number or scalar tensor).
    x: Value passed to `func`, or returned as-is.

  Returns:
    Either `func(x)` or `x`, chosen in-graph by `tf.cond`.
  """
  draw = tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32)
  should_apply = tf.less(draw, tf.cast(p, tf.float32))
  return tf.cond(should_apply, lambda: func(x), lambda: x)
def random_brightness(image, max_delta, impl='simclrv2'):
  """Randomly changes image brightness, multiplicatively or additively.

  Args:
    image: Image tensor.
    max_delta: Magnitude of the brightness perturbation.
    impl: 'simclrv2' scales the image by a factor drawn uniformly from
      [max(1 - max_delta, 0), 1 + max_delta) (multiplicative change);
      'simclrv1' delegates to `tf.image.random_brightness` (additive change).

  Returns:
    The brightness-adjusted image tensor.

  Raises:
    ValueError: If `impl` is neither 'simclrv1' nor 'simclrv2'.
  """
  if impl == 'simclrv1':
    return tf.image.random_brightness(image, max_delta=max_delta)
  if impl == 'simclrv2':
    lower_bound = tf.maximum(1.0 - max_delta, 0)
    scale = tf.random_uniform([], lower_bound, 1.0 + max_delta)
    return image * scale
  raise ValueError('Unknown impl {} for random brightness.'.format(impl))
def to_grayscale(image, keep_channels=True):
  """Converts an RGB image to grayscale.

  Args:
    image: RGB image tensor; the tiling below assumes rank 3 (H, W, C).
    keep_channels: If True, the single gray channel is tiled to three
      channels.

  Returns:
    Grayscale image tensor with 3 channels if `keep_channels`, else 1.
  """
  gray = tf.image.rgb_to_grayscale(image)
  return tf.tile(gray, [1, 1, 3]) if keep_channels else gray
def color_jitter(image, strength, random_order=True, impl='simclrv2'):
  """Distorts the color of the image.

  A single `strength` sets all four jitter magnitudes: brightness, contrast
  and saturation use 0.8 * strength, and hue uses 0.2 * strength.

  Args:
    image: The input image tensor.
    strength: The floating number for the strength of the color augmentation.
    random_order: A bool, specifying whether to randomize the jittering order.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  jitter_fn = color_jitter_rand if random_order else color_jitter_nonrand
  return jitter_fn(
      image,
      0.8 * strength,  # brightness
      0.8 * strength,  # contrast
      0.8 * strength,  # saturation
      0.2 * strength,  # hue
      impl=impl)
def color_jitter_nonrand(image,
                         brightness=0,
                         contrast=0,
                         saturation=0,
                         hue=0,
                         impl='simclrv2'):
  """Distorts the color of the image (jittering order is fixed).

  The transformations are applied in the fixed order brightness, contrast,
  saturation, hue. A transformation whose magnitude is 0 is skipped. The
  image is clipped to [0, 1] after every step.

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):
    def apply_transform(i, x, brightness, contrast, saturation, hue):
      """Apply the i-th transformation; a no-op if its magnitude is 0."""
      if brightness != 0 and i == 0:
        x = random_brightness(x, max_delta=brightness, impl=impl)
      elif contrast != 0 and i == 1:
        x = tf.image.random_contrast(
            x, lower=1-contrast, upper=1+contrast)
      elif saturation != 0 and i == 2:
        x = tf.image.random_saturation(
            x, lower=1-saturation, upper=1+saturation)
      elif hue != 0 and i == 3:
        # The index guard is required: without it, any earlier slot whose
        # magnitude is 0 falls through to this branch and applies an
        # extra random hue shift.
        x = tf.image.random_hue(x, max_delta=hue)
      return x

    for i in range(4):
      image = apply_transform(i, image, brightness, contrast, saturation, hue)
      image = tf.clip_by_value(image, 0., 1.)
    return image
def color_jitter_rand(image,
                      brightness=0,
                      contrast=0,
                      saturation=0,
                      hue=0,
                      impl='simclrv2'):
  """Distorts the color of the image (jittering order is random).

  Each of the four transformations (brightness, contrast, saturation, hue)
  is applied once, in the order of a random permutation of 0..3. A
  transformation whose magnitude is 0 acts as the identity. The image is
  clipped to [0, 1] after every step.

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  def _jitter_at(index, img):
    """Applies the transformation selected by the scalar tensor `index`."""
    def _brightness():
      if brightness == 0:
        return img
      return random_brightness(img, max_delta=brightness, impl=impl)

    def _contrast():
      if contrast == 0:
        return img
      return tf.image.random_contrast(img, lower=1-contrast, upper=1+contrast)

    def _saturation():
      if saturation == 0:
        return img
      return tf.image.random_saturation(
          img, lower=1-saturation, upper=1+saturation)

    def _hue():
      if hue == 0:
        return img
      return tf.image.random_hue(img, max_delta=hue)

    def _first_pair():
      return tf.cond(tf.less(index, 1), _brightness, _contrast)

    def _second_pair():
      return tf.cond(tf.less(index, 3), _saturation, _hue)

    # Two-level binary dispatch on the traced index: {0, 1} vs {2, 3}.
    return tf.cond(tf.less(index, 2), _first_pair, _second_pair)

  with tf.name_scope('distort_color'):
    order = tf.random_shuffle(tf.range(4))
    for step in range(4):
      image = _jitter_at(order[step], image)
      image = tf.clip_by_value(image, 0., 1.)
    return image
def _compute_crop_shape(
    image_height, image_width, aspect_ratio, crop_proportion):
  """Computes an aspect-ratio-preserving shape for a central crop.

  The result keeps `crop_proportion` of the image along one side and a
  proportion less than or equal to `crop_proportion` along the other.

  Args:
    image_height: Height of image to be cropped.
    image_width: Width of image to be cropped.
    aspect_ratio: Desired aspect ratio (width / height) of output.
    crop_proportion: Proportion of image to retain along the less-cropped side.

  Returns:
    crop_height: Height of image after cropping (int32, rounded).
    crop_width: Width of image after cropping (int32, rounded).
  """
  width_f = tf.cast(image_width, tf.float32)
  height_f = tf.cast(image_height, tf.float32)

  def _to_int(value):
    return tf.cast(tf.rint(value), tf.int32)

  def _width_limited():
    # The requested ratio is wider than the image, so the width is the binding
    # side and the height follows from the aspect ratio.
    crop_h = _to_int(crop_proportion / aspect_ratio * width_f)
    crop_w = _to_int(crop_proportion * width_f)
    return crop_h, crop_w

  def _height_limited():
    # The image is at least as wide as requested, so the height is the binding
    # side.
    crop_h = _to_int(crop_proportion * height_f)
    crop_w = _to_int(crop_proportion * aspect_ratio * height_f)
    return crop_h, crop_w

  return tf.cond(aspect_ratio > width_f / height_f,
                 _width_limited,
                 _height_limited)
def center_crop(image, height, width, crop_proportion):
  """Crops the center of `image` and resizes it to `height` x `width`.

  The crop has aspect ratio `width / height` and keeps `crop_proportion` of
  the image along its less-cropped side (see `_compute_crop_shape`).

  Args:
    image: Image Tensor to crop.
    height: Height of the output image.
    width: Width of the output image.
    crop_proportion: Proportion of image to retain along the less-cropped side.

  Returns:
    A `height` x `width` x channels Tensor holding a central crop of `image`.
  """
  dims = tf.shape(image)
  in_height, in_width = dims[0], dims[1]
  crop_height, crop_width = _compute_crop_shape(
      in_height, in_width, width / height, crop_proportion)
  # Rounding up means an odd leftover margin puts the extra pixel on the
  # top/left side.
  top = (in_height - crop_height + 1) // 2
  left = (in_width - crop_width + 1) // 2
  cropped = tf.image.crop_to_bounding_box(
      image, top, left, crop_height, crop_width)
  return tf.image.resize_bicubic([cropped], [height, width])[0]
def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Crops `image` to a randomly sampled, distorted bounding box.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image: `Tensor` of image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
      where each coordinate is [0, 1) and the coordinates are arranged
      as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
      image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
      area of the image must contain at least this fraction of any bounding
      box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
      image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
      must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
      region of the image of the specified constraints. After `max_attempts`
      failures, return the entire image.
    scope: Optional `str` for name scope.

  Returns:
    The cropped image `Tensor`. Only the crop is returned; the sampled box
    itself is discarded.
  """
  with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
    shape = tf.shape(image)
    bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)

    # Crop the image to the sampled box; the channel component is ignored.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    return tf.image.crop_to_bounding_box(
        image, offset_y, offset_x, target_height, target_width)
def crop_and_resize(image, height, width):
  """Takes a random crop of `image` and resizes it to `height` x `width`.

  The crop covers 8%-100% of the image area. Its aspect ratio is within
  [3/4, 4/3] times the output aspect ratio `width / height`.

  Args:
    image: Tensor representing the image.
    height: Desired image height.
    width: Desired image width.

  Returns:
    A `height` x `width` x channels Tensor holding a random crop of `image`.
  """
  whole_image_box = tf.constant(
      [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  target_ratio = width / height
  cropped = distorted_bounding_box_crop(
      image,
      whole_image_box,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4 * target_ratio, 4. / 3. * target_ratio),
      area_range=(0.08, 1.0),
      max_attempts=100,
      scope=None)
  return tf.image.resize_bicubic([cropped], [height, width])[0]
def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
  """Blurs the given image with a separable Gaussian convolution.

  Args:
    image: Tensor of shape [height, width, channels] and dtype float to blur.
      A rank-4 batched tensor is also accepted.
    kernel_size: Integer Tensor for the size of the blur kernel. This should
      be an odd number. If it is an even number, the actual kernel size will
      be size + 1.
    sigma: Sigma value for gaussian operator.
    padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.

  Returns:
    A Tensor representing the blurred image.
  """
  radius = tf.to_int32(kernel_size / 2)
  taps = radius * 2 + 1
  offsets = tf.to_float(tf.range(-radius, radius + 1))
  weights = tf.exp(
      -tf.pow(offsets, 2.0) / (2.0 * tf.pow(tf.to_float(sigma), 2.0)))
  weights /= tf.reduce_sum(weights)

  # The same normalized 1-D kernel is used as a vertical and a horizontal
  # filter, tiled so every channel is filtered independently (depthwise).
  vertical = tf.reshape(weights, [taps, 1, 1, 1])
  horizontal = tf.reshape(weights, [1, taps, 1, 1])
  channels = tf.shape(image)[-1]
  horizontal = tf.tile(horizontal, [1, 1, channels, 1])
  vertical = tf.tile(vertical, [1, 1, channels, 1])

  # Convolutions need a batch dimension; add one for a single image.
  is_single_image = image.shape.ndims == 3
  result = tf.expand_dims(image, axis=0) if is_single_image else image
  for kernel in (horizontal, vertical):
    result = tf.nn.depthwise_conv2d(
        result, kernel, strides=[1, 1, 1, 1], padding=padding)
  return tf.squeeze(result, axis=0) if is_single_image else result
def random_crop_with_resize(image, height, width, p=1.0):
  """With probability `p`, random-crops `image` and resizes it.

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    p: Probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """
  return random_apply(
      lambda img: crop_and_resize(img, height, width), p=p, x=image)
def random_color_jitter(image, p=1.0, impl='simclrv2'):
  """With probability `p`, applies color jitter then random grayscale.

  Within the transformation, color jitter (strength taken from
  `FLAGS.color_jitter_strength`) is applied with probability 0.8, and
  grayscale conversion is applied afterwards with probability 0.2.

  Args:
    image: Image tensor.
    p: Probability of applying the whole transformation.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The (possibly) color-distorted image tensor.
  """
  def _jitter_then_gray(img):
    jitter = functools.partial(
        color_jitter, strength=FLAGS.color_jitter_strength, impl=impl)
    jittered = random_apply(jitter, p=0.8, x=img)
    return random_apply(to_grayscale, p=0.2, x=jittered)

  return random_apply(_jitter_then_gray, p=p, x=image)
def random_blur(image, height, width, p=1.0):
  """With probability `p`, applies a Gaussian blur with random sigma.

  The kernel size is `height // 10`, and sigma is drawn uniformly from
  [0.1, 2.0).

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image (unused).
    p: probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """
  del width  # Only the height determines the kernel size.

  def _blur(img):
    sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
    return gaussian_blur(
        img, kernel_size=height // 10, sigma=sigma, padding='SAME')

  return random_apply(_blur, p=p, x=image)
def batch_random_blur(images_list, height, width, blur_probability=0.5):
  """Blurs each example of each batch independently with some probability.

  Each batch is blurred once as a whole. A per-example 0/1 mask then chooses
  between the blurred and original images, and the result is clipped to
  [0, 1].

  Args:
    images_list: a list of image tensors.
    height: the height of image.
    width: the width of image.
    blur_probability: the probability to apply the blur operator.

  Returns:
    Preprocessed feature list.
  """
  def _example_mask(prob, batch_size):
    # Shape [batch, 1, 1, 1] so the mask broadcasts over the remaining dims.
    draws = tf.random_uniform([batch_size, 1, 1, 1], 0, 1, dtype=tf.float32)
    return tf.cast(tf.less(draws, prob), tf.float32)

  blended = []
  for batch in images_list:
    blurred = random_blur(batch, height, width, p=1.)
    mask = _example_mask(blur_probability, tf.shape(batch)[0])
    mixed = blurred * mask + batch * (1 - mask)
    blended.append(tf.clip_by_value(mixed, 0., 1.))
  return blended
def preprocess_for_train(image,
                         height,
                         width,
                         color_distort=True,
                         crop=True,
                         flip=True,
                         impl='simclrv2'):
  """Applies the training augmentation pipeline to one image.

  The steps are an optional random crop+resize, an optional horizontal flip
  and an optional color distortion. The result is then reshaped to
  [height, width, 3] and clipped to [0, 1].

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    color_distort: Whether to apply the color distortion.
    crop: Whether to crop the image.
    flip: Whether or not to flip left and right of an image.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    A preprocessed image `Tensor`.
  """
  augmented = image
  if crop:
    augmented = random_crop_with_resize(augmented, height, width)
  if flip:
    augmented = tf.image.random_flip_left_right(augmented)
  if color_distort:
    augmented = random_color_jitter(augmented, impl=impl)
  augmented = tf.reshape(augmented, [height, width, 3])
  return tf.clip_by_value(augmented, 0., 1.)
def preprocess_for_eval(image, height, width, crop=True):
  """Applies the deterministic evaluation preprocessing to one image.

  If `crop` is set, the image is center-cropped (keeping `CROP_PROPORTION`)
  and resized. The result is then reshaped to [height, width, 3] and
  clipped to [0, 1].

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    crop: Whether or not to (center) crop the test images.

  Returns:
    A preprocessed image `Tensor`.
  """
  processed = (
      center_crop(image, height, width, crop_proportion=CROP_PROPORTION)
      if crop else image)
  processed = tf.reshape(processed, [height, width, 3])
  return tf.clip_by_value(processed, 0., 1.)
def preprocess_image(image, height, width, is_training=False,
                     color_distort=True, test_crop=True):
  """Converts `image` to float32 and runs the train or eval pipeline.

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    is_training: `bool` for whether the preprocessing is for training.
    color_distort: whether to apply the color distortion.
    test_crop: whether or not to extract a central crop of the images
      (as for standard ImageNet evaluation) during the evaluation.

  Returns:
    A preprocessed image `Tensor` of range [0, 1].
  """
  as_float = tf.image.convert_image_dtype(image, dtype=tf.float32)
  if not is_training:
    return preprocess_for_eval(as_float, height, width, crop=test_crop)
  return preprocess_for_train(
      as_float, height, width, color_distort=color_distort)