Fix race condition in `_safe_divide` by creating tensor directly on device #3284
What does this PR do?
Fixes a race condition in `_safe_divide` that could lead to uninitialized values when using non-blocking tensor transfers, particularly affecting MPS devices.

Closes #3095
The Problem
The previous implementation created a tensor on CPU and then transferred it to the target device:
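The exact code is not reproduced in this description; the following is a minimal sketch of the problematic pattern, assuming a `_safe_divide(num, denom, zero_division)` signature (names are illustrative, not verbatim from the repository):

```python
import torch
from torch import Tensor


def _safe_divide_old(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tensor:
    """Sketch of the previous pattern (illustrative, not the exact repo code)."""
    # The zero_division scalar is created on CPU first ...
    zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(
        num.device, non_blocking=True  # ... and copied to the device asynchronously.
    )
    # torch.where may read zero_division_tensor before the copy has finished,
    # which can yield uninitialized values on devices such as MPS.
    return torch.where(denom != 0, num / denom, zero_division_tensor)
```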
This caused a race condition when `non_blocking=True`:

- the `.to()` call returns immediately, without waiting for the memory copy to complete
- `torch.where()` can read the destination tensor before the copy finishes

The issue reporter experienced "sometimes correct default (0.0) but sometimes uninitialized numbers" on MPS devices.
The Solution
Create the tensor directly on the target device:
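A minimal sketch of the fixed pattern rather than the exact diff; the function name `_safe_divide_fixed` is illustrative:

```python
import torch
from torch import Tensor


def _safe_divide_fixed(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tensor:
    """Sketch of the fixed pattern (illustrative, not the exact repo code)."""
    # Allocating the scalar directly on the target device avoids the
    # asynchronous host-to-device copy entirely.
    zero_division_tensor = torch.tensor(
        zero_division, dtype=num.dtype, device=num.device
    )
    return torch.where(denom != 0, num / denom, zero_division_tensor)
```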
This eliminates the race condition: the tensor is allocated and filled directly on the target device, so there is no asynchronous host-to-device copy for `torch.where()` to race with.
Benefits

- `torch.tensor(..., device=device)` doesn't cause CUDA synchronization

Testing
Added a comprehensive test in `tests/unittests/utilities/test_utilities.py` that verifies the expected `zero_division` values are returned.

All existing tests pass, including those covering `_safe_divide`.
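A rough sketch of what such a test might look like; the import path `torchmetrics.utilities.compute._safe_divide`, the `zero_division` keyword, and the test name are assumptions, and the actual test in the PR may additionally parametrize devices:

```python
import pytest
import torch

# Assumed import path; the real test lives in tests/unittests/utilities/test_utilities.py.
from torchmetrics.utilities.compute import _safe_divide


@pytest.mark.parametrize("zero_division", [0.0, 1.0, 5.0])
def test_safe_divide_returns_zero_division_value(zero_division):
    """Hypothetical check: zero denominators yield the requested zero_division value."""
    num = torch.tensor([1.0, 2.0, 3.0])
    denom = torch.tensor([0.0, 2.0, 0.0])
    result = _safe_divide(num, denom, zero_division=zero_division)
    expected = torch.tensor([zero_division, 1.0, zero_division])
    torch.testing.assert_close(result, expected)
```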
Related
This is similar to the approach in #3094, which was initially closed due to concerns about CUDA synchronization. However, creating tensors directly on device with `torch.tensor(..., device=device)` does not cause synchronization, unlike using `.to(device)`. PR #3101 attempted to fix this by disabling `non_blocking` for MPS, but the race condition could still occur. This PR properly fixes the root cause.
📚 Documentation preview 📚: https://torchmetrics--3284.org.readthedocs.build/en/3284/