From 6c8f653c33b6035341a5c9d8a329dc8da6c78c61 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Tue, 15 Aug 2023 12:51:38 -0700 Subject: [PATCH] Model uncertainty metrics for logging (#1741) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/1741 This commit introduces `_model_std_quality`, which returns the reciprocal of the multiplicative worst over or under-estimation of the predictive uncertainty compared to the observed errors. Reviewed By: Balandat Differential Revision: D46969426 fbshipit-source-id: 04665aa46be90d1a256097526963c9b4e03a2c21 --- ax/modelbridge/base.py | 3 +- ax/modelbridge/cross_validation.py | 41 ++++++------------------- ax/telemetry/ax_client.py | 6 ++++ ax/telemetry/optimization.py | 8 ++++- ax/telemetry/scheduler.py | 36 +++++++++++++++++++++- ax/telemetry/tests/test_ax_client.py | 3 ++ ax/telemetry/tests/test_optimization.py | 1 - ax/telemetry/tests/test_scheduler.py | 16 ++++++++-- ax/utils/stats/model_fit_stats.py | 5 +-- 9 files changed, 79 insertions(+), 40 deletions(-) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index 373f7c6f679..3917fff9568 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -12,7 +12,7 @@ from dataclasses import dataclass, field from logging import Logger -from typing import Any, cast, Dict, List, MutableMapping, Optional, Set, Tuple, Type +from typing import Any, Dict, List, MutableMapping, Optional, Set, Tuple, Type import numpy as np from ax.core.arm import Arm @@ -990,7 +990,6 @@ def compute_model_fit_metrics( "mean_of_the_standardized_error": mean_of_the_standardized_error, "std_of_the_standardized_error": std_of_the_standardized_error, } - fit_metrics_dict = cast(Dict[str, ModelFitMetricProtocol], fit_metrics_dict) return compute_model_fit_metrics( y_obs=y_obs, diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py index 155d1375470..4e6bec34a8f 100644 --- a/ax/modelbridge/cross_validation.py +++ 
b/ax/modelbridge/cross_validation.py @@ -11,19 +11,7 @@ from logging import Logger from numbers import Number -from typing import ( - Any, - Callable, - cast, - Dict, - Iterable, - List, - Mapping, - NamedTuple, - Optional, - Set, - Tuple, -) +from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple import numpy as np from ax.core.observation import Observation, ObservationData @@ -40,7 +28,6 @@ _rank_correlation, _total_raw_effect, compute_model_fit_metrics, - ModelFitMetricProtocol, ) logger: Logger = get_logger(__name__) @@ -252,23 +239,15 @@ def compute_diagnostics(result: List[CVResult]) -> CVDiagnostics: y_pred = _arrayify_dict_values(y_pred) se_pred = _arrayify_dict_values(se_pred) - # We need to cast here since pyre infers specific types T < ModelFitMetricProtocol - # for the dict values, which is type variant upon initialization, leading - # diagnostic_fns to not be recognized as a Mapping[str, ModelFitMetricProtocol], - # see the last tip in the Pyre docs on [9] Incompatible Variable Type: - # https://staticdocs.internalfb.com/pyre/docs/errors/#9-incompatible-variable-type - diagnostic_fns = cast( - Mapping[str, ModelFitMetricProtocol], - { - MEAN_PREDICTION_CI: _mean_prediction_ci, - MAPE: _mape, - TOTAL_RAW_EFFECT: _total_raw_effect, - CORRELATION_COEFFICIENT: _correlation_coefficient, - RANK_CORRELATION: _rank_correlation, - FISHER_EXACT_TEST_P: _fisher_exact_test_p, - LOG_LIKELIHOOD: _log_likelihood, - }, - ) + diagnostic_fns = { + MEAN_PREDICTION_CI: _mean_prediction_ci, + MAPE: _mape, + TOTAL_RAW_EFFECT: _total_raw_effect, + CORRELATION_COEFFICIENT: _correlation_coefficient, + RANK_CORRELATION: _rank_correlation, + FISHER_EXACT_TEST_P: _fisher_exact_test_p, + LOG_LIKELIHOOD: _log_likelihood, + } diagnostics = compute_model_fit_metrics( y_obs=y_obs, y_pred=y_pred, se_pred=se_pred, fit_metrics_dict=diagnostic_fns ) diff --git a/ax/telemetry/ax_client.py b/ax/telemetry/ax_client.py index 6ee63f503ec..211729fa909 100644 
--- a/ax/telemetry/ax_client.py +++ b/ax/telemetry/ax_client.py @@ -99,6 +99,9 @@ class AxClientCompletedRecord: best_point_quality: float model_fit_quality: float + model_std_quality: float + model_fit_generalization: float + model_std_generalization: float @classmethod def from_ax_client(cls, ax_client: AxClient) -> AxClientCompletedRecord: @@ -108,6 +111,9 @@ def from_ax_client(cls, ax_client: AxClient) -> AxClientCompletedRecord: ), best_point_quality=float("-inf"), # TODO[T147907632] model_fit_quality=float("-inf"), # TODO[T147907632] + model_std_quality=float("-inf"), + model_fit_generalization=float("-inf"), # TODO via cross_validate_by_trial + model_std_generalization=float("-inf"), ) def flatten(self) -> Dict[str, Any]: diff --git a/ax/telemetry/optimization.py b/ax/telemetry/optimization.py index 9bf500719d8..9fc8e663b57 100644 --- a/ax/telemetry/optimization.py +++ b/ax/telemetry/optimization.py @@ -301,6 +301,9 @@ class OptimizationCompletedRecord: # SchedulerCompletedRecord fields best_point_quality: float model_fit_quality: float + model_std_quality: float + model_fit_generalization: float + model_std_generalization: float num_metric_fetch_e_encountered: int num_trials_bad_due_to_err: int @@ -403,6 +406,9 @@ def _extract_model_fit_dict( completed_record: Union[SchedulerCompletedRecord, AxClientCompletedRecord], ) -> Dict[str, float]: model_fit_names = [ - "model_fit_quality", # TODO: add calibration, generalization. 
+ "model_fit_quality", + "model_std_quality", + "model_fit_generalization", + "model_std_generalization", ] return {n: getattr(completed_record, n) for n in model_fit_names} diff --git a/ax/telemetry/scheduler.py b/ax/telemetry/scheduler.py index f529c8b1827..a51491f456d 100644 --- a/ax/telemetry/scheduler.py +++ b/ax/telemetry/scheduler.py @@ -9,6 +9,8 @@ from typing import Any, Dict, Optional from warnings import warn +import numpy as np + from ax.service.scheduler import get_fitted_model_bridge, Scheduler from ax.telemetry.common import _get_max_transformed_dimensionality @@ -99,6 +101,9 @@ class SchedulerCompletedRecord: best_point_quality: float model_fit_quality: float + model_std_quality: float + model_fit_generalization: float + model_std_generalization: float num_metric_fetch_e_encountered: int num_trials_bad_due_to_err: int @@ -118,11 +123,14 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord: model_fit_quality = min( model_fit_dict["coefficient_of_determination"].values() ) + # similar for uncertainty quantification, but distance from 1 matters + std = list(model_fit_dict["std_of_the_standardized_error"].values()) + model_std_quality = _model_std_quality(np.array(std)) except Exception as e: warn("Encountered exception in computing model fit quality: " + str(e)) model_fit_quality = float("-inf") - # model_std_quality = float("-inf") + model_std_quality = float("-inf") return cls( experiment_completed_record=ExperimentCompletedRecord.from_experiment( @@ -130,6 +138,9 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord: ), best_point_quality=float("-inf"), # TODO[T147907632] model_fit_quality=model_fit_quality, + model_std_quality=model_std_quality, + model_fit_generalization=float("-inf"), # TODO by cross_validate_by_trial + model_std_generalization=float("-inf"), num_metric_fetch_e_encountered=scheduler._num_metric_fetch_e_encountered, num_trials_bad_due_to_err=scheduler._num_trials_bad_due_to_err, ) @@ 
-146,3 +157,26 @@ def flatten(self) -> Dict[str, Any]: **self_dict, **experiment_completed_record_dict, } + + +def _model_std_quality(std: np.ndarray) -> float: + """Quantifies quality of the model uncertainty. A value of one means the + uncertainty is perfectly predictive of the true standard deviation of the error. + Values larger than one indicate over-estimation and values smaller than one + indicate under-estimation of the true standard deviation of the error. In particular, a value + of 2 (resp. 1 / 2) represents an over-estimation (resp. under-estimation) of the + true standard deviation of the error by a factor of 2. + + Args: + std: The standard deviation of the standardized error. + + Returns: + The factor corresponding to the worst over- or under-estimation factor of the + standard deviation of the error among all experimentally observed metrics. + """ + max_std, min_std = np.max(std), np.min(std) + # comparing worst over-estimation factor with worst under-estimation factor + inv_model_std_quality = max_std if max_std > 1 / min_std else min_std + # reciprocal so that values greater than one indicate over-estimation and + # values smaller than one indicate under-estimation of the uncertainty. 
+ return 1 / inv_model_std_quality diff --git a/ax/telemetry/tests/test_ax_client.py b/ax/telemetry/tests/test_ax_client.py index 3dd60e36edf..ecebc66760f 100644 --- a/ax/telemetry/tests/test_ax_client.py +++ b/ax/telemetry/tests/test_ax_client.py @@ -120,5 +120,8 @@ def test_ax_client_completed_record_from_ax_client(self) -> None: ), best_point_quality=float("-inf"), model_fit_quality=float("-inf"), + model_std_quality=float("-inf"), + model_fit_generalization=float("-inf"), + model_std_generalization=float("-inf"), ) self.assertEqual(record, expected) diff --git a/ax/telemetry/tests/test_optimization.py b/ax/telemetry/tests/test_optimization.py index 9cb57ce8b81..01c3055b6d2 100644 --- a/ax/telemetry/tests/test_optimization.py +++ b/ax/telemetry/tests/test_optimization.py @@ -149,7 +149,6 @@ def test_optimization_completed_record_from_ax_client(self) -> None: estimated_early_stopping_savings=19, estimated_global_stopping_savings=98, ) - expected_dict = { **AxClientCompletedRecord.from_ax_client(ax_client=ax_client).flatten(), "unique_identifier": "foo", diff --git a/ax/telemetry/tests/test_scheduler.py b/ax/telemetry/tests/test_scheduler.py index 5d66ae8a5c8..f46fedbaac1 100644 --- a/ax/telemetry/tests/test_scheduler.py +++ b/ax/telemetry/tests/test_scheduler.py @@ -92,6 +92,9 @@ def test_scheduler_completed_record_from_scheduler(self) -> None: ), best_point_quality=float("-inf"), model_fit_quality=float("-inf"), # -inf because no model has been fit + model_std_quality=float("-inf"), + model_fit_generalization=float("-inf"), + model_std_generalization=float("-inf"), num_metric_fetch_e_encountered=0, num_trials_bad_due_to_err=0, ) @@ -104,6 +107,9 @@ def test_scheduler_completed_record_from_scheduler(self) -> None: ).__dict__, "best_point_quality": float("-inf"), "model_fit_quality": float("-inf"), + "model_std_quality": float("-inf"), + "model_fit_generalization": float("-inf"), + "model_std_generalization": float("-inf"), "num_metric_fetch_e_encountered": 0, 
"num_trials_bad_due_to_err": 0, } @@ -163,8 +169,8 @@ def test_scheduler_model_fit_metrics_logging(self) -> None: self.assertTrue("branin" in std) std_branin = std["branin"] self.assertIsInstance(std_branin, float) - # log_std_branin = math.log10(std_branin) - # model_std_quality = -log_std_branin # align positivity with over-estimation + + model_std_quality = 1 / std_branin expected = SchedulerCompletedRecord( experiment_completed_record=ExperimentCompletedRecord.from_experiment( @@ -172,6 +178,9 @@ def test_scheduler_model_fit_metrics_logging(self) -> None: ), best_point_quality=float("-inf"), model_fit_quality=r2_branin, + model_std_quality=model_std_quality, + model_fit_generalization=float("-inf"), + model_std_generalization=float("-inf"), num_metric_fetch_e_encountered=0, num_trials_bad_due_to_err=0, ) @@ -184,6 +193,9 @@ def test_scheduler_model_fit_metrics_logging(self) -> None: ).__dict__, "best_point_quality": float("-inf"), "model_fit_quality": r2_branin, + "model_std_quality": model_std_quality, + "model_fit_generalization": float("-inf"), + "model_std_generalization": float("-inf"), "num_metric_fetch_e_encountered": 0, "num_trials_bad_due_to_err": 0, } diff --git a/ax/utils/stats/model_fit_stats.py b/ax/utils/stats/model_fit_stats.py index 9dc5e0f8ad2..433a17907f8 100644 --- a/ax/utils/stats/model_fit_stats.py +++ b/ax/utils/stats/model_fit_stats.py @@ -16,8 +16,9 @@ class ModelFitMetricProtocol(Protocol): """Structural type for model fit metrics.""" - @staticmethod - def __call__(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float: + def __call__( + self, y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray + ) -> float: ...