tests: move reset seed to fixture [2/2] (#2706)
Borda authored Aug 30, 2024
1 parent c233f36 commit 7fc4177
Showing 36 changed files with 208 additions and 257 deletions.
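Every file below follows the same pattern: the per-example torch.manual_seed(...) line is dropped from the doctest, and seeding is handled once by a shared test fixture (the companion change the commit title refers to). A minimal sketch of such a fixture, assuming a pytest conftest.py with an autouse fixture; the actual fixture name, location, and seed value used by torchmetrics are assumptions, not shown in this diff:

# Hypothetical sketch of a seed-resetting fixture in conftest.py; the fixture
# name and the seed value (42) are placeholders, not taken from this commit.
import pytest
import torch

@pytest.fixture(autouse=True)
def reset_torch_seed():
    """Reseed the global RNG before each collected test, including doctests,
    so the examples stay deterministic without per-example manual_seed calls."""
    torch.manual_seed(42)
    yield

With the seed reset centrally, the doctests can call rand(...) directly, which is also why several expected outputs below change: the tensors are now drawn under the fixture's RNG state rather than the old per-example seeds.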
47 changes: 21 additions & 26 deletions src/torchmetrics/image/_deprecated.py
@@ -17,11 +17,11 @@
class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionlessSynthesis):
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> from torch import rand
>>> preds = rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis()
>>> torch.round(ergas(preds, target))
>>> ergas(preds, target).round()
tensor(10.)
"""
@@ -39,8 +39,8 @@ def __init__(
class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarityIndexMeasure):
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> from torch import rand
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
@@ -103,13 +103,12 @@ def __init__(
class _RelativeAverageSpectralError(RelativeAverageSpectralError):
"""Wrapper for deprecated import.
>>> import torch
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> from torch import rand
>>> preds = rand(4, 3, 16, 16)
>>> target = rand(4, 3, 16, 16)
>>> rase = _RelativeAverageSpectralError()
>>> rase(preds, target)
tensor(5114.66...)
tensor(5326.40...)
"""

@@ -125,13 +124,12 @@ def __init__(
class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWindow):
"""Wrapper for deprecated import.
>>> import torch
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> from torch import rand
>>> preds = rand(4, 3, 16, 16)
>>> target = rand(4, 3, 16, 16)
>>> rmse_sw = RootMeanSquaredErrorUsingSlidingWindow()
>>> rmse_sw(preds, target)
tensor(0.3999)
tensor(0.4158)
"""

@@ -147,10 +145,9 @@ def __init__(
class _SpectralAngleMapper(SpectralAngleMapper):
"""Wrapper for deprecated import.
>>> import torch
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> from torch import rand
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> sam = _SpectralAngleMapper()
>>> sam(preds, target)
tensor(0.5914)
@@ -169,10 +166,9 @@ def __init__(
class _SpectralDistortionIndex(SpectralDistortionIndex):
"""Wrapper for deprecated import.
>>> import torch
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> from torch import rand
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> sdi = _SpectralDistortionIndex()
>>> sdi(preds, target)
tensor(0.0234)
@@ -229,10 +225,9 @@ def __init__(
class _TotalVariation(TotalVariation):
"""Wrapper for deprecated import.
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> tv = _TotalVariation()
>>> img = torch.rand(5, 3, 28, 28)
>>> img = rand(5, 3, 28, 28)
>>> tv(img)
tensor(7546.8018)
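Where an expected output changes (for example _RelativeAverageSpectralError above, where tensor(5114.66...) becomes tensor(5326.40...)), the metric itself is untouched; only the RNG state at doctest time differs. A short sketch of that effect, using the public RelativeAverageSpectralError class, the old per-example seed 22 from the removed lines, and a placeholder for the fixture seed:

# The old documented value is reproduced with the removed seed (22); under the
# new shared fixture the inputs come from a different RNG state, so the
# documented value changes. FIXTURE_SEED below is a placeholder, not from this diff.
import torch
from torchmetrics.image import RelativeAverageSpectralError

rase = RelativeAverageSpectralError()

torch.manual_seed(22)  # old per-example seed
print(rase(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)))  # ~ tensor(5114.66), the old documented value

rase.reset()
FIXTURE_SEED = 42  # placeholder seed
torch.manual_seed(FIXTURE_SEED)
print(rase(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)))  # a different value under the new RNG state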
21 changes: 9 additions & 12 deletions src/torchmetrics/image/d_lambda.py
@@ -53,11 +53,10 @@ class SpectralDistortionIndex(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> sdi = SpectralDistortionIndex()
>>> sdi(preds, target)
tensor(0.0234)
@@ -126,11 +125,10 @@ def plot(
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
@@ -139,11 +137,10 @@
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> preds = rand([16, 3, 16, 16])
>>> target = rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> values = [ ]
>>> for _ in range(10):
27 changes: 12 additions & 15 deletions src/torchmetrics/image/d_s.py
@@ -74,13 +74,12 @@ class SpatialDistortionIndex(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> preds = rand([16, 3, 32, 32])
>>> target = {
... 'ms': torch.rand([16, 3, 16, 16]),
... 'pan': torch.rand([16, 3, 32, 32]),
... 'ms': rand([16, 3, 16, 16]),
... 'pan': rand([16, 3, 32, 32]),
... }
>>> sdi = SpatialDistortionIndex()
>>> sdi(preds, target)
@@ -191,13 +190,12 @@ def plot(
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> preds = rand([16, 3, 32, 32])
>>> target = {
... 'ms': torch.rand([16, 3, 16, 16]),
... 'pan': torch.rand([16, 3, 32, 32]),
... 'ms': rand([16, 3, 16, 16]),
... 'pan': rand([16, 3, 32, 32]),
... }
>>> metric = SpatialDistortionIndex()
>>> metric.update(preds, target)
@@ -207,13 +205,12 @@
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import rand
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> preds = rand([16, 3, 32, 32])
>>> target = {
... 'ms': torch.rand([16, 3, 16, 16]),
... 'pan': torch.rand([16, 3, 32, 32]),
... 'ms': rand([16, 3, 16, 16]),
... 'pan': rand([16, 3, 32, 32]),
... }
>>> metric = SpatialDistortionIndex()
>>> values = [ ]
14 changes: 7 additions & 7 deletions src/torchmetrics/image/ergas.py
@@ -62,12 +62,12 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torch import rand
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> preds = rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> ergas = ErrorRelativeGlobalDimensionlessSynthesis()
>>> torch.round(ergas(preds, target))
>>> ergas(preds, target).round()
tensor(10.)
"""
@@ -131,9 +131,9 @@ def plot(
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torch import rand
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> preds = rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> metric.update(preds, target)
@@ -143,9 +143,9 @@
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torch import rand
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> preds = rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> values = [ ]
5 changes: 2 additions & 3 deletions src/torchmetrics/image/fid.py
@@ -265,8 +265,7 @@ class FrechetInceptionDistance(Metric):
If ``reset_real_features`` is not a ``bool``
Example:
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torch import rand
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> fid = FrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
@@ -275,7 +274,7 @@ class FrechetInceptionDistance(Metric):
>>> fid.update(imgs_dist1, real=True)
>>> fid.update(imgs_dist2, real=False)
>>> fid.compute()
tensor(12.7202)
tensor(12.6388)
"""

5 changes: 2 additions & 3 deletions src/torchmetrics/image/inception.py
@@ -84,15 +84,14 @@ class InceptionScore(Metric):
If ``feature`` is not a ``str``, ``int`` or ``torch.nn.Module``
Example:
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torch import rand
>>> from torchmetrics.image.inception import InceptionScore
>>> inception = InceptionScore()
>>> # generate some images
>>> imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> inception.update(imgs)
>>> inception.compute()
(tensor(1.0544), tensor(0.0117))
(tensor(1.0549), tensor(0.0121))
"""

9 changes: 4 additions & 5 deletions src/torchmetrics/image/kid.py
@@ -148,17 +148,16 @@ class KernelInceptionDistance(Metric):
If ``reset_real_features`` is not a ``bool``
Example:
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torch import randint
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> kid = KernelInceptionDistance(subset_size=50)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist1 = randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> kid.update(imgs_dist1, real=True)
>>> kid.update(imgs_dist2, real=False)
>>> kid.compute()
(tensor(0.0337), tensor(0.0023))
(tensor(0.0312), tensor(0.0025))
"""

9 changes: 4 additions & 5 deletions src/torchmetrics/image/lpip.py
@@ -78,15 +78,14 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
If ``reduction`` is not one of ``"mean"`` or ``"sum"``
Example:
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torch import rand
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
>>> # LPIPS needs the images to be in the [-1, 1] range.
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img1 = (rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (rand(10, 3, 100, 100) * 2) - 1
>>> lpips(img1, img2)
tensor(0.1046)
tensor(0.1024)
"""

7 changes: 3 additions & 4 deletions src/torchmetrics/image/mifid.py
@@ -129,13 +129,12 @@ class MemorizationInformedFrechetInceptionDistance(Metric):
If ``reset_real_features`` is not a ``bool``
Example::
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torch import randint
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> mifid = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist1 = randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> mifid.update(imgs_dist1, real=True)
>>> mifid.update(imgs_dist2, real=False)
>>> mifid.compute()
12 changes: 4 additions & 8 deletions src/torchmetrics/image/perceptual_path_length.py
@@ -23,9 +23,9 @@
perceptual_path_length,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCHVISION_AVAILABLE

if not _TORCHVISION_AVAILABLE:
if not _TORCHVISION_AVAILABLE or not _TORCH_GREATER_EQUAL_2_0:
__doctest_skip__ = ["PerceptualPathLength"]


@@ -98,9 +98,7 @@ class PerceptualPathLength(Metric):
If ``upper_discard`` is not a float between 0 and 1 or None.
Example::
>>> from torchmetrics.image import PerceptualPathLength
>>> import torch
>>> _ = torch.manual_seed(42)
>>> class DummyGenerator(torch.nn.Module):
... def __init__(self, z_size) -> None:
... super().__init__()
@@ -112,10 +110,8 @@
... return torch.randn(num_samples, self.z_size)
>>> generator = DummyGenerator(2)
>>> ppl = PerceptualPathLength(num_samples=10)
>>> ppl(generator) # doctest: +SKIP
(tensor(0.2371),
tensor(0.1763),
tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921]))
>>> ppl(generator)
(tensor(...), tensor(...), tensor([...]))
"""

8 changes: 3 additions & 5 deletions src/torchmetrics/image/psnrb.py
@@ -51,12 +51,10 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> from torch import rand
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(2, 1, 10, 10)
>>> target = torch.rand(2, 1, 10, 10)
>>> preds = rand(2, 1, 10, 10)
>>> target = rand(2, 1, 10, 10)
>>> metric(preds, target)
tensor(7.2893)