Skip to content

Commit 127411b

Browse files
csferngtensorflow-copybara
authored andcommitted
Make the adversarial estimator test more robust.
PiperOrigin-RevId: 326539266
1 parent 537c11e commit 127411b

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

neural_structured_learning/estimator/adversarial_regularization_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,16 @@ def test_adversarial_wrapper_adds_regularization(self, adv_step_size,
128128

129129
@test_util.run_v1_only('Requires tf.train.GradientDescentOptimizer')
130130
def test_adversarial_wrapper_saving_batch_statistics(self):
131-
x0, y0 = np.array([[0.9, 0.1], [0.2, 0.8]]), np.array([1, 0])
131+
x0 = np.array([[0.9, 0.1], [0.2, -0.8], [-0.7, -0.3], [-0.4, 0.6]])
132+
y0 = np.array([1, 0, 1, 0])
132133
input_fn = single_batch_input_fn({FEATURE_NAME: x0}, y0)
133134
fc = tf.feature_column.numeric_column(FEATURE_NAME, shape=[2])
134135
base_est = tf.estimator.DNNClassifier(
135136
hidden_units=[4],
136137
feature_columns=[fc],
137138
model_dir=self.model_dir,
139+
activation_fn=lambda x: tf.abs(x) + 0.1,
140+
dropout=None,
138141
batch_norm=True)
139142
adv_est = nsl_estimator.add_adversarial_regularization(
140143
base_est,
@@ -145,6 +148,8 @@ def test_adversarial_wrapper_saving_batch_statistics(self):
145148
'dnn/hiddenlayer_0/batchnorm_0/moving_mean')
146149
moving_variance = adv_est.get_variable_value(
147150
'dnn/hiddenlayer_0/batchnorm_0/moving_variance')
151+
# The activation function always returns a positive number, so the batch
152+
# mean cannot be zero if updated successfully.
148153
self.assertNotAllClose(moving_mean, np.zeros(moving_mean.shape))
149154
self.assertNotAllClose(moving_variance, np.ones(moving_variance.shape))
150155

0 commit comments

Comments
 (0)