Skip to content

Commit

Permalink
Fix a NaN issue with streaming_covariance.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 363099343
Change-Id: I7012be19e3afd9b0b1f8f7b8e047065ae09909ce
  • Loading branch information
TF-Slim Team authored and copybara-github committed Mar 16, 2021
1 parent 6387efb commit 77b4412
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
2 changes: 1 addition & 1 deletion tf_slim/metrics/metric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3242,7 +3242,7 @@ def streaming_covariance(predictions,
delta_comoment = (
batch_comoment + (prev_mean_prediction - batch_mean_prediction) *
(prev_mean_label - batch_mean_label) *
(prev_count * batch_count / update_count))
(math_ops.div_no_nan(prev_count * batch_count, update_count)))
update_comoment = state_ops.assign_add(comoment, delta_comoment)

covariance = array_ops.where(
Expand Down
27 changes: 17 additions & 10 deletions tf_slim/metrics/metric_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5947,12 +5947,13 @@ def testMultiUpdateWithErrorAndWeights(self):
with self.cached_session() as sess:
np.random.seed(123)
n = 100
stride = 10
predictions = np.random.randn(n)
labels = 0.5 * predictions + np.random.randn(n)
weights = np.tile(np.arange(n // 10), n // 10)
weights = np.tile(np.arange(n // stride), n // stride)
np.random.shuffle(weights)
weights[0:stride] = 0.0

stride = 10
predictions_t = array_ops.placeholder(dtypes_lib.float32, [stride])
labels_t = array_ops.placeholder(dtypes_lib.float32, [stride])
weights_t = array_ops.placeholder(dtypes_lib.float32, [stride])
Expand All @@ -5974,14 +5975,20 @@ def testMultiUpdateWithErrorAndWeights(self):
if not np.isnan(prev_expected_cov):
self.assertAlmostEqual(prev_expected_cov,
sess.run(cov, feed_dict=feed_dict), 5)
expected_cov = np.cov(
predictions[:stride * (i + 1)],
labels[:stride * (i + 1)],
fweights=weights[:stride * (i + 1)])[0, 1]
self.assertAlmostEqual(expected_cov,
sess.run(update_op, feed_dict=feed_dict), 5)
self.assertAlmostEqual(expected_cov, sess.run(cov, feed_dict=feed_dict),
5)
if np.sum(weights[:stride * (i + 1)]) != 0.0:
expected_cov = np.cov(
predictions[:stride * (i + 1)],
labels[:stride * (i + 1)],
fweights=weights[:stride * (i + 1)])[0, 1]
else:
expected_cov = NAN
sess.run(update_op, feed_dict=feed_dict)
self.assertEqual(
np.isnan(expected_cov),
np.isnan(sess.run(cov, feed_dict=feed_dict)))
if not np.isnan(expected_cov):
self.assertAlmostEqual(expected_cov,
sess.run(cov, feed_dict=feed_dict), 5)
prev_expected_cov = expected_cov


Expand Down

0 comments on commit 77b4412

Please sign in to comment.