Skip to content

Commit b67b210

Browse files
csferng authored and tensorflow-copybara committed
Make the graph estimator test more robust.
PiperOrigin-RevId: 326555753
1 parent 127411b commit b67b210

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

neural_structured_learning/estimator/graph_regularization_test.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,9 @@ def embedding_fn(features, mode):
447447
input_layer = features[FEATURE_NAME]
448448
with tf.compat.v1.variable_scope('hidden_layer', reuse=tf.AUTO_REUSE):
449449
hidden_layer = tf.compat.v1.layers.dense(
450-
input_layer, units=4, activation=tf.nn.relu)
450+
input_layer, units=4, activation=lambda x: tf.abs(x) + 0.1)
451+
# The always-positive activation function is to make sure the batch mean
452+
# is non-zero.
451453
batch_norm_layer = tf.compat.v1.layers.batch_normalization(
452454
hidden_layer, training=(mode == tf.estimator.ModeKeys.TRAIN))
453455
return batch_norm_layer
@@ -482,12 +484,12 @@ def input_fn():
482484
nbr_feature = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, FEATURE_NAME)
483485
nbr_weight = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
484486
features = {
485-
FEATURE_NAME: tf.constant([[0.1, 0.9], [0.8, 0.2]]),
486-
nbr_feature: tf.constant([[0.11, 0.89], [0.81, 0.21]]),
487-
nbr_weight: tf.constant([[0.9], [0.8]]),
487+
FEATURE_NAME: tf.constant([[0.1, 0.9], [-0.8, -0.2], [0.3, -0.7]]),
488+
nbr_feature: tf.constant([[0.1, 0.89], [-0.81, -0.2], [0.3, -0.69]]),
489+
nbr_weight: tf.constant([[0.9], [0.8], [0.7]]),
488490
}
489-
labels = tf.constant([[1], [0]])
490-
return tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)
491+
labels = tf.constant([[1], [0], [1]])
492+
return tf.data.Dataset.from_tensor_slices((features, labels)).batch(3)
491493

492494
base_est = tf.estimator.Estimator(model_fn, model_dir=self.model_dir)
493495
graph_reg_config = nsl_configs.make_graph_reg_config(

0 commit comments

Comments
 (0)