Skip to content

Commit

Permalink
Regression metrics patch (#382)
Browse files Browse the repository at this point in the history
* adding mask option to mae, rmse, mse metrics

* added optional masking for regression metrics

* added optional mask for regression tasks

* cleaned comments

---------

Co-authored-by: Ido Amos [email protected] <[email protected]>
Co-authored-by: Ido Amos [email protected] <[email protected]>
  • Loading branch information
3 people authored Nov 20, 2024
1 parent a156da0 commit 4b58fcb
Showing 1 changed file with 77 additions and 1 deletion.
78 changes: 77 additions & 1 deletion fuse/eval/metrics/regression/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Sequence
from fuse.eval.metrics.libs.stat import Stat
from fuse.eval.metrics.metrics_common import MetricDefault
import numpy as np
Expand Down Expand Up @@ -41,6 +41,7 @@ def __init__(
self,
pred: str,
target: str,
mask: Optional[str] = None,
**kwargs: dict,
) -> None:
"""
Expand All @@ -53,6 +54,7 @@ def __init__(
super().__init__(
pred=pred,
target=target,
mask=mask,
metric_func=self.mae,
**kwargs,
)
Expand All @@ -61,8 +63,31 @@ def mae(
self,
pred: Union[List, np.ndarray],
target: Union[List, np.ndarray],
mask: Optional[np.ndarray] = None,
**kwargs: dict,
) -> float:

if mask is not None:

if isinstance(pred, Sequence):
if np.isscalar(pred[0]):
pred = np.array(pred)
else:
pred = np.concatenate(pred)
if isinstance(target, Sequence):
if np.isscalar(target[0]):
target = np.array(target)
else:
target = np.concatenate(target)
if isinstance(mask, Sequence):
if np.isscalar(mask[0]):
mask = np.array(mask).astype("bool")
else:
mask = np.concatenate(mask).astype("bool")

target = target[mask]
pred = pred[mask]

return mean_absolute_error(y_true=target, y_pred=pred)


Expand All @@ -71,6 +96,7 @@ def __init__(
self,
pred: str,
target: str,
mask: Optional[str] = None,
**kwargs: dict,
) -> None:
"""
Expand All @@ -84,6 +110,7 @@ def __init__(
super().__init__(
pred=pred,
target=target,
mask=mask,
metric_func=self.mse,
**kwargs,
)
Expand All @@ -92,8 +119,31 @@ def mse(
self,
pred: Union[List, np.ndarray],
target: Union[List, np.ndarray],
mask: Optional[np.ndarray] = None,
**kwargs: dict,
) -> float:

if mask is not None:

if isinstance(pred, Sequence):
if np.isscalar(pred[0]):
pred = np.array(pred)
else:
pred = np.concatenate(pred)
if isinstance(target, Sequence):
if np.isscalar(target[0]):
target = np.array(target)
else:
target = np.concatenate(target)
if isinstance(mask, Sequence):
if np.isscalar(mask[0]):
mask = np.array(mask).astype("bool")
else:
mask = np.concatenate(mask).astype("bool")

target = target[mask]
pred = pred[mask]

return mean_squared_error(y_true=target, y_pred=pred)


Expand All @@ -102,6 +152,7 @@ def __init__(
self,
pred: str,
target: str,
mask: Optional[str] = None,
**kwargs: dict,
) -> None:
"""
Expand All @@ -112,6 +163,7 @@ def __init__(
super().__init__(
pred=pred,
target=target,
mask=mask,
metric_func=self.rmse,
**kwargs,
)
Expand All @@ -120,10 +172,34 @@ def rmse(
self,
pred: Union[List, np.ndarray],
target: Union[List, np.ndarray],
mask: Optional[np.ndarray] = None,
**kwargs: dict,
) -> float:

if mask is not None:

if isinstance(pred, Sequence):
if np.isscalar(pred[0]):
pred = np.array(pred)
else:
pred = np.concatenate(pred)
if isinstance(target, Sequence):
if np.isscalar(target[0]):
target = np.array(target)
else:
target = np.concatenate(target)
if isinstance(mask, Sequence):
if np.isscalar(mask[0]):
mask = np.array(mask).astype("bool")
else:
mask = np.concatenate(mask).astype("bool")

target = target[mask]
pred = pred[mask]

pred = np.array(pred).flatten()
target = np.array(target).flatten()

assert len(pred) == len(
target
), f"Expected pred and target to have the dimensions but found: {len(pred)} elements in pred and {len(target)} in target"
Expand Down

0 comments on commit 4b58fcb

Please sign in to comment.