From faee8b7d38facdd36ba9d75bcf20262051b4ec79 Mon Sep 17 00:00:00 2001 From: Sivan Ravid <12941495+sivanravidos@users.noreply.github.com> Date: Sun, 8 Dec 2024 12:09:41 +0200 Subject: [PATCH] Allow skipping stats in GroupAnalysis metric (#383) allow skipping stats in GroupAnalysis metric Co-authored-by: Sivan Ravid <sivanra@il.ibm.com> --- fuse/eval/metrics/metrics_common.py | 59 +++++++++++++++++------------ 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/fuse/eval/metrics/metrics_common.py b/fuse/eval/metrics/metrics_common.py index c29d760e1..54852c986 100644 --- a/fuse/eval/metrics/metrics_common.py +++ b/fuse/eval/metrics/metrics_common.py @@ -688,19 +688,27 @@ def eval( class GroupAnalysis(MetricWithCollectorBase): """ - Evaluate a metric per group and compute basic statistics about the different per group results. + Evaluate a metric per group and compute basic statistics over the different per group results. eval() method returns a dictionary of the following format: {'mean': <>, 'std': <>, 'median': <>, <group 0>: <>, <group 1>: <>, ...} """ - def __init__(self, metric: MetricBase, group: str, **super_kwargs: Any) -> None: + def __init__( + self, + metric: MetricBase, + group: str, + compute_group_stats: bool = True, + **super_kwargs: Any, + ) -> None: """ :param metric: metric to analyze :param group: key to extract the group from + :compute_group_stats: wether to compute stats such as mean, std, median over the per group results :param super_kwargs: additional arguments for super class (MetricWithCollectorBase) constructor """ super().__init__(group=group, **super_kwargs) self._metric = metric + self._compute_group_stats = compute_group_stats def collect(self, batch: Dict) -> None: "See super class" @@ -718,7 +726,9 @@ def reset(self) -> None: return super().reset() def eval( - self, results: Dict[str, Any] = None, ids: Optional[Sequence[Hashable]] = None + self, + results: Dict[str, Any] = None, + ids: Optional[Sequence[Hashable]] = None, ) -> Dict[str, Any]: """ See super class @@ -745,31 +755,32 @@ def eval( ) # compute stats - group_results_list = list(group_analysis_results.values()) - if isinstance(group_results_list[0], dict): # multiple values - # get all keys - all_keys = set() - for group_result in group_results_list: - all_keys |= set(group_result.keys()) - - for key in all_keys: - values = [group_result[key] for group_result in group_results_list] + if self._compute_group_stats: + group_results_list = list(group_analysis_results.values()) + if isinstance(group_results_list[0], dict): # multiple values + # get all keys + all_keys = set() + for group_result in group_results_list: + all_keys |= set(group_result.keys()) + + for key in all_keys: + values = [group_result[key] for group_result in group_results_list] + try: + group_analysis_results[f"{key}.mean"] = np.mean(values) + group_analysis_results[f"{key}.std"] = np.std(values) + group_analysis_results[f"{key}.median"] = np.median(values) + except: + # do nothing + pass + else: # single value + values = [group_result for group_result in group_results_list] try: - group_analysis_results[f"{key}.mean"] = np.mean(values) - group_analysis_results[f"{key}.std"] = np.std(values) - group_analysis_results[f"{key}.median"] = np.median(values) + group_analysis_results["mean"] = np.mean(values) + group_analysis_results["std"] = np.std(values) + group_analysis_results["median"] = np.median(values) except: # do nothing pass - else: # single value - values = [group_result for group_result in group_results_list] - try: - group_analysis_results["mean"] = np.mean(values) - group_analysis_results["std"] = np.std(values) - group_analysis_results["median"] = np.median(values) - except: - # do nothing - pass return group_analysis_results