diff --git a/tensorflow_probability/python/distributions/categorical_test.py b/tensorflow_probability/python/distributions/categorical_test.py index 6d9644cc20..b4ae3b4ec0 100644 --- a/tensorflow_probability/python/distributions/categorical_test.py +++ b/tensorflow_probability/python/distributions/categorical_test.py @@ -338,6 +338,19 @@ def testEntropyGradient(self): res["categorical_entropy"]) self.assertAllClose(res["true_entropy_g"], res["categorical_entropy_g"]) + + + def testEntropyWithZeroProbabilities(self): + probs = [[0, 0.5, 0.5], [0, 1, 0]] + dist = categorical.Categorical(probs=probs) + dist_entropy = dist.entropy() + + with self.cached_session(): + self.assertAllClose(dist_entropy.eval(), + [ + -(0.5*np.log(0.5) + 0.5*np.log(0.5)), + -(np.log(1)) + ]) def testSample(self): with self.cached_session():