Fix bootstrapping with few samples (#2052)
* fix

* test

* changelog

(cherry picked from commit 8d82db0)
SkafteNicki authored and Borda committed Sep 11, 2023
1 parent 10243d7 commit 7325d6f
Showing 3 changed files with 32 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))


+- Fixed bug in `BootStrapper` when very few samples were evaluated that could lead to a crash ([#2052](https://github.com/Lightning-AI/torchmetrics/pull/2052))
+
+
 - Fixed bug when creating multiple plots that led to not all plots being shown ([#2060](https://github.com/Lightning-AI/torchmetrics/pull/2060))


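For context on the failure mode fixed here: with the "poisson" sampling strategy, each of the N samples in a batch is drawn a Poisson(1) number of times, so for small N there is a real chance that a bootstrap draws no samples at all. A minimal standalone sketch of that effect (the helper name is ours, not the library's private `_bootstrap_sampler`):

import torch

def poisson_bootstrap_indices(size: int) -> torch.Tensor:
    # Each of the `size` samples is included a Poisson(1) number of times;
    # with few samples, all counts can be zero, yielding an empty index tensor.
    counts = torch.poisson(torch.ones(size)).long()
    return torch.arange(size).repeat_interleave(counts)

torch.manual_seed(0)
n_empty = sum(poisson_bootstrap_indices(3).numel() == 0 for _ in range(1000))
print(f"empty index sets: {n_empty}/1000")  # roughly exp(-3) ~ 5% of draws

Feeding such an empty index into `torch.index_select` yields zero-length tensors, which is what could crash the underlying metric before this commit.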
19 changes: 11 additions & 8 deletions src/torchmetrics/wrappers/bootstrapping.py
@@ -127,16 +127,19 @@ def update(self, *args: Any, **kwargs: Any) -> None:
         Any tensor passed in will be bootstrapped along dimension 0.
         """
+        args_sizes = apply_to_collection(args, Tensor, len)
+        kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len))
+        if len(args_sizes) > 0:
+            size = args_sizes[0]
+        elif len(kwargs_sizes) > 0:
+            size = kwargs_sizes[0]
+        else:
+            raise ValueError("None of the input contained tensors, so could not determine the sampling size")
+
         for idx in range(self.num_bootstraps):
-            args_sizes = apply_to_collection(args, Tensor, len)
-            kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len))
-            if len(args_sizes) > 0:
-                size = args_sizes[0]
-            elif len(kwargs_sizes) > 0:
-                size = kwargs_sizes[0]
-            else:
-                raise ValueError("None of the input contained tensors, so could not determine the sampling size")
             sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy).to(self.device)
+            if sample_idx.numel() == 0:
+                continue
             new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
             new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx)
             self.metrics[idx].update(*new_args, **new_kwargs)
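With the guard in place, a batch small enough to produce empty bootstrap draws no longer raises. A minimal usage sketch (the metric choice and seed are illustrative, not part of the commit):

import torch
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.wrappers import BootStrapper

torch.manual_seed(42)
metric = BootStrapper(
    MulticlassAccuracy(num_classes=3), num_bootstraps=20, sampling_strategy="poisson"
)
preds = torch.randn(3, 3).softmax(dim=-1)  # only 3 samples, so empty draws are likely
target = torch.tensor([0, 1, 2])
metric.update(preds, target)
out = metric.compute()  # dict with "mean" and "std" of the bootstrap estimates
print(out["mean"], out["std"])

Bootstraps whose draw comes back empty are simply skipped, so their copy of the base metric keeps its previous state rather than being updated with zero-length tensors.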
19 changes: 18 additions & 1 deletion tests/unittests/wrappers/test_bootstrapping.py
@@ -21,7 +21,7 @@
 from lightning_utilities import apply_to_collection
 from sklearn.metrics import mean_squared_error, precision_score, recall_score
 from torch import Tensor
-from torchmetrics.classification import MulticlassPrecision, MulticlassRecall
+from torchmetrics.classification import MulticlassF1Score, MulticlassPrecision, MulticlassRecall
 from torchmetrics.regression import MeanSquaredError
 from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler

@@ -123,3 +123,20 @@ def test_bootstrap(device, sampling_strategy, metric, ref_metric):
     assert np.allclose(output["mean"].cpu(), np.mean(sk_scores))
     assert np.allclose(output["std"].cpu(), np.std(sk_scores, ddof=1))
     assert np.allclose(output["raw"].cpu(), sk_scores)
+
+
+@pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"])
+def test_low_sample_amount(sampling_strategy):
+    """Test that the metric works with very little data.
+
+    In this case it is very likely that at least one bootstrap draws no samples from the current batch,
+    but this should still not crash the metric.
+
+    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2048
+
+    """
+    preds = torch.randn(3, 3).softmax(dim=-1)
+    target = torch.LongTensor([0, 0, 0])
+    bootstrap_f1 = BootStrapper(
+        MulticlassF1Score(num_classes=3, average=None), num_bootstraps=20, sampling_strategy=sampling_strategy
+    )
+    assert bootstrap_f1(preds, target)  # does not crash

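A quick sanity check on why this test reliably exercises the empty-draw path under the "poisson" strategy: with 3 samples, a single bootstrap is empty with probability exp(-3), so across 20 bootstraps at least one empty draw occurs with probability about 0.64. (Assuming the multinomial strategy resamples exactly `size` indices with replacement, as is standard, it never produces an empty draw and acts here as a regression control.)

import math

p_empty = math.exp(-3)  # all three Poisson(1) counts are zero
p_at_least_one = 1 - (1 - p_empty) ** 20  # over 20 independent bootstraps
print(f"{p_empty:.4f} {p_at_least_one:.4f}")  # ~0.0498 ~0.6400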