Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: move reset seed to fixture [2/2] #2706

Merged
merged 13 commits into from
Aug 30, 2024
47 changes: 21 additions & 26 deletions src/torchmetrics/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionlessSynthesis):
"""Wrapper for deprecated import.

>>> import torch
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> from torch import rand
>>> preds = rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis()
>>> torch.round(ergas(preds, target))
>>> ergas(preds, target).round()
tensor(10.)

"""
Expand All @@ -39,8 +39,8 @@ def __init__(
class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarityIndexMeasure):
"""Wrapper for deprecated import.

>>> import torch
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> from torch import rand
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
Expand Down Expand Up @@ -103,13 +103,12 @@ def __init__(
class _RelativeAverageSpectralError(RelativeAverageSpectralError):
"""Wrapper for deprecated import.

>>> import torch
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> from torch import rand
>>> preds = rand(4, 3, 16, 16)
>>> target = rand(4, 3, 16, 16)
>>> rase = _RelativeAverageSpectralError()
>>> rase(preds, target)
tensor(5114.66...)
tensor(5326.40...)

"""

Expand All @@ -125,13 +124,12 @@ def __init__(
class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWindow):
"""Wrapper for deprecated import.

>>> import torch
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> from torch import rand
>>> preds = rand(4, 3, 16, 16)
>>> target = rand(4, 3, 16, 16)
>>> rmse_sw = RootMeanSquaredErrorUsingSlidingWindow()
>>> rmse_sw(preds, target)
tensor(0.3999)
tensor(0.4158)

"""

Expand All @@ -147,10 +145,9 @@ def __init__(
class _SpectralAngleMapper(SpectralAngleMapper):
"""Wrapper for deprecated import.

