diff --git a/ax/utils/stats/model_fit_stats.py b/ax/utils/stats/model_fit_stats.py index 34647e448d0..fd09984a98e 100644 --- a/ax/utils/stats/model_fit_stats.py +++ b/ax/utils/stats/model_fit_stats.py @@ -5,12 +5,21 @@ # pyre-strict +from logging import Logger from typing import Dict, Mapping, Optional, Protocol import numpy as np + +from ax.utils.common.logger import get_logger from scipy.stats import fisher_exact, norm, pearsonr, spearmanr from sklearn.neighbors import KernelDensity + +logger: Logger = get_logger(__name__) + + +DEFAULT_KDE_BANDWIDTH = 0.1 # default bandwidth for kernel density estimators + """ ################################ Model Fit Metrics ############################### """ @@ -132,7 +141,7 @@ def entropy_of_observations( y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray, - bandwidth: float = 0.1, + bandwidth: float = DEFAULT_KDE_BANDWIDTH, ) -> float: """Computes the entropy of the observations y_obs using a kernel density estimator. This can be used to quantify how "clustered" the outcomes are. NOTE: y_pred and @@ -140,8 +149,8 @@ def entropy_of_observations( Args: y_obs: An array of observations for a single metric. - y_pred: An array of the predicted values corresponding to y_obs. - se_pred: An array of the standard errors of the predicted values. + y_pred: Unused. + se_pred: Unused. bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable value for standardized outcomes y_obs. The rank ordering of the results on a set of y_obs data sets is not generally sensitive to the bandwidth, if it is @@ -153,10 +162,20 @@ def entropy_of_observations( """ if y_obs.ndim == 1: y_obs = y_obs[:, np.newaxis] + + # Check if standardization was applied to the observations. + if bandwidth == DEFAULT_KDE_BANDWIDTH: + y_std = np.std(y_obs, axis=0, ddof=1) + if np.any(y_std < 0.5) or np.any(2.0 < y_std): # allowing a fudge factor of 2. + logger.warning( + "Standardization of observations was not applied. " + f"The default bandwidth of {DEFAULT_KDE_BANDWIDTH} is a reasonable " + "choice if observations are standardize, but may not be otherwise." + ) return _entropy_via_kde(y_obs, bandwidth=bandwidth) -def _entropy_via_kde(y: np.ndarray, bandwidth: float = 0.1) -> float: +def _entropy_via_kde(y: np.ndarray, bandwidth: float = DEFAULT_KDE_BANDWIDTH) -> float: """Computes the entropy of the kernel density estimate of the input data. Args: diff --git a/ax/utils/stats/tests/test_model_fit_stats.py b/ax/utils/stats/tests/test_model_fit_stats.py index 324af2ef9fc..6acfaf19aac 100644 --- a/ax/utils/stats/tests/test_model_fit_stats.py +++ b/ax/utils/stats/tests/test_model_fit_stats.py @@ -46,6 +46,23 @@ def test_entropy_of_observations(self) -> None: # ordering of entropies stays the same, though the difference is smaller self.assertTrue(er2 - ec2 > 3) + # test warning if y is not standardized + module_name = "ax.utils.stats.model_fit_stats" + expected_warning = ( + "WARNING:ax.utils.stats.model_fit_stats:Standardization of observations " + "was not applied. The default bandwidth of 0.1 is a reasonable " + "choice if observations are standardize, but may not be otherwise." + ) + with self.assertLogs(module_name, level="WARNING") as logger: + ec = entropy_of_observations(y_obs=10 * yc, y_pred=ones, se_pred=ones) + self.assertEqual(len(logger.output), 1) + self.assertEqual(logger.output[0], expected_warning) + + with self.assertLogs(module_name, level="WARNING") as logger: + ec = entropy_of_observations(y_obs=yc / 10, y_pred=ones, se_pred=ones) + self.assertEqual(len(logger.output), 1) + self.assertEqual(logger.output[0], expected_warning) + def test_contingency_table_construction(self) -> None: # Create a dummy set of observations and predictions y_obs = np.array([1, 3, 2, 5, 7, 3])