Entropy of observations metric (#2340)
Summary:

This commit introduces `entropy_of_observations` as a model-fit metric. It quantifies the entropy of the outcomes `y_obs` using a kernel density estimator. This metric can be useful for detecting datasets in which the outcomes are clustered (low entropy) rather than uniformly distributed in the outcome space (high entropy).
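For example, on standardized outcomes the metric assigns a bimodal ("clustered") sample a much lower entropy than a comparably scaled Gaussian sample. A minimal usage sketch (illustrative only; the sample construction, sizes, and seed below are arbitrary choices, not part of this diff):

import numpy as np
from ax.utils.stats.model_fit_stats import entropy_of_observations

rng = np.random.default_rng(0)
n = 32

# Outcomes concentrated in two tight clusters vs. spread over the outcome space.
y_clustered = np.concatenate([rng.normal(-1, 0.05, n // 2), rng.normal(1, 0.05, n // 2)])
y_spread = rng.normal(0.0, 1.0, n)

# Standardize; the default bandwidth of 0.1 targets standardized outcomes.
y_clustered /= y_clustered.std()
y_spread /= y_spread.std()

# y_pred and se_pred are unused by this metric but required by the shared API.
dummy = np.ones(n)
e_clustered = entropy_of_observations(y_obs=y_clustered, y_pred=dummy, se_pred=dummy)
e_spread = entropy_of_observations(y_obs=y_spread, y_pred=dummy, se_pred=dummy)
print(e_clustered < e_spread)  # expected: True (clustered outcomes have lower entropy)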

Differential Revision: D55930954
SebastianAment authored and facebook-github-bot committed Apr 9, 2024
1 parent a161618 commit c5c4f3a
Showing 2 changed files with 82 additions and 2 deletions.
ax/utils/stats/model_fit_stats.py (45 additions, 0 deletions)
@@ -9,6 +9,7 @@
 
 import numpy as np
 from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
+from sklearn.neighbors import KernelDensity
 
 """
 ################################ Model Fit Metrics ###############################
@@ -127,6 +128,50 @@ def std_of_the_standardized_error(
     return ((y_obs - y_pred) / se_pred).std()
 
 
+def entropy_of_observations(
+    y_obs: np.ndarray,
+    y_pred: np.ndarray,
+    se_pred: np.ndarray,
+    bandwidth: float = 0.1,
+) -> float:
+    """Computes the entropy of the observations y_obs using a kernel density
+    estimator. This can be used to quantify how "clustered" the outcomes are.
+    NOTE: y_pred and se_pred are not used, but are required by the API.
+
+    Args:
+        y_obs: An array of observations for a single metric.
+        y_pred: An array of the predicted values corresponding to y_obs.
+        se_pred: An array of the standard errors of the predicted values.
+        bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable
+            value for standardized outcomes y_obs. The rank ordering of the
+            results on a set of y_obs data sets is generally not sensitive to
+            the bandwidth, as long as it is held fixed across the data sets.
+            The absolute values of the results, however, change significantly
+            with the bandwidth.
+
+    Returns:
+        The scalar entropy of the observations.
+    """
+    if y_obs.ndim == 1:
+        y_obs = y_obs.reshape(-1, 1)
+    return _entropy_via_kde(y_obs, bandwidth=bandwidth)
+
+
+def _entropy_via_kde(y: np.ndarray, bandwidth: float = 0.1) -> float:
+    """Computes the entropy of the kernel density estimate of the input data.
+
+    Args:
+        y: An (n x m) array of observations.
+        bandwidth: The kernel bandwidth.
+
+    Returns:
+        The scalar entropy of the kernel density estimate.
+    """
+    kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth)
+    kde.fit(y)
+    log_p = kde.score_samples(y)  # computes the log probability of each data point
+    return -np.sum(np.exp(log_p) * log_p)  # entropy: the negated sum of p log p
+
+
 def _mean_prediction_ci(
     y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
 ) -> float:
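As a sanity check on what `_entropy_via_kde` computes (a sketch, not part of this diff): sklearn's `score_samples` returns the log of the kernel mixture density, which in 1-D matches an explicit average of Gaussian pdfs centered at the data points.

import numpy as np
from scipy.stats import norm
from sklearn.neighbors import KernelDensity

y = np.random.randn(8, 1)  # (n x 1) observations, as in entropy_of_observations
bandwidth = 0.1
kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(y)
log_p = kde.score_samples(y)

# The Gaussian KDE density is the average of normal pdfs centered at each point.
p_manual = norm.pdf(y.T, loc=y, scale=bandwidth).mean(axis=0)
assert np.allclose(np.exp(log_p), p_manual)

# The value reported by _entropy_via_kde is then the negated sum of p log p.
print(-np.sum(np.exp(log_p) * log_p))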
ax/utils/stats/tests/test_model_fit_stats.py (37 additions, 2 deletions)
@@ -6,11 +6,46 @@
 
 import numpy as np
 from ax.utils.common.testutils import TestCase
-from ax.utils.stats.model_fit_stats import _fisher_exact_test_p
+from ax.utils.stats.model_fit_stats import _fisher_exact_test_p, entropy_of_observations
 from scipy.stats import fisher_exact
 
 
-class FisherExactTestTest(TestCase):
+class TestModelFitStats(TestCase):
+    def test_entropy_of_observations(self) -> None:
+        np.random.seed(1234)
+        n = 16
+        yc = np.ones(n)
+        yc[: n // 2] = -1
+        yc += np.random.randn(n) * 0.05
+        yr = np.random.randn(n)
+
+        # standardize both observations
+        yc = yc / yc.std()
+        yr = yr / yr.std()
+
+        ones = np.ones(n)
+
+        # compute entropy of observations
+        ec = entropy_of_observations(y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.1)
+        er = entropy_of_observations(y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.1)
+
+        # test that the Gaussian-distributed data has a much larger entropy than
+        # the clustered distribution
+        self.assertTrue(er - ec > 10.0)
+
+        ec2 = entropy_of_observations(
+            y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.2
+        )
+        er2 = entropy_of_observations(
+            y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.2
+        )
+        # entropy increases with larger bandwidth
+        self.assertGreater(ec2, ec)
+        self.assertGreater(er2, er)
+
+        # ordering of entropies stays the same, though the difference is smaller
+        self.assertTrue(er2 - ec2 > 3)
+
     def test_contingency_table_construction(self) -> None:
         # Create a dummy set of observations and predictions
         y_obs = np.array([1, 3, 2, 5, 7, 3])
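The bandwidth behavior asserted in the test can also be seen directly by sweeping the bandwidth on the same data; a short sketch (reusing the test's construction; the printed values are illustrative, not part of this diff):

import numpy as np
from ax.utils.stats.model_fit_stats import entropy_of_observations

np.random.seed(1234)
n = 16
yc = np.ones(n)
yc[: n // 2] = -1
yc += np.random.randn(n) * 0.05  # two tight clusters
yr = np.random.randn(n)  # a Gaussian sample
yc, yr = yc / yc.std(), yr / yr.std()
ones = np.ones(n)

for bw in (0.05, 0.1, 0.2):
    ec = entropy_of_observations(y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=bw)
    er = entropy_of_observations(y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=bw)
    # The absolute entropies shift with the bandwidth, but er > ec throughout.
    print(f"bandwidth={bw}: clustered={ec:.2f}, gaussian={er:.2f}")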
