Makes BatchNormTest work without disabling resource variables.
PiperOrigin-RevId: 294973730
Change-Id: Ic2b327314698e7fd0eee70c48993a0e145fa8a28
TF-Slim Team authored and copybara-github committed Feb 13, 2020
Parent: a248ee3 · Commit: 2abdae7
Showing 1 changed file with 26 additions and 31 deletions.
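
Note: the deleted setUp/tearDown below worked by flipping process-global TF flags. A minimal sketch of that pattern, assuming the TF 1.x compat-v1 API (not part of this commit):

    import tensorflow.compat.v1 as tf

    # Flipping these flags mutates process-global state that leaks across
    # tests, which is what this change avoids.
    was_enabled = tf.resource_variables_enabled()
    tf.disable_resource_variables()   # later tf.Variable calls make ref vars
    # ... exercise code that assumes ref variables ...
    if was_enabled:
      tf.enable_resource_variables()  # restore the previous global state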
tf_slim/layers/layers_test.py: 57 changes (26 additions, 31 deletions)
@@ -60,6 +60,7 @@
 def setUpModule():
   tf.disable_eager_execution()
 
+
 arg_scope = arg_scope_lib.arg_scope
 
 
@@ -1818,26 +1819,6 @@ def testReuseFCWithBatchNorm(self):
 
 class BatchNormTest(test.TestCase):
 
-  def setUp(self):
-    super(BatchNormTest, self).setUp()
-    # TODO(b/148892830): Investigate if this should be re-enabled.
-    self.rv_enabled = tf.resource_variables_enabled()
-    self.cf_enabled = tf.control_flow_v2_enabled()
-    # Control flow is automatically re-enabled in cond() mode
-    # (b/149312871), and this breaks disable_resource_variables below;
-    # remove once either of the bugs is fixed.
-    tf.disable_control_flow_v2()
-    # TODO(b/148892830): Investigate if this should be re-enabled.
-    # Also see TODO(b/149311854) for batchnorm-specific shenanigans.
-    tf.disable_resource_variables()
-
-  def tearDown(self):
-    super(BatchNormTest, self).tearDown()
-    if self.rv_enabled:
-      tf.enable_resource_variables()
-    if self.cf_enabled:
-      tf.enable_control_flow_v2()
-
   def _addBesselsCorrection(self, sample_size, expected_var):
     correction_factor = sample_size / (sample_size - 1)
     expected_var *= correction_factor
@@ -1985,9 +1966,10 @@ def testUpdatesCollection(self):
       update_layers = ops.get_collection('my_update_ops')
       update_moving_mean = update_layers[0]
       update_moving_variance = update_layers[1]
-      self.assertEqual(update_moving_mean.op.name, 'BatchNorm/AssignMovingAvg')
-      self.assertEqual(update_moving_variance.op.name,
-                       'BatchNorm/AssignMovingAvg_1')
+      self.assertStartsWith(update_moving_mean.op.name,
+                            'BatchNorm/AssignMovingAvg')
+      self.assertStartsWith(update_moving_variance.op.name,
+                            'BatchNorm/AssignMovingAvg_1')
 
   def testVariablesCollections(self):
     variables_collections = {
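
Note: with resource variables enabled, the op returned by the moving-average update may be nested inside the AssignMovingAvg name scope rather than named exactly after it, hence the switch from assertEqual to assertStartsWith above. A rough, runnable illustration, assuming the TF 1.x compat-v1 API (exact op names vary by variable kind):

    import tensorflow.compat.v1 as tf
    from tensorflow.python.training import moving_averages

    tf.disable_eager_execution()
    var = tf.Variable([0.0], name='moving_mean')
    update = moving_averages.assign_moving_average(
        var, [1.0], decay=0.9, zero_debias=False)
    # Ref variables yield an op named '.../AssignMovingAvg'; resource
    # variables nest the actual assign inside that scope, so only the
    # prefix is stable across both modes.
    print(update.op.name)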
@@ -2139,6 +2121,7 @@ def _testNoneUpdatesCollections(self,
 
   def testNoneUpdatesCollectionsNHWC(self):
     self._testNoneUpdatesCollections(False, data_format='NHWC')
+    print(tf.all_variables())
 
   def testNoneUpdatesCollectionsNCHW(self):
     self._testNoneUpdatesCollections(False, data_format='NCHW')
@@ -2396,7 +2379,8 @@ def _testIsTrainingVariable(self,
         batch_size * height * width, expected_var)
     images = constant_op.constant(
         image_values, shape=image_shape, dtype=dtypes.float32)
-    is_training = variables_lib.VariableV1(True)
+    # NB: tf.identity is required, because variables can't be fed.
+    is_training = tf.identity(tf.Variable(True))
     output = _layers.batch_norm(
         images,
         decay=0.1,
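
Note: a minimal sketch of the feeding trick used above, assuming the TF 1.x compat-v1 API (not part of this commit). A variable cannot be fed directly, but a tf.identity of it can, so a test can override the value per session.run call:

    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()
    is_training = tf.identity(tf.Variable(True))  # feedable view of the var
    out = tf.cond(is_training,
                  lambda: tf.constant('train'),
                  lambda: tf.constant('eval'))
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      print(sess.run(out))                        # variable's value: b'train'
      print(sess.run(out, {is_training: False}))  # feed overrides it: b'eval'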
@@ -2543,7 +2527,10 @@ def _testNoneUpdatesCollectionIsTrainingVariable(self,
         batch_size * height * width, expected_var)
     images = constant_op.constant(
         image_values, shape=image_shape, dtype=dtypes.float32)
-    is_training = variables_lib.VariableV1(True)
+    # NB: tf.identity is required, because variables can't be fed.
+    # Ref: https://github.com/tensorflow/tensorflow/issues/19884
+    is_training = tf.identity(tf.Variable(True))
+
     output = _layers.batch_norm(
         images,
         decay=0.1,
@@ -2557,7 +2544,8 @@ def _testNoneUpdatesCollectionIsTrainingVariable(self,
       sess.run(variables_lib.global_variables_initializer())
       moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
       moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
-      mean, variance = sess.run([moving_mean, moving_variance])
+      mean, variance = sess.run([moving_mean, moving_variance], {
+          is_training: True})
       # After initialization moving_mean == 0 and moving_variance == 1.
       self.assertAllClose(mean, [0] * channels)
       self.assertAllClose(variance, [1] * channels)
@@ -2584,7 +2572,7 @@ def _testNoneUpdatesCollectionIsTrainingVariable(self,
       moving_variance_corrected = moving_variance / correction_factor
       correct_moving_variance = state_ops.assign(moving_variance,
                                                  moving_variance_corrected)
-      sess.run(correct_moving_variance)
+      sess.run(correct_moving_variance, {is_training: True})
       output_false = sess.run([output], {is_training: False})
       self.assertTrue(np.allclose(output_true, output_false))
@@ -2774,6 +2762,13 @@ def testBatchNormBeta(self):
           a_16, center=False, data_format='NCHW', zero_debias_moving_mean=True)
       sess.run(variables_lib.global_variables_initializer())
 
+  def is_float_var(self, v):
+    if v.dtype == dtypes.float32_ref:
+      return True
+    if v.dtype == tf.float32 and v.op.outputs[0].dtype == tf.resource:
+      return True
+    return False
+
   def testVariablesAreFloat32(self):
     height, width = 3, 3
     with self.cached_session():
@@ -2782,12 +2777,12 @@ def testVariablesAreFloat32(self):
       _layers.batch_norm(images, scale=True)
       beta = variables.get_variables_by_name('beta')[0]
       gamma = variables.get_variables_by_name('gamma')[0]
-      self.assertEqual(beta.dtype, dtypes.float32_ref)
-      self.assertEqual(gamma.dtype, dtypes.float32_ref)
+      self.assertTrue(self.is_float_var(beta))
+      self.assertTrue(self.is_float_var(gamma))
       moving_mean = variables.get_variables_by_name('moving_mean')[0]
       moving_variance = variables.get_variables_by_name('moving_variance')[0]
-      self.assertEqual(moving_mean.dtype, dtypes.float32_ref)
-      self.assertEqual(moving_variance.dtype, dtypes.float32_ref)
+      self.assertTrue(self.is_float_var(moving_mean))
+      self.assertTrue(self.is_float_var(moving_variance))
 
   def _runFusedBatchNorm(self, shape, dtype):
     channels = shape[1]
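
Note: for reference, the dtype distinction that the new is_float_var helper accommodates. A minimal sketch, assuming the TF 1.x compat-v1 API (not part of this commit):

    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()
    tf.disable_resource_variables()
    ref_var = tf.Variable(1.0)          # a ref variable
    tf.enable_resource_variables()
    res_var = tf.Variable(1.0)          # a resource variable
    print(ref_var.dtype)                # float32_ref: old assertEqual passed
    print(res_var.dtype)                # float32: old assertEqual failed
    print(res_var.op.outputs[0].dtype)  # resource: the helper's second check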