Enable cross entropy loss for xla autocast with FP32 precision #7992
Many operators in the XLA autocast policy have been commented out, even though the same operators are cast under GPU autocast; to keep the behavior consistent, we need to support these operators as well. `cross_entropy_loss` is currently commented out in the XLA autocast policy, so no casting occurs for it and it executes based on its inputs' dtypes.
The output dtype of the linear layer is BF16, which is expected because `linear` is listed in the XLA autocast policy. The loss dtype is FP32, which looks correct, but there is a catch: no autocasting was actually done for `CrossEntropyLoss`. The loss is FP32 only because the target's dtype is FP32. Inside `cross_entropy_loss` there is a multiplication between the generated output and the target; all the exponentiation/log ops run in BF16, and the result ends up in FP32 only because that final multiplication promotes to the higher precision (FP32). This is not the expected behavior: all the ops that make up `cross_entropy_loss` (exponentiation, logs, etc.) should execute in FP32. The reason they do not is that `cross_entropy_loss` is not listed in the XLA autocast policy. This finding is based on a detailed analysis of the HLO outputs attached below.
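For reference, here is a minimal sketch of the kind of experiment described above; the model, shapes, and the use of `torch.autocast("xla", ...)` are illustrative assumptions, not code taken from this PR:

```python
# Minimal sketch (illustrative shapes; assumes torch_xla is installed and that
# torch.autocast accepts device_type="xla", as in recent torch_xla releases).
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

model = torch.nn.Linear(10, 10).to(device)
data = torch.randn(16, 10).to(device)    # FP32 input
target = torch.randn(16, 10).to(device)  # FP32 target (soft labels)

with torch.autocast("xla", dtype=torch.bfloat16):
    output = model(data)                               # linear is autocast -> BF16
    loss = torch.nn.CrossEntropyLoss()(output, target)

print(output.dtype)  # torch.bfloat16
print(loss.dtype)    # torch.float32, but only via promotion against the FP32 target
```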
Before this change:
Exp1:
HLO:
Exp2:
The target dtype is also BF16 in this case. This experiment was done to show that the target's dtype was the true cause of the FP32 loss, as shown below.
Code change from previous experiment:
```python
target = torch.randn(16, 10).to(torch.bfloat16).to(device)
```
HLO:
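Under the same assumptions as the sketch above, the Exp2 variant changes only the target dtype; the loss then comes out in BF16, which supports the conclusion that the FP32 result in Exp1 came from the target rather than from autocasting:

```python
# Exp2 variant of the sketch above: only the target dtype changes.
target = torch.randn(16, 10).to(torch.bfloat16).to(device)

with torch.autocast("xla", dtype=torch.bfloat16):
    output = model(data)
    loss = torch.nn.CrossEntropyLoss()(output, target)

print(loss.dtype)  # torch.bfloat16 -- no promotion, since both operands are BF16
```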
After uncommenting `cross_entropy_loss` in the XLA autocast policy, as done in this PR:
Exp3:
The input and target are in FP32, so the output of the linear layer is in BF16, and it is then upcast to FP32 for `cross_entropy_loss`, as seen in the HLO.
The operators now execute in the expected precisions.
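One way to check this is to dump the HLO for the loss tensor. Note that `_get_xla_tensors_hlo` is an internal torch_xla helper that may change between releases, so this is only a debugging sketch:

```python
# Dump the HLO for the loss tensor to confirm the upcast to FP32.
# (_get_xla_tensors_hlo is an internal, version-dependent helper.)
import torch_xla

target = torch.randn(16, 10).to(device)  # back to an FP32 target, as in Exp3

with torch.autocast("xla", dtype=torch.bfloat16):
    output = model(data)
    loss = torch.nn.CrossEntropyLoss()(output, target)

print(torch_xla._XLAC._get_xla_tensors_hlo([loss]))
# With this PR, the dump should show the BF16 linear output converted to f32
# before the exp/log/multiply ops of cross_entropy_loss.
```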