From ac177f5c23bddf15c33cb3befae2782aef446737 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Fri, 12 Apr 2024 18:16:54 -0700 Subject: [PATCH] Entropy of observations metric (#2340) Summary: This commit introduces `entropy_of_observations` as a model fit metric. It quantifies the entropy of the outcomes `y_obs` using a kernel density estimator. This metric can be useful in detecting datasets in which the outcomes are clustered (implying a low entropy), rather than uniformly distributed in the outcome space (high entropy). Reviewed By: saitcakmak Differential Revision: D55930954 --- ax/modelbridge/cross_validation.py | 1 - .../tests/test_model_fit_metrics.py | 31 ++++++++++++- ax/utils/stats/model_fit_stats.py | 45 +++++++++++++++++++ ax/utils/stats/tests/test_model_fit_stats.py | 42 +++++++++++++++-- 4 files changed, 114 insertions(+), 5 deletions(-) diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py index 27b022594ce..fbdae071b33 100644 --- a/ax/modelbridge/cross_validation.py +++ b/ax/modelbridge/cross_validation.py @@ -535,7 +535,6 @@ def compute_model_fit_metrics_from_modelbridge( if generalization else _predict_on_training_data(model_bridge=model_bridge, experiment=experiment) ) - if fit_metrics_dict is None: fit_metrics_dict = { "coefficient_of_determination": coefficient_of_determination, diff --git a/ax/modelbridge/tests/test_model_fit_metrics.py b/ax/modelbridge/tests/test_model_fit_metrics.py index 1127e75ced4..96af249d946 100644 --- a/ax/modelbridge/tests/test_model_fit_metrics.py +++ b/ax/modelbridge/tests/test_model_fit_metrics.py @@ -9,17 +9,23 @@ import warnings from typing import cast, Dict +import numpy as np + from ax.core.experiment import Experiment from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig from ax.metrics.branin import BraninMetric -from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge +from ax.modelbridge.cross_validation import ( + _predict_on_cross_validation_data, + compute_model_fit_metrics_from_modelbridge, +) from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import Models from ax.runners.synthetic import SyntheticRunner from ax.service.scheduler import get_fitted_model_bridge, Scheduler, SchedulerOptions from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase +from ax.utils.stats.model_fit_stats import _entropy_via_kde, entropy_of_observations from ax.utils.testing.core_stubs import get_branin_search_space NUM_SOBOL = 5 @@ -83,6 +89,29 @@ def test_model_fit_metrics(self) -> None: std_branin = std["branin"] self.assertIsInstance(std_branin, float) + # checking non-default model-fit-metric + untransform = False + fit_metrics = compute_model_fit_metrics_from_modelbridge( + model_bridge=model_bridge, + experiment=scheduler.experiment, + generalization=True, + untransform=untransform, + fit_metrics_dict={"Entropy": entropy_of_observations}, + ) + entropy = fit_metrics.get("Entropy") + self.assertIsInstance(entropy, dict) + entropy = cast(Dict[str, float], entropy) + self.assertTrue("branin" in entropy) + entropy_branin = entropy["branin"] + self.assertIsInstance(entropy_branin, float) + + y_obs, _, _ = _predict_on_cross_validation_data( + model_bridge=model_bridge, untransform=untransform + ) + y_obs_branin = np.array(y_obs["branin"])[:, np.newaxis] + entropy_truth = _entropy_via_kde(y_obs_branin) + self.assertAlmostEqual(entropy_branin, entropy_truth) + # testing with empty metrics empty_metrics = compute_model_fit_metrics_from_modelbridge( model_bridge=model_bridge, diff --git a/ax/utils/stats/model_fit_stats.py b/ax/utils/stats/model_fit_stats.py index 5b3e1043bd8..34647e448d0 100644 --- a/ax/utils/stats/model_fit_stats.py +++ b/ax/utils/stats/model_fit_stats.py @@ -9,6 +9,7 @@ import numpy as np from scipy.stats import fisher_exact, norm, pearsonr, spearmanr +from sklearn.neighbors import KernelDensity """ ################################ Model Fit Metrics ############################### @@ -127,6 +128,50 @@ def std_of_the_standardized_error( return ((y_obs - y_pred) / se_pred).std() +def entropy_of_observations( + y_obs: np.ndarray, + y_pred: np.ndarray, + se_pred: np.ndarray, + bandwidth: float = 0.1, +) -> float: + """Computes the entropy of the observations y_obs using a kernel density estimator. + This can be used to quantify how "clustered" the outcomes are. NOTE: y_pred and + se_pred are not used, but are required for the API. + + Args: + y_obs: An array of observations for a single metric. + y_pred: An array of the predicted values corresponding to y_obs. + se_pred: An array of the standard errors of the predicted values. + bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable value + for standardized outcomes y_obs. The rank ordering of the results on a set + of y_obs data sets is not generally sensitive to the bandwidth, if it is + held fixed across the data sets. The absolute value of the results however + changes significantly with the bandwidth. + + Returns: + The scalar entropy of the observations. + """ + if y_obs.ndim == 1: + y_obs = y_obs[:, np.newaxis] + return _entropy_via_kde(y_obs, bandwidth=bandwidth) + + +def _entropy_via_kde(y: np.ndarray, bandwidth: float = 0.1) -> float: + """Computes the entropy of the kernel density estimate of the input data. + + Args: + y: An (n x m) array of observations. + bandwidth: The kernel bandwidth. + + Returns: + The scalar entropy of the kernel density estimate. + """ + kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth) + kde.fit(y) + log_p = kde.score_samples(y) # computes the log probability of each data point + return -np.sum(np.exp(log_p) * log_p) # compute entropy, the negated sum of p log p + + def _mean_prediction_ci( y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray ) -> float: diff --git a/ax/utils/stats/tests/test_model_fit_stats.py b/ax/utils/stats/tests/test_model_fit_stats.py index 9d6b258eb46..324af2ef9fc 100644 --- a/ax/utils/stats/tests/test_model_fit_stats.py +++ b/ax/utils/stats/tests/test_model_fit_stats.py @@ -6,20 +6,56 @@ import numpy as np from ax.utils.common.testutils import TestCase -from ax.utils.stats.model_fit_stats import _fisher_exact_test_p +from ax.utils.stats.model_fit_stats import _fisher_exact_test_p, entropy_of_observations from scipy.stats import fisher_exact -class FisherExactTestTest(TestCase): +class TestModelFitStats(TestCase): + def test_entropy_of_observations(self) -> None: + np.random.seed(1234) + n = 16 + yc = np.ones(n) + yc[: n // 2] = -1 + yc += np.random.randn(n) * 0.05 + yr = np.random.randn(n) + + # standardize both observations + yc = yc / yc.std() + yr = yr / yr.std() + + ones = np.ones(n) + + # compute entropy of observations + ec = entropy_of_observations(y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.1) + er = entropy_of_observations(y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.1) + + # testing that the Gaussian distributed data has a much larger entropy than + # the clustered distribution + self.assertTrue(er - ec > 10.0) + + ec2 = entropy_of_observations( + y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.2 + ) + er2 = entropy_of_observations( + y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.2 + ) + # entropy increases with larger bandwidth + self.assertGreater(ec2, ec) + self.assertGreater(er2, er) + + # ordering of entropies stays the same, though the difference is smaller + self.assertTrue(er2 - ec2 > 3) + def test_contingency_table_construction(self) -> None: # Create a dummy set of observations and predictions y_obs = np.array([1, 3, 2, 5, 7, 3]) y_pred = np.array([2, 4, 1, 6, 8, 2.5]) + se_pred = np.full(len(y_obs), np.nan) # not used for fisher exact # Compute ground truth contingency table true_table = np.array([[2, 1], [1, 2]]) scipy_result = fisher_exact(true_table, alternative="greater")[1] - ax_result = _fisher_exact_test_p(y_obs, y_pred, se_pred=None) + ax_result = _fisher_exact_test_p(y_obs, y_pred, se_pred=se_pred) self.assertEqual(scipy_result, ax_result)