diff --git a/tf_slim/metrics/metric_ops.py b/tf_slim/metrics/metric_ops.py index 3febbfa..cdd844c 100644 --- a/tf_slim/metrics/metric_ops.py +++ b/tf_slim/metrics/metric_ops.py @@ -3242,7 +3242,7 @@ def streaming_covariance(predictions, delta_comoment = ( batch_comoment + (prev_mean_prediction - batch_mean_prediction) * (prev_mean_label - batch_mean_label) * - (prev_count * batch_count / update_count)) + (math_ops.div_no_nan(prev_count * batch_count, update_count))) update_comoment = state_ops.assign_add(comoment, delta_comoment) covariance = array_ops.where( diff --git a/tf_slim/metrics/metric_ops_test.py b/tf_slim/metrics/metric_ops_test.py index 9920fa2..3d63758 100644 --- a/tf_slim/metrics/metric_ops_test.py +++ b/tf_slim/metrics/metric_ops_test.py @@ -5947,12 +5947,13 @@ def testMultiUpdateWithErrorAndWeights(self): with self.cached_session() as sess: np.random.seed(123) n = 100 + stride = 10 predictions = np.random.randn(n) labels = 0.5 * predictions + np.random.randn(n) - weights = np.tile(np.arange(n // 10), n // 10) + weights = np.tile(np.arange(n // stride), n // stride) np.random.shuffle(weights) + weights[0:stride] = 0.0 - stride = 10 predictions_t = array_ops.placeholder(dtypes_lib.float32, [stride]) labels_t = array_ops.placeholder(dtypes_lib.float32, [stride]) weights_t = array_ops.placeholder(dtypes_lib.float32, [stride]) @@ -5974,14 +5975,20 @@ def testMultiUpdateWithErrorAndWeights(self): if not np.isnan(prev_expected_cov): self.assertAlmostEqual(prev_expected_cov, sess.run(cov, feed_dict=feed_dict), 5) - expected_cov = np.cov( - predictions[:stride * (i + 1)], - labels[:stride * (i + 1)], - fweights=weights[:stride * (i + 1)])[0, 1] - self.assertAlmostEqual(expected_cov, - sess.run(update_op, feed_dict=feed_dict), 5) - self.assertAlmostEqual(expected_cov, sess.run(cov, feed_dict=feed_dict), - 5) + if np.sum(weights[:stride * (i + 1)]) != 0.0: + expected_cov = np.cov( + predictions[:stride * (i + 1)], + labels[:stride * (i + 1)], + fweights=weights[:stride * (i + 1)])[0, 1] + else: + expected_cov = NAN + sess.run(update_op, feed_dict=feed_dict) + self.assertEqual( + np.isnan(expected_cov), + np.isnan(sess.run(cov, feed_dict=feed_dict))) + if not np.isnan(expected_cov): + self.assertAlmostEqual(expected_cov, + sess.run(cov, feed_dict=feed_dict), 5) prev_expected_cov = expected_cov