Skip to content

Commit

Permalink
fix plotting code + test
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Oct 11, 2024
1 parent b967aa0 commit 4d57398
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/torchmetrics/regression/nrmse.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def plot(
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = MeanSquaredError()
>>> from torchmetrics.regression import NormalizedRootMeanSquaredError
>>> metric = NormalizedRootMeanSquaredError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
Expand All @@ -233,8 +233,8 @@ def plot(
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = MeanSquaredError()
>>> from torchmetrics.regression import NormalizedRootMeanSquaredError
>>> metric = NormalizedRootMeanSquaredError()
>>> values = []
>>> for _ in range(10):
... values.append(metric(randn(10,), randn(10,)))
Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
MeanSquaredError,
MeanSquaredLogError,
MinkowskiDistance,
NormalizedRootMeanSquaredError,
PearsonCorrCoef,
R2Score,
RelativeSquaredError,
Expand Down Expand Up @@ -476,6 +477,7 @@
pytest.param(MeanAbsoluteError, _rand_input, _rand_input, id="mean absolute error"),
pytest.param(MeanAbsolutePercentageError, _rand_input, _rand_input, id="mean absolute percentage error"),
pytest.param(partial(MinkowskiDistance, p=3), _rand_input, _rand_input, id="minkowski distance"),
pytest.param(NormalizedRootMeanSquaredError, _rand_input, _rand_input, id="normalized root mean squared error"),
pytest.param(PearsonCorrCoef, _rand_input, _rand_input, id="pearson corr coef"),
pytest.param(R2Score, _rand_input, _rand_input, id="r2 score"),
pytest.param(RelativeSquaredError, _rand_input, _rand_input, id="relative squared error"),
Expand Down

0 comments on commit 4d57398

Please sign in to comment.