forked from tinkoff-ai/etna
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement
ConformalPredictionIntervals
(#152)
* added `residuals_matrices` function * updated tests * added `ConformalPredictionIntervals` * updated tests * updated documentation * updated changelog
- Loading branch information
Showing
9 changed files
with
342 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from etna.experimental.prediction_intervals.base import BasePredictionIntervals | ||
from etna.experimental.prediction_intervals.conformal import ConformalPredictionIntervals | ||
from etna.experimental.prediction_intervals.naive_variance import NaiveVariancePredictionIntervals |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from typing import Sequence | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from etna.datasets import TSDataset | ||
from etna.experimental.prediction_intervals import BasePredictionIntervals | ||
from etna.experimental.prediction_intervals.utils import residuals_matrices | ||
from etna.pipeline import BasePipeline | ||
|
||
|
||
class ConformalPredictionIntervals(BasePredictionIntervals):
    r"""Estimate conformal prediction intervals using absolute values of historical residuals.

    1. Compute matrix of absolute residuals :math:`r_{it} = |\hat y_{it} - y_{it}|` using k-fold backtest, where :math:`i` is fold index.

    2. Estimate corresponding quantile levels using the provided coverage (e.g. apply Bonferroni correction).

    3. Estimate quantiles for each timestamp using computed absolute residuals and levels.

    `Relevant paper <https://proceedings.neurips.cc/paper/2021/file/312f1ba2a72318edaaa995a67835fad5-Paper.pdf>`_.

    `Reference implementation <https://www.sktime.net/en/stable/api_reference/auto_generated/sktime.forecasting.conformal.ConformalIntervals.html>`_.
    """

    # NOTE: the class docstring is a raw string because it contains the LaTeX
    # command ``\hat``; in a regular string literal ``\h`` is an invalid escape.

    def __init__(
        self, pipeline: BasePipeline, coverage: float = 0.95, bonferroni_correction: bool = False, stride: int = 1
    ):
        """Initialize instance of ``ConformalPredictionIntervals`` with given parameters.

        Parameters
        ----------
        pipeline:
            Base pipeline or ensemble for prediction intervals estimation.
        coverage:
            Interval coverage. In the literature this value may be referred to as ``1 - alpha``.
        bonferroni_correction:
            Whether to use Bonferroni correction when estimating quantiles.
        stride:
            Number of points between folds.

        Raises
        ------
        ValueError:
            If ``coverage`` is outside of ``[0, 1]``.
        ValueError:
            If ``stride`` is not positive.
        """
        if not (0 <= coverage <= 1):
            raise ValueError("Parameter `coverage` must be non-negative number not greater than 1!")

        if stride <= 0:
            raise ValueError("Parameter `stride` must be positive!")

        # Own parameters are stored before the base initializer is invoked.
        self.coverage = coverage
        self.bonferroni_correction = bonferroni_correction
        self.stride = stride

        super().__init__(pipeline=pipeline)

    def _forecast_prediction_interval(
        self, ts: TSDataset, predictions: TSDataset, quantiles: Sequence[float], n_folds: int
    ) -> TSDataset:
        """Estimate and store prediction intervals.

        Parameters
        ----------
        ts:
            Dataset to forecast.
        predictions:
            Dataset with point predictions.
        quantiles:
            Levels of prediction distribution. This argument is not read by this method:
            the interval borders are derived from ``coverage`` only.
        n_folds:
            Number of folds to use in the backtest for prediction interval estimation.

        Returns
        -------
        :
            The ``predictions`` dataset after ``target_lower`` and ``target_upper`` borders
            were added to it via ``add_prediction_intervals``.
        """
        # The backtest is run on this wrapper itself (``pipeline=self``).
        residuals = residuals_matrices(pipeline=self, ts=ts, n_folds=n_folds, stride=self.stride)
        abs_residuals = np.abs(residuals)

        level = self.coverage
        if self.bonferroni_correction:
            # Bonferroni: split the miscoverage ``1 - coverage`` evenly across the horizon steps.
            level = 1 - (1 - self.coverage) / self.horizon

        # Reduce over the first axis of the residual matrices; ``residuals_matrices`` is
        # documented to return an array shaped ``(n_folds, horizon, n_segments)``.
        critical_scores = np.quantile(abs_residuals, q=level, axis=0)

        # Symmetric band around the point forecast: target +/- critical score.
        upper_border = predictions[:, :, "target"] + critical_scores
        upper_border.rename({"target": "target_upper"}, inplace=True, axis=1)

        lower_border = predictions[:, :, "target"] - critical_scores
        lower_border.rename({"target": "target_lower"}, inplace=True, axis=1)

        intervals_df = pd.concat([lower_border, upper_border], axis=1)
        predictions.add_prediction_intervals(prediction_intervals_df=intervals_df)
        return predictions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from etna.datasets import TSDataset | ||
from etna.pipeline import BasePipeline | ||
|
||
|
||
def residuals_matrices(
    pipeline: BasePipeline, ts: TSDataset, n_folds: int = 5, stride: Optional[int] = None
) -> np.ndarray:
    """Compute backtest residuals (forecast minus actual) arranged per fold.

    Parameters
    ----------
    pipeline:
        Pipeline whose historical forecasts are compared with the actual values.
    ts:
        Dataset providing the actual ``target`` values and the segments.
    n_folds:
        Number of folds for backtest.
    stride:
        Number of points between folds; forwarded unchanged to
        ``pipeline.get_historical_forecasts``. By default, is set to ``horizon``.

    Returns
    -------
    :
        Residuals matrices for each segment. Array with shape: ``(n_folds, horizon, n_segments)``.

    Raises
    ------
    ValueError:
        If ``n_folds`` is not positive.
    ValueError:
        If ``stride`` is given and is not positive.
    """
    if n_folds <= 0:
        raise ValueError("Parameter `n_folds` must be positive!")

    if stride is not None:
        if stride <= 0:
            raise ValueError("Parameter `stride` must be positive!")

    forecasts = pipeline.get_historical_forecasts(ts=ts, n_folds=n_folds, stride=stride)

    # Keep only the ``target`` column of every segment and fetch the actual values
    # for exactly the timestamps that were forecast.
    predicted = forecasts.loc[:, pd.IndexSlice[:, "target"]]
    actual = ts[forecasts.index, :, "target"]
    errors = predicted - actual

    # shape: (n_folds, horizon, n_segments)
    target_shape = (-1, pipeline.horizon, len(ts.segments))
    return errors.values.reshape(target_shape)
146 changes: 146 additions & 0 deletions
146
tests/test_experimental/test_prediction_intervals/test_conformal.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
from unittest.mock import MagicMock | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from etna.ensembles import DirectEnsemble | ||
from etna.ensembles import StackingEnsemble | ||
from etna.ensembles import VotingEnsemble | ||
from etna.experimental.prediction_intervals import ConformalPredictionIntervals | ||
from etna.models import NaiveModel | ||
from etna.pipeline import AutoRegressivePipeline | ||
from etna.pipeline import HierarchicalPipeline | ||
from etna.pipeline import Pipeline | ||
from etna.reconciliation import BottomUpReconciliator | ||
from tests.test_experimental.test_prediction_intervals.common import get_arima_pipeline | ||
from tests.test_experimental.test_prediction_intervals.common import get_catboost_pipeline | ||
from tests.test_experimental.test_prediction_intervals.common import get_naive_pipeline | ||
from tests.test_experimental.test_prediction_intervals.common import get_naive_pipeline_with_transforms | ||
from tests.test_experimental.test_prediction_intervals.common import run_base_pipeline_compat_check | ||
|
||
|
||
@pytest.mark.parametrize("stride", [-1, 0])
def test_invalid_stride_parameter_error(stride):
    """Constructing the wrapper with a non-positive ``stride`` raises ``ValueError``."""
    expected_message = "Parameter `stride` must be positive!"
    with pytest.raises(ValueError, match=expected_message):
        ConformalPredictionIntervals(pipeline=Pipeline(model=NaiveModel()), stride=stride)
|
||
|
||
@pytest.mark.parametrize("coverage", (-3, -1, 1.5, 3))
def test_invalid_coverage_parameter_error(coverage):
    """Constructing the wrapper with ``coverage`` outside ``[0, 1]`` raises ``ValueError``.

    Both sides of the valid range are exercised: negative values and values greater than 1.
    The same error message is expected in either case.
    """
    with pytest.raises(ValueError, match="Parameter `coverage` must be non-negative"):
        ConformalPredictionIntervals(pipeline=Pipeline(model=NaiveModel()), coverage=coverage)
|
||
|
||
@pytest.mark.parametrize("pipeline_name", ["naive_pipeline", "naive_pipeline_with_transforms"])
def test_pipeline_fit_forecast_without_intervals(example_tsds, pipeline_name, request):
    """With intervals disabled, the wrapper's forecast equals the wrapped pipeline's forecast."""
    base_pipeline = request.getfixturevalue(pipeline_name)
    wrapper = ConformalPredictionIntervals(pipeline=base_pipeline)

    wrapper.fit(ts=example_tsds)

    wrapper_forecast = wrapper.forecast(prediction_interval=False)
    base_forecast = base_pipeline.forecast(prediction_interval=False)

    pd.testing.assert_frame_equal(wrapper_forecast.df, base_forecast.df)
|
||
|
||
@pytest.mark.parametrize("stride", [2, 5, 10])
@pytest.mark.parametrize("expected_columns", [{"target", "target_lower", "target_upper"}])
def test_valid_strides(example_tsds, expected_columns, stride):
    """Run the shared compatibility check for several positive ``stride`` values."""
    base_pipeline = Pipeline(model=NaiveModel(), horizon=5)
    wrapper = ConformalPredictionIntervals(pipeline=base_pipeline, stride=stride)
    run_base_pipeline_compat_check(
        ts=example_tsds,
        intervals_pipeline=wrapper,
        expected_columns=expected_columns,
    )
|
||
|
||
@pytest.mark.parametrize(
    "expected_columns",
    [{"target", "target_lower", "target_upper"}],
)
@pytest.mark.parametrize(
    "pipeline",
    [
        get_naive_pipeline(horizon=1),
        get_naive_pipeline_with_transforms(horizon=1),
        AutoRegressivePipeline(model=NaiveModel(), horizon=1),
        HierarchicalPipeline(
            model=NaiveModel(),
            horizon=1,
            reconciliator=BottomUpReconciliator(target_level="market", source_level="product"),
        ),
    ],
)
def test_pipelines_forecast_intervals_exist(product_level_constant_hierarchical_ts, pipeline, expected_columns):
    """Run the shared compatibility check for each supported pipeline type on a hierarchical dataset."""
    wrapper = ConformalPredictionIntervals(pipeline=pipeline)
    check_kwargs = {
        "ts": product_level_constant_hierarchical_ts,
        "intervals_pipeline": wrapper,
        "expected_columns": expected_columns,
    }
    run_base_pipeline_compat_check(**check_kwargs)
|
||
|
||
@pytest.mark.parametrize("pipeline", [get_arima_pipeline(horizon=5)])
def test_forecast_prediction_intervals_is_used(example_tsds, pipeline):
    """Forecasting with intervals enabled must go through ``_forecast_prediction_interval``."""
    wrapper = ConformalPredictionIntervals(pipeline=pipeline)
    # Replace the interval estimation hook so that its invocation can be observed.
    wrapper._forecast_prediction_interval = MagicMock()

    wrapper.fit(ts=example_tsds)
    wrapper.forecast(prediction_interval=True)

    hook = wrapper._forecast_prediction_interval
    hook.assert_called()
|
||
|
||
@pytest.mark.parametrize(
    "pipeline",
    [
        get_naive_pipeline(horizon=5),
        get_naive_pipeline_with_transforms(horizon=5),
        AutoRegressivePipeline(model=NaiveModel(), horizon=5),
        get_catboost_pipeline(horizon=5),
        get_arima_pipeline(horizon=5),
    ],
)
def test_pipelines_forecast_intervals_valid(example_tsds, pipeline):
    """The point forecast must lie between the lower and the upper interval borders."""
    wrapper = ConformalPredictionIntervals(pipeline=pipeline)
    wrapper.fit(ts=example_tsds)

    forecast = wrapper.forecast(prediction_interval=True)

    lower_values = forecast[:, :, "target_lower"].values
    point_values = forecast[:, :, "target"].values
    assert np.all(lower_values <= point_values)

    upper_values = forecast[:, :, "target_upper"].values
    assert np.all(forecast[:, :, "target"].values <= upper_values)
|
||
|
||
@pytest.mark.parametrize(
    "expected_columns",
    [{"target", "target_lower", "target_upper"}],
)
@pytest.mark.parametrize(
    "ensemble",
    [
        DirectEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=2)]),
        VotingEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=1)]),
        StackingEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=1)]),
    ],
)
def test_ensembles_forecast_intervals_exist(example_tsds, ensemble, expected_columns):
    """Run the shared compatibility check with each ensemble type as the wrapped pipeline."""
    wrapper = ConformalPredictionIntervals(pipeline=ensemble)
    run_base_pipeline_compat_check(
        ts=example_tsds,
        intervals_pipeline=wrapper,
        expected_columns=expected_columns,
    )
|
||
|
||
@pytest.mark.parametrize(
    "ensemble",
    [
        DirectEnsemble(pipelines=[get_naive_pipeline(horizon=5), get_naive_pipeline_with_transforms(horizon=6)]),
        VotingEnsemble(pipelines=[get_naive_pipeline(horizon=5), get_naive_pipeline_with_transforms(horizon=5)]),
        StackingEnsemble(pipelines=[get_naive_pipeline(horizon=5), get_naive_pipeline_with_transforms(horizon=5)]),
    ],
)
def test_ensembles_forecast_intervals_valid(example_tsds, ensemble):
    """For ensembles, the point forecast must lie between the lower and the upper interval borders."""
    wrapper = ConformalPredictionIntervals(pipeline=ensemble)
    wrapper.fit(ts=example_tsds)

    forecast = wrapper.forecast(prediction_interval=True)

    lower_values = forecast[:, :, "target_lower"].values
    point_values = forecast[:, :, "target"].values
    assert np.all(lower_values <= point_values)

    upper_values = forecast[:, :, "target_upper"].values
    assert np.all(forecast[:, :, "target"].values <= upper_values)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.