diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index b97649b700e..ddd232195f0 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -133,6 +133,7 @@ class NormalizedRootMeanSquaredError(Metric): total: Tensor min_val: Tensor max_val: Tensor + target_squared: Tensor mean_val: Tensor var_val: Tensor @@ -160,6 +161,7 @@ def __init__( self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) self.add_state("var_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("target_squared", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets. @@ -171,9 +173,10 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.sum_squared_error += sum_squared_error target = target.view(-1) if self.num_outputs == 1 else target - # Update min and max + # Update min and max and target squared self.min_val = torch.minimum(target.min(dim=0).values, self.min_val) self.max_val = torch.maximum(target.max(dim=0).values, self.max_val) + self.target_squared += (target**2).sum(dim=0) # Update mean and variance new_mean = (self.total * self.mean_val + target.sum(dim=0)) / (self.total + num_obs) @@ -197,8 +200,10 @@ def compute(self) -> Tensor: denom = self.mean_val elif self.normalization == "range": denom = self.max_val - self.min_val - else: + elif self.normalization == "std": denom = torch.sqrt(self.var_val / self.total) + else: + denom = torch.sqrt(self.target_squared) return _normalized_root_mean_squared_error_compute(self.sum_squared_error, self.total, denom) def plot( diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index d738f7e3124..e0df6fc84ec 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -124,9 +124,12 @@ def _reference_normalized_root_mean_squared_error( if num_outputs == 1: y_true = y_true.flatten() y_pred = y_pred.flatten() - evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true) - arg_mapping = {"mean": 1, "range": 2, "std": 4} - return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization]) + if normalization != "l2": + evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true) + arg_mapping = {"mean": 1, "range": 2, "std": 4} + return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization]) + # for l2 normalization we do not have a reference implementation + return np.sqrt(np.mean(np.square(y_true - y_pred), axis=0)) / np.linalg.norm(y_true, axis=0) def _reference_weighted_mean_abs_percentage_error(target, preds): @@ -172,24 +175,50 @@ def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): @pytest.mark.parametrize( ("metric_class", "metric_functional", "sk_fn", "metric_args"), [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}), - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}), - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": NUM_TARGETS}), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {"num_outputs": NUM_TARGETS}), - (MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}), - ( + pytest.param( + MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}, id="mse_singleoutput" + ), + pytest.param( + MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}, id="rmse_singleoutput" + ), + pytest.param( + MeanSquaredError, + mean_squared_error, + sk_mean_squared_error, + {"squared": True, "num_outputs": NUM_TARGETS}, + id="mse_multioutput", + ), + pytest.param(MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}, id="mae_singleoutput"), + pytest.param( + MeanAbsoluteError, + mean_absolute_error, + sk_mean_absolute_error, + {"num_outputs": NUM_TARGETS}, + id="mae_multioutput", + ), + pytest.param( + MeanAbsolutePercentageError, + mean_absolute_percentage_error, + sk_mean_abs_percentage_error, + {}, + id="mape_singleoutput", + ), + pytest.param( SymmetricMeanAbsolutePercentageError, symmetric_mean_absolute_percentage_error, _reference_symmetric_mape, {}, + id="symmetric_mean_absolute_percentage_error", ), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}), - ( + pytest.param( + MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}, id="mean_squared_log_error" + ), + pytest.param( WeightedMeanAbsolutePercentageError, weighted_mean_absolute_percentage_error, _reference_weighted_mean_abs_percentage_error, {}, + id="weighted_mean_absolute_percentage_error", ), pytest.param( NormalizedRootMeanSquaredError,