>>> import torch
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> from torch import rand
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> sam = _SpectralAngleMapper()
>>> sam(preds, target)
tensor(0.5914)
Expand All @@ -169,10 +166,9 @@ def __init__(
class _SpectralDistortionIndex(SpectralDistortionIndex):
"""Wrapper for deprecated import.

>>> import torch
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> from torch import rand
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> sdi = _SpectralDistortionIndex()
>>> sdi(preds, target)
tensor(0.0234)
Expand Down Expand Up @@ -229,10 +225,9 @@ def __init__(
class _TotalVariation(TotalVariation):
"""Wrapper for deprecated import.

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> tv = _TotalVariation()
>>> img = torch.rand(5, 3, 28, 28)
>>> img = rand(5, 3, 28, 28)
>>> tv(img)
tensor(7546.8018)

Expand Down
21 changes: 9 additions & 12 deletions src/torchmetrics/image/d_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ class SpectralDistortionIndex(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> sdi = SpectralDistortionIndex()
>>> sdi(preds, target)
tensor(0.0234)
Expand Down Expand Up @@ -126,11 +125,10 @@ def plot(
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
Expand All @@ -139,11 +137,10 @@ def plot(
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> values = [ ]
>>> for _ in range(10):
Expand Down
27 changes: 12 additions & 15 deletions src/torchmetrics/image/d_s.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,12 @@ class SpatialDistortionIndex(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> preds = rand([16, 3, 32, 32])
>>> target = {
... 'ms': torch.rand([16, 3, 16, 16]),
... 'pan': torch.rand([16, 3, 32, 32]),
... 'ms': rand([16, 3, 16, 16]),
... 'pan': rand([16, 3, 32, 32]),
... }
>>> sdi = SpatialDistortionIndex()
>>> sdi(preds, target)
Expand Down Expand Up @@ -191,13 +190,12 @@ def plot(
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> preds = rand([16, 3, 32, 32])
>>> target = {
... 'ms': torch.rand([16, 3, 16, 16]),
... 'pan': torch.rand([16, 3, 32, 32]),
... 'ms': rand([16, 3, 16, 16]),
... 'pan': rand([16, 3, 32, 32]),
... }
>>> metric = SpatialDistortionIndex()
>>> metric.update(preds, target)
Expand All @@ -207,13 +205,12 @@ def plot(
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> preds = rand([16, 3, 32, 32])
>>> target = {
... 'ms': torch.rand([16, 3, 16, 16]),
... 'pan': torch.rand([16, 3, 32, 32]),
... 'ms': rand([16, 3, 16, 16]),
... 'pan': rand([16, 3, 32, 32]),
... }
>>> metric = SpatialDistortionIndex()
>>> values = [ ]
Expand Down
14 changes: 7 additions & 7 deletions src/torchmetrics/image/ergas.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torch import rand
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> preds = rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> ergas = ErrorRelativeGlobalDimensionlessSynthesis()
>>> torch.round(ergas(preds, target))
>>> ergas(preds, target).round()
tensor(10.)

"""
Expand Down Expand Up @@ -131,9 +131,9 @@ def plot(
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torch import rand
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> preds = rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> metric.update(preds, target)
Expand All @@ -143,9 +143,9 @@ def plot(
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torch import rand
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> preds = rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> values = [ ]
Expand Down
5 changes: 2 additions & 3 deletions src/torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ class FrechetInceptionDistance(Metric):
If ``reset_real_features`` is not an ``bool``

Example:
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torch import rand
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> fid = FrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
Expand All @@ -275,7 +274,7 @@ class FrechetInceptionDistance(Metric):
>>> fid.update(imgs_dist1, real=True)
>>> fid.update(imgs_dist2, real=False)
>>> fid.compute()
tensor(12.7202)
tensor(12.6388)

"""

Expand Down
5 changes: 2 additions & 3 deletions src/torchmetrics/image/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,14 @@ class InceptionScore(Metric):
If ``feature`` is not an ``str``, ``int`` or ``torch.nn.Module``

Example:
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torch import rand
>>> from torchmetrics.image.inception import InceptionScore
>>> inception = InceptionScore()
>>> # generate some images
>>> imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> inception.update(imgs)
>>> inception.compute()
(tensor(1.0544), tensor(0.0117))
(tensor(1.0549), tensor(0.0121))

"""

Expand Down
9 changes: 4 additions & 5 deletions src/torchmetrics/image/kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,16 @@ class KernelInceptionDistance(Metric):
If ``reset_real_features`` is not an ``bool``

Example:
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torch import randint
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> kid = KernelInceptionDistance(subset_size=50)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist1 = randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> kid.update(imgs_dist1, real=True)
>>> kid.update(imgs_dist2, real=False)
>>> kid.compute()
(tensor(0.0337), tensor(0.0023))
(tensor(0.0312), tensor(0.0025))

"""

Expand Down
9 changes: 4 additions & 5 deletions src/torchmetrics/image/lpip.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,14 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
If ``reduction`` is not one of ``"mean"`` or ``"sum"``

Example:
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torch import rand
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
>>> # LPIPS needs the images to be in the [-1, 1] range.
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img1 = (rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (rand(10, 3, 100, 100) * 2) - 1
>>> lpips(img1, img2)
tensor(0.1046)
tensor(0.1024)

"""

Expand Down
7 changes: 3 additions & 4 deletions src/torchmetrics/image/mifid.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,12 @@ class MemorizationInformedFrechetInceptionDistance(Metric):
If ``reset_real_features`` is not an ``bool``

Example::
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import randint
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> mifid = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist1 = randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> mifid.update(imgs_dist1, real=True)
>>> mifid.update(imgs_dist2, real=False)
>>> mifid.compute()
Expand Down
12 changes: 4 additions & 8 deletions src/torchmetrics/image/perceptual_path_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
perceptual_path_length,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCHVISION_AVAILABLE

if not _TORCHVISION_AVAILABLE:
if not _TORCHVISION_AVAILABLE or not _TORCH_GREATER_EQUAL_2_0:
__doctest_skip__ = ["PerceptualPathLength"]


Expand Down Expand Up @@ -98,9 +98,7 @@ class PerceptualPathLength(Metric):
If ``upper_discard`` is not a float between 0 and 1 or None.

Example::
>>> from torchmetrics.image import PerceptualPathLength
>>> import torch
>>> _ = torch.manual_seed(42)
>>> class DummyGenerator(torch.nn.Module):
... def __init__(self, z_size) -> None:
... super().__init__()
Expand All @@ -112,10 +110,8 @@ class PerceptualPathLength(Metric):
... return torch.randn(num_samples, self.z_size)
>>> generator = DummyGenerator(2)
>>> ppl = PerceptualPathLength(num_samples=10)
>>> ppl(generator) # doctest: +SKIP
(tensor(0.2371),
tensor(0.1763),
tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921]))
>>> ppl(generator)
(tensor(...), tensor(...), tensor([...]))

"""

Expand Down
8 changes: 3 additions & 5 deletions src/torchmetrics/image/psnrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,10 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> from torch import rand
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(2, 1, 10, 10)
>>> target = torch.rand(2, 1, 10, 10)
>>> preds = rand(2, 1, 10, 10)
>>> target = rand(2, 1, 10, 10)
>>> metric(preds, target)
tensor(7.2893)

Expand Down
Loading
Loading