-
Notifications
You must be signed in to change notification settings - Fork 2
/
tf_utils.py
166 lines (136 loc) · 6.7 KB
/
tf_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.framework.python.ops import add_arg_scope
def string_tuple_to_image_pair(image_file, label_file, label_mapping):
  """Load an image/label PNG pair from disk and remap the label ids.

  Args:
    image_file: 0-d string tensor, path to an RGB PNG image.
    label_file: 0-d string tensor, path to a single-channel PNG label image.
    label_mapping: int tensor used as a lookup table; a raw label value v is
      replaced by label_mapping[v].

  Returns:
    A (image, labels) tuple: the float32 image of shape (H, W, 3) and the
    int32 remapped labels of shape (H, W, 1).
  """
  raw_image = tf.read_file(image_file)
  image = tf.to_float(tf.image.decode_png(raw_image, channels=3))
  raw_labels = tf.read_file(label_file)
  labels = tf.cast(tf.image.decode_png(raw_labels, channels=1), tf.int32)
  labels = tf.gather(label_mapping, labels)
  return image, labels
def flip_augment(image, labels):
  """Mirror image and labels horizontally with probability 0.5, in lockstep."""
  coin = tf.random_uniform((1,))[0]
  keep = lambda: (image, labels)
  # Axis 1 is the width axis, so this is a left-right flip of both tensors.
  mirror = lambda: (tf.reverse(image, [1]), tf.reverse(labels, [1]))
  image, labels = tf.cond(coin < 0.5, keep, mirror)
  return image, labels
def gamma_augment(image, labels, gamma_range=0.1):
  """Apply a random gamma perturbation to the image; labels pass through.

  See "Full-Resolution Residual Networks for Semantic Segmentation in Street
  Scenes" (CVPR'17) for a discussion of this augmentation.

  Args:
    image: image tensor with values in [0, 255].
    labels: label tensor, returned unchanged.
    gamma_range: half-width of the uniform interval the perturbation factor
      is drawn from.

  Returns:
    The (gamma-adjusted image, unchanged labels) tuple.
  """
  normalized = image / 255.0
  factor = tf.random_uniform(
      shape=[], minval=-gamma_range, maxval=gamma_range, dtype=tf.float32)
  # Map the uniform factor to a gamma exponent; factor == 0 gives gamma == 1
  # (the identity transform).
  offset = 1.0 / tf.sqrt(2.0) * factor
  gamma = tf.log(0.5 + offset) / tf.log(0.5 - offset)
  # Gamma correction operates on intensities in [0, 1], so scale back up.
  image = tf.pow(normalized, gamma) * 255.0
  return image, labels
# TODO(pandoro): change the underlying workings of the crop augment functions
# they also need to support crops that are bigger than the image. Right now
# this will likely just completely break.
def crop_augment(image, labels, pixel_to_remove_h, pixel_to_remove_w):
  """Randomly crop `pixel_to_remove_*` pixels off image and labels.

  The same randomly positioned crop window is applied to both tensors so
  they stay spatially aligned. The removal amounts must be non-negative and
  smaller than the image (see the TODO above about larger crops).

  Args:
    image: (H, W, 3) image tensor.
    labels: (H, W, 1) label tensor.
    pixel_to_remove_h: Python int or 0-d int32 tensor, total number of rows
      to remove.
    pixel_to_remove_w: Python int or 0-d int32 tensor, total number of
      columns to remove.

  Returns:
    The cropped (image, labels) pair with spatial size
    (H - pixel_to_remove_h, W - pixel_to_remove_w).
  """
  # Sample the top-left corner. A valid offset is anything in
  # [0, pixel_to_remove] inclusive; `maxval` is exclusive, hence the +1.
  # The +1 also keeps the op valid when zero pixels are removed, where the
  # former `maxval=pixel_to_remove` violated tf.random_uniform's
  # maxval > minval requirement and never sampled the maximal offset.
  begin_h = tf.random_uniform(
      shape=[], minval=0, maxval=pixel_to_remove_h + 1, dtype=tf.int32)
  begin_w = tf.random_uniform(
      shape=[], minval=0, maxval=pixel_to_remove_w + 1, dtype=tf.int32)
  end_h = tf.shape(image)[0] - pixel_to_remove_h + begin_h
  end_w = tf.shape(image)[1] - pixel_to_remove_w + begin_w
  # Recover a static output shape when possible. This needs the removal
  # amounts as Python ints; when they are tensors (the fixed_crop_augment
  # call path) Dimension arithmetic would fail, so the caller is then
  # responsible for fixing the shapes itself.
  if (image.shape.is_fully_defined()
      and isinstance(pixel_to_remove_h, int)
      and isinstance(pixel_to_remove_w, int)):
    h = image.shape[0] - pixel_to_remove_h
    w = image.shape[1] - pixel_to_remove_w
  else:
    h = None
    w = None
  # Actually cut out the crop.
  image = image[begin_h:end_h, begin_w:end_w]
  labels = labels[begin_h:end_h, begin_w:end_w]
  # Only pin down the static shape when it is actually known.
  if h is not None:
    image.set_shape([h, w, 3])
    labels.set_shape([h, w, 1])
  return image, labels
