diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index ded0daf1b27..b97649b700e 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -223,8 +223,8 @@ def plot( >>> from torch import randn >>> # Example plotting a single value - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import NormalizedRootMeanSquaredError + >>> metric = NormalizedRootMeanSquaredError() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot() @@ -233,8 +233,8 @@ def plot( >>> from torch import randn >>> # Example plotting multiple values - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import NormalizedRootMeanSquaredError + >>> metric = NormalizedRootMeanSquaredError() >>> values = [] >>> for _ in range(10): ... values.append(metric(randn(10,), randn(10,))) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 465ed2d55e5..add2b78b1b8 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -130,6 +130,7 @@ MeanSquaredError, MeanSquaredLogError, MinkowskiDistance, + NormalizedRootMeanSquaredError, PearsonCorrCoef, R2Score, RelativeSquaredError, @@ -476,6 +477,7 @@ pytest.param(MeanAbsoluteError, _rand_input, _rand_input, id="mean absolute error"), pytest.param(MeanAbsolutePercentageError, _rand_input, _rand_input, id="mean absolute percentage error"), pytest.param(partial(MinkowskiDistance, p=3), _rand_input, _rand_input, id="minkowski distance"), + pytest.param(NormalizedRootMeanSquaredError, _rand_input, _rand_input, id="normalized root mean squared error"), pytest.param(PearsonCorrCoef, _rand_input, _rand_input, id="pearson corr coef"), pytest.param(R2Score, _rand_input, _rand_input, id="r2 score"), pytest.param(RelativeSquaredError, _rand_input, _rand_input, id="relative squared error"),