Commit 66f8859

Revert "mul: remove opmath cast sequence (#9663)" (#9701)
Commit 2a9138a removed `.use_opmathtype_for_compute()` from the element-wise `mul` operation. This breaks the mixed-precision accumulation behavior expected by the Neuron compiler, which traces/compiles on CPU and later executes the binary on Neuron hardware, causing accuracy degradation in transformer models that use mixed-precision compilation.

Reverts: commit 2a9138a; the other changes are the result of a rebase from r2.9
Fixes: Model accuracy failures with mixed-precision accumulation (#9699)
1 parent be33668 commit 66f8859
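
For context, a minimal sketch of how the restored upcast shows up in the traced HLO. This is illustrative only, assuming a torch_xla build that includes this revert; the HLO inspection call is the same one used by the tests removed below.

import torch
import torch_xla

# Two bfloat16 tensors on the XLA device.
a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
c = a * b

# With .use_opmathtype_for_compute() restored, the bf16 inputs are
# expected to be upcast to the opmath type (f32) for the multiply,
# so 'f32' should appear in the HLO. This is the inverse of the
# removed test_bfloat16_mul_not_upcast assertion.
hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
assert 'f32' in hlo_text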

File tree

2 files changed (+1 −16 lines)


test/test_operations_hlo.py

Lines changed: 0 additions & 16 deletions
@@ -67,22 +67,6 @@ def test_dropout_by_u8_mask(self):
     hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b])
     assert 'u8' in hlo_text
 
-  def test_bfloat16_mul_not_upcast(self):
-    a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
-    b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
-    c = a * b
-    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
-    # Check that the output is not upcasted to float32
-    assert 'f32' not in hlo_text
-
-  def test_bfloat16_float32_mul_upcast(self):
-    a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
-    b = torch.rand(5, 5, dtype=torch.float32).to('xla')
-    c = a * b
-    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
-    # Check that the output is upcasted to float32
-    assert 'f32' in hlo_text
-
 
 if __name__ == '__main__':
   torch.set_default_dtype(torch.float32)

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 1 addition & 0 deletions
@@ -2535,6 +2535,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
       .add_input(self)
      .add_input(other)
       .cast_inputs_to_common_dtype()
+      .use_opmathtype_for_compute()
       .run();
 }

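For reference, the numerical semantics of the opmath compute path for bfloat16 can be approximated in eager PyTorch. A minimal sketch, assuming only that the opmath type for bfloat16 is float32 (as in ATen's opmath_type); this is not code from the commit.

import torch

a = torch.rand(5, 5, dtype=torch.bfloat16)
b = torch.rand(5, 5, dtype=torch.bfloat16)

# Opmath-style compute: upcast both operands to float32, multiply at
# full precision, then cast the result back to bfloat16.
opmath_result = (a.float() * b.float()).to(torch.bfloat16)

# Direct bf16 multiply for comparison: rounding happens at bf16 precision.
direct_result = a * b

For a single multiply the two results typically match; the distinction matters when the compiler relies on the visible f32 compute sequence in the traced HLO to drive mixed-precision accumulation across many operations, which is the failure mode the commit message describes.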