def fixed_crop_augment(image, labels, crop_size_h, crop_size_w):
  """Randomly crop a fixed (crop_size_h, crop_size_w) window from both tensors.

  Computes how many pixels have to be removed in each dimension and then
  delegates the actual random cropping to `crop_augment`.
  """
  input_shape = tf.shape(image)
  surplus_h = input_shape[0] - crop_size_h
  surplus_w = input_shape[1] - crop_size_w
  image, labels = crop_augment(image, labels, surplus_h, surplus_w)
  # The crop size is known statically here, so pin it down; crop_augment
  # cannot do so itself given dynamic removal amounts.
  image.set_shape([crop_size_h, crop_size_w, 3])
  labels.set_shape([crop_size_h, crop_size_w, 1])
  return image, labels
@add_arg_scope
def group_normalization(input, group_count=None, channel_count=None,
                        epsilon=1e-5, scope=None):
  """Group Normalization as introduced in:
  "Group Normalization", Yuxin Wu, Kaiming He,
  https://arxiv.org/abs/1803.08494

  Exactly one of `group_count` and `channel_count` must be provided; the
  other is derived from the (statically known) input channel count.

  Channels are grouped contiguously: input channel i belongs to group
  i // channels_per_group, as in the reference implementation. BUGFIX over
  the previous version, which reshaped to [N, H, W, channels, groups] (i.e.
  grouped channel i into group i % groups) but then repeated the statistics
  per contiguous block (group i // channels) — so whenever both groups > 1
  and channels_per_group > 1, most channels were normalized with the wrong
  group's moments.

  For reshaping we deliberately use the dynamic shapes, so that batch and
  image size may stay unspecified in the graph (the static channel count is
  still required). Note this is noticeably slower than vanilla batch
  normalization, easily increasing training time by a factor of two.

  Args:
    input: (N, H, W, C) tensor to normalize. (The name shadows the builtin,
      but is kept for caller compatibility.)
    group_count: number of groups to split the channels into.
    channel_count: number of channels per group.
    epsilon: numerical stabilizer added to the variance.
    scope: optional variable scope name (default 'group_normalization').

  Returns:
    A tensor of the same shape as `input`, normalized per group and
    transformed with learnable per-channel scale and offset.

  Raises:
    ValueError: if both or neither of group_count/channel_count are given,
      or if the channel count is not divisible as requested.
  """
  # Check that the provided parameters make sense.
  if group_count is not None and channel_count is not None:
    raise ValueError('You cannot specify both the group and channel count '
                     'for group normalization.')
  if group_count is None and channel_count is None:
    raise ValueError('You have to specify either the group or the channel '
                     'count for group normalization.')
  # The channel count must be statically known to set up the grouping and
  # the parameter shapes.
  C = input.shape[-1].value
  if group_count is not None:
    if C % group_count:
      raise ValueError(
          'An input channel count of {} cannot be divided into {} groups.'
          ''.format(C, group_count))
    else:
      groups = group_count
      channels = C // group_count
  else:
    if C % channel_count:
      raise ValueError(
          'An input channel count of {} cannot be divided into groups of '
          '{} channels.'.format(C, channel_count))
    else:
      groups = C // channel_count
      channels = channel_count
  with tf.variable_scope(scope, 'group_normalization'):
    # Dynamic shapes for reshaping; see the docstring.
    N = tf.shape(input)[0]
    H = tf.shape(input)[1]
    W = tf.shape(input)[2]
    # Split the channel axis into `groups` contiguous blocks of `channels`
    # each: channel i -> (group i // channels, within-group i % channels).
    grouped = tf.reshape(input, [N, H, W, groups, channels])
    # Per-sample, per-group statistics over space and the channels within
    # the group; shapes [N, 1, 1, groups, 1].
    mean, var = tf.nn.moments(grouped, [1, 2, 4], keep_dims=True)
    # Broadcast the statistics back to one value per input channel
    # (np.repeat-style: group g's statistic covers its `channels`
    # consecutive channels), matching the grouping above.
    mean = tf.reshape(tf.tile(mean, [1, 1, 1, 1, channels]), [N, 1, 1, C])
    var = tf.reshape(tf.tile(var, [1, 1, 1, 1, channels]), [N, 1, 1, C])
    # Learnable per-channel scale and offset.
    gamma = tf.get_variable(
        'gamma', [1, 1, 1, C], initializer=tf.constant_initializer(1.0))
    beta = tf.get_variable(
        'beta', [1, 1, 1, C], initializer=tf.constant_initializer(0.0))
    # Fold normalization and affine transform into one scale and one shift.
    inv_std = tf.rsqrt(var + epsilon)
    gamma = gamma * inv_std
    beta = beta - mean * gamma
    return input * gamma + beta