From dd5e1d1cf290f946eaaa99043474cf4087aad27f Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Fri, 10 May 2024 06:02:20 +0000 Subject: [PATCH] Fix F16 -> F32 upcasting to support Pascal GPUs Based on code from openai/triton#2780 provided by @wkpark --- .../lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index a87fc936d47b..e4ad66ceaca9 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -476,6 +476,15 @@ struct FpToFpOpConversion return outVals; } + if (srcElementType.isF16() && dstElementType.isF32()) { + SmallVector outVals; + for (Value v : operands[0]) { + outVals.push_back( + convertFp16ToFp32(loc, rewriter, v)); + } + return outVals; + } + if (srcElementType.isF32() && dstElementType.isBF16()) { assert(roundingMode.has_value() && "rounding mode must be specified for fp32->bf16 conversion");