From faee8b7d38facdd36ba9d75bcf20262051b4ec79 Mon Sep 17 00:00:00 2001
From: Sivan Ravid <12941495+sivanravidos@users.noreply.github.com>
Date: Sun, 8 Dec 2024 12:09:41 +0200
Subject: [PATCH] Allow skipping stats in GroupAnalysis metric (#383)

allow skipping stats in GroupAnalysis metric

Co-authored-by: Sivan Ravid <sivanra@il.ibm.com>
---
 fuse/eval/metrics/metrics_common.py | 59 +++++++++++++++++------------
 1 file changed, 35 insertions(+), 24 deletions(-)

diff --git a/fuse/eval/metrics/metrics_common.py b/fuse/eval/metrics/metrics_common.py
index c29d760e1..54852c986 100644
--- a/fuse/eval/metrics/metrics_common.py
+++ b/fuse/eval/metrics/metrics_common.py
@@ -688,19 +688,27 @@ def eval(
 
 class GroupAnalysis(MetricWithCollectorBase):
     """
-    Evaluate a metric per group and compute basic statistics about the different per group results.
+    Evaluate a metric per group and compute basic statistics over the different per group results.
     eval() method returns a dictionary of the following format:
     {'mean': <>, 'std': <>, 'median': <>, <group 0>: <>, <group 1>: <>, ...}
     """
 
-    def __init__(self, metric: MetricBase, group: str, **super_kwargs: Any) -> None:
+    def __init__(
+        self,
+        metric: MetricBase,
+        group: str,
+        compute_group_stats: bool = True,
+        **super_kwargs: Any,
+    ) -> None:
         """
         :param metric: metric to analyze
         :param group: key to extract the group from
+        :compute_group_stats: wether to compute stats such as mean, std, median over the per group results
         :param super_kwargs: additional arguments for super class (MetricWithCollectorBase) constructor
         """
         super().__init__(group=group, **super_kwargs)
         self._metric = metric
+        self._compute_group_stats = compute_group_stats
 
     def collect(self, batch: Dict) -> None:
         "See super class"
@@ -718,7 +726,9 @@ def reset(self) -> None:
         return super().reset()
 
     def eval(
-        self, results: Dict[str, Any] = None, ids: Optional[Sequence[Hashable]] = None
+        self,
+        results: Dict[str, Any] = None,
+        ids: Optional[Sequence[Hashable]] = None,
     ) -> Dict[str, Any]:
         """
         See super class
@@ -745,31 +755,32 @@ def eval(
             )
 
         # compute stats
-        group_results_list = list(group_analysis_results.values())
-        if isinstance(group_results_list[0], dict):  # multiple values
-            # get all keys
-            all_keys = set()
-            for group_result in group_results_list:
-                all_keys |= set(group_result.keys())
-
-            for key in all_keys:
-                values = [group_result[key] for group_result in group_results_list]
+        if self._compute_group_stats:
+            group_results_list = list(group_analysis_results.values())
+            if isinstance(group_results_list[0], dict):  # multiple values
+                # get all keys
+                all_keys = set()
+                for group_result in group_results_list:
+                    all_keys |= set(group_result.keys())
+
+                for key in all_keys:
+                    values = [group_result[key] for group_result in group_results_list]
+                    try:
+                        group_analysis_results[f"{key}.mean"] = np.mean(values)
+                        group_analysis_results[f"{key}.std"] = np.std(values)
+                        group_analysis_results[f"{key}.median"] = np.median(values)
+                    except:
+                        # do nothing
+                        pass
+            else:  # single value
+                values = [group_result for group_result in group_results_list]
                 try:
-                    group_analysis_results[f"{key}.mean"] = np.mean(values)
-                    group_analysis_results[f"{key}.std"] = np.std(values)
-                    group_analysis_results[f"{key}.median"] = np.median(values)
+                    group_analysis_results["mean"] = np.mean(values)
+                    group_analysis_results["std"] = np.std(values)
+                    group_analysis_results["median"] = np.median(values)
                 except:
                     # do nothing
                     pass
-        else:  # single value
-            values = [group_result for group_result in group_results_list]
-            try:
-                group_analysis_results["mean"] = np.mean(values)
-                group_analysis_results["std"] = np.std(values)
-                group_analysis_results["median"] = np.median(values)
-            except:
-                # do nothing
-                pass
 
         return group_analysis_results