diff --git a/CHANGELOG.md b/CHANGELOG.md index 093a681de..77d00ac6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - Rework validation of `FoldMask` to not fail on tail nans ([#536](https://github.com/etna-team/etna/pull/536)) - Add parameter `missing_mode` into `R2` and `MedAE` metrics ([#537](https://github.com/etna-team/etna/pull/537)) +- Update `analysis.forecast.plots.plot_metric_per_segment` to handle `None` from metrics ([#540](https://github.com/etna-team/etna/pull/540)) +- - - diff --git a/etna/analysis/forecast/plots.py b/etna/analysis/forecast/plots.py index abc046aba..983bf0500 100644 --- a/etna/analysis/forecast/plots.py +++ b/etna/analysis/forecast/plots.py @@ -25,6 +25,8 @@ from statsmodels.graphics.gofplots import qqplot from typing_extensions import Literal +from etna.analysis.forecast.utils import _check_metrics_df_empty_segments +from etna.analysis.forecast.utils import _check_metrics_df_same_folds_for_each_segment from etna.analysis.forecast.utils import _prepare_forecast_results from etna.analysis.forecast.utils import _select_prediction_intervals_names from etna.analysis.forecast.utils import _validate_intersecting_segments @@ -687,6 +689,11 @@ def plot_metric_per_segment( ): """Plot barplot with per-segment metrics. + If for some segment all metric values are missing, it isn't plotted, and the warning is raised. + + If some segments have different set of folds with non-missing metrics, + it can lead to incompatible values between folds. The warning is raised in such case. + Parameters ---------- metrics_df: @@ -715,12 +722,22 @@ def plot_metric_per_segment( if ``metric_name`` isn't present in ``metrics_df`` NotImplementedError: unknown ``per_fold_aggregation_mode`` is given + + Warnings + -------- + UserWarning: + There are segments without non-missing metric values. + UserWarning: + Some segments have different set of folds to be aggregated on due to missing values. """ if barplot_params is None: barplot_params = {} aggregation_mode = PerFoldAggregation(per_fold_aggregation_mode) + _check_metrics_df_empty_segments(metrics_df=metrics_df, metric_name=metric_name) + _check_metrics_df_same_folds_for_each_segment(metrics_df=metrics_df, metric_name=metric_name) + plt.figure(figsize=figsize) if metric_name not in metrics_df.columns: @@ -728,10 +745,13 @@ def plot_metric_per_segment( if "fold_number" in metrics_df.columns: metrics_dict = ( - metrics_df.groupby("segment").agg({metric_name: aggregation_mode.get_function()}).to_dict()[metric_name] + metrics_df.groupby("segment") + .agg({metric_name: aggregation_mode.get_function()}) + .dropna() + .to_dict()[metric_name] ) else: - metrics_dict = metrics_df["segment", metric_name].to_dict()[metric_name] + metrics_dict = metrics_df[["segment", metric_name]].set_index("segment").dropna().to_dict()[metric_name] segments = np.array(list(metrics_dict.keys())) values = np.array(list(metrics_dict.values())) diff --git a/etna/analysis/forecast/utils.py b/etna/analysis/forecast/utils.py index a796eb8c4..e6e6074ed 100644 --- a/etna/analysis/forecast/utils.py +++ b/etna/analysis/forecast/utils.py @@ -1,3 +1,4 @@ +import reprlib import warnings from copy import deepcopy from typing import TYPE_CHECKING @@ -119,3 +120,31 @@ def _validate_intersecting_segments(fold_numbers: pd.Series): for fold_info_1, fold_info_2 in zip(fold_info[:-1], fold_info[1:]): if fold_info_2["fold_start"] <= fold_info_1["fold_end"]: raise ValueError("Folds are intersecting") + + +def _check_metrics_df_empty_segments(metrics_df: pd.DataFrame, metric_name: str) -> None: + """Check if there are segments without any non-missing metrics.""" + df = metrics_df[["segment", metric_name]] + initial_segments = set(df["segment"].unique()) + df = df.dropna(subset=[metric_name]) + filtered_segments = set(df["segment"].unique()) + + if initial_segments != filtered_segments: + missing_segments = initial_segments - filtered_segments + missing_segments_repr = reprlib.repr(missing_segments) + warnings.warn( + f"There are segments with all missing metric values, they won't be plotted: {missing_segments_repr}." + ) + + +def _check_metrics_df_same_folds_for_each_segment(metrics_df: pd.DataFrame, metric_name: str) -> None: + """Check if the same set of folds is present for each segment.""" + if "fold_number" not in metrics_df.columns: + return + + df = metrics_df[["segment", "fold_number", metric_name]] + # we don't take into account segments without any non-missing metrics, they are handled by other check + df = df.dropna(subset=[metric_name]) + num_unique = df.groupby("segment")["fold_number"].apply(frozenset).nunique() + if num_unique > 1: + warnings.warn("Some segments have different set of folds to be aggregated on due to missing values.") diff --git a/tests/test_analysis/test_forecast/test_plots.py b/tests/test_analysis/test_forecast/test_plots.py index 31f2fb026..d7d9d7942 100644 --- a/tests/test_analysis/test_forecast/test_plots.py +++ b/tests/test_analysis/test_forecast/test_plots.py @@ -1,6 +1,7 @@ import pandas as pd import pytest +from etna.analysis import plot_metric_per_segment from etna.analysis import plot_residuals from etna.analysis.forecast.plots import _get_borders_comparator from etna.metrics import MAE @@ -50,3 +51,64 @@ def test_compare_error(segments_df): def test_compare(segments_df, expected): comparator = _get_borders_comparator(segment_borders=segments_df) assert comparator(name_a="a", name_b="b") == expected + + +@pytest.fixture +def metrics_df_with_folds() -> pd.DataFrame: + df = pd.DataFrame( + { + "segment": ["segment_0"] * 3 + ["segment_1"] * 3 + ["segment_2"] * 3, + "MAE": [1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0], + "MSE": [None, 3.0, 4.0, 3.0, 4.0, 5.0, 5.0, 6.0, 7.0], + "MAPE": [None, None, None, 20.0, 30.0, 40.0, 30.0, 40.0, 50.0], + "SMAPE": [None, None, None, None, None, None, 50.0, 60.0, 70.0], + "RMSE": [None, None, None, None, None, None, None, None, None], + "fold_number": [0, 1, 2, 0, 1, 2, 0, 1, 2], + } + ) + return df + + +@pytest.fixture +def metrics_df_no_folds(metrics_df_with_folds) -> pd.DataFrame: + df = metrics_df_with_folds + df = df.groupby("segment").mean(numeric_only=False).reset_index().drop("fold_number", axis=1) + return df + + +@pytest.mark.parametrize( + "df_name, metric_name", + [ + ("metrics_df_with_folds", "MAE"), + ("metrics_df_no_folds", "MSE"), + ("metrics_df_no_folds", "MSE"), + ], +) +def test_plot_metric_per_segment_ok(df_name, metric_name, request): + metrics_df = request.getfixturevalue(df_name) + plot_metric_per_segment(metrics_df=metrics_df, metric_name=metric_name) + + +@pytest.mark.parametrize( + "df_name, metric_name", + [ + ("metrics_df_with_folds", "MAPE"), + ("metrics_df_no_folds", "RMSE"), + ], +) +def test_plot_metric_per_segment_warning_empty_segments(df_name, metric_name, request): + metrics_df = request.getfixturevalue(df_name) + with pytest.warns(UserWarning, match="There are segments with all missing metric values"): + plot_metric_per_segment(metrics_df=metrics_df, metric_name=metric_name) + + +@pytest.mark.parametrize( + "df_name, metric_name", + [ + ("metrics_df_with_folds", "MSE"), + ], +) +def test_plot_metric_per_segment_warning_non_comparable_segments(df_name, metric_name, request): + metrics_df = request.getfixturevalue(df_name) + with pytest.warns(UserWarning, match="Some segments have different set of folds to be aggregated on"): + plot_metric_per_segment(metrics_df=metrics_df, metric_name=metric_name)