
Conversation

@rajkthakur commented Nov 4, 2025

Commit 2a9138a removed .use_opmathtype_for_compute() from the element-wise 'mul' operation. This breaks the mixed-precision accumulation behavior expected by the Neuron compiler, which traces/compiles on CPU and later executes the binary on Neuron hardware, causing significant accuracy degradation in:

  • Llama 3.1 70B models (16.7% throughput drop, accuracy failures)
  • Mixtral 8x22B models (accuracy test failures)
  • Other transformer models using mixed-precision compilation

Reverts: commit 2a9138a; the other changes are the result of a rebase from r2.9.
Fixes: Model accuracy failures with mixed-precision accumulation #9699
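
For context, here is a minimal sketch of the two behaviors being compared, written in plain PyTorch rather than the actual torch_xla codegen; the function names are illustrative only:

```python
import torch

def mul_with_opmath(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Opmath cast sequence: upcast bf16 -> f32, multiply in f32, downcast back.
    # In the traced HLO this shows up as convert/multiply/convert, a pattern
    # the Neuron compiler relies on to keep accumulation in f32.
    return (a.to(torch.float32) * b.to(torch.float32)).to(a.dtype)

def mul_without_opmath(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # After commit 2a9138a: a plain multiply emitted in the input dtype.
    return a * b
```

For an isolated multiply the two usually round to the same bf16 value; the practical difference is in what the backend compiler sees in the HLO and how precision propagates through fused and accumulated operations.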

@rajkthakur (Author)

Is the torchax test failure expected?

@ysiraichi (Collaborator) left a comment:

  • The TorchAX CI failure could be fixed by this change
  • Instead of reverting the commit, I think it would be better to make it backend-specific, land it on master, and cherry-pick it to r2.9

What do you think?

@rajkthakur (Author) commented Nov 6, 2025

While I understand the suggestion from #8545 to make this backend-specific, I believe a full revert is more appropriate, at least for the release branch r2.9:

  1. The change affects the fundamental numerical behavior of torch.mul, which is called extensively throughout any model. Even small precision differences compound (see the sketch after this list) in:
    • Attention mechanisms with large sequence lengths
    • Gradient accumulation/reduction over many steps, a scenario that XLA-based compilers likely need to handle explicitly

  2. Risk vs. benefit:
    • Risk: Our testing shows concrete accuracy regressions that are blocking production workloads (Mixtral, Llama) on the current compiler, which may well impact other users too.
    • Benefit: The original issue [Q][GPU][BF16] torch.mul is lowered to HLO as an f32 multiply #8545 was a question about the seemingly unnecessary upcast/downcast that also appears in PyTorch CUDA. The upcast/downcast can already be removed by using PyTorch autocast, so there is no need to fix #8545 for r2.9. We can keep the change on the master branch for further investigation.
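
As a rough illustration of that compounding (a toy reduction, not taken from the affected models), rounding each product to bf16 before accumulating loses noticeably more precision than keeping the products in f32:

```python
import torch

torch.manual_seed(0)
a = torch.rand(100_000).to(torch.bfloat16)
b = torch.rand(100_000).to(torch.bfloat16)

# High-precision reference for the same bf16 inputs.
reference = (a.double() * b.double()).sum()

# Plain bf16 multiply: each product is rounded to bf16 before accumulation.
err_bf16_products = ((a * b).float().sum().double() - reference).abs().item()

# Opmath-style multiply: products computed in f32, then accumulated in f32.
err_f32_products = ((a.float() * b.float()).sum().double() - reference).abs().item()

print(err_bf16_products, err_f32_products)
```

In this toy example the bf16-rounded products typically show a noticeably larger accumulated error than the f32 products; the affected models hit the same effect at far larger scale in attention and gradient reductions.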

Let me know your thoughts.

@ysiraichi (Collaborator) commented Nov 6, 2025

Thank you for your detailed analysis.

I'm sorry, but I still think that the best way to go about this is to make it backend specific. Actually, now that I'm thinking about it, I think it would make more sense to do it in the lowering step.

That said, since we don't really want to introduce a lot of changes here, and the infrastructure is already in place, I would say that we should make that use_opmathtype_for_compute() call backend specific.

It's a simple change (something like this condition for TPU) that fixes this problem for the Neuron backend, while leaving the other backends with the "backend-independent" choice of not doing that.

> The original issue [Q][GPU][BF16] torch.mul is lowered to HLO as an f32 multiply #8545 was a question about the seemingly unnecessary upcast/downcast that also appears in PyTorch CUDA. The upcast/downcast can already be removed by using PyTorch autocast, so there is no need to fix #8545 for r2.9. We can keep the change on the master branch for further investigation.

That "seeming unnecessary upcast/downcast" is exactly what we are talking about, here. And, no, PyTorch autocast does not solve this issue (see this comment).

@rajkthakur changed the title from 'Revert - "mul: remove opmath cast sequence (#9663)"' to 'mul: add opmath cast sequence for Neuron/CPU' on Nov 8, 2025
@jeffhataws requested review from qihqi and ysiraichi on November 8, 2025 at 05:28
@ysiraichi (Collaborator) left a comment:

Thank you for the PR.
Let me know if you have any questions.

@rajkthakur changed the title from 'mul: add opmath cast sequence for Neuron/CPU' to 'Revert "mul: remove opmath cast sequence (pytorch#9663)"' on Nov 10, 2025
@ysiraichi (Collaborator)

Could you rebase this PR, so that there are only your commits?

@rajkthakur (Author)

> Could you rebase this PR, so that there are only your commits?

updated

@ysiraichi (Collaborator) left a comment:

LGTM.
Let's just wait for the CI.

@rajkthakur (Author)

@ysiraichi CI is completed.

@jeffhataws merged commit 66f8859 into pytorch:r2.9 on Nov 12, 2025 (24 checks passed)
@jeffhataws (Collaborator)

Thanks @rajkthakur and @ysiraichi. It is merged now and ready to go. @bhavya01 will make a final 2.9 release candidate.
