diff --git a/ax/utils/stats/model_fit_stats.py b/ax/utils/stats/model_fit_stats.py index 5b3e1043bd8..6f985ff5851 100644 --- a/ax/utils/stats/model_fit_stats.py +++ b/ax/utils/stats/model_fit_stats.py @@ -9,6 +9,7 @@ import numpy as np from scipy.stats import fisher_exact, norm, pearsonr, spearmanr +from sklearn.neighbors import KernelDensity """ ################################ Model Fit Metrics ############################### @@ -127,6 +128,50 @@ def std_of_the_standardized_error( return ((y_obs - y_pred) / se_pred).std() +def entropy_of_observations( + y_obs: np.ndarray, + y_pred: np.ndarray, + se_pred: np.ndarray, + bandwidth: float = 0.1, +) -> float: + """Computes the entropy of the observations y_obs using a kernel density estimator. + This can be used to quantify how "clustered" the outcomes are. NOTE: y_pred and + se_pred are not used, but are required for the API. + + Args: + y_obs: An array of observations for a single metric. + y_pred: An array of the predicted values corresponding to y_obs. + se_pred: An array of the standard errors of the predicted values. + bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable value + for standardized outcomes y_obs. The rank ordering of the results on a set + of y_obs data sets is not generally sensitive to the bandwidth, if it is + held fixed across the data sets. The absolute value of the results however + changes significantly with the bandwidth. + + Returns: + The scalar entropy of the observations. + """ + if y_obs.ndim == 1: + y_obs = y_obs.reshape(-1, 1) + return _entropy_via_kde(y_obs, bandwidth=bandwidth) + + +def _entropy_via_kde(y: np.ndarray, bandwidth: float = 0.1) -> float: + """Computes the entropy of the kernel density estimate of the input data. + + Args: + y: An (n x m) array of observations. + bandwidth: The kernel bandwidth. + + Returns: + The scalar entropy of the kernel density estimate. + """ + kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth) + kde.fit(y) + log_p = kde.score_samples(y) # computes the log probability of each data point + return -np.sum(np.exp(log_p) * log_p) # compute entropy, the negated sum of p log p + + def _mean_prediction_ci( y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray ) -> float: diff --git a/ax/utils/stats/tests/test_model_fit_stats.py b/ax/utils/stats/tests/test_model_fit_stats.py index 9d6b258eb46..324af2ef9fc 100644 --- a/ax/utils/stats/tests/test_model_fit_stats.py +++ b/ax/utils/stats/tests/test_model_fit_stats.py @@ -6,20 +6,56 @@ import numpy as np from ax.utils.common.testutils import TestCase -from ax.utils.stats.model_fit_stats import _fisher_exact_test_p +from ax.utils.stats.model_fit_stats import _fisher_exact_test_p, entropy_of_observations from scipy.stats import fisher_exact -class FisherExactTestTest(TestCase): +class TestModelFitStats(TestCase): + def test_entropy_of_observations(self) -> None: + np.random.seed(1234) + n = 16 + yc = np.ones(n) + yc[: n // 2] = -1 + yc += np.random.randn(n) * 0.05 + yr = np.random.randn(n) + + # standardize both observations + yc = yc / yc.std() + yr = yr / yr.std() + + ones = np.ones(n) + + # compute entropy of observations + ec = entropy_of_observations(y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.1) + er = entropy_of_observations(y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.1) + + # testing that the Gaussian distributed data has a much larger entropy than + # the clustered distribution + self.assertTrue(er - ec > 10.0) + + ec2 = entropy_of_observations( + y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.2 + ) + er2 = entropy_of_observations( + y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.2 + ) + # entropy increases with larger bandwidth + self.assertGreater(ec2, ec) + self.assertGreater(er2, er) + + # ordering of entropies stays the same, though the difference is smaller + self.assertTrue(er2 - ec2 > 3) + def test_contingency_table_construction(self) -> None: # Create a dummy set of observations and predictions y_obs = np.array([1, 3, 2, 5, 7, 3]) y_pred = np.array([2, 4, 1, 6, 8, 2.5]) + se_pred = np.full(len(y_obs), np.nan) # not used for fisher exact # Compute ground truth contingency table true_table = np.array([[2, 1], [1, 2]]) scipy_result = fisher_exact(true_table, alternative="greater")[1] - ax_result = _fisher_exact_test_p(y_obs, y_pred, se_pred=None) + ax_result = _fisher_exact_test_p(y_obs, y_pred, se_pred=se_pred) self.assertEqual(scipy_result, ax_result)