fix part of tests
SkafteNicki committed Oct 11, 2024
1 parent 4d57398 commit 2d26828
Showing 2 changed files with 48 additions and 14 deletions.
9 changes: 7 additions & 2 deletions src/torchmetrics/regression/nrmse.py
@@ -133,6 +133,7 @@ class NormalizedRootMeanSquaredError(Metric):
total: Tensor
min_val: Tensor
max_val: Tensor
target_squared: Tensor
mean_val: Tensor
var_val: Tensor

@@ -160,6 +161,7 @@ def __init__(
self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None)
self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
self.add_state("var_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
self.add_state("target_squared", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets.
@@ -171,9 +173,10 @@ def update(self, preds: Tensor, target: Tensor) -> None:
self.sum_squared_error += sum_squared_error
target = target.view(-1) if self.num_outputs == 1 else target

# Update min and max
# Update min and max and target squared
self.min_val = torch.minimum(target.min(dim=0).values, self.min_val)
self.max_val = torch.maximum(target.max(dim=0).values, self.max_val)
self.target_squared += (target**2).sum(dim=0)

# Update mean and variance
new_mean = (self.total * self.mean_val + target.sum(dim=0)) / (self.total + num_obs)
@@ -197,8 +200,10 @@ def compute(self) -> Tensor:
denom = self.mean_val
elif self.normalization == "range":
denom = self.max_val - self.min_val
else:
elif self.normalization == "std":
denom = torch.sqrt(self.var_val / self.total)
else:
denom = torch.sqrt(self.target_squared)
return _normalized_root_mean_squared_error_compute(self.sum_squared_error, self.total, denom)

def plot(
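The compute() change above adds a fourth normalization branch that divides by the L2 norm of the targets, reconstructed as the square root of the accumulated target_squared state. Below is a minimal usage sketch, not part of the diff, assuming the new branch is selected with normalization="l2" as in the updated tests (the other options shown in compute() are "mean", "range", and "std").

# Minimal usage sketch (assumption: the new branch is exposed as normalization="l2").
# With this option the RMSE is divided by sqrt(sum(target**2)), i.e. the L2 norm of the targets.
import torch
from torchmetrics.regression import NormalizedRootMeanSquaredError

preds = torch.tensor([2.5, 5.0, 4.0, 8.0])
target = torch.tensor([3.0, 5.0, 2.5, 7.0])

metric = NormalizedRootMeanSquaredError(normalization="l2")
metric.update(preds, target)
nrmse = metric.compute()  # RMSE / target.norm()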
53 changes: 41 additions & 12 deletions tests/unittests/regression/test_mean_error.py
@@ -124,9 +124,12 @@ def _reference_normalized_root_mean_squared_error(
if num_outputs == 1:
y_true = y_true.flatten()
y_pred = y_pred.flatten()
evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true)
arg_mapping = {"mean": 1, "range": 2, "std": 4}
return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization])
if normalization != "l2":
evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true)
arg_mapping = {"mean": 1, "range": 2, "std": 4}
return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization])
# for l2 normalization we do not have a reference implementation
return np.sqrt(np.mean(np.square(y_true - y_pred), axis=0)) / np.linalg.norm(y_true, axis=0)
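Since there is no external reference implementation for the l2 case, the fallback above computes sqrt(mean((y_true - y_pred)**2)) / ||y_true||. A quick standalone sanity sketch with hypothetical values (not part of the test suite) showing that this denominator is exactly the sqrt(target_squared) the metric accumulates:

# Hypothetical check (not from this commit): the l2 reference denominator,
# np.linalg.norm(y_true), equals sqrt(sum(y_true**2)), i.e. sqrt(target_squared).
import numpy as np

y_true = np.array([3.0, 5.0, 2.5, 7.0])
y_pred = np.array([2.5, 5.0, 4.0, 8.0])

reference = np.sqrt(np.mean((y_true - y_pred) ** 2)) / np.linalg.norm(y_true)
state_based = np.sqrt(np.mean((y_true - y_pred) ** 2)) / np.sqrt(np.sum(y_true**2))
assert np.isclose(reference, state_based)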


def _reference_weighted_mean_abs_percentage_error(target, preds):
@@ -172,24 +175,50 @@ def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args):
@pytest.mark.parametrize(
("metric_class", "metric_functional", "sk_fn", "metric_args"),
[
(MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}),
(MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}),
(MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": NUM_TARGETS}),
(MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}),
(MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {"num_outputs": NUM_TARGETS}),
(MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}),
(
pytest.param(
MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}, id="mse_singleoutput"
),
pytest.param(
MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}, id="rmse_singleoutput"
),
pytest.param(
MeanSquaredError,
mean_squared_error,
sk_mean_squared_error,
{"squared": True, "num_outputs": NUM_TARGETS},
id="mse_multioutput",
),
pytest.param(MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}, id="mae_singleoutput"),
pytest.param(
MeanAbsoluteError,
mean_absolute_error,
sk_mean_absolute_error,
{"num_outputs": NUM_TARGETS},
id="mae_multioutput",
),
pytest.param(
MeanAbsolutePercentageError,
mean_absolute_percentage_error,
sk_mean_abs_percentage_error,
{},
id="mape_singleoutput",
),
pytest.param(
SymmetricMeanAbsolutePercentageError,
symmetric_mean_absolute_percentage_error,
_reference_symmetric_mape,
{},
id="symmetric_mean_absolute_percentage_error",
),
(MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}),
(
pytest.param(
MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}, id="mean_squared_log_error"
),
pytest.param(
WeightedMeanAbsolutePercentageError,
weighted_mean_absolute_percentage_error,
_reference_weighted_mean_abs_percentage_error,
{},
id="weighted_mean_absolute_percentage_error",
),
pytest.param(
NormalizedRootMeanSquaredError,
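The parametrization above also switches from bare tuples to pytest.param entries with explicit ids. A small hypothetical illustration (not from this commit) of why that helps: the ids appear in the generated test names and can be selected directly with pytest -k.

# Hypothetical example of explicit pytest.param ids; run a subset with e.g.
#   pytest -k "nrmse_l2"
import pytest

@pytest.mark.parametrize(
    "normalization",
    [
        pytest.param("mean", id="nrmse_mean"),
        pytest.param("l2", id="nrmse_l2"),
    ],
)
def test_normalization_name(normalization):
    # the ids only affect reporting and selection, not the value passed to the test
    assert normalization in {"mean", "range", "std", "l2"}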
