From 2ab44e8eb063134f98d4c2c343feeef6acdb04de Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 6 Sep 2023 12:01:32 +0200 Subject: [PATCH] Fix plot splitter (#2060) * fix + tests * changelog (cherry picked from commit acaf4cc191149eb0454046025721dbcaad569fe9) --- CHANGELOG.md | 4 +++- src/torchmetrics/utilities/plot.py | 4 ++-- tests/unittests/utilities/test_plot.py | 12 ++++++++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8e03ddacf4..119127df7b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042)) -- + + +- Fixed bug when creating multiple plots that lead to not all plots being shown ([#2060](https://github.com/Lightning-AI/torchmetrics/pull/2060)) ## [1.1.1] - 2023-08-29 diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 0bbd23d50ed..476be4bd0e8 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -172,9 +172,9 @@ def plot_single_or_multi_val( def _get_col_row_split(n: int) -> Tuple[int, int]: """Split `n` figures into `rows` x `cols` figures.""" nsq = sqrt(n) - if nsq * nsq == n: + if int(nsq) == nsq: # square number return int(nsq), int(nsq) - if floor(nsq) * ceil(nsq) > n: + if floor(nsq) * ceil(nsq) >= n: return floor(nsq), ceil(nsq) return ceil(nsq), ceil(nsq) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index b47b06da7f8..cd452af8633 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -167,6 +167,7 @@ _TORCH_GREATER_EQUAL_1_10, _TORCHAUDIO_GREATER_EQUAL_0_10, ) +from torchmetrics.utilities.plot import _get_col_row_split from torchmetrics.wrappers import ( BootStrapper, ClasswiseWrapper, @@ -789,6 +790,17 @@ def test_plot_methods_retrieval(metric_class, preds, target, indexes, num_vals): plt.close(fig) +@pytest.mark.parametrize( + ("n", "expected_row", "expected_col"), + [(1, 1, 1), (2, 1, 2), (3, 2, 2), (4, 2, 2), (5, 2, 3), (6, 2, 3), (7, 3, 3), (8, 3, 3), (9, 3, 3), (10, 3, 4)], +) +def test_row_col_splitter(n, expected_row, expected_col): + """Test the row col splitter function works as expected.""" + row, col = _get_col_row_split(n) + assert row == expected_row + assert col == expected_col + + @pytest.mark.parametrize( ("metric_class", "preds", "target", "labels"), [