Skip to content

Commit

Permalink
refactor usage of manual seed (Lightning-AI#2007)
Browse files Browse the repository at this point in the history
* refactor usage of manual seed

* fix outputs
  • Loading branch information
Borda authored Aug 21, 2023
1 parent 29f3289 commit 1cf4de9
Show file tree
Hide file tree
Showing 15 changed files with 74 additions and 55 deletions.
16 changes: 10 additions & 6 deletions examples/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ def error_relative_global_dimensionless_synthesis() -> tuple:
"""Plot error relative global dimensionless synthesis example."""
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis

p = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
t = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
gen = torch.manual_seed(42)
p = lambda: torch.rand([16, 1, 16, 16], generator=gen)
t = lambda: torch.rand([16, 1, 16, 16], generator=gen)

# plot single value
metric = ErrorRelativeGlobalDimensionlessSynthesis()
Expand Down Expand Up @@ -286,8 +287,9 @@ def spectral_angle_mapper() -> tuple:
"""Plot spectral angle mapper example."""
from torchmetrics.image.sam import SpectralAngleMapper

p = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
t = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
gen = torch.manual_seed(42)
p = lambda: torch.rand([16, 3, 16, 16], generator=gen)
t = lambda: torch.rand([16, 3, 16, 16], generator=gen)

# plot single value
metric = SpectralAngleMapper()
Expand All @@ -306,7 +308,8 @@ def structural_similarity_index_measure() -> tuple:
"""Plot structural similarity index measure example."""
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
gen = torch.manual_seed(42)
p = lambda: torch.rand([3, 3, 256, 256], generator=gen)
t = lambda: p() * 0.75

# plot single value
Expand All @@ -326,7 +329,8 @@ def multiscale_structural_similarity_index_measure() -> tuple:
"""Plot multiscale structural similarity index measure example."""
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure

p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
gen = torch.manual_seed(42)
p = lambda: torch.rand([3, 3, 256, 256], generator=gen)
t = lambda: p() * 0.75

# plot single value
Expand Down
25 changes: 14 additions & 11 deletions src/torchmetrics/functional/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def _error_relative_global_dimensionless_synthesis(
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 1, 16, 16], generator=gen)
>>> target = preds * 0.75
>>> ergds = _error_relative_global_dimensionless_synthesis(preds, target)
>>> torch.round(ergds)
Expand Down Expand Up @@ -105,9 +106,9 @@ def _relative_average_spectral_error(preds: Tensor, target: Tensor, window_size:
"""Wrapper for deprecated import.
>>> import torch
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> gen = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16, generator=gen)
>>> target = torch.rand(4, 3, 16, 16, generator=gen)
>>> _relative_average_spectral_error(preds, target)
tensor(5114.6641)
Expand All @@ -122,9 +123,9 @@ def _root_mean_squared_error_using_sliding_window(
"""Wrapper for deprecated import.
>>> import torch
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> gen = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16, generator=gen)
>>> target = torch.rand(4, 3, 16, 16, generator=gen)
>>> _root_mean_squared_error_using_sliding_window(preds, target)
tensor(0.3999)
Expand All @@ -143,10 +144,11 @@ def _spectral_angle_mapper(
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> _spectral_angle_mapper(preds, target)
tensor(0.5943)
tensor(0.5914)
"""
_deprecated_root_import_func("spectral_angle_mapper", "image")
Expand All @@ -169,7 +171,8 @@ def _multiscale_structural_similarity_index_measure(
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([3, 3, 256, 256], generator=gen)
>>> target = preds * 0.75
>>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9627)
Expand Down
6 changes: 4 additions & 2 deletions src/torchmetrics/functional/image/ergas.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def _ergas_compute(
- ``'none'`` or ``None``: no reduction will be applied
Example:
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 1, 16, 16], generator=gen)
>>> target = preds * 0.75
>>> preds, target = _ergas_update(preds, target)
>>> torch.round(_ergas_compute(preds, target))
Expand Down Expand Up @@ -111,7 +112,8 @@ def error_relative_global_dimensionless_synthesis(
Example:
>>> from torchmetrics.functional.image import error_relative_global_dimensionless_synthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 1, 16, 16], generator=gen)
>>> target = preds * 0.75
>>> ergds = error_relative_global_dimensionless_synthesis(preds, target)
>>> torch.round(ergds)
Expand Down
14 changes: 8 additions & 6 deletions src/torchmetrics/functional/image/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@ def _sam_compute(
- ``'none'`` or ``None``: no reduction will be applied
Example:
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> preds, target = _sam_update(preds, target)
>>> _sam_compute(preds, target)
tensor(0.5943)
tensor(0.5914)
"""
dot_product = (preds * target).sum(dim=1)
Expand Down Expand Up @@ -106,10 +107,11 @@ def spectral_angle_mapper(
Example:
>>> from torchmetrics.functional.image import spectral_angle_mapper
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> spectral_angle_mapper(preds, target)
tensor(0.5943)
tensor(0.5914)
References:
[1] Roberta H. Yuhas, Alexander F. H. Goetz and Joe W. Boardman, "Discrimination among semi-arid
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ def multiscale_structural_similarity_index_measure(
Example:
>>> from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([3, 3, 256, 256], generator=gen)
>>> target = preds * 0.75
>>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9627)
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def clip_score(
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.functional.multimodal import clip_score
>>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
>>> print(score.detach())
>>> score.detach()
tensor(24.4255)
"""
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/functional/text/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,12 @@ def _perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = Non
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand(2, 8, 5, generator=gen)
>>> target = torch.randint(5, (2, 8), generator=gen)
>>> target[0, 6:] = -100
>>> _perplexity(preds, target, ignore_index=-100)
tensor(5.2545)
tensor(5.8540)
"""
_deprecated_root_import_func("perplexity", "text")
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/functional/text/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,12 @@ def perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None
Examples:
>>> import torch
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand(2, 8, 5, generator=gen)
>>> target = torch.randint(5, (2, 8), generator=gen)
>>> target[0, 6:] = -100
>>> perplexity(preds, target, ignore_index=-100)
tensor(5.2545)
tensor(5.8540)
"""
total, count = _perplexity_update(preds, target, ignore_index)
Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,12 @@ class _SpectralAngleMapper(SpectralAngleMapper):
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> sam = _SpectralAngleMapper()
>>> sam(preds, target)
tensor(0.5943)
tensor(0.5914)
"""

Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/image/kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ class KernelInceptionDistance(Metric):
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> kid.update(imgs_dist1, real=True)
>>> kid.update(imgs_dist2, real=False)
>>> kid_mean, kid_std = kid.compute()
>>> print((kid_mean, kid_std))
>>> kid.compute()
(tensor(0.0337), tensor(0.0023))
"""
Expand Down
17 changes: 10 additions & 7 deletions src/torchmetrics/image/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@ class SpectralAngleMapper(Metric):
Example:
>>> import torch
>>> from torchmetrics.image import SpectralAngleMapper
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> sam = SpectralAngleMapper()
>>> sam(preds, target)
tensor(0.5943)
tensor(0.5914)
"""

Expand Down Expand Up @@ -125,8 +126,9 @@ def plot(
>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.image import SpectralAngleMapper
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> metric = SpectralAngleMapper()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
Expand All @@ -137,8 +139,9 @@ def plot(
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import SpectralAngleMapper
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> metric = SpectralAngleMapper()
>>> values = [ ]
>>> for _ in range(10):
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric):
Example:
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> import torch
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
Expand Down
7 changes: 3 additions & 4 deletions src/torchmetrics/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,11 @@ class CLIPScore(Metric):
Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> score = metric(torch.randint(255, (3, 224, 224)), "a photo of a cat")
>>> print(score.detach())
tensor(24.7691)
>>> score = metric(torch.randint(255, (3, 224, 224), generator=torch.manual_seed(42)), "a photo of a cat")
>>> score.detach()
tensor(24.4255)
"""

Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/text/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,13 @@ class _Perplexity(Perplexity):
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand(2, 8, 5, generator=gen)
>>> target = torch.randint(5, (2, 8), generator=gen)
>>> target[0, 6:] = -100
>>> perp = _Perplexity(ignore_index=-100)
>>> perp(preds, target)
tensor(5.2545)
tensor(5.8540)
"""

Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/text/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ class Perplexity(Metric):
Examples:
>>> from torchmetrics.text import Perplexity
>>> import torch
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand(2, 8, 5, generator=gen)
>>> target = torch.randint(5, (2, 8), generator=gen)
>>> target[0, 6:] = -100
>>> perp = Perplexity(ignore_index=-100)
>>> perp(preds, target)
tensor(5.2545)
tensor(5.8540)
"""
is_differentiable = True
Expand Down

0 comments on commit 1cf4de9

Please sign in to comment.