Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bootstrapping with few samples #2052

Merged
merged 6 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))


- Fixed bug in `BootStrapper` when very few samples were evaluated that could lead to crash ([#2052](https://github.com/Lightning-AI/torchmetrics/pull/2052))


- Fixed bug when creating multiple plots that lead to not all plots being shown ([#2060](https://github.com/Lightning-AI/torchmetrics/pull/2060))


Expand Down
19 changes: 11 additions & 8 deletions src/torchmetrics/wrappers/bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,19 @@ def update(self, *args: Any, **kwargs: Any) -> None:
Any tensor passed in will be bootstrapped along dimension 0.

"""
args_sizes = apply_to_collection(args, Tensor, len)
kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len))
if len(args_sizes) > 0:
size = args_sizes[0]
elif len(kwargs_sizes) > 0:
size = kwargs_sizes[0]
else:
raise ValueError("None of the input contained tensors, so could not determine the sampling size")

for idx in range(self.num_bootstraps):
args_sizes = apply_to_collection(args, Tensor, len)
kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len))
if len(args_sizes) > 0:
size = args_sizes[0]
elif len(kwargs_sizes) > 0:
size = kwargs_sizes[0]
else:
raise ValueError("None of the input contained tensors, so could not determine the sampling size")
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy).to(self.device)
if sample_idx.numel() == 0:
continue
new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx)
self.metrics[idx].update(*new_args, **new_kwargs)
Expand Down
19 changes: 18 additions & 1 deletion tests/unittests/wrappers/test_bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from lightning_utilities import apply_to_collection
from sklearn.metrics import mean_squared_error, precision_score, recall_score
from torch import Tensor
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall
from torchmetrics.classification import MulticlassF1Score, MulticlassPrecision, MulticlassRecall
from torchmetrics.regression import MeanSquaredError
from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler

Expand Down Expand Up @@ -123,3 +123,20 @@ def test_bootstrap(device, sampling_strategy, metric, ref_metric):
assert np.allclose(output["mean"].cpu(), np.mean(sk_scores))
assert np.allclose(output["std"].cpu(), np.std(sk_scores, ddof=1))
assert np.allclose(output["raw"].cpu(), sk_scores)


@pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"])
def test_low_sample_amount(sampling_strategy):
"""Test that the metric works with very little data.

In this case it is very likely that no samples from a current batch should be included in one of the bootstraps,
but this should still not crash the metric.
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2048

"""
preds = torch.randn(3, 3).softmax(dim=-1)
target = torch.LongTensor([0, 0, 0])
bootstrap_f1 = BootStrapper(
MulticlassF1Score(num_classes=3, average=None), num_bootstraps=20, sampling_strategy=sampling_strategy
)
assert bootstrap_f1(preds, target) # does not work
Loading