From ea54b8c26b7427be71e5713f76be4f594f4563d5 Mon Sep 17 00:00:00 2001
From: IdoAmosIBM <141218370+IdoAmosIBM@users.noreply.github.com>
Date: Tue, 10 Sep 2024 16:41:44 +0300
Subject: [PATCH] added new metrics for regression tasks (#364)

* added new metrics for regression tasks

* refactored regression metrics to new dir (usage sketch below)
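
For reference, a minimal usage sketch of the new metrics. It is only a sketch: it assumes the evaluator entry point used by the existing fuse.eval examples (EvaluatorDefault and its eval(ids=..., data=..., metrics=...) call) and toy column names; adjust to the actual API if it differs.

```python
# Minimal sketch, assuming EvaluatorDefault from fuse.eval.evaluator and its
# eval(ids=..., data=..., metrics=...) signature, as used in the existing fuse.eval examples.
from collections import OrderedDict

import numpy as np
import pandas as pd

from fuse.eval.evaluator import EvaluatorDefault
from fuse.eval.metrics.regression.metrics import (
    MetricMAE,
    MetricMSE,
    MetricPearsonCorrelation,
    MetricRMSE,
    MetricSpearmanCorrelation,
)

# Toy regression data: "pred" and "target" are hypothetical column names
# that the metrics read from the evaluation DataFrame.
data = pd.DataFrame(
    {
        "id": np.arange(100),
        "pred": np.random.randn(100),
        "target": np.random.randn(100),
    }
)

metrics = OrderedDict(
    [
        ("mae", MetricMAE(pred="pred", target="target")),
        ("mse", MetricMSE(pred="pred", target="target")),
        ("rmse", MetricRMSE(pred="pred", target="target")),
        ("pearson", MetricPearsonCorrelation(pred="pred", target="target")),
        ("spearman", MetricSpearmanCorrelation(pred="pred", target="target")),
    ]
)

results = EvaluatorDefault().eval(ids=None, data=data, metrics=metrics)
```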

---------

Co-authored-by: Ido Amos Ido.Amos@ibm.com <idoamos@cccxc432.pok.ibm.com>
Co-authored-by: Ido Amos Ido.Amos@ibm.com <idoamos@cccxc428.pok.ibm.com>
---
 fuse/eval/examples/examples_stats.py          |   2 +-
 fuse/eval/metrics/libs/stat.py                |  56 +++++++-
 fuse/eval/metrics/regression/__init__.py      |   0
 fuse/eval/metrics/regression/metrics.py       | 131 ++++++++++++++++++
 fuse/eval/metrics/stat/metrics_stat_common.py |  16 +--
 5 files changed, 188 insertions(+), 17 deletions(-)
 create mode 100644 fuse/eval/metrics/regression/__init__.py
 create mode 100644 fuse/eval/metrics/regression/metrics.py

diff --git a/fuse/eval/examples/examples_stats.py b/fuse/eval/examples/examples_stats.py
index 782d18fd2..d0fe36388 100644
--- a/fuse/eval/examples/examples_stats.py
+++ b/fuse/eval/examples/examples_stats.py
@@ -17,7 +17,7 @@
 
 """
 
-from fuse.eval.metrics.stat.metrics_stat_common import MetricPearsonCorrelation
+from fuse.eval.metrics.regression.metrics import MetricPearsonCorrelation
 import numpy as np
 import pandas as pd
 from collections import OrderedDict
diff --git a/fuse/eval/metrics/libs/stat.py b/fuse/eval/metrics/libs/stat.py
index 149b02b41..d51f21a8b 100644
--- a/fuse/eval/metrics/libs/stat.py
+++ b/fuse/eval/metrics/libs/stat.py
@@ -1,6 +1,6 @@
 import numpy as np
 from typing import Sequence, Union
-from scipy.stats import pearsonr
+from scipy.stats import pearsonr, spearmanr
 
 
 class Stat:
@@ -55,3 +55,57 @@ def pearson_correlation(
         results["statistic"] = statistic
         results["p_value"] = p_value
         return results
+
+    @staticmethod
+    def spearman_correlation(
+        pred: Union[np.ndarray, Sequence],
+        target: Union[np.ndarray, Sequence],
+        mask: Union[np.ndarray, Sequence, None] = None,
+    ) -> dict:
+        """
+        Spearman correlation coefficient measuring the monotonic relationship between two datasets/vectors.
+        :param pred: prediction values
+        :param target: target values
+        :param mask: optional boolean mask. If provided, the metric is computed only over the samples where the mask is True
+        """
+        if len(pred) == 0:
+            return dict(statistic=float("nan"), p_value=float("nan"))
+
+        if isinstance(pred, Sequence):
+            if np.isscalar(pred[0]):
+                pred = np.array(pred)
+            else:
+                pred = np.concatenate(pred)
+        if isinstance(target, Sequence):
+            if np.isscalar(target[0]):
+                target = np.array(target)
+            else:
+                target = np.concatenate(target)
+        if isinstance(mask, Sequence):
+            if np.isscalar(mask[0]):
+                mask = np.array(mask).astype("bool")
+            else:
+                mask = np.concatenate(mask).astype("bool")
+        if mask is not None:
+            pred = pred[mask]
+            target = target[mask]
+
+        pred = pred.squeeze()
+        target = target.squeeze()
+        if len(pred.shape) > 1 or len(target.shape) > 1:
+            raise ValueError(
+                f"expected 1D vectors. got pred shape: {pred.shape}, target shape: {target.shape}"
+            )
+
+        assert len(pred) == len(
+            target
+        ), f"Spearman correlation expects pred and target of the same length but got len(pred)={len(pred)}, len(target)={len(target)}"
+
+        statistic, p_value = spearmanr(
+            pred, target, nan_policy="propagate"
+        )  # nans will result in nan outputs
+
+        results = {}
+        results["statistic"] = statistic
+        results["p_value"] = p_value
+        return results
diff --git a/fuse/eval/metrics/regression/__init__.py b/fuse/eval/metrics/regression/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fuse/eval/metrics/regression/metrics.py b/fuse/eval/metrics/regression/metrics.py
new file mode 100644
index 000000000..d5c771a71
--- /dev/null
+++ b/fuse/eval/metrics/regression/metrics.py
@@ -0,0 +1,131 @@
+from typing import List, Optional, Union
+from fuse.eval.metrics.libs.stat import Stat
+from fuse.eval.metrics.metrics_common import MetricDefault
+import numpy as np
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+
+
+class MetricPearsonCorrelation(MetricDefault):
+    def __init__(
+        self, pred: str, target: str, mask: Optional[str] = None, **kwargs: dict
+    ) -> None:
+        super().__init__(
+            pred=pred,
+            target=target,
+            mask=mask,
+            metric_func=Stat.pearson_correlation,
+            **kwargs,
+        )
+
+
+class MetricSpearmanCorrelation(MetricDefault):
+    def __init__(
+        self, pred: str, target: str, mask: Optional[str] = None, **kwargs: dict
+    ) -> None:
+        super().__init__(
+            pred=pred,
+            target=target,
+            mask=mask,
+            metric_func=Stat.spearman_correlation,
+            **kwargs,
+        )
+
+
+class MetricMAE(MetricDefault):
+    def __init__(
+        self,
+        pred: str,
+        target: str,
+        **kwargs: dict,
+    ) -> None:
+        """
+        See MetricDefault for the missing params
+        :param pred: scalar predictions
+        :param target: ground truth scalar labels
+        :param threshold: threshold to apply to both pred and target
+        :param balanced: optionally to use balanced accuracy (from sklearn) instead of regular accuracy.
+        """
+        super().__init__(
+            pred=pred,
+            target=target,
+            metric_func=self.mae,
+            **kwargs,
+        )
+
+    def mae(
+        self,
+        pred: Union[List, np.ndarray],
+        target: Union[List, np.ndarray],
+        **kwargs: dict,
+    ) -> float:
+        return mean_absolute_error(y_true=target, y_pred=pred)
+
+
+class MetricMSE(MetricDefault):
+    def __init__(
+        self,
+        pred: str,
+        target: str,
+        **kwargs: dict,
+    ) -> None:
+        """
+        Our implementation of standard MSE, current version of scikit dones't support it as a metric.
+        See MetricDefault for the missing params
+        :param pred: scalar predictions
+        :param target: ground truth scalar labels
+        :param threshold: threshold to apply to both pred and target
+        :param balanced: optionally to use balanced accuracy (from sklearn) instead of regular accuracy.
+        """
+        super().__init__(
+            pred=pred,
+            target=target,
+            metric_func=self.mse,
+            **kwargs,
+        )
+
+    def mse(
+        self,
+        pred: Union[List, np.ndarray],
+        target: Union[List, np.ndarray],
+        **kwargs: dict,
+    ) -> float:
+        return mean_squared_error(y_true=target, y_pred=pred)
+
+
+class MetricRMSE(MetricDefault):
+    def __init__(
+        self,
+        pred: str,
+        target: str,
+        **kwargs: dict,
+    ) -> None:
+        """
+        See MetricDefault for the missing params
+        :param pred: scalar predictions
+        :param target: ground truth scalar labels
+        :param threshold: threshold to apply to both pred and target
+        :param balanced: optionally to use balanced accuracy (from sklearn) instead of regular accuracy.
+        """
+        super().__init__(
+            pred=pred,
+            target=target,
+            metric_func=self.rmse,
+            **kwargs,
+        )
+
+    def rmse(
+        self,
+        pred: Union[List, np.ndarray],
+        target: Union[List, np.ndarray],
+        **kwargs: dict,
+    ) -> float:
+
+        pred = np.array(pred).flatten()
+        target = np.array(target).flatten()
+
+        assert len(pred) == len(
+            target
+        ), f"Expected pred and target to have the same number of elements but found: {len(pred)} elements in pred and {len(target)} in target"
+
+        squared_diff = (pred - target) ** 2
+        return float(np.sqrt(squared_diff.mean()))
diff --git a/fuse/eval/metrics/stat/metrics_stat_common.py b/fuse/eval/metrics/stat/metrics_stat_common.py
index 15d475505..7cf3eaaa5 100644
--- a/fuse/eval/metrics/stat/metrics_stat_common.py
+++ b/fuse/eval/metrics/stat/metrics_stat_common.py
@@ -1,7 +1,6 @@
 from typing import Any, Dict, Hashable, Optional, Sequence
 from collections import Counter
-from fuse.eval.metrics.metrics_common import MetricDefault, MetricWithCollectorBase
-from fuse.eval.metrics.libs.stat import Stat
+from fuse.eval.metrics.metrics_common import MetricWithCollectorBase
 
 
 class MetricUniqueValues(MetricWithCollectorBase):
@@ -20,16 +19,3 @@ def eval(
         counter = Counter(values)
 
         return list(counter.items())
-
-
-class MetricPearsonCorrelation(MetricDefault):
-    def __init__(
-        self, pred: str, target: str, mask: Optional[str] = None, **kwargs: dict
-    ) -> None:
-        super().__init__(
-            pred=pred,
-            target=target,
-            mask=mask,
-            metric_func=Stat.pearson_correlation,
-            **kwargs
-        )