From 6f7c821f83f65a9ce6c42f6697083cc6ede2af62 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 12:03:02 +0200 Subject: [PATCH] fix implementation --- src/torchmetrics/regression/nrmse.py | 58 ++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index ddd232195f0..60f3f717bfd 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -33,40 +33,58 @@ def _final_aggregation( min_val: Tensor, max_val: Tensor, mean_val: Tensor, - std_val: Tensor, + var_val: Tensor, + target_squared: Tensor, total: Tensor, - normalization: Literal["mean", "range", "std"] = "mean", + normalization: Literal["mean", "range", "std", "l2"] = "mean", ) -> Tensor: + """In the case of multiple devices we need to aggregate the statistics from the different devices.""" if len(min_val) == 1: if normalization == "mean": return mean_val[0] if normalization == "range": return max_val[0] - min_val[0] if normalization == "std": - return std_val[0] - - min_val_1, max_val_1, mean_val_1, std_val_1, total_1 = min_val[0], max_val[0], mean_val[0], std_val[0], total[0] + return var_val[0] + if normalization == "l2": + return target_squared[0] + + min_val_1, max_val_1, mean_val_1, var_val_1, target_squared_1, total_1 = ( + min_val[0], + max_val[0], + mean_val[0], + var_val[0], + target_squared[0], + total[0], + ) for i in range(1, len(min_val)): - min_val_2, max_val_2, mean_val_2, std_val_2, total_2 = min_val[i], max_val[i], mean_val[i], std_val[i], total[i] + min_val_2, max_val_2, mean_val_2, var_val_2, target_squared_2, total_2 = ( + min_val[i], + max_val[i], + mean_val[i], + var_val[i], + target_squared[i], + total[i], + ) total = total_1 + total_2 mean = (total_1 * mean_val_1 + total_2 * mean_val_2) / total - std = torch.sqrt( - ( - std_val_1**2 * (total_1 - 1) - + std_val_2**2 * (total_2 - 1) - + (mean_val_1 - mean) ** 2 * total_1 - + (mean_val_2 - mean) ** 2 * total_2 - ) - / (total - 1) - ) + var = ( + (total_1 - 1) * var_val_1 + + (total_2 - 1) * var_val_2 + + ((mean_val_1 - mean) ** 2) * total_1 + + ((mean_val_2 - mean) ** 2) * total_2 + ) / (total - 1) min_val = torch.min(min_val_1, min_val_2) max_val = torch.max(max_val_1, max_val_2) + target_squared = target_squared_1 + target_squared_2 if normalization == "mean": return mean if normalization == "range": return max_val - min_val - return std + if normalization == "std": + return var + return target_squared class NormalizedRootMeanSquaredError(Metric): @@ -193,7 +211,13 @@ def compute(self) -> Tensor: """ if (self.num_outputs == 1 and self.mean_val.numel() > 1) or (self.num_outputs > 1 and self.mean_val.ndim > 1): denom = _final_aggregation( - self.min_val, self.max_val, self.mean_val, self.var_val, self.total, self.normalization + min_val=self.min_val, + max_val=self.max_val, + mean_val=self.mean_val, + var_val=self.var_val, + target_squared=self.target_squared, + total=self.total, + normalization=self.normalization, ) else: if self.normalization == "mean":