Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardization check for entropy of observations #2366

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions ax/utils/stats/model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@

# pyre-strict

from logging import Logger
from typing import Dict, Mapping, Optional, Protocol

import numpy as np

from ax.utils.common.logger import get_logger
from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
from sklearn.neighbors import KernelDensity


logger: Logger = get_logger(__name__)


DEFAULT_KDE_BANDWIDTH = 0.1 # default bandwidth for kernel density estimators

"""
################################ Model Fit Metrics ###############################
"""
Expand Down Expand Up @@ -132,16 +141,16 @@ def entropy_of_observations(
y_obs: np.ndarray,
y_pred: np.ndarray,
se_pred: np.ndarray,
bandwidth: float = 0.1,
bandwidth: float = DEFAULT_KDE_BANDWIDTH,
) -> float:
"""Computes the entropy of the observations y_obs using a kernel density estimator.
This can be used to quantify how "clustered" the outcomes are. NOTE: y_pred and
se_pred are not used, but are required for the API.

Args:
y_obs: An array of observations for a single metric.
y_pred: An array of the predicted values corresponding to y_obs.
se_pred: An array of the standard errors of the predicted values.
y_pred: Unused.
se_pred: Unused.
bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable value
for standardized outcomes y_obs. The rank ordering of the results on a set
of y_obs data sets is not generally sensitive to the bandwidth, if it is
Expand All @@ -153,10 +162,20 @@ def entropy_of_observations(
"""
if y_obs.ndim == 1:
y_obs = y_obs[:, np.newaxis]

# Check if standardization was applied to the observations.
if bandwidth == DEFAULT_KDE_BANDWIDTH:
y_std = np.std(y_obs, axis=0, ddof=1)
if np.any(y_std < 0.5) or np.any(2.0 < y_std): # allowing a fudge factor of 2.
logger.warning(
"Standardization of observations was not applied. "
f"The default bandwidth of {DEFAULT_KDE_BANDWIDTH} is a reasonable "
"choice if observations are standardize, but may not be otherwise."
)
return _entropy_via_kde(y_obs, bandwidth=bandwidth)


def _entropy_via_kde(y: np.ndarray, bandwidth: float = 0.1) -> float:
def _entropy_via_kde(y: np.ndarray, bandwidth: float = DEFAULT_KDE_BANDWIDTH) -> float:
"""Computes the entropy of the kernel density estimate of the input data.

Args:
Expand Down
17 changes: 17 additions & 0 deletions ax/utils/stats/tests/test_model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,23 @@ def test_entropy_of_observations(self) -> None:
# ordering of entropies stays the same, though the difference is smaller
self.assertTrue(er2 - ec2 > 3)

# test warning if y is not standardized
module_name = "ax.utils.stats.model_fit_stats"
expected_warning = (
"WARNING:ax.utils.stats.model_fit_stats:Standardization of observations "
"was not applied. The default bandwidth of 0.1 is a reasonable "
"choice if observations are standardize, but may not be otherwise."
)
with self.assertLogs(module_name, level="WARNING") as logger:
ec = entropy_of_observations(y_obs=10 * yc, y_pred=ones, se_pred=ones)
self.assertEqual(len(logger.output), 1)
self.assertEqual(logger.output[0], expected_warning)

with self.assertLogs(module_name, level="WARNING") as logger:
ec = entropy_of_observations(y_obs=yc / 10, y_pred=ones, se_pred=ones)
self.assertEqual(len(logger.output), 1)
self.assertEqual(logger.output[0], expected_warning)

def test_contingency_table_construction(self) -> None:
# Create a dummy set of observations and predictions
y_obs = np.array([1, 3, 2, 5, 7, 3])
Expand Down