From 53c65d5d298e2267b4223eae70e38813e72de612 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 28 Aug 2023 22:09:17 +0200 Subject: [PATCH] Fix `MetricCollection` when input are metrics that return dicts with same keywords (#2027) Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Jirka (cherry picked from commit 58fc9c68ac89b3adf4ba953944e590d7b337e4db) --- CHANGELOG.md | 9 ++++--- src/torchmetrics/collections.py | 18 ++++++++++--- src/torchmetrics/utilities/data.py | 13 ++++++--- tests/unittests/bases/test_collections.py | 30 ++++++++++++++++++--- tests/unittests/utilities/test_utilities.py | 5 ++-- 5 files changed, 58 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fbd5aa645b..1bd786f566d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,16 +26,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019) +- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019)) -- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017) +- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)) + + +- Fixed bug in `MetricCollection` when used with multiple metrics that return dicts with same keys ([#2027](https://github.com/Lightning-AI/torchmetrics/pull/2027)) - Fixed bug in detection intersection metrics when `class_metrics=True` resulting in wrong values ([#1924](https://github.com/Lightning-AI/torchmetrics/pull/1924)) -- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028) +- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028)) ## [1.1.0] - 2023-08-22 diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index d82646d7ce7..c7077f537f5 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -23,7 +23,7 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.data import allclose +from torchmetrics.utilities.data import _flatten_dict, allclose from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val @@ -334,17 +334,27 @@ def _compute_and_reduce( res = m(*args, **m._filter_kwargs(**kwargs)) else: raise ValueError("method_name should be either 'compute' or 'forward', but got {method_name}") + result[k] = res + _, duplicates = _flatten_dict(result) + + flattened_results = {} + for k, res in result.items(): if isinstance(res, dict): for key, v in res.items(): + # if duplicates of keys we need to add unique prefix to each key + if duplicates: + stripped_k = k.replace(getattr(m, "prefix", ""), "") + stripped_k = stripped_k.replace(getattr(m, "postfix", ""), "") + key = f"{stripped_k}_{key}" if hasattr(m, "prefix") and m.prefix is not None: key = f"{m.prefix}{key}" if hasattr(m, "postfix") and m.postfix is not None: key = f"{key}{m.postfix}" - result[key] = v + flattened_results[key] = v else: - result[k] = res - return {self._set_name(k): v for k, v in result.items()} + flattened_results[k] = res + return {self._set_name(k): v for k, v in flattened_results.items()} def reset(self) -> None: """Call reset for each metric sequentially.""" diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index ebb81679a02..8e818a144f7 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch from lightning_utilities import apply_to_collection @@ -60,16 +60,21 @@ def _flatten(x: Sequence) -> list: return [item for sublist in x for item in sublist] -def _flatten_dict(x: Dict) -> Dict: - """Flatten dict of dicts into single dict.""" +def _flatten_dict(x: Dict) -> Tuple[Dict, bool]: + """Flatten dict of dicts into single dict and checking for duplicates in keys along the way.""" new_dict = {} + duplicates = False for key, value in x.items(): if isinstance(value, dict): for k, v in value.items(): + if k in new_dict: + duplicates = True new_dict[k] = v else: + if key in new_dict: + duplicates = True new_dict[key] = value - return new_dict + return new_dict, duplicates def to_onehot( diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index fafef8fb14d..834f764ff81 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -614,11 +614,33 @@ def test_nested_collections(input_collections): assert "valmetrics/micro_MulticlassPrecision" in val -def test_double_nested_collections(): +@pytest.mark.parametrize( + ("base_metrics", "expected"), + [ + ( + DummyMetricMultiOutputDict(), + ( + "prefix2_prefix1_output1_postfix1_postfix2", + "prefix2_prefix1_output2_postfix1_postfix2", + ), + ), + ( + {"metric1": DummyMetricMultiOutputDict(), "metric2": DummyMetricMultiOutputDict()}, + ( + "prefix2_prefix1_metric1_output1_postfix1_postfix2", + "prefix2_prefix1_metric1_output2_postfix1_postfix2", + "prefix2_prefix1_metric2_output1_postfix1_postfix2", + "prefix2_prefix1_metric2_output2_postfix1_postfix2", + ), + ), + ], +) +def test_double_nested_collections(base_metrics, expected): """Test that double nested collections gets flattened to a single collection.""" - collection1 = MetricCollection([DummyMetricMultiOutputDict()], prefix="prefix1_", postfix="_postfix1") + collection1 = MetricCollection(base_metrics, prefix="prefix1_", postfix="_postfix1") collection2 = MetricCollection([collection1], prefix="prefix2_", postfix="_postfix2") x = torch.randn(10).sum() val = collection2(x) - assert "prefix2_prefix1_output1_postfix1_postfix2" in val - assert "prefix2_prefix1_output2_postfix1_postfix2" in val + + for key in val: + assert key in expected diff --git a/tests/unittests/utilities/test_utilities.py b/tests/unittests/utilities/test_utilities.py index d0e38abadfb..ca05ce5f75b 100644 --- a/tests/unittests/utilities/test_utilities.py +++ b/tests/unittests/utilities/test_utilities.py @@ -113,8 +113,9 @@ def test_flatten_list(): def test_flatten_dict(): """Check that _flatten_dict utility function works as expected.""" inp = {"a": {"b": 1, "c": 2}, "d": 3} - out = _flatten_dict(inp) - assert out == {"b": 1, "c": 2, "d": 3} + out_dict, out_dup = _flatten_dict(inp) + assert out_dict == {"b": 1, "c": 2, "d": 3} + assert out_dup is False @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")