Skip to content

Commit

Permalink
Reduce memory usage for certain image metrics (#2089)
Browse files Browse the repository at this point in the history
(cherry picked from commit 51439e6)
  • Loading branch information
SkafteNicki authored and Borda committed Dec 1, 2023
1 parent 99ca092 commit 7cf5d09
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 30 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

-


### Changed

-
- Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089))


### Removed
Expand Down
45 changes: 31 additions & 14 deletions src/torchmetrics/image/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.image.sam import _sam_compute, _sam_update
Expand Down Expand Up @@ -75,33 +75,50 @@ class SpectralAngleMapper(Metric):

preds: List[Tensor]
target: List[Tensor]
sum_sam: Tensor
numel: Tensor

def __init__(
self,
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
rank_zero_warn(
"Metric `SpectralAngleMapper` will save all targets and predictions in the buffer."
" For large datasets, this may lead to a large memory footprint."
)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
if reduction not in ("elementwise_mean", "sum", "none", None):
raise ValueError(
f"The `reduction` {reduction} is not valid. Valid options are `elementwise_mean`, `sum`, `none`, None."
)
if reduction == "none" or reduction is None:
rank_zero_warn(
"Metric `SpectralAngleMapper` will save all targets and predictions in the buffer when using"
"`reduction=None` or `reduction='none'. For large datasets, this may lead to a large memory footprint."
)
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
else:
self.add_state("sum_sam", tensor(0.0), dist_reduce_fx="sum")
self.add_state("numel", tensor(0), dist_reduce_fx="sum")
self.reduction = reduction

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
preds, target = _sam_update(preds, target)
self.preds.append(preds)
self.target.append(target)
if self.reduction == "none" or self.reduction is None:
self.preds.append(preds)
self.target.append(target)
else:
sam_score = _sam_compute(preds, target, reduction="sum")
self.sum_sam += sam_score
p_shape = preds.shape
self.numel += p_shape[0] * p_shape[2] * p_shape[3]

def compute(self) -> Tensor:
"""Compute spectra over state."""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _sam_compute(preds, target, self.reduction)
if self.reduction == "none" or self.reduction is None:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _sam_compute(preds, target, self.reduction)
return self.sum_sam / self.numel if self.reduction == "elementwise_mean" else self.sum_sam

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
44 changes: 30 additions & 14 deletions src/torchmetrics/image/uqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.image.uqi import _uqi_compute, _uqi_update
Expand Down Expand Up @@ -73,6 +73,8 @@ class UniversalImageQualityIndex(Metric):

preds: List[Tensor]
target: List[Tensor]
sum_uqi: Tensor
numel: Tensor

def __init__(
self,
Expand All @@ -82,29 +84,43 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
rank_zero_warn(
"Metric `UniversalImageQualityIndex` will save all targets and"
" predictions in buffer. For large datasets this may lead"
" to large memory footprint."
)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
if reduction not in ("elementwise_mean", "sum", "none", None):
raise ValueError(
f"The `reduction` {reduction} is not valid. Valid options are `elementwise_mean`, `sum`, `none`, None."
)
if reduction is None or reduction == "none":
rank_zero_warn(
"Metric `UniversalImageQualityIndex` will save all targets and predictions in the buffer when using"
"`reduction=None` or `reduction='none'. For large datasets, this may lead to a large memory footprint."
)
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
else:
self.add_state("sum_uqi", tensor(0.0), dist_reduce_fx="sum")
self.add_state("numel", tensor(0), dist_reduce_fx="sum")
self.kernel_size = kernel_size
self.sigma = sigma
self.reduction = reduction

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
preds, target = _uqi_update(preds, target)
self.preds.append(preds)
self.target.append(target)
if self.reduction is None or self.reduction == "none":
self.preds.append(preds)
self.target.append(target)
else:
uqi_score = _uqi_compute(preds, target, self.kernel_size, self.sigma, reduction="sum")
self.sum_uqi += uqi_score
ps = preds.shape
self.numel += ps[0] * ps[1] * (ps[2] - self.kernel_size[0] + 1) * (ps[3] - self.kernel_size[1] + 1)

def compute(self) -> Tensor:
"""Compute explained variance over state."""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction)
if self.reduction == "none" or self.reduction is None:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction)
return self.sum_uqi / self.numel if self.reduction == "elementwise_mean" else self.sum_uqi

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down

0 comments on commit 7cf5d09

Please sign in to comment.