From 7325d6fb7de98fd0c80c025ececa2ac4c024fa48 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 6 Sep 2023 16:19:34 +0200 Subject: [PATCH] Fix bootstrapping with few samples (#2052) * fix * test * changelog (cherry picked from commit 8d82db077b32fc2b68a4584a0f260cd0d845ee37) --- CHANGELOG.md | 3 +++ src/torchmetrics/wrappers/bootstrapping.py | 19 +++++++++++-------- .../unittests/wrappers/test_bootstrapping.py | 19 ++++++++++++++++++- 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 119127df7b6..814578e512b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042)) +- Fixed bug in `BootStrapper` when very few samples were evaluated that could lead to a crash ([#2052](https://github.com/Lightning-AI/torchmetrics/pull/2052)) + + - Fixed bug when creating multiple plots that lead to not all plots being shown ([#2060](https://github.com/Lightning-AI/torchmetrics/pull/2060)) diff --git a/src/torchmetrics/wrappers/bootstrapping.py b/src/torchmetrics/wrappers/bootstrapping.py index 0157ce99f71..dd01a0c5c35 100644 --- a/src/torchmetrics/wrappers/bootstrapping.py +++ b/src/torchmetrics/wrappers/bootstrapping.py @@ -127,16 +127,19 @@ def update(self, *args: Any, **kwargs: Any) -> None: Any tensor passed in will be bootstrapped along dimension 0. 
""" + args_sizes = apply_to_collection(args, Tensor, len) + kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len)) + if len(args_sizes) > 0: + size = args_sizes[0] + elif len(kwargs_sizes) > 0: + size = kwargs_sizes[0] + else: + raise ValueError("None of the input contained tensors, so could not determine the sampling size") + for idx in range(self.num_bootstraps): - args_sizes = apply_to_collection(args, Tensor, len) - kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len)) - if len(args_sizes) > 0: - size = args_sizes[0] - elif len(kwargs_sizes) > 0: - size = kwargs_sizes[0] - else: - raise ValueError("None of the input contained tensors, so could not determine the sampling size") sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy).to(self.device) + if sample_idx.numel() == 0: + continue new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx) new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx) self.metrics[idx].update(*new_args, **new_kwargs) diff --git a/tests/unittests/wrappers/test_bootstrapping.py b/tests/unittests/wrappers/test_bootstrapping.py index 7d38d5728ac..2bbb4da8592 100644 --- a/tests/unittests/wrappers/test_bootstrapping.py +++ b/tests/unittests/wrappers/test_bootstrapping.py @@ -21,7 +21,7 @@ from lightning_utilities import apply_to_collection from sklearn.metrics import mean_squared_error, precision_score, recall_score from torch import Tensor -from torchmetrics.classification import MulticlassPrecision, MulticlassRecall +from torchmetrics.classification import MulticlassF1Score, MulticlassPrecision, MulticlassRecall from torchmetrics.regression import MeanSquaredError from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler @@ -123,3 +123,20 @@ def test_bootstrap(device, sampling_strategy, metric, ref_metric): assert np.allclose(output["mean"].cpu(), np.mean(sk_scores)) assert 
np.allclose(output["std"].cpu(), np.std(sk_scores, ddof=1)) assert np.allclose(output["raw"].cpu(), sk_scores) + + +@pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"]) +def test_low_sample_amount(sampling_strategy): + """Test that the metric works with very little data. + + In this case it is very likely that no samples from a current batch should be included in one of the bootstraps, + but this should still not crash the metric. + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2048 + + """ + preds = torch.randn(3, 3).softmax(dim=-1) + target = torch.LongTensor([0, 0, 0]) + bootstrap_f1 = BootStrapper( + MulticlassF1Score(num_classes=3, average=None), num_bootstraps=20, sampling_strategy=sampling_strategy + ) + assert bootstrap_f1(preds, target) # does not work