Skip to content

Commit

Permalink
Adds a test case for the imagenet/rank1_bnn.py script.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 449523540
  • Loading branch information
dusenberrymw authored and copybara-github committed May 18, 2022
1 parent 4c48e7a commit 5082c15
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions robustness_metrics/metrics/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ def bregman_kl_variance(x):
Returns:
tf.Tensor of shape [batch_size].
"""
num_models = x.shape[0]
batch_size = x.shape[1]
num_models = tf.shape(x)[0]
batch_size = tf.shape(x)[1]

variance = tf.zeros(batch_size)
central_prediction = tf.nn.softmax(tf.reduce_mean(tf.math.log(x), axis=0))
for i in range(num_models):
variance += kl_divergence(central_prediction, x[i])
return variance / num_models
return variance / tf.cast(num_models, dtype=tf.float32)


@metrics_base.registry.register('average_pairwise_diversity')
Expand Down

0 comments on commit 5082c15

Please sign in to comment.