Skip to content

Commit

Permalink
Model uncertainty metrics for logging (#1741)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1741

This commit introduces `_model_std_quality`, which returns the reciprocal of the multiplicative worst over or under-estimation of the predictive uncertainty compared to the observed errors.

Reviewed By: Balandat

Differential Revision: D46969426

fbshipit-source-id: 04665aa46be90d1a256097526963c9b4e03a2c21
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Aug 15, 2023
1 parent 1bc1632 commit 6c8f653
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 40 deletions.
3 changes: 1 addition & 2 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dataclasses import dataclass, field

from logging import Logger
from typing import Any, cast, Dict, List, MutableMapping, Optional, Set, Tuple, Type
from typing import Any, Dict, List, MutableMapping, Optional, Set, Tuple, Type

import numpy as np
from ax.core.arm import Arm
Expand Down Expand Up @@ -990,7 +990,6 @@ def compute_model_fit_metrics(
"mean_of_the_standardized_error": mean_of_the_standardized_error,
"std_of_the_standardized_error": std_of_the_standardized_error,
}
fit_metrics_dict = cast(Dict[str, ModelFitMetricProtocol], fit_metrics_dict)

return compute_model_fit_metrics(
y_obs=y_obs,
Expand Down
41 changes: 10 additions & 31 deletions ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,7 @@

from logging import Logger
from numbers import Number
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Set,
Tuple,
)
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from ax.core.observation import Observation, ObservationData
Expand All @@ -40,7 +28,6 @@
_rank_correlation,
_total_raw_effect,
compute_model_fit_metrics,
ModelFitMetricProtocol,
)

logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -252,23 +239,15 @@ def compute_diagnostics(result: List[CVResult]) -> CVDiagnostics:
y_pred = _arrayify_dict_values(y_pred)
se_pred = _arrayify_dict_values(se_pred)

# We need to cast here since pyre infers specific types T < ModelFitMetricProtocol
# for the dict values, which is type variant upon initialization, leading
# diagnostic_fns to not be recognized as a Mapping[str, ModelFitMetricProtocol],
# see the last tip in the Pyre docs on [9] Incompatible Variable Type:
# https://staticdocs.internalfb.com/pyre/docs/errors/#9-incompatible-variable-type
diagnostic_fns = cast(
Mapping[str, ModelFitMetricProtocol],
{
MEAN_PREDICTION_CI: _mean_prediction_ci,
MAPE: _mape,
TOTAL_RAW_EFFECT: _total_raw_effect,
CORRELATION_COEFFICIENT: _correlation_coefficient,
RANK_CORRELATION: _rank_correlation,
FISHER_EXACT_TEST_P: _fisher_exact_test_p,
LOG_LIKELIHOOD: _log_likelihood,
},
)
diagnostic_fns = {
MEAN_PREDICTION_CI: _mean_prediction_ci,
MAPE: _mape,
TOTAL_RAW_EFFECT: _total_raw_effect,
CORRELATION_COEFFICIENT: _correlation_coefficient,
RANK_CORRELATION: _rank_correlation,
FISHER_EXACT_TEST_P: _fisher_exact_test_p,
LOG_LIKELIHOOD: _log_likelihood,
}
diagnostics = compute_model_fit_metrics(
y_obs=y_obs, y_pred=y_pred, se_pred=se_pred, fit_metrics_dict=diagnostic_fns
)
Expand Down
6 changes: 6 additions & 0 deletions ax/telemetry/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ class AxClientCompletedRecord:

best_point_quality: float
model_fit_quality: float
model_std_quality: float
model_fit_generalization: float
model_std_generalization: float

@classmethod
def from_ax_client(cls, ax_client: AxClient) -> AxClientCompletedRecord:
Expand All @@ -108,6 +111,9 @@ def from_ax_client(cls, ax_client: AxClient) -> AxClientCompletedRecord:
),
best_point_quality=float("-inf"), # TODO[T147907632]
model_fit_quality=float("-inf"), # TODO[T147907632]
model_std_quality=float("-inf"),
model_fit_generalization=float("-inf"), # TODO via cross_validate_by_trial
model_std_generalization=float("-inf"),
)

def flatten(self) -> Dict[str, Any]:
Expand Down
8 changes: 7 additions & 1 deletion ax/telemetry/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ class OptimizationCompletedRecord:
# SchedulerCompletedRecord fields
best_point_quality: float
model_fit_quality: float
model_std_quality: float
model_fit_generalization: float
model_std_generalization: float

num_metric_fetch_e_encountered: int
num_trials_bad_due_to_err: int
Expand Down Expand Up @@ -403,6 +406,9 @@ def _extract_model_fit_dict(
completed_record: Union[SchedulerCompletedRecord, AxClientCompletedRecord],
) -> Dict[str, float]:
model_fit_names = [
"model_fit_quality", # TODO: add calibration, generalization.
"model_fit_quality",
"model_std_quality",
"model_fit_generalization",
"model_std_generalization",
]
return {n: getattr(completed_record, n) for n in model_fit_names}
36 changes: 35 additions & 1 deletion ax/telemetry/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import Any, Dict, Optional
from warnings import warn

import numpy as np

from ax.service.scheduler import get_fitted_model_bridge, Scheduler
from ax.telemetry.common import _get_max_transformed_dimensionality

Expand Down Expand Up @@ -99,6 +101,9 @@ class SchedulerCompletedRecord:

best_point_quality: float
model_fit_quality: float
model_std_quality: float
model_fit_generalization: float
model_std_generalization: float

num_metric_fetch_e_encountered: int
num_trials_bad_due_to_err: int
Expand All @@ -118,18 +123,24 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCompletedRecord:
model_fit_quality = min(
model_fit_dict["coefficient_of_determination"].values()
)
# similar for uncertainty quantification, but distance from 1 matters
std = list(model_fit_dict["std_of_the_standardized_error"].values())
model_std_quality = _model_std_quality(np.array(std))

except Exception as e:
warn("Encountered exception in computing model fit quality: " + str(e))
model_fit_quality = float("-inf")
# model_std_quality = float("-inf")
model_std_quality = float("-inf")

return cls(
experiment_completed_record=ExperimentCompletedRecord.from_experiment(
experiment=scheduler.experiment
),
best_point_quality=float("-inf"), # TODO[T147907632]
model_fit_quality=model_fit_quality,
model_std_quality=model_std_quality,
model_fit_generalization=float("-inf"), # TODO by cross_validate_by_trial
model_std_generalization=float("-inf"),
num_metric_fetch_e_encountered=scheduler._num_metric_fetch_e_encountered,
num_trials_bad_due_to_err=scheduler._num_trials_bad_due_to_err,
)
Expand All @@ -146,3 +157,26 @@ def flatten(self) -> Dict[str, Any]:
**self_dict,
**experiment_completed_record_dict,
}


def _model_std_quality(std: np.ndarray) -> float:
"""Quantifies quality of the model uncertainty. A value of one means the
uncertainty is perfectly predictive of the true standard deviation of the error.
Values larger than one indicate over-estimation and negative values indicate
under-estimation of the true standard deviation of the error. In particular, a value
of 2 (resp. 1 / 2) represents an over-estimation (resp. under-estimation) of the
true standard deviation of the error by a factor of 2.
Args:
std: The standard deviation of the standardized error.
Returns:
The factor corresponding to the worst over- or under-estimation factor of the
standard deviation of the error among all experimentally observed metrics.
"""
max_std, min_std = np.max(std), np.min(std)
# comparing worst over-estimation factor with worst under-estimation factor
inv_model_std_quality = max_std if max_std > 1 / min_std else min_std
# reciprocal so that values greater than one indicate over-estimation and
# values smaller than indicate underestimation of the uncertainty.
return 1 / inv_model_std_quality
3 changes: 3 additions & 0 deletions ax/telemetry/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,8 @@ def test_ax_client_completed_record_from_ax_client(self) -> None:
),
best_point_quality=float("-inf"),
model_fit_quality=float("-inf"),
model_std_quality=float("-inf"),
model_fit_generalization=float("-inf"),
model_std_generalization=float("-inf"),
)
self.assertEqual(record, expected)
1 change: 0 additions & 1 deletion ax/telemetry/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def test_optimization_completed_record_from_ax_client(self) -> None:
estimated_early_stopping_savings=19,
estimated_global_stopping_savings=98,
)

expected_dict = {
**AxClientCompletedRecord.from_ax_client(ax_client=ax_client).flatten(),
"unique_identifier": "foo",
Expand Down
16 changes: 14 additions & 2 deletions ax/telemetry/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def test_scheduler_completed_record_from_scheduler(self) -> None:
),
best_point_quality=float("-inf"),
model_fit_quality=float("-inf"), # -inf because no model has been fit
model_std_quality=float("-inf"),
model_fit_generalization=float("-inf"),
model_std_generalization=float("-inf"),
num_metric_fetch_e_encountered=0,
num_trials_bad_due_to_err=0,
)
Expand All @@ -104,6 +107,9 @@ def test_scheduler_completed_record_from_scheduler(self) -> None:
).__dict__,
"best_point_quality": float("-inf"),
"model_fit_quality": float("-inf"),
"model_std_quality": float("-inf"),
"model_fit_generalization": float("-inf"),
"model_std_generalization": float("-inf"),
"num_metric_fetch_e_encountered": 0,
"num_trials_bad_due_to_err": 0,
}
Expand Down Expand Up @@ -163,15 +169,18 @@ def test_scheduler_model_fit_metrics_logging(self) -> None:
self.assertTrue("branin" in std)
std_branin = std["branin"]
self.assertIsInstance(std_branin, float)
# log_std_branin = math.log10(std_branin)
# model_std_quality = -log_std_branin # align positivity with over-estimation

model_std_quality = 1 / std_branin

expected = SchedulerCompletedRecord(
experiment_completed_record=ExperimentCompletedRecord.from_experiment(
experiment=scheduler.experiment
),
best_point_quality=float("-inf"),
model_fit_quality=r2_branin,
model_std_quality=model_std_quality,
model_fit_generalization=float("-inf"),
model_std_generalization=float("-inf"),
num_metric_fetch_e_encountered=0,
num_trials_bad_due_to_err=0,
)
Expand All @@ -184,6 +193,9 @@ def test_scheduler_model_fit_metrics_logging(self) -> None:
).__dict__,
"best_point_quality": float("-inf"),
"model_fit_quality": r2_branin,
"model_std_quality": model_std_quality,
"model_fit_generalization": float("-inf"),
"model_std_generalization": float("-inf"),
"num_metric_fetch_e_encountered": 0,
"num_trials_bad_due_to_err": 0,
}
Expand Down
5 changes: 3 additions & 2 deletions ax/utils/stats/model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
class ModelFitMetricProtocol(Protocol):
"""Structural type for model fit metrics."""

@staticmethod
def __call__(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float:
def __call__(
self, y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
...


Expand Down

0 comments on commit 6c8f653

Please sign in to comment.