From c5c4f3a18879efc88fabad9525d60bd4630f6e54 Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Tue, 9 Apr 2024 12:00:06 -0700
Subject: [PATCH] Entropy of observations metric (#2340)

Summary: This commit introduces `entropy_of_observations` as a model fit
metric. It quantifies the entropy of the outcomes `y_obs` using a kernel
density estimator. This metric can be useful for detecting datasets in which
the outcomes are clustered (implying low entropy) rather than uniformly
distributed in the outcome space (high entropy).

Differential Revision: D55930954
---
 ax/utils/stats/model_fit_stats.py            | 45 ++++++++++++++++++++
 ax/utils/stats/tests/test_model_fit_stats.py | 39 ++++++++++++++++-
 2 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/ax/utils/stats/model_fit_stats.py b/ax/utils/stats/model_fit_stats.py
index 5b3e1043bd8..6f985ff5851 100644
--- a/ax/utils/stats/model_fit_stats.py
+++ b/ax/utils/stats/model_fit_stats.py
@@ -9,6 +9,7 @@
 
 import numpy as np
 from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
+from sklearn.neighbors import KernelDensity
 
 """
 ################################ Model Fit Metrics ###############################
@@ -127,6 +128,50 @@ def std_of_the_standardized_error(
     return ((y_obs - y_pred) / se_pred).std()
 
 
+def entropy_of_observations(
+    y_obs: np.ndarray,
+    y_pred: np.ndarray,
+    se_pred: np.ndarray,
+    bandwidth: float = 0.1,
+) -> float:
+    """Computes the entropy of the observations y_obs using a kernel density
+    estimator. This can be used to quantify how "clustered" the outcomes are.
+    NOTE: y_pred and se_pred are not used, but are required by the model fit
+    metric API.
+
+    Args:
+        y_obs: An array of observations for a single metric.
+        y_pred: An array of the predicted values corresponding to y_obs.
+        se_pred: An array of the standard errors of the predicted values.
+        bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable
+            value for standardized outcomes y_obs. The rank ordering of the
+            results on a set of y_obs data sets is generally not sensitive to
+            the bandwidth, as long as it is held fixed across the data sets.
+            The absolute values of the results, however, change significantly
+            with the bandwidth.
+
+    Returns:
+        The scalar entropy of the observations.
+    """
+    if y_obs.ndim == 1:
+        y_obs = y_obs.reshape(-1, 1)
+    return _entropy_via_kde(y_obs, bandwidth=bandwidth)
+
+
+def _entropy_via_kde(y: np.ndarray, bandwidth: float = 0.1) -> float:
+    """Computes the entropy of the kernel density estimate of the input data.
+
+    Args:
+        y: An (n x m) array of observations.
+        bandwidth: The kernel bandwidth.
+
+    Returns:
+        The scalar entropy of the kernel density estimate.
+ """ + kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth) + kde.fit(y) + log_p = kde.score_samples(y) # computes the log probability of each data point + return -np.sum(np.exp(log_p) * log_p) # compute entropy, the negated sum of p log p + + def _mean_prediction_ci( y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray ) -> float: diff --git a/ax/utils/stats/tests/test_model_fit_stats.py b/ax/utils/stats/tests/test_model_fit_stats.py index 9d6b258eb46..a83eb4d90bc 100644 --- a/ax/utils/stats/tests/test_model_fit_stats.py +++ b/ax/utils/stats/tests/test_model_fit_stats.py @@ -6,11 +6,46 @@ import numpy as np from ax.utils.common.testutils import TestCase -from ax.utils.stats.model_fit_stats import _fisher_exact_test_p +from ax.utils.stats.model_fit_stats import _fisher_exact_test_p, entropy_of_observations from scipy.stats import fisher_exact -class FisherExactTestTest(TestCase): +class TestModelFitStats(TestCase): + def test_entropy_of_observations(self) -> None: + np.random.seed(1234) + n = 16 + yc = np.ones(n) + yc[: n // 2] = -1 + yc += np.random.randn(n) * 0.05 + yr = np.random.randn(n) + + # standardize both observations + yc = yc / yc.std() + yr = yr / yr.std() + + ones = np.ones(n) + + # compute entropy of observations + ec = entropy_of_observations(y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.1) + er = entropy_of_observations(y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.1) + + # testing that the Gaussian distributed data has a much larger entropy than + # the clustered distribution + self.assertTrue(er - ec > 10.0) + + ec2 = entropy_of_observations( + y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.2 + ) + er2 = entropy_of_observations( + y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.2 + ) + # entropy increases with larger bandwidth + self.assertGreater(ec2, ec) + self.assertGreater(er2, er) + + # ordering of entropies stays the same, though the difference is smaller + self.assertTrue(er2 - ec2 > 3) + def test_contingency_table_construction(self) -> None: # Create a dummy set of observations and predictions y_obs = np.array([1, 3, 2, 5, 7, 3])