Skip to content

Commit

Permalink
Fix math tests for armv8 (#178)
Browse files Browse the repository at this point in the history
We only use Sleef on !x86 platforms. Sleef APIs are not fully agnostic
of the underlying architecture. For example, `Sleef_sinf8_u10` does not
exist on Arm.

This PR, makes the `MathToVecLibPass` aware of the
CPU SIMD architecture by accepting `cpu_features` as new optional
argument.

No change is expected on x86 side.
  • Loading branch information
digantdesai authored and int3 committed Dec 6, 2024
1 parent b02e487 commit 273f93a
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 43 deletions.
18 changes: 18 additions & 0 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,24 @@ void init_triton_llvm(py::module &&m) {
if (f.second)
res.insert(f.first().str());
}

// Likely something went wrong with the LLVM feature detection.
if (!res.size()) {
std::string triple = llvm::sys::getProcessTriple();
// e.g. arm64-apple-darwin24.1.0
// ^^^^^
std::size_t pos = triple.find('-');
if (pos == std::string::npos) {
return res;
}

std::string arch = triple.substr(0, pos);
if (arch == "aarch64" || arch == "arm64") {
// Safe because NEON is a mandatory feature for aarch64.
res.insert("neon"); // For math tests
}
}

return res;
});
}
Expand Down
23 changes: 20 additions & 3 deletions python/test/unit/cpu/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,23 @@

import triton
import triton.language as tl
from triton._C.libtriton import llvm
from triton.language.extra import libdevice
from itertools import chain, product


def get_native_vector_size_in_bits():
"""
Returns the native vector size of the CPU.
Assuming x86 always uses "auto dispatch" with 512-bit vectors for Sleef.
"""
cpu_features = llvm.get_cpu_features()
# TODO support for arm sve w/ VLA
if "neon" in cpu_features:
return 128
return 512


def is_interpreter():
return os.environ.get('TRITON_INTERPRET', '0') == '1'

Expand All @@ -34,9 +47,13 @@ def check_num_vec_calls(meta, vec_lib, dtype_str, size, is_always_extern=False):
# FP16 and BF16 are cast to FP32 for math ops
elem_size = 8 if dtype_str == "float64" else 4
data_size = size * elem_size
if data_size > 64:
num_vec_calls = data_size // 64
elif data_size >= 16:

vec_size = get_native_vector_size_in_bits() / 8 # bytes
# 128-bit vector is the smallest supported by Sleef for both x86 and arm
smallest_vec_size = 128 / 8 # bytes
if data_size > vec_size:
num_vec_calls = data_size // vec_size
elif data_size >= smallest_vec_size:
num_vec_calls = 1
else:
num_vec_calls = 1 if is_always_extern else 0
Expand Down
5 changes: 3 additions & 2 deletions third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def make_tttcir(self, mod, metadata, opt):
pm.enable_debug()
cpu.passes.ttcpuir.add_optimize_masks(pm)
passes.common.add_canonicalizer(pm)
convert_bf16_dot_product = self.cpu_arch == "aarch64" and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features
convert_bf16_dot_product = ((self.cpu_arch == "aarch64" or self.cpu_arch == "armv8")
and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features)
if convert_bf16_dot_product:
use_horizontal_sum = os.getenv("TRITON_CPU_DOT_PROD_HORIZ_SUM", "1") == "1"
cpu.passes.ttcpuir.add_convert_dot_product(pm, use_horizontal_sum)
Expand Down Expand Up @@ -215,7 +216,7 @@ def make_llir(self, src, metadata, options):
VecLib.libmvec: {"avx512f"},
}
if (vec_lib := options.get_vec_lib()) and vec_lib_requirements[vec_lib] & self.cpu_features:
cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib)
cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib, self.cpu_features)

passes.convert.add_math_to_llvmir(pm)
cpu.passes.ttcpuir.add_math_to_libm(pm)
Expand Down
3 changes: 2 additions & 1 deletion third_party/cpu/include/TritonCPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ std::unique_ptr<OperationPass<triton::FuncOp>> createLowerMultiReductionPass();
std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>> createDebugOpsToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
createMathToVecLibPass(VecLib lib = VecLib::Sleef);
createMathToVecLibPass(VecLib lib = VecLib::Sleef,
std::set<std::string> cpu_features = {});

#define GEN_PASS_REGISTRATION
#include "cpu/include/TritonCPUToLLVM/Passes.h.inc"
Expand Down
115 changes: 80 additions & 35 deletions third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ template <typename OpT> struct VecOpToFp32 : public OpRewritePattern<OpT> {
};

// Decompose vector operation to single-dimensional vector operations
// with a native AVX512 vector size.
// with a AVX512 for x86 or NEON for ARM.
template <typename OpT>
struct DecomposeToNativeVecs : public OpRewritePattern<OpT> {
public:
using OpRewritePattern<OpT>::OpRewritePattern;
// CPU SIMD vector size in bits
size_t vec_bits;

DecomposeToNativeVecs(MLIRContext *context)
: OpRewritePattern<OpT>(context) {}
DecomposeToNativeVecs(MLIRContext *context,
size_t native_vec_size_in_bits = 512)
: OpRewritePattern<OpT>(context), vec_bits(native_vec_size_in_bits) {}

LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Expand All @@ -83,7 +86,7 @@ struct DecomposeToNativeVecs : public OpRewritePattern<OpT> {
// vector size.
auto shape = vecTy.getShape();
SmallVector<int64_t> newShape(1, 1);
int64_t elemsPerVec = 512 / elemTy.getIntOrFloatBitWidth();
int64_t elemsPerVec = vec_bits / elemTy.getIntOrFloatBitWidth();
for (int64_t i = shape.size() - 1; i >= 0; --i) {
int64_t size = shape[i];
if (newShape.size() > 1) {
Expand Down Expand Up @@ -330,18 +333,39 @@ struct ExternElementwiseOpConversion

template <typename OpTy>
void populatePatternsForOp(RewritePatternSet &patterns,
GetVecFnNameFn getVecFnName) {
GetVecFnNameFn getVecFnName,
size_t vec_size_in_bits = 512) {
patterns.add<VecOpToFp32<OpTy>>(patterns.getContext());
patterns.add<DecomposeToNativeVecs<OpTy>>(patterns.getContext());
patterns.add<DecomposeToNativeVecs<OpTy>>(patterns.getContext(),
vec_size_in_bits);
patterns.add<VecOpToVecLibConversion<OpTy>>(patterns.getContext(),
getVecFnName);
}

struct MathToVecLibPass
: public mlir::triton::cpu::impl::MathToVecLibBase<MathToVecLibPass> {
MathToVecLibPass() = default;
size_t vec_size_in_bits;

explicit MathToVecLibPass(VecLib lib) { this->lib = lib; }
explicit MathToVecLibPass(VecLib lib, std::set<std::string> cpu_features) {
this->lib = lib;
update_vec_size(cpu_features);
}

void update_vec_size(std::set<std::string> &cpu_features) {
// TODO:
// Refactor this as an independent function.
// And improve this to support other x86 SIMD ISAs and also for arm SVE
// (VLA)
vec_size_in_bits = 512;
for (auto feature : cpu_features) {
// Arm NEON is fixed 128-bit SIMD ISA.
if (feature == "neon") {
vec_size_in_bits = 128;
break;
}
}
}

void runOnOperation() override {
Operation *op = getOperation();
Expand All @@ -356,20 +380,20 @@ struct MathToVecLibPass
}
case VecLib::Sleef: {
populateCommonPatterns<SleefNameGenerator>(patterns);
populatePatternsForOp<math::ExpM1Op>(patterns,
SleefNameGenerator("expm1"));
populatePatternsForOp<math::ExpM1Op>(
patterns, SleefNameGenerator("expm1"), vec_size_in_bits);
populatePatternsForOp<math::FloorOp>(
patterns, SleefNameGenerator("floor", /*ulp=*/0));
patterns, SleefNameGenerator("floor", /*ulp=*/0), vec_size_in_bits);
populatePatternsForOp<math::SqrtOp>(
patterns, SleefNameGenerator("sqrt", /*ulp=*/5));
patterns, SleefNameGenerator("sqrt", /*ulp=*/5), vec_size_in_bits);
populatePatternsForOp<math::TruncOp>(
patterns, SleefNameGenerator("trunc", /*ulp=*/0));
patterns, SleefNameGenerator("trunc", /*ulp=*/0), vec_size_in_bits);
break;
}
}

patterns.add<DecomposeToNativeVecs<ExternElementwiseOp>>(
patterns.getContext());
patterns.getContext(), vec_size_in_bits);
patterns.add<PadSmallVecsForSleef>(patterns.getContext());
patterns.add<ExternElementwiseOpConversion>(patterns.getContext());

Expand All @@ -379,26 +403,46 @@ struct MathToVecLibPass

template <typename VecFnNameGenerator>
void populateCommonPatterns(RewritePatternSet &patterns) const {
populatePatternsForOp<math::AcosOp>(patterns, VecFnNameGenerator("acos"));
populatePatternsForOp<math::AcoshOp>(patterns, VecFnNameGenerator("acosh"));
populatePatternsForOp<math::AsinOp>(patterns, VecFnNameGenerator("asin"));
populatePatternsForOp<math::AsinhOp>(patterns, VecFnNameGenerator("asinh"));
populatePatternsForOp<math::AtanOp>(patterns, VecFnNameGenerator("atan"));
populatePatternsForOp<math::AtanhOp>(patterns, VecFnNameGenerator("atanh"));
populatePatternsForOp<math::CbrtOp>(patterns, VecFnNameGenerator("cbrt"));
populatePatternsForOp<math::CosOp>(patterns, VecFnNameGenerator("cos"));
populatePatternsForOp<math::CoshOp>(patterns, VecFnNameGenerator("cosh"));
populatePatternsForOp<math::ErfOp>(patterns, VecFnNameGenerator("erf"));
populatePatternsForOp<math::ExpOp>(patterns, VecFnNameGenerator("exp"));
populatePatternsForOp<math::Exp2Op>(patterns, VecFnNameGenerator("exp2"));
populatePatternsForOp<math::LogOp>(patterns, VecFnNameGenerator("log"));
populatePatternsForOp<math::Log2Op>(patterns, VecFnNameGenerator("log2"));
populatePatternsForOp<math::Log10Op>(patterns, VecFnNameGenerator("log10"));
populatePatternsForOp<math::Log1pOp>(patterns, VecFnNameGenerator("log1p"));
populatePatternsForOp<math::SinOp>(patterns, VecFnNameGenerator("sin"));
populatePatternsForOp<math::SinhOp>(patterns, VecFnNameGenerator("sinh"));
populatePatternsForOp<math::TanOp>(patterns, VecFnNameGenerator("tan"));
populatePatternsForOp<math::TanhOp>(patterns, VecFnNameGenerator("tanh"));
populatePatternsForOp<math::AcosOp>(patterns, VecFnNameGenerator("acos"),
vec_size_in_bits);
populatePatternsForOp<math::AcoshOp>(patterns, VecFnNameGenerator("acosh"),
vec_size_in_bits);
populatePatternsForOp<math::AsinOp>(patterns, VecFnNameGenerator("asin"),
vec_size_in_bits);
populatePatternsForOp<math::AsinhOp>(patterns, VecFnNameGenerator("asinh"),
vec_size_in_bits);
populatePatternsForOp<math::AtanOp>(patterns, VecFnNameGenerator("atan"),
vec_size_in_bits);
populatePatternsForOp<math::AtanhOp>(patterns, VecFnNameGenerator("atanh"),
vec_size_in_bits);
populatePatternsForOp<math::CbrtOp>(patterns, VecFnNameGenerator("cbrt"),
vec_size_in_bits);
populatePatternsForOp<math::CosOp>(patterns, VecFnNameGenerator("cos"),
vec_size_in_bits);
populatePatternsForOp<math::CoshOp>(patterns, VecFnNameGenerator("cosh"),
vec_size_in_bits);
populatePatternsForOp<math::ErfOp>(patterns, VecFnNameGenerator("erf"),
vec_size_in_bits);
populatePatternsForOp<math::ExpOp>(patterns, VecFnNameGenerator("exp"),
vec_size_in_bits);
populatePatternsForOp<math::Exp2Op>(patterns, VecFnNameGenerator("exp2"),
vec_size_in_bits);
populatePatternsForOp<math::LogOp>(patterns, VecFnNameGenerator("log"),
vec_size_in_bits);
populatePatternsForOp<math::Log2Op>(patterns, VecFnNameGenerator("log2"),
vec_size_in_bits);
populatePatternsForOp<math::Log10Op>(patterns, VecFnNameGenerator("log10"),
vec_size_in_bits);
populatePatternsForOp<math::Log1pOp>(patterns, VecFnNameGenerator("log1p"),
vec_size_in_bits);
populatePatternsForOp<math::SinOp>(patterns, VecFnNameGenerator("sin"),
vec_size_in_bits);
populatePatternsForOp<math::SinhOp>(patterns, VecFnNameGenerator("sinh"),
vec_size_in_bits);
populatePatternsForOp<math::TanOp>(patterns, VecFnNameGenerator("tan"),
vec_size_in_bits);
populatePatternsForOp<math::TanhOp>(patterns, VecFnNameGenerator("tanh"),
vec_size_in_bits);
}
};

Expand All @@ -408,8 +452,9 @@ namespace mlir {
namespace triton {
namespace cpu {

std::unique_ptr<OperationPass<ModuleOp>> createMathToVecLibPass(VecLib lib) {
return std::make_unique<MathToVecLibPass>(lib);
std::unique_ptr<OperationPass<ModuleOp>>
createMathToVecLibPass(VecLib lib, std::set<std::string> cpu_features) {
return std::make_unique<MathToVecLibPass>(lib, cpu_features);
}

} // namespace cpu
Expand Down
5 changes: 3 additions & 2 deletions third_party/cpu/triton_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) {
pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
});
m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib) {
pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib));
m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib,
std::set<std::string> cpu_features) {
pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib, cpu_features));
});
m.def("add_math_to_libm", [](mlir::PassManager &pm) {
pm.addPass(mlir::createConvertMathToLibmPass());
Expand Down

0 comments on commit 273f93a

Please sign in to comment.