Skip to content

Commit

Permalink
Allow skipping stats in GroupAnalysis metric (#383)
Browse files Browse the repository at this point in the history
allow skipping stats in GroupAnalysis metric

Co-authored-by: Sivan Ravid <[email protected]>
  • Loading branch information
sivanravidos and Sivan Ravid authored Dec 8, 2024
1 parent 4b58fcb commit faee8b7
Showing 1 changed file with 35 additions and 24 deletions.
59 changes: 35 additions & 24 deletions fuse/eval/metrics/metrics_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,19 +688,27 @@ def eval(

class GroupAnalysis(MetricWithCollectorBase):
"""
Evaluate a metric per group and compute basic statistics about the different per group results.
Evaluate a metric per group and compute basic statistics over the different per group results.
eval() method returns a dictionary of the following format (the 'mean', 'std' and 'median' entries are included only when compute_group_stats is True):
{'mean': <>, 'std': <>, 'median': <>, <group 0>: <>, <group 1>: <>, ...}
"""

def __init__(self, metric: MetricBase, group: str, **super_kwargs: Any) -> None:
def __init__(
self,
metric: MetricBase,
group: str,
compute_group_stats: bool = True,
**super_kwargs: Any,
) -> None:
"""
:param metric: metric to analyze
:param group: key to extract the group from
:param compute_group_stats: whether to compute stats such as mean, std, median over the per group results
:param super_kwargs: additional arguments for super class (MetricWithCollectorBase) constructor
"""
super().__init__(group=group, **super_kwargs)
self._metric = metric
self._compute_group_stats = compute_group_stats

def collect(self, batch: Dict) -> None:
"See super class"
Expand All @@ -718,7 +726,9 @@ def reset(self) -> None:
return super().reset()

def eval(
self, results: Dict[str, Any] = None, ids: Optional[Sequence[Hashable]] = None
self,
results: Dict[str, Any] = None,
ids: Optional[Sequence[Hashable]] = None,
) -> Dict[str, Any]:
"""
See super class
Expand All @@ -745,31 +755,32 @@ def eval(
)

# compute stats
group_results_list = list(group_analysis_results.values())
if isinstance(group_results_list[0], dict): # multiple values
# get all keys
all_keys = set()
for group_result in group_results_list:
all_keys |= set(group_result.keys())

for key in all_keys:
values = [group_result[key] for group_result in group_results_list]
if self._compute_group_stats:
group_results_list = list(group_analysis_results.values())
if isinstance(group_results_list[0], dict): # multiple values
# get all keys
all_keys = set()
for group_result in group_results_list:
all_keys |= set(group_result.keys())

for key in all_keys:
values = [group_result[key] for group_result in group_results_list]
try:
group_analysis_results[f"{key}.mean"] = np.mean(values)
group_analysis_results[f"{key}.std"] = np.std(values)
group_analysis_results[f"{key}.median"] = np.median(values)
except:
# do nothing
pass
else: # single value
values = [group_result for group_result in group_results_list]
try:
group_analysis_results[f"{key}.mean"] = np.mean(values)
group_analysis_results[f"{key}.std"] = np.std(values)
group_analysis_results[f"{key}.median"] = np.median(values)
group_analysis_results["mean"] = np.mean(values)
group_analysis_results["std"] = np.std(values)
group_analysis_results["median"] = np.median(values)
except:
# do nothing
pass
else: # single value
values = [group_result for group_result in group_results_list]
try:
group_analysis_results["mean"] = np.mean(values)
group_analysis_results["std"] = np.std(values)
group_analysis_results["median"] = np.median(values)
except:
# do nothing
pass

return group_analysis_results

Expand Down

0 comments on commit faee8b7

Please sign in to comment.