Skip to content

Commit ddfe243

Browse files
committed
Add opmath cast sequence for CPU or Neuron
1 parent 611a5cc commit ddfe243

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2529,13 +2529,26 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
25292529
at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
                                   const at::Tensor& other) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

  // Mixed-precision handling is backend-dependent, so resolve the device
  // type backing `self` before building the op.
  XLA_ASSIGN_OR_THROW(XLATensorPtr self_tensor, bridge::GetXlaTensor(self));
  const auto device_type =
      static_cast<XlaDeviceType>(self_tensor->GetDevice().type());

  // Build the mul op: both operands are cast to a common dtype, and the
  // functor forwards that dtype to tensor_methods::mul.
  auto op = OpConfig([](const XLAInputVector& inputs, at::ScalarType dtype) {
              return tensor_methods::mul(inputs[0], inputs[1], dtype);
            })
                .add_input(self)
                .add_input(other)
                .cast_inputs_to_common_dtype();

  // Only the CPU and Neuron backends compute in the opmath type.
  // NOTE(review): assumes use_opmathtype_for_compute() mutates `op` in
  // place (builder returning a reference) — confirm against OpConfig.
  const bool wants_opmath = device_type == XlaDeviceType::CPU ||
                            device_type == XlaDeviceType::NEURON;
  if (wants_opmath) {
    op.use_opmathtype_for_compute();
  }

  return op.run();
}
25402553

25412554
at::Tensor XLANativeFunctions::mul(const at::Tensor& self,

0 commit comments

Comments
 (0)