Skip to content

Commit

Permalink
Entropy of observations metric (#2340)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2340

This commit introduces `entropy_of_observations` as a model fit metric. It quantifies the entropy of the outcomes `y_obs` using a kernel density estimator. This metric can be useful in detecting datasets in which the outcomes are clustered (implying a low entropy), rather than uniformly distributed in the outcome space (high entropy).

Reviewed By: saitcakmak

Differential Revision: D55930954

fbshipit-source-id: e2f839e6856879bae7268637966ca160aa8d6b74
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Apr 13, 2024
1 parent cc2a7ad commit cefe7bf
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 5 deletions.
1 change: 0 additions & 1 deletion ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,6 @@ def compute_model_fit_metrics_from_modelbridge(
if generalization
else _predict_on_training_data(model_bridge=model_bridge, experiment=experiment)
)

if fit_metrics_dict is None:
fit_metrics_dict = {
"coefficient_of_determination": coefficient_of_determination,
Expand Down
31 changes: 30 additions & 1 deletion ax/modelbridge/tests/test_model_fit_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,23 @@
import warnings
from typing import cast, Dict

import numpy as np

from ax.core.experiment import Experiment
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.metrics.branin import BraninMetric
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
from ax.modelbridge.cross_validation import (
_predict_on_cross_validation_data,
compute_model_fit_metrics_from_modelbridge,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.runners.synthetic import SyntheticRunner
from ax.service.scheduler import get_fitted_model_bridge, Scheduler, SchedulerOptions
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from ax.utils.stats.model_fit_stats import _entropy_via_kde, entropy_of_observations
from ax.utils.testing.core_stubs import get_branin_search_space

NUM_SOBOL = 5
Expand Down Expand Up @@ -83,6 +89,29 @@ def test_model_fit_metrics(self) -> None:
std_branin = std["branin"]
self.assertIsInstance(std_branin, float)

# checking non-default model-fit-metric
untransform = False
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
generalization=True,
untransform=untransform,
fit_metrics_dict={"Entropy": entropy_of_observations},
)
entropy = fit_metrics.get("Entropy")
self.assertIsInstance(entropy, dict)
entropy = cast(Dict[str, float], entropy)
self.assertTrue("branin" in entropy)
entropy_branin = entropy["branin"]
self.assertIsInstance(entropy_branin, float)

y_obs, _, _ = _predict_on_cross_validation_data(
model_bridge=model_bridge, untransform=untransform
)
y_obs_branin = np.array(y_obs["branin"])[:, np.newaxis]
entropy_truth = _entropy_via_kde(y_obs_branin)
self.assertAlmostEqual(entropy_branin, entropy_truth)

# testing with empty metrics
empty_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
Expand Down
45 changes: 45 additions & 0 deletions ax/utils/stats/model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
from sklearn.neighbors import KernelDensity

"""
################################ Model Fit Metrics ###############################
Expand Down Expand Up @@ -127,6 +128,50 @@ def std_of_the_standardized_error(
return ((y_obs - y_pred) / se_pred).std()


def entropy_of_observations(
y_obs: np.ndarray,
y_pred: np.ndarray,
se_pred: np.ndarray,
bandwidth: float = 0.1,
) -> float:
"""Computes the entropy of the observations y_obs using a kernel density estimator.
This can be used to quantify how "clustered" the outcomes are. NOTE: y_pred and
se_pred are not used, but are required for the API.
Args:
y_obs: An array of observations for a single metric.
y_pred: An array of the predicted values corresponding to y_obs.
se_pred: An array of the standard errors of the predicted values.
bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable value
for standardized outcomes y_obs. The rank ordering of the results on a set
of y_obs data sets is not generally sensitive to the bandwidth, if it is
held fixed across the data sets. The absolute value of the results however
changes significantly with the bandwidth.
Returns:
The scalar entropy of the observations.
"""
if y_obs.ndim == 1:
y_obs = y_obs[:, np.newaxis]
return _entropy_via_kde(y_obs, bandwidth=bandwidth)


def _entropy_via_kde(y: np.ndarray, bandwidth: float = 0.1) -> float:
"""Computes the entropy of the kernel density estimate of the input data.
Args:
y: An (n x m) array of observations.
bandwidth: The kernel bandwidth.
Returns:
The scalar entropy of the kernel density estimate.
"""
kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth)
kde.fit(y)
log_p = kde.score_samples(y) # computes the log probability of each data point
return -np.sum(np.exp(log_p) * log_p) # compute entropy, the negated sum of p log p


def _mean_prediction_ci(
y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
Expand Down
42 changes: 39 additions & 3 deletions ax/utils/stats/tests/test_model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,56 @@

import numpy as np
from ax.utils.common.testutils import TestCase
from ax.utils.stats.model_fit_stats import _fisher_exact_test_p
from ax.utils.stats.model_fit_stats import _fisher_exact_test_p, entropy_of_observations
from scipy.stats import fisher_exact


class FisherExactTestTest(TestCase):
class TestModelFitStats(TestCase):
def test_entropy_of_observations(self) -> None:
np.random.seed(1234)
n = 16
yc = np.ones(n)
yc[: n // 2] = -1
yc += np.random.randn(n) * 0.05
yr = np.random.randn(n)

# standardize both observations
yc = yc / yc.std()
yr = yr / yr.std()

ones = np.ones(n)

# compute entropy of observations
ec = entropy_of_observations(y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.1)
er = entropy_of_observations(y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.1)

# testing that the Gaussian distributed data has a much larger entropy than
# the clustered distribution
self.assertTrue(er - ec > 10.0)

ec2 = entropy_of_observations(
y_obs=yc, y_pred=ones, se_pred=ones, bandwidth=0.2
)
er2 = entropy_of_observations(
y_obs=yr, y_pred=ones, se_pred=ones, bandwidth=0.2
)
# entropy increases with larger bandwidth
self.assertGreater(ec2, ec)
self.assertGreater(er2, er)

# ordering of entropies stays the same, though the difference is smaller
self.assertTrue(er2 - ec2 > 3)

def test_contingency_table_construction(self) -> None:
# Create a dummy set of observations and predictions
y_obs = np.array([1, 3, 2, 5, 7, 3])
y_pred = np.array([2, 4, 1, 6, 8, 2.5])
se_pred = np.full(len(y_obs), np.nan) # not used for fisher exact

# Compute ground truth contingency table
true_table = np.array([[2, 1], [1, 2]])

scipy_result = fisher_exact(true_table, alternative="greater")[1]
ax_result = _fisher_exact_test_p(y_obs, y_pred, se_pred=None)
ax_result = _fisher_exact_test_p(y_obs, y_pred, se_pred=se_pred)

self.assertEqual(scipy_result, ax_result)

0 comments on commit cefe7bf

Please sign in to comment.