Use fast math function for tl.math.log as exp (#4723)
We were using the precise log op by mistake. Users who need higher precision can call
libdevice directly. Also clean up the special case for math.exp.
ThomasRaoux authored Sep 13, 2024
1 parent df26ec6 commit 84fe9da
Showing 4 changed files with 19 additions and 41 deletions.
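
For context, here is a minimal sketch of the user-facing behavior this commit describes: tl.math.log now takes the fast approximate path, and libdevice remains available when full precision is needed. The kernel below and the libdevice import path are illustrative assumptions, not part of the commit.

import torch
import triton
import triton.language as tl
# Assumption: the libdevice import path varies across Triton versions;
# triton.language.extra.libdevice is used here purely for illustration.
from triton.language.extra import libdevice

@triton.jit
def log_kernel(X, Y_FAST, Y_PRECISE, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    # tl.math.log now carries the afn fastmath flag and lowers to the
    # approximate PTX sequence (lg2.approx.ftz.f32) on NVIDIA targets.
    tl.store(Y_FAST + offs, tl.math.log(x))
    # For higher precision, call the libdevice function directly.
    tl.store(Y_PRECISE + offs, libdevice.log(x))

x = torch.rand(1024, device='cuda') + 1e-6
y_fast = torch.empty_like(x)
y_precise = torch.empty_like(x)
log_kernel[(1,)](x, y_fast, y_precise, BLOCK=1024)
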
3 changes: 1 addition & 2 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -346,8 +346,7 @@ struct ElementwiseOpConversion
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
return {rewriter.create<DestOp>(loc, elemTy, operands[0],
adaptor.getAttributes().getValue())};
return {rewriter.create<DestOp>(loc, elemTy, operands[0], op->getAttrs())};
}
};

19 changes: 12 additions & 7 deletions python/src/ir.cc
@@ -1,4 +1,4 @@
#include <pybind11/functional.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

@@ -135,6 +135,11 @@ void outputWarning(Location loc, const std::string &msg) {
/*stack_level=*/2);
}

template <typename OpTy> OpTy approxMath(OpTy op) {
op.setFastmath(arith::FastMathFlags::afn);
return op;
}

} // anonymous namespace

/*****************************************************************************/
@@ -1447,27 +1452,27 @@ void init_triton_ir(py::module &&m) {
})
.def("create_exp",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::ExpOp>(val);
return approxMath(self.create<math::ExpOp>(val));
})
.def("create_exp2",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::Exp2Op>(val);
return approxMath(self.create<math::Exp2Op>(val));
})
.def("create_cos",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::CosOp>(val);
return approxMath(self.create<math::CosOp>(val));
})
.def("create_sin",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::SinOp>(val);
return approxMath(self.create<math::SinOp>(val));
})
.def("create_log",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::LogOp>(val);
return approxMath(self.create<math::LogOp>(val));
})
.def("create_log2",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::Log2Op>(val);
return approxMath(self.create<math::Log2Op>(val));
})
.def("create_erf",
[](TritonOpBuilder &self, Value &val) -> Value {
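
The approxMath helper added to ir.cc above tags each newly created math-dialect op with the afn (approximate functions) fastmath flag; the backend lowering keys on that flag rather than on a dedicated conversion pattern. A rough way to observe the effect from Python — kernel name and sizes are assumptions — is to compile a small kernel and inspect the intermediate representations exposed by the compiled handle:

import torch
import triton
import triton.language as tl

@triton.jit
def exp_kernel(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # tl.math.exp is created with the afn fastmath flag after this change.
    tl.store(Y + offs, tl.math.exp(tl.load(X + offs)))

x = torch.rand(1024, device='cuda', dtype=torch.float32)
y = torch.empty_like(x)
k = exp_kernel[(1,)](x, y, BLOCK=1024)
print(k.asm['ttir'])                 # expect math.exp ... fastmath<afn>
print('ex2.approx' in k.asm['ptx'])  # expect True on CUDA targets
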
7 changes: 6 additions & 1 deletion python/test/unit/language/test_core.py
@@ -4376,9 +4376,14 @@ def kernel(X, Y, BLOCK: tl.constexpr):
x = torch.max(x, torch.tensor(1e-6, dtype=torch.float32, device=device))
y = torch.zeros(shape, dtype=torch.float32, device=device)

kernel[(1, )](x, y, BLOCK=shape[0])
k = kernel[(1, )](x, y, BLOCK=shape[0])
torch.allclose(getattr(torch, func_str)(x), y, rtol=1e-3)

if func_str in ['log', 'log2'] and is_cuda():
assert 'lg2.approx.ftz.f32' in k.asm['ptx']
if func_str in ['exp', 'exp2'] and is_cuda():
assert 'ex2.approx.ftz.f32' in k.asm['ptx']


# -----------------------
# test inline asm
31 changes: 0 additions & 31 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -755,32 +755,6 @@ struct TruncFOpConversion
}
};

struct ExpOpConversionApprox
: ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox> {
using Base = ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

SmallVector<Value> createDestOps(math::ExpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
// For non-FP32 input, call __nv_expf for higher-precision calculation
if (elemTy.getIntOrFloatBitWidth() != 32)
return {};

const double log2e = 1.4426950408889634;
Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e));

PTXBuilder ptxBuilder;
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
auto output = ptxBuilder.newOperand("=f");
auto input = ptxBuilder.newOperand(prod, "f");
exp2(output, input);
return {ptxBuilder.launch(rewriter, loc, f32_ty, false)};
}
};

struct ClampFOpConversion
: ElementwiseOpConversionBase<ClampFOp, ClampFOpConversion> {
using Base = ElementwiseOpConversionBase<ClampFOp, ClampFOpConversion>;
@@ -951,11 +925,6 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns(
patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,
computeCapability, benefit);

// ExpOpConversionApprox will try using ex2.approx if the input type is
// FP32. For other input types, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
bool hwNanPropagationSupported = computeCapability >= 80;
mlir::triton::populateMinMaxFOpToLLVMPattern(
typeConverter, patterns, axisInfoAnalysis, hwNanPropagationSupported,
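
The deleted ExpOpConversionApprox pattern above hand-rolled the fast exponential for FP32 by rescaling the argument into base 2 and emitting ex2.approx; with the afn flag now set in ir.cc, the generic math lowering covers this case instead. For reference, the identity the deleted code relied on, with the constant it hard-coded, is:

e^{x} = 2^{x \log_2 e}, \qquad \log_2 e \approx 1.4426950408889634
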
