Skip to content

Commit

Permalink
Update analysis.forecast.plots.plot_metric_per_segment to handle `N…
Browse files Browse the repository at this point in the history
…one` from metrics (#540)
  • Loading branch information
d-a-bunin authored Dec 23, 2024
1 parent b98c853 commit ac792f8
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Rework validation of `FoldMask` to not fail on tail nans ([#536](https://github.com/etna-team/etna/pull/536))
- Add parameter `missing_mode` into `R2` and `MedAE` metrics ([#537](https://github.com/etna-team/etna/pull/537))
- Update `analysis.forecast.plots.plot_metric_per_segment` to handle `None` from metrics ([#540](https://github.com/etna-team/etna/pull/540))
-
-
-

Expand Down
24 changes: 22 additions & 2 deletions etna/analysis/forecast/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from statsmodels.graphics.gofplots import qqplot
from typing_extensions import Literal

from etna.analysis.forecast.utils import _check_metrics_df_empty_segments
from etna.analysis.forecast.utils import _check_metrics_df_same_folds_for_each_segment
from etna.analysis.forecast.utils import _prepare_forecast_results
from etna.analysis.forecast.utils import _select_prediction_intervals_names
from etna.analysis.forecast.utils import _validate_intersecting_segments
Expand Down Expand Up @@ -687,6 +689,11 @@ def plot_metric_per_segment(
):
"""Plot barplot with per-segment metrics.
If for some segment all metric values are missing, it isn't plotted, and the warning is raised.
If some segments have different set of folds with non-missing metrics,
it can lead to incompatible values between folds. The warning is raised in such case.
Parameters
----------
metrics_df:
Expand Down Expand Up @@ -715,23 +722,36 @@ def plot_metric_per_segment(
if ``metric_name`` isn't present in ``metrics_df``
NotImplementedError:
unknown ``per_fold_aggregation_mode`` is given
Warnings
--------
UserWarning:
There are segments without non-missing metric values.
UserWarning:
Some segments have different set of folds to be aggregated on due to missing values.
"""
if barplot_params is None:
barplot_params = {}

aggregation_mode = PerFoldAggregation(per_fold_aggregation_mode)

_check_metrics_df_empty_segments(metrics_df=metrics_df, metric_name=metric_name)
_check_metrics_df_same_folds_for_each_segment(metrics_df=metrics_df, metric_name=metric_name)

plt.figure(figsize=figsize)

if metric_name not in metrics_df.columns:
raise ValueError("Given metric_name isn't present in metrics_df")

if "fold_number" in metrics_df.columns:
metrics_dict = (
metrics_df.groupby("segment").agg({metric_name: aggregation_mode.get_function()}).to_dict()[metric_name]
metrics_df.groupby("segment")
.agg({metric_name: aggregation_mode.get_function()})
.dropna()
.to_dict()[metric_name]
)
else:
metrics_dict = metrics_df["segment", metric_name].to_dict()[metric_name]
metrics_dict = metrics_df[["segment", metric_name]].set_index("segment").dropna().to_dict()[metric_name]

segments = np.array(list(metrics_dict.keys()))
values = np.array(list(metrics_dict.values()))
Expand Down
29 changes: 29 additions & 0 deletions etna/analysis/forecast/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import reprlib
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -119,3 +120,31 @@ def _validate_intersecting_segments(fold_numbers: pd.Series):
for fold_info_1, fold_info_2 in zip(fold_info[:-1], fold_info[1:]):
if fold_info_2["fold_start"] <= fold_info_1["fold_end"]:
raise ValueError("Folds are intersecting")


def _check_metrics_df_empty_segments(metrics_df: pd.DataFrame, metric_name: str) -> None:
    """Warn about segments that have no non-missing values of the given metric.

    Parameters
    ----------
    metrics_df:
        dataframe with a ``segment`` column and a column for the metric
    metric_name:
        name of the metric column to inspect

    Raises
    ------
    ValueError:
        if ``metric_name`` isn't present in ``metrics_df``

    Warnings
    --------
    UserWarning:
        There are segments for which every value of the metric is missing.
    """
    # ``plot_metric_per_segment`` runs this check before its own ``metric_name`` validation,
    # so without this guard a missing column would escape as a ``KeyError`` from the
    # column selection below instead of the ``ValueError`` that function documents.
    if metric_name not in metrics_df.columns:
        raise ValueError("Given metric_name isn't present in metrics_df")

    df = metrics_df[["segment", metric_name]]
    initial_segments = set(df["segment"].unique())
    # segments that still have at least one row once missing metric values are removed
    filtered_segments = set(df.dropna(subset=[metric_name])["segment"].unique())

    missing_segments = initial_segments - filtered_segments
    if missing_segments:
        # reprlib keeps the message short when many segments are affected
        missing_segments_repr = reprlib.repr(missing_segments)
        warnings.warn(
            f"There are segments with all missing metric values, they won't be plotted: {missing_segments_repr}."
        )


def _check_metrics_df_same_folds_for_each_segment(metrics_df: pd.DataFrame, metric_name: str) -> None:
    """Warn if segments don't all share one set of folds with a non-missing metric value.

    Nothing is checked when ``metrics_df`` has no ``fold_number`` column.
    """
    if "fold_number" in metrics_df.columns:
        relevant = metrics_df[["segment", "fold_number", metric_name]]
        # Rows with a missing metric are dropped first, so segments without any
        # non-missing metric don't take part here; a separate check handles them.
        present = relevant.dropna(subset=[metric_name])
        # one hashable set of fold numbers per remaining segment
        folds_per_segment = present.groupby("segment")["fold_number"].apply(frozenset)
        if folds_per_segment.nunique() > 1:
            warnings.warn("Some segments have different set of folds to be aggregated on due to missing values.")
62 changes: 62 additions & 0 deletions tests/test_analysis/test_forecast/test_plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import pytest

from etna.analysis import plot_metric_per_segment
from etna.analysis import plot_residuals
from etna.analysis.forecast.plots import _get_borders_comparator
from etna.metrics import MAE
Expand Down Expand Up @@ -50,3 +51,64 @@ def test_compare_error(segments_df):
def test_compare(segments_df, expected):
comparator = _get_borders_comparator(segment_borders=segments_df)
assert comparator(name_a="a", name_b="b") == expected


@pytest.fixture
def metrics_df_with_folds() -> pd.DataFrame:
    """Build per-fold metrics for three segments over three folds.

    The metric columns differ in how many values are missing: ``MAE`` has none,
    ``MSE`` misses one fold of ``segment_0``, ``MAPE`` misses all of ``segment_0``,
    ``SMAPE`` misses ``segment_0`` and ``segment_1`` entirely, ``RMSE`` has no values at all.
    """
    num_segments = 3
    num_folds = 3
    # each segment name repeated once per fold, segments in ascending order
    segment_column = [f"segment_{i}" for i in range(num_segments) for _ in range(num_folds)]
    # folds 0..2 cycled for every segment
    fold_column = list(range(num_folds)) * num_segments
    return pd.DataFrame(
        {
            "segment": segment_column,
            "MAE": [1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0],
            "MSE": [None, 3.0, 4.0, 3.0, 4.0, 5.0, 5.0, 6.0, 7.0],
            "MAPE": [None, None, None, 20.0, 30.0, 40.0, 30.0, 40.0, 50.0],
            "SMAPE": [None, None, None, None, None, None, 50.0, 60.0, 70.0],
            "RMSE": [None, None, None, None, None, None, None, None, None],
            "fold_number": fold_column,
        }
    )


@pytest.fixture
def metrics_df_no_folds(metrics_df_with_folds) -> pd.DataFrame:
    """Collapse the per-fold metrics into one row per segment by averaging over folds.

    The result has ``segment`` as a regular column and no ``fold_number`` column.
    """
    per_segment_mean = metrics_df_with_folds.groupby("segment").mean(numeric_only=False)
    return per_segment_mean.reset_index().drop(columns="fold_number")


@pytest.mark.parametrize(
    "df_name, metric_name",
    [
        ("metrics_df_with_folds", "MAE"),
        # previously ("metrics_df_no_folds", "MSE") was listed twice, so this case was never run
        ("metrics_df_no_folds", "MAE"),
        ("metrics_df_no_folds", "MSE"),
    ],
)
def test_plot_metric_per_segment_ok(df_name, metric_name, request):
    """Check that plotting runs without errors, with and without a ``fold_number`` column."""
    # fixtures are looked up by name so that the dataframe itself can be parametrized
    metrics_df = request.getfixturevalue(df_name)
    plot_metric_per_segment(metrics_df=metrics_df, metric_name=metric_name)


@pytest.mark.parametrize(
    "df_name, metric_name",
    [
        ("metrics_df_with_folds", "MAPE"),
        ("metrics_df_no_folds", "RMSE"),
    ],
)
def test_plot_metric_per_segment_warning_empty_segments(df_name, metric_name, request):
    """Check that a warning is emitted when some segment has only missing metric values."""
    expected_message = "There are segments with all missing metric values"
    # fixtures are looked up by name so that the dataframe itself can be parametrized
    df = request.getfixturevalue(df_name)
    with pytest.warns(UserWarning, match=expected_message):
        plot_metric_per_segment(metrics_df=df, metric_name=metric_name)


@pytest.mark.parametrize(
    "df_name, metric_name",
    [
        ("metrics_df_with_folds", "MSE"),
    ],
)
def test_plot_metric_per_segment_warning_non_comparable_segments(df_name, metric_name, request):
    """Check that a warning is emitted when segments are aggregated over different sets of folds."""
    expected_message = "Some segments have different set of folds to be aggregated on"
    # fixtures are looked up by name so that the dataframe itself can be parametrized
    df = request.getfixturevalue(df_name)
    with pytest.warns(UserWarning, match=expected_message):
        plot_metric_per_segment(metrics_df=df, metric_name=metric_name)

0 comments on commit ac792f8

Please sign in to comment.