Fix TransformerEncoderLayer Full Mask UT Failure on XPU #2336
The Problem Solved
The test_transformer_encoder_layer test for nn.TransformerEncoderLayer failed an assertion when run on an XPU device.
#2015
Root Cause: The test checks the Transformer's behavior when the key padding mask masks the entire input (sequence length = 1, mask = [[True]]).
On XPU devices, the MHA module lacks the fast-path optimization and falls back to the non-fast-path execution, where the computed result is a finite, non-NaN value (the mathematically robust result of X + Attention(0)).
However, in XPU/non-cross-ref mode the test incorrectly entered the fast-path assertion branch, which expects a NaN result. The actual non-NaN output did not satisfy that assertion, so the test failed. A minimal sketch of the failing scenario follows.
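A hedged reproduction sketch (the d_model/nhead values are illustrative assumptions, not the real test's parameters, and "xpu" requires a PyTorch build with XPU support):

```python
import torch
import torch.nn as nn

device = "xpu"
layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, batch_first=True).to(device).eval()

src = torch.randn(1, 1, 4, device=device)     # (batch, seq_len=1, d_model)
mask = torch.tensor([[True]], device=device)  # key padding mask that masks everything

with torch.no_grad():
    result = layer(src, src_key_padding_mask=mask).cpu().numpy()

# The fast-path assertion branch expects all-NaN output:
#     self.assertTrue(np.isnan(result).all())
# On XPU the non-fast-path fallback yields finite values instead, so it fails.
```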
Why This Solution
To validate the core semantics of TransformerEncoderLayer on XPU, the test should use the numerically most robust non-fast-path logic as the reference. Non-fast-path semantics: when the input is fully masked, the attention output is 0, and the result after LayerNorm(X) is a finite (non-NaN) value, as illustrated below.
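An expository sketch of those semantics (the zeroed attention output stands in for what the non-fast-path computes when every key is masked; it is not the actual kernel code):

```python
import torch

x = torch.randn(1, 1, 4)          # residual input X
attn_out = torch.zeros_like(x)    # fully masked -> attention contributes 0

norm = torch.nn.LayerNorm(4)
out = norm(x + attn_out)          # post-norm: LayerNorm(X + Attention(0)) = LayerNorm(X)
assert torch.isfinite(out).all()  # finite, non-NaN
```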
Solution: Forcing the test into the non-fast-path branch aligns its assertion with the actual non-NaN result on the XPU device, so the unit test (UT) passes.
The Implementation Method
The fix manually sets TEST_WITH_CROSSREF=1 within the affected test function in the relevant test file.
This forces the test method down the non-fast-path branch, switching the assertion from the incorrect self.assertTrue(np.isnan(result).all()) to the correct self.assertTrue(not np.isnan(result).any()), thereby fixing the test. A hedged sketch of the change follows.
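A sketch of one way to scope the override (TEST_WITH_CROSSREF is the flag in torch.testing._internal.common_utils; the exact patching mechanism in the PR may differ, e.g. if the test module imported the flag by name, that binding would need to be patched instead):

```python
import unittest.mock
import torch.testing._internal.common_utils as common_utils

def test_transformerencoderlayer(self, device):
    # Override is scoped to this test only and restored on exit.
    with unittest.mock.patch.object(common_utils, "TEST_WITH_CROSSREF", True):
        ...  # the fully-masked case now asserts: not np.isnan(result).any()
```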
TIPS:
Because PYTORCH_TEST_WITH_CROSSREF can affect global test semantics and performance, the fix is scoped as narrowly as possible to avoid unintended performance degradation or logic changes in other unit tests. For contrast, the global alternative is sketched below.
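The global alternative (deliberately not used here) would flip the cross-ref semantics for every test in the process; note the flag is evaluated when common_utils is first imported, so the environment variable would have to be set before that import:

```python
import os

# Affects ALL tests in the process, not just the affected XPU UT.
os.environ["PYTORCH_TEST_WITH_CROSSREF"] = "1"

# Must happen after the env var is set: the flag is read at import time.
from torch.testing._internal.common_utils import TEST_WITH_CROSSREF
assert TEST_WITH_CROSSREF
```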