old_dataloader.py

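"""A tf.data based data loader for semantic segmentation.

Pipeline order in data_batch: parse and normalize -> resize ->
(optional) colour augmentations, random crop and flip ->
(optional) one-hot encoding -> shuffle/batch/prefetch.
"""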
import random

import tensorflow as tf


class DataLoader(object):
    """A TensorFlow Dataset API based loader for semantic segmentation problems."""

    def __init__(self, image_paths, mask_paths, image_size, channels=(3, 3),
                 crop_percent=None, palette=None, seed=None, testing=False):
"""
Initializes the data loader object
Args:
image_paths: List of paths of train images.
mask_paths: List of paths of train masks (segmentation masks)
image_size: Tuple of (Height, Width), the final height
of the loaded images.
channels: List of ints, first element is number of channels in images,
second is the number of channels in the mask image (needed to
correctly read the images into tensorflow.)
crop_percent: Float in the range 0-1, defining percentage of image
to randomly crop.
palette: A list of RGB pixel values in the mask. If specified, the mask
will be one hot encoded along the channel dimension.
seed: An int, if not specified, chosen randomly. Used as the seed for
the RNG in the data pipeline.
"""
        image_paths.sort()
        mask_paths.sort()
        if testing:
            # Keep a fixed, deterministic subset for evaluation.
            self.image_paths = image_paths[:500]
            self.mask_paths = mask_paths[:500]
        else:
            # Shuffle image/mask pairs together so they stay aligned,
            # then keep a random subset for training.
            pairs = list(zip(image_paths, mask_paths))
            random.shuffle(pairs)
            self.image_paths = [image for image, _ in pairs[:1000]]
            self.mask_paths = [mask for _, mask in pairs[:1000]]
        self.palette = palette
        self.image_size = image_size
        if crop_percent is not None:
            if 0.0 < crop_percent <= 1.0:
                self.crop_percent = tf.constant(crop_percent, tf.float32)
            elif 0 < crop_percent <= 100:
                self.crop_percent = tf.constant(crop_percent / 100., tf.float32)
            else:
                raise ValueError("Invalid value entered for crop_percent. Please use "
                                 "an integer between 0 and 100, or a float between "
                                 "0 and 1.0.")
        else:
            self.crop_percent = None
        self.channels = channels
        if seed is None:
            self.seed = random.randint(0, 1000)
        else:
            self.seed = seed
    def _corrupt_brightness(self, image, mask):
        """
        Randomly applies a brightness change to the image.
        """
        cond_brightness = tf.cast(tf.random.uniform(
            [], maxval=2, dtype=tf.int32), tf.bool)
        image = tf.cond(cond_brightness, lambda: tf.image.random_brightness(
            image, 0.1), lambda: tf.identity(image))
        return image, mask

    def _corrupt_contrast(self, image, mask):
        """
        Randomly applies a contrast change to the image.
        """
        cond_contrast = tf.cast(tf.random.uniform(
            [], maxval=2, dtype=tf.int32), tf.bool)
        image = tf.cond(cond_contrast, lambda: tf.image.random_contrast(
            image, 0.1, 0.8), lambda: tf.identity(image))
        return image, mask

    def _corrupt_saturation(self, image, mask):
        """
        Randomly applies a saturation change to the image.
        """
        cond_saturation = tf.cast(tf.random.uniform(
            [], maxval=2, dtype=tf.int32), tf.bool)
        image = tf.cond(cond_saturation, lambda: tf.image.random_saturation(
            image, 0.1, 0.8), lambda: tf.identity(image))
        return image, mask
    def _crop_random(self, image, mask):
        """
        Randomly crops image and mask in accord.
        """
        # A single coin flip decides whether to crop, so the image and
        # mask are always cropped (or left alone) together.
        cond_crop = tf.cast(tf.random.uniform(
            [], maxval=2, dtype=tf.int32, seed=self.seed), tf.bool)
        shape = tf.cast(tf.shape(image), tf.float32)
        h = tf.cast(shape[0] * self.crop_percent, tf.int32)
        w = tf.cast(shape[1] * self.crop_percent, tf.int32)
        # Sharing the same op-level seed makes both random_crop calls pick
        # the same window, keeping image and mask aligned.
        image = tf.cond(cond_crop, lambda: tf.image.random_crop(
            image, [h, w, self.channels[0]], seed=self.seed), lambda: tf.identity(image))
        mask = tf.cond(cond_crop, lambda: tf.image.random_crop(
            mask, [h, w, self.channels[1]], seed=self.seed), lambda: tf.identity(mask))
        return image, mask
    def _flip_left_right(self, image, mask):
        """
        Randomly flips image and mask left or right in accord.
        """
        # The shared seed makes both flips take the same decision,
        # keeping image and mask aligned.
        image = tf.image.random_flip_left_right(image, seed=self.seed)
        mask = tf.image.random_flip_left_right(mask, seed=self.seed)
        return image, mask
    def _resize_data(self, image, mask):
        """
        Resizes image and mask to the specified size.
        """
        image = tf.image.resize(image, [self.image_size[0], self.image_size[1]])
        # Nearest-neighbour interpolation avoids introducing new label
        # values into the mask.
        mask = tf.image.resize(mask, [self.image_size[0], self.image_size[1]], method='nearest')
        return image, mask
    def _parse_data(self, image_paths, mask_paths):
        """
        Reads and decodes an image/mask pair and normalizes
        both to the [0, 1] range.
        """
        image_content = tf.io.read_file(image_paths)
        mask_content = tf.io.read_file(mask_paths)
        images = tf.image.decode_jpeg(image_content, channels=self.channels[0])
        masks = tf.image.decode_jpeg(mask_content, channels=self.channels[1])
        images = tf.cast(images, tf.float32) / 255.0
        masks = tf.cast(masks, tf.float32) / 255.0
        return images, masks
    def _one_hot_encode(self, image, mask):
        """
        Converts the mask to a one-hot encoding specified by the palette.
        """
        one_hot_map = []
        # NOTE: masks are scaled to [0, 1] in _parse_data, so the palette
        # values must be given in the same range.
        for colour in self.palette:
            class_map = tf.reduce_all(tf.equal(mask, colour), axis=-1)
            one_hot_map.append(class_map)
        one_hot_map = tf.stack(one_hot_map, axis=-1)
        one_hot_map = tf.cast(one_hot_map, tf.float32)
        return image, one_hot_map
    def data_batch(self, batch_size, augment, shuffle=False, one_hot_encode=False):
        """
        Reads, normalizes, optionally augments, shuffles and batches
        the data, and returns it as a tf.data.Dataset.
        Inputs:
            batch_size: Number of image/mask pairs in each batch.
            augment: Boolean, whether to augment the data or not.
            shuffle: Boolean, whether to shuffle the data in the buffer or not.
            one_hot_encode: Boolean, whether to one-hot encode the mask or not.
                            Encoding is done according to the palette specified
                            when initializing the object.
        Returns:
            data: A tf.data.Dataset object.
        """
        # Create a dataset out of the two path lists:
        data = tf.data.Dataset.from_tensor_slices((self.image_paths, self.mask_paths))
        # Parse images and labels
        data = data.map(self._parse_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # Resize to smaller dims for speed
        data = data.map(self._resize_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # If augmentation is to be applied
        if augment:
            data = data.map(self._corrupt_brightness,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
            data = data.map(self._corrupt_contrast,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
            data = data.map(self._corrupt_saturation,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
            if self.crop_percent is not None:
                data = data.map(self._crop_random,
                                num_parallel_calls=tf.data.experimental.AUTOTUNE)
            data = data.map(self._flip_left_right,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # One-hot encode the mask
        if one_hot_encode:
            if self.palette is None:
                raise ValueError('No palette for one-hot encoding specified in the '
                                 'data loader! Please specify one when initializing '
                                 'the loader.')
            data = data.map(self._one_hot_encode, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if shuffle:
            # Shuffle with a buffer covering the whole subset, then batch
            # and prefetch.
            data = data.shuffle(len(self.image_paths)).batch(batch_size)
            data = data.prefetch(tf.data.experimental.AUTOTUNE)
        else:
            # Batch and prefetch
            data = data.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
        return data
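

# Minimal usage sketch: the directory layout, image size, palette and
# hyper-parameters below are illustrative assumptions, not values from
# any particular dataset. Random cropping (crop_percent) changes image
# shapes, so under this pipeline order it only batches cleanly with
# batch_size=1; it is left off here.
if __name__ == "__main__":
    import glob

    # Hypothetical layout: JPEG images in images/, JPEG masks in masks/.
    image_paths = glob.glob("images/*.jpg")
    mask_paths = glob.glob("masks/*.jpg")

    loader = DataLoader(image_paths, mask_paths,
                        image_size=(256, 256),
                        channels=(3, 1),
                        # Palette in the normalized [0, 1] range, matching
                        # the scaling applied in _parse_data.
                        palette=[0.0, 1.0],
                        seed=42)

    dataset = loader.data_batch(batch_size=8, augment=True,
                                shuffle=True, one_hot_encode=True)
    for images, masks in dataset.take(1):
        print(images.shape, masks.shape)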