Skip to content

Commit

Permalink
Fix issue with shared state of MetricCollection compute group when …
Browse files Browse the repository at this point in the history
…using `DiceScore(average="weighted")` (#2848)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
  • Loading branch information
3 people authored Dec 11, 2024
1 parent a968ebe commit cd24d2b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))


---
Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
)
self.numerator.append(numerator)
self.denominator.append(denominator)
if self.average == "weighted":
self.support.append(support)
self.support.append(support)

def compute(self) -> Tensor:
"""Computes the Dice Score."""
Expand Down
30 changes: 30 additions & 0 deletions tests/unittests/segmentation/test_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest
import torch
from sklearn.metrics import f1_score
from torchmetrics import MetricCollection
from torchmetrics.functional.segmentation.dice import dice_score
from torchmetrics.segmentation.dice import DiceScore

Expand Down Expand Up @@ -106,3 +107,32 @@ def test_dice_score_functional(self, preds, target, input_format, include_backgr
"input_format": input_format,
},
)


@pytest.mark.parametrize("compute_groups", [True, False])
def test_dice_score_metric_collection(compute_groups: bool, num_batches: int = 4):
"""Test that the metric works within a metric collection with and without compute groups."""
metric_collection = MetricCollection(
metrics={
"DiceScore (micro)": DiceScore(
num_classes=NUM_CLASSES,
average="micro",
),
"DiceScore (macro)": DiceScore(
num_classes=NUM_CLASSES,
average="macro",
),
"DiceScore (weighted)": DiceScore(
num_classes=NUM_CLASSES,
average="weighted",
),
},
compute_groups=compute_groups,
)

for _ in range(num_batches):
metric_collection.update(_inputs1.preds, _inputs1.target)
result = metric_collection.compute()

assert isinstance(result, dict)
assert len(set(metric_collection.keys()) - set(result.keys())) == 0

0 comments on commit cd24d2b

Please sign in to comment.