Skip to content

Commit

Permalink
fix implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Oct 11, 2024
1 parent 2d26828 commit 6f7c821
Showing 1 changed file with 41 additions and 17 deletions.
58 changes: 41 additions & 17 deletions src/torchmetrics/regression/nrmse.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,40 +33,58 @@ def _final_aggregation(
min_val: Tensor,
max_val: Tensor,
mean_val: Tensor,
std_val: Tensor,
var_val: Tensor,
target_squared: Tensor,
total: Tensor,
normalization: Literal["mean", "range", "std"] = "mean",
normalization: Literal["mean", "range", "std", "l2"] = "mean",
) -> Tensor:
"""In the case of multiple devices we need to aggregate the statistics from the different devices."""
if len(min_val) == 1:
if normalization == "mean":
return mean_val[0]
if normalization == "range":
return max_val[0] - min_val[0]
if normalization == "std":
return std_val[0]

min_val_1, max_val_1, mean_val_1, std_val_1, total_1 = min_val[0], max_val[0], mean_val[0], std_val[0], total[0]
return var_val[0]
if normalization == "l2":
return target_squared[0]

min_val_1, max_val_1, mean_val_1, var_val_1, target_squared_1, total_1 = (
min_val[0],
max_val[0],
mean_val[0],
var_val[0],
target_squared[0],
total[0],
)
for i in range(1, len(min_val)):
min_val_2, max_val_2, mean_val_2, std_val_2, total_2 = min_val[i], max_val[i], mean_val[i], std_val[i], total[i]
min_val_2, max_val_2, mean_val_2, var_val_2, target_squared_2, total_2 = (
min_val[i],
max_val[i],
mean_val[i],
var_val[i],
target_squared[i],
total[i],
)
total = total_1 + total_2
mean = (total_1 * mean_val_1 + total_2 * mean_val_2) / total
std = torch.sqrt(
(
std_val_1**2 * (total_1 - 1)
+ std_val_2**2 * (total_2 - 1)
+ (mean_val_1 - mean) ** 2 * total_1
+ (mean_val_2 - mean) ** 2 * total_2
)
/ (total - 1)
)
var = (
(total_1 - 1) * var_val_1
+ (total_2 - 1) * var_val_2
+ ((mean_val_1 - mean) ** 2) * total_1
+ ((mean_val_2 - mean) ** 2) * total_2
) / (total - 1)
min_val = torch.min(min_val_1, min_val_2)
max_val = torch.max(max_val_1, max_val_2)
target_squared = target_squared_1 + target_squared_2

if normalization == "mean":
return mean
if normalization == "range":
return max_val - min_val
return std
if normalization == "std":
return var
return target_squared


class NormalizedRootMeanSquaredError(Metric):
Expand Down Expand Up @@ -193,7 +211,13 @@ def compute(self) -> Tensor:
"""
if (self.num_outputs == 1 and self.mean_val.numel() > 1) or (self.num_outputs > 1 and self.mean_val.ndim > 1):
denom = _final_aggregation(
self.min_val, self.max_val, self.mean_val, self.var_val, self.total, self.normalization
min_val=self.min_val,
max_val=self.max_val,
mean_val=self.mean_val,
var_val=self.var_val,
target_squared=self.target_squared,
total=self.total,
normalization=self.normalization,
)
else:
if self.normalization == "mean":
Expand Down

0 comments on commit 6f7c821

Please sign in to comment.