Add mean_squared_error functional metric #515

Merged
merged 9 commits into from
Dec 6, 2024
13 changes: 9 additions & 4 deletions etna/metrics/base.py
Expand Up @@ -128,7 +128,7 @@ class Metric(AbstractMetric, BaseMixin):
def __init__(
self,
metric_fn: MetricFunction,
- mode: str = MetricAggregationMode.per_segment,
+ mode: str = MetricAggregationMode.per_segment.value,
Collaborator

We should consider setting a plain string here instead of an enum value.

Collaborator Author

@d-a-bunin, Dec 5, 2024

You suggest using just "per-segment"? I suppose we should do this in all other places too?

Why do you think we should consider that?
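
For context on the question above, here is a minimal sketch of what the two defaults look like at runtime. It assumes the aggregation mode is a string-valued enum; `AggregationMode` and `describe_default` are stand-ins written for illustration, not the actual `MetricAggregationMode` definition or any etna API.

```python
from enum import Enum


class AggregationMode(str, Enum):
    """Stand-in for MetricAggregationMode (assumed to be a str-valued enum)."""

    macro = "macro"
    per_segment = "per-segment"


def describe_default(mode=AggregationMode.per_segment.value):
    """Return the type name and repr of the value received as ``mode``."""
    return type(mode).__name__, repr(mode)


# With `.value` the default is a plain string, matching the `mode: str` annotation.
plain_default = describe_default()

# Without `.value` the default is the enum member itself.
enum_default = describe_default(AggregationMode.per_segment)

# Either spelling resolves to the same member once passed through the enum constructor.
resolved = AggregationMode("per-segment")
```

A literal `"per-segment"` would behave the same as the `.value` form at runtime; the two differ only in whether the signature refers to the enum.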

metric_fn_signature: str = "array_to_scalar",
**kwargs,
):
Expand All @@ -146,6 +146,8 @@ def __init__(

* if "per-segment" -- does not aggregate metrics

See :py:class:`~etna.metrics.base.MetricAggregationMode`.

metric_fn_signature:
type of signature of ``metric_fn`` (see :py:class:`~etna.metrics.base.MetricFunctionSignature`)
kwargs:
Expand Down Expand Up @@ -385,7 +387,7 @@ class MetricWithMissingHandling(Metric):
def __init__(
self,
metric_fn: MetricFunction,
- mode: str = MetricAggregationMode.per_segment,
+ mode: str = MetricAggregationMode.per_segment.value,
metric_fn_signature: str = "array_to_scalar",
missing_mode: str = "error",
**kwargs,
Expand All @@ -404,6 +406,8 @@ def __init__(

* if "per-segment" -- does not aggregate metrics

See :py:class:`~etna.metrics.base.MetricAggregationMode`.

metric_fn_signature:
type of signature of ``metric_fn`` (see :py:class:`~etna.metrics.base.MetricFunctionSignature`)
missing_mode:
Expand All @@ -421,7 +425,8 @@ def __init__(
If non-existent ``missing_mode`` is used.
"""
super().__init__(metric_fn=metric_fn, mode=mode, metric_fn_signature=metric_fn_signature, **kwargs)
- self.missing_mode = MetricMissingMode(missing_mode)
+ self.missing_mode = missing_mode
+ self._missing_mode_enum = MetricMissingMode(missing_mode)

def _validate_nans(self, y_true: TSDataset, y_pred: TSDataset):
"""Check that ``y_true`` and ``y_pred`` doesn't have NaNs depending on ``missing_mode``.
Expand All @@ -442,7 +447,7 @@ def _validate_nans(self, y_true: TSDataset, y_pred: TSDataset):
df_pred = y_pred.df.loc[:, pd.IndexSlice[:, "target"]]

df_true_isna_sum = df_true.isna().sum()
- if self.missing_mode is MetricMissingMode.error and (df_true_isna_sum > 0).any():
+ if self._missing_mode_enum is MetricMissingMode.error and (df_true_isna_sum > 0).any():
error_segments = set(df_true_isna_sum[df_true_isna_sum > 0].index.droplevel("feature").tolist())
raise ValueError(f"There are NaNs in y_true! Segments with NaNs: {reprlib.repr(error_segments)}.")
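
The `missing_mode` change in this hunk follows a pattern that can be sketched in isolation: the public attribute keeps the string exactly as passed, while a private attribute holds the validated enum used for comparisons. The class and enum below are hypothetical stand-ins; the member names and the stated motivation (keeping the constructor argument readable as a plain string) are assumptions, since the PR does not spell them out.

```python
from enum import Enum


class MissingMode(str, Enum):
    """Stand-in for MetricMissingMode; member names are illustrative."""

    error = "error"
    ignore = "ignore"


class MetricSketch:
    """Hypothetical class showing only the attribute pattern from the diff."""

    def __init__(self, missing_mode: str = "error"):
        # Public attribute keeps exactly what the caller passed.
        self.missing_mode = missing_mode
        # Private attribute holds the validated enum; an unknown string raises ValueError here.
        self._missing_mode_enum = MissingMode(missing_mode)

    def must_fail_on_nans(self) -> bool:
        # Identity comparison is done against the enum member, as in `_validate_nans`.
        return self._missing_mode_enum is MissingMode.error


ignoring = MetricSketch("ignore")
strict = MetricSketch()
```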

47 changes: 47 additions & 0 deletions etna/metrics/functional_metrics.py
@@ -1,3 +1,4 @@
import warnings
from enum import Enum
from functools import partial
from typing import Optional
Expand Down Expand Up @@ -41,6 +42,52 @@
assert_never(multioutput_enum)


def mse_with_missing_handling(y_true: ArrayLike, y_pred: ArrayLike, multioutput: str = "joint") -> ArrayLike:
Collaborator Author

We should discuss naming and the fact that sklearn's mse is in our etna/metrics/__init__.py.

Collaborator

We can name it mean_squared_error. Mentioning missing handling in the name seems a bit excessive. Better to leave this clarification to the documentation.

Collaborator Author

If we have both mse from sklearn and our mean_squared_error it could be confusing.

We could remove sklearn's version: we haven't mentioned it in our public documentation (but we do have it in etna/metrics/__init__.py).

I'm also a bit worried that by replacing sklearn's mse with our mse we change the list of available kwargs.

"""Mean squared error with missing values handling.

`Wikipedia entry on the Mean squared error
<https://en.wikipedia.org/wiki/Mean_squared_error>`_
Collaborator

I don't think it is a good idea to reference Wikipedia. Better to replace it with a more reputable source (e.g. Hyndman Forecasting) or to remove it completely.

Collaborator Author

Here I just repeated what we have done with mape and smape. I could replace the link in those places too.


The nans are ignored during computation.
Collaborator

It would be helpful to also note what will be returned if the segment contains only NaNs.


Parameters
----------
y_true:
array-like of shape (n_samples,) or (n_samples, n_outputs)

Ground truth (correct) target values.

y_pred:
array-like of shape (n_samples,) or (n_samples, n_outputs)

Estimated target values.

multioutput:
Defines aggregating of multiple output values
(see :py:class:`~etna.metrics.functional_metrics.FunctionalMetricMultioutput`).

Returns
-------
:
A non-negative floating point value (the best value is 0.0), or an array of floating point values,
one for each individual target.
"""
y_true_array, y_pred_array = np.asarray(y_true), np.asarray(y_pred)

if len(y_true_array.shape) != len(y_pred_array.shape):
raise ValueError("Shapes of the labels must be the same")

Codecov / codecov/patch warning: added line #L78 in etna/metrics/functional_metrics.py was not covered by tests.

axis = _get_axis_by_multioutput(multioutput)
with warnings.catch_warnings():
# this helps to prevent warning in case of all nans
warnings.filterwarnings(
message="Mean of empty slice",
action="ignore",
)
result = np.nanmean((y_true_array - y_pred_array) ** 2, axis=axis)
return result
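
The NaN handling above, including the all-NaN case raised in the review, can be tried out with a small standalone sketch. `nan_aware_mse` is a hypothetical name, and its `axis` parameter stands in for whatever `_get_axis_by_multioutput` returns (assumed here to be `None` for a joint score and `0` for per-output values). The only behaviour relied on is that of `np.nanmean`, which returns NaN for a slice with no non-NaN entries.

```python
import warnings

import numpy as np


def nan_aware_mse(y_true, y_pred, axis=None):
    """Sketch of an MSE that skips NaNs; ``axis`` stands in for the value derived from ``multioutput``."""
    y_true_array = np.asarray(y_true, dtype=float)
    y_pred_array = np.asarray(y_pred, dtype=float)
    with warnings.catch_warnings():
        # np.nanmean emits a RuntimeWarning when a slice has no non-NaN values.
        warnings.filterwarnings(message="Mean of empty slice", action="ignore")
        return np.nanmean((y_true_array - y_pred_array) ** 2, axis=axis)


# One NaN in y_true: only the two valid squared errors (1 and 4) are averaged.
partial_nans = nan_aware_mse([1.0, np.nan, 3.0], [2.0, 2.0, 5.0])

# Every pair contains a NaN: the result is NaN rather than an exception.
all_nans = nan_aware_mse([np.nan, np.nan], [1.0, 2.0])

# Two outputs, the second entirely NaN: one value per column.
per_column = nan_aware_mse([[1.0, np.nan], [3.0, np.nan]], [[2.0, 1.0], [5.0, 1.0]], axis=0)
```

In this sketch a segment with no valid pairs yields NaN, which is the case the docstring could mention explicitly.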


def mape(y_true: ArrayLike, y_pred: ArrayLike, eps: float = 1e-15, multioutput: str = "joint") -> ArrayLike:
"""Mean absolute percentage error.

24 changes: 18 additions & 6 deletions etna/metrics/intervals_metrics.py
Expand Up @@ -56,7 +56,7 @@ class Coverage(Metric, _IntervalsMetricMixin):
def __init__(
self,
quantiles: Optional[Tuple[float, float]] = None,
- mode: str = MetricAggregationMode.per_segment,
+ mode: str = MetricAggregationMode.per_segment.value,
upper_name: Optional[str] = None,
lower_name: Optional[str] = None,
**kwargs,
Expand All @@ -67,8 +67,14 @@ def __init__(
----------
quantiles:
lower and upper quantiles
- mode: 'macro' or 'per-segment'
- metrics aggregation mode
+ mode:
+ "macro" or "per-segment", way to aggregate metric values over segments:
+
+ * if "macro" computes average value
+
+ * if "per-segment" -- does not aggregate metrics
+
+ See :py:class:`~etna.metrics.base.MetricAggregationMode`.
upper_name:
name of column with upper border of the interval
lower_name:
Expand Down Expand Up @@ -169,7 +175,7 @@ class Width(Metric, _IntervalsMetricMixin):
def __init__(
self,
quantiles: Optional[Tuple[float, float]] = None,
- mode: str = MetricAggregationMode.per_segment,
+ mode: str = MetricAggregationMode.per_segment.value,
upper_name: Optional[str] = None,
lower_name: Optional[str] = None,
**kwargs,
Expand All @@ -180,8 +186,14 @@ def __init__(
----------
quantiles:
lower and upper quantiles
- mode: 'macro' or 'per-segment'
- metrics aggregation mode
+ mode:
+ "macro" or "per-segment", way to aggregate metric values over segments:
+
+ * if "macro" computes average value
+
+ * if "per-segment" -- does not aggregate metrics
+
+ See :py:class:`~etna.metrics.base.MetricAggregationMode`.
upper_name:
name of column with upper border of the interval
lower_name:
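
The two `mode` values documented in these docstrings can be illustrated with a small sketch. `aggregate` is a hypothetical helper written for this example, not a function from etna; it assumes per-segment metric values are available as a mapping from segment name to value.

```python
from statistics import mean
from typing import Dict, Union


def aggregate(per_segment_values: Dict[str, float], mode: str = "per-segment") -> Union[float, Dict[str, float]]:
    """Hypothetical helper illustrating the two modes described in the docstring."""
    if mode == "macro":
        # "macro": average the per-segment values into a single number.
        return mean(list(per_segment_values.values()))
    if mode == "per-segment":
        # "per-segment": no aggregation, one value per segment.
        return dict(per_segment_values)
    raise ValueError(f"Unknown mode: {mode!r}")


values = {"segment_a": 0.5, "segment_b": 1.0}
macro_result = aggregate(values, mode="macro")
per_segment_result = aggregate(values, mode="per-segment")
```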