From feb95c36dd4b60913a518f1306fa1a2e83b9248a Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 12 Dec 2024 16:31:18 -0600 Subject: [PATCH] AMX lowering improvements (#194) * Improve AMX lowering to minimize loads and stores. Signed-off-by: Ilya Enkovich * Support bfloat16 in CPU matmul tutorials. Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- .../tutorials/03-matrix-multiplication-cpu.py | 4 +- python/tutorials/cpu-blocked-matmul-fp32.py | 58 +-- .../include/TritonCPUTransforms/OptCommon.h | 11 + .../ConvertDotOp/ConvertDotToAMX.cpp | 430 +++++++++--------- 4 files changed, 261 insertions(+), 242 deletions(-) diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index f2ee03dfadc2..3b44a30bf7ad 100644 --- a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -156,7 +156,7 @@ import os DTYPE = getattr(torch, (os.getenv("DTYPE", "float32"))) -# Chosse block size depending on dtype. We have more register +# Choose block size depending on dtype. We have more register # capacity for bfloat16/float16 compared to float32. BLOCK_SIZE_M = 8 if DTYPE == torch.float32 else 32 BLOCK_SIZE_N = 32 @@ -306,7 +306,7 @@ def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, n K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" if c is None: # Allocates output. 
- c = torch.empty((M, N), device=a.device, dtype=a.dtype) + c = torch.empty((M, N), device=a.device, dtype=torch.float32) else: assert c.shape == (M, N), "Incompatible dimensions" diff --git a/python/tutorials/cpu-blocked-matmul-fp32.py b/python/tutorials/cpu-blocked-matmul-fp32.py index 0df8f41b9b2b..8f0f0ebce41a 100644 --- a/python/tutorials/cpu-blocked-matmul-fp32.py +++ b/python/tutorials/cpu-blocked-matmul-fp32.py @@ -15,10 +15,14 @@ import triton import triton.language as tl +import os -BLOCK_SIZE_M = 8 +DTYPE = os.getenv("DTYPE", "float32") +# Choose block size depending on dtype. We have more register +# capacity for bfloat16/float16 compared to float32. +BLOCK_SIZE_M = 8 if DTYPE == "float32" else 32 BLOCK_SIZE_N = 32 -BLOCK_SIZE_K = 8 +BLOCK_SIZE_K = 8 if DTYPE == "float32" else 32 GROUP_SIZE_M = 8 @@ -39,13 +43,10 @@ # used by Triton CPU backend which processes RHS block row-by-row and LHS # block column-by-column. @triton.jit -def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, - TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr, +def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): - tl.static_assert(M % BLOCK_SIZE_M == 0) - tl.static_assert(N % BLOCK_SIZE_N == 0) tl.static_assert(BLOCKED_A or not TRANSPOSED_BLOCK_A) tl.static_assert(BLOCKED_B or not TRANSPOSED_B) pid = tl.program_id(axis=0) @@ -62,10 +63,10 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N a_out_block_m = in_block_m A_OUT_BLOCK_SIZE_M: tl.constexpr = BLOCK_SIZE_K if TRANSPOSED_BLOCK_A else BLOCK_SIZE_M 
A_OUT_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_M if TRANSPOSED_BLOCK_A else BLOCK_SIZE_K - A_OUT_BLOCKS_M: tl.constexpr = M // BLOCK_SIZE_M - A_OUT_BLOCKS_K: tl.constexpr = K // BLOCK_SIZE_K + A_OUT_BLOCKS_M = M // BLOCK_SIZE_M + A_OUT_BLOCKS_K = K // BLOCK_SIZE_K A_OUT_STRIDE_M: tl.constexpr = A_OUT_BLOCK_SIZE_K - A_OUT_STRIDE_BLOCK_M: tl.constexpr = BLOCK_SIZE_M * K + A_OUT_STRIDE_BLOCK_M = BLOCK_SIZE_M * K A_OUT_STRIDE_BLOCK_K: tl.constexpr = BLOCK_SIZE_M * BLOCK_SIZE_K for in_block_k in tl.range(in_block_n, A_OUT_BLOCKS_K, N // BLOCK_SIZE_N): a_out_block_k = in_block_k @@ -84,10 +85,10 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N tl.store(a_out_ptr, val) if BLOCKED_B: - B_OUT_BLOCKS_K: tl.constexpr = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K - B_OUT_BLOCKS_N: tl.constexpr = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N + B_OUT_BLOCKS_K = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K + B_OUT_BLOCKS_N = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N B_OUT_STRIDE_K: tl.constexpr = BLOCK_SIZE_N - B_OUT_STRIDE_BLOCK_K: tl.constexpr = (K * BLOCK_SIZE_N if TRANSPOSED_B else BLOCK_SIZE_K * N) + B_OUT_STRIDE_BLOCK_K = (K * BLOCK_SIZE_N if TRANSPOSED_B else BLOCK_SIZE_K * N) B_OUT_STRIDE_BLOCK_N: tl.constexpr = BLOCK_SIZE_K * BLOCK_SIZE_N for in_block_k in tl.range(in_block_m, K // BLOCK_SIZE_K, M // BLOCK_SIZE_M): b_out_block_k = in_block_n if TRANSPOSED_B else in_block_k @@ -120,11 +121,11 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N # Reshape is used to remove the heading (1, 1) dimensions, but CPU backend folds it with the load # operation and it doesn't prevent direct vector loads from the input memory. 
@triton.jit -def matmul_kernel_fma(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - # number of blocks in a group - GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, - BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + # number of blocks in a group + GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, + BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): # TRANSPOSED_BLOCK_A means that each block in A is transposed. # It is allowed only for blocked input. assert (BLOCKED_A or not TRANSPOSED_BLOCK_A) @@ -188,8 +189,8 @@ def matmul_kernel_fma(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, tl.store(c_block_ptr, c) -def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, - PREPACKED, BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, num_threads=0): +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, PREPACKED, + BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, num_threads=0): #TODO: Currently masked load is not supported yet. 
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" @@ -207,7 +208,7 @@ def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tens a = ab if BLOCKED_B: b = bb - matmul_kernel_fma[grid]( + matmul_kernel[grid]( a, b, c, # M, N, K, # BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # @@ -233,14 +234,14 @@ def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tens rtol = 0 a_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_M) * (512 // BLOCK_SIZE_K) * 64), device='cpu', dtype=torch.float32) b_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_K) * (512 // BLOCK_SIZE_N) * 64), device='cpu', dtype=torch.float32) -triton_output = matmul_fma(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False) +triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False) if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonCPU and TorchCPU match") else: print("❌ TritonCPU and TorchCPU differ, the maximum difference is " f'{torch.max(torch.abs(triton_output - torch_output))}') assert False -triton_output = matmul_fma(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True) +triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True) if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonCPU pre-packed and TorchCPU match") else: @@ -289,7 +290,7 @@ def decode_provider(provider): BLOCK_TRANSPOSE_B_OPTS = [(True, True), (False, False)] PREPACK_OPTS = [False, True] SINGLE_THREAD_OPTS = [False] -DTYPE_OPTS = ['float32'] +DTYPE_OPTS = [DTYPE] LINE_VALS = [ encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype) for single_thread in SINGLE_THREAD_OPTS @@ -316,7 +317,7 @@ def 
decode_provider(provider): ylabel='GFLOPS', # Label name for the y-axis. plot_name= # Name for the plot. Used also as a file name for saving the plot. - f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', + f'matmul-performance-{DTYPE} (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', args={}, # Values for function arguments not in `x_names` and `y_name`. )) def benchmark(M, N, K, provider): @@ -360,9 +361,8 @@ def benchmark(M, N, K, provider): ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles) elif backend == 'triton-cpu': ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul_fma(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, - transposed_b, num_threads=int(single_thread)), quantiles=quantiles, - measure_time_with_hooks=True, rep=1000) + lambda: matmul(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, transposed_b, + num_threads=int(single_thread)), quantiles=quantiles, measure_time_with_hooks=True, rep=1000) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h index a2e94f894caf..09e7ec65595d 100644 --- a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h +++ b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h @@ -125,12 +125,14 @@ inline Value shapeCast(Location loc, Value in, } // namespace mlir #define int_cst(ty, val) intCst(loc, ty, val, rewriter) +#define index_cst(val) rewriter.create(loc, val) #define cst_like(src, val) cstLike(loc, src, val, rewriter) #define op_addi(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_addf(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_subi(lhs, rhs) 
rewriter.create(loc, lhs, rhs) #define op_subf(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_muli(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_mulf(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_bitcast(ty, val) rewriter.create(loc, ty, val) #define op_lshr(lhs, rhs) rewriter.create(loc, lhs, rhs) @@ -146,6 +148,15 @@ inline Value shapeCast(Location loc, Value in, rewriter.create(loc, cond, val, other) #define op_sitofp(ty, val) rewriter.create(loc, ty, val) #define op_fptosi(ty, val) rewriter.create(loc, ty, val) +#define op_read(ty, memRef, indices) \ + rewriter.create(loc, ty, memRef, indices) +#define op_write(val, memRef, indices) \ + rewriter.create(loc, val, memRef, indices) +#define op_interleave(lhs, rhs) \ + rewriter.create(loc, lhs, rhs) +#define op_extract(vec, idx) rewriter.create(loc, vec, idx) +#define op_store(val, mem, idx) \ + rewriter.create(loc, val, mem, idx) #define op_icmp_eq(lhs, rhs) \ rewriter.create(loc, arith::CmpIPredicate::eq, lhs, rhs) diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp index 23f16944de41..1b6dd9269ac1 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp @@ -1,4 +1,4 @@ -#include "cpu/include/TritonCPUTransforms/OptCommon.h" +#include "ConvertDotCommon.h" #include "cpu/include/TritonCPUTransforms/Passes.h" @@ -24,24 +24,12 @@ namespace cpu { } // namespace triton } // namespace mlir -#define DEBUG_TYPE "triton-cpu-dot-to-amx" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - using namespace mlir; using namespace mlir::triton; using namespace mlir::triton::cpu; namespace { -// This struct describes buffers used to load/store AMX tiles. 
-struct AmxBuffer { - Value memRef; - SmallVector indices; - - bool empty() const { return !memRef; } -}; - // This structure is used to hold candidates for conversion to AMX // Mul[F|I]Op operations. struct AmxDotOpCandidate { @@ -75,7 +63,7 @@ struct AmxDotOpCandidate { // If resulting tiles are not required to be trasfered to vectors and can be // directly stored to the output memory instead, then this field holds a // buffer to use. - AmxBuffer outBuf; + MemBuffer outBuf; // If output buffer is used then keep the original vector store here. Operation *origStore = nullptr; }; @@ -182,50 +170,6 @@ bool checkInputShapes(VectorType lhsTy, VectorType resTy) { return true; } -// Check if accumulator value is updated in a loop and has no other -// usages than a dot op, that updates it. Tile loads/stores and casts -// for such accumulators can be done outside of the loop. -bool isLoopCarriedAcc(Value acc) { - LDBG("Check if accumulator can be held in tiles: " << acc); - if (!acc.hasOneUse()) { - LDBG(" No. Has multiple uses."); - for (auto op : acc.getUsers()) - LDBG(" " << *op); - return false; - } - - auto blockArg = dyn_cast(acc); - if (!blockArg) { - LDBG(" No. Not a block argument."); - return false; - } - - auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); - if (!forOp) { - LDBG(" No. Not in a for-loop."); - return false; - } - - blockArg.getArgNumber(); - - Value updAcc = acc.getUsers().begin()->getResult(0); - if (!updAcc.hasOneUse()) { - LDBG(" No. Has multiple uses."); - return false; - } - - auto &updAccUse = *updAcc.getUses().begin(); - if (!isa(updAccUse.getOwner()) || - updAccUse.getOperandNumber() != - (blockArg.getArgNumber() - forOp.getNumInductionVars())) { - LDBG(" No. Loop carried dependency not detected."); - return false; - } - - LDBG(" Yes."); - return true; -} - // Return a value that holds the resulting loop carried accumulator value. // It's one of ForOp's results. 
Value getResValueForLoopCarriedAcc(cpu::DotOp op) { @@ -239,11 +183,11 @@ Value getResValueForLoopCarriedAcc(cpu::DotOp op) { // by input shapes and types. Block sizes are chosen to minimize number of // tile loads/stores including tile register spills. void setupBlockAndTileSizes(ArrayRef lhsShape, - ArrayRef rhsShape, + ArrayRef resShape, AmxDotOpCandidate &candidate) { - int64_t m = lhsShape[0]; - int64_t n = rhsShape[1]; - int64_t k = rhsShape[0]; + int64_t m = resShape[0]; + int64_t n = resShape[1]; + int64_t k = lhsShape[1]; int64_t tileM = std::min(m, (int64_t)16); int64_t tileN = std::min(n, (int64_t)16); int64_t tileK = std::min( @@ -288,7 +232,7 @@ void findOutputBuffer(Value val, AmxDotOpCandidate &candidate) { if (val.hasOneUse()) { auto store = dyn_cast(*val.user_begin()); if (store && !hasMaskOrBoundsCheck(store)) - candidate.outBuf = AmxBuffer{store.getSource(), store.getIndices()}; + candidate.outBuf = MemBuffer{store.getSource(), store.getIndices()}; candidate.origStore = store; } } @@ -319,15 +263,16 @@ bool isAmxCandidate(cpu::DotOp op, bool supportInt8, bool supportFp16, return false; candidate.op = op; - setupBlockAndTileSizes(lhsTy.getShape(), rhsTy.getShape(), candidate); + setupBlockAndTileSizes(lhsTy.getShape(), resTy.getShape(), candidate); candidate.keepAccOnTiles = isLoopCarriedAcc(op.getC()); // Can't keep acc in a tile the whole loop right now: // https://github.com/llvm/llvm-project/issues/109481 if (candidate.keepAccOnTiles) { - // We might not have enough tiles to hold accumulator. In this case - // keep it in a bufffer. - if (candidate.tilesInBlockM * candidate.tilesInBlockN > 1) { + // We might not have enough tiles to hold the whole accumulator. If we + // have more than one block, keep it in a buffer. + if (candidate.tilesInBlockM * candidate.tileM < resTy.getDimSize(0) || + candidate.tilesInBlockN * candidate.tileN < resTy.getDimSize(1)) { LDBG("Accumulator is too big to keep on tiles. 
Keep it bufferized " "insterad."); candidate.keepAccOnTiles = false; @@ -335,14 +280,6 @@ bool isAmxCandidate(cpu::DotOp op, bool supportInt8, bool supportFp16, } else { findOutputBuffer(getResValueForLoopCarriedAcc(op), candidate); } - - // TODO: fix LLVM bug and remove this code. - LDBG("Avoid accumulator on tiles due to LLVM bug: " - "https://github.com/llvm/llvm-project/issues/109481."); - LDBG("Keep accumulator bufferized instead."); - candidate.keepAccOnTiles = false; - candidate.keepAccInBuf = true; - candidate.outBuf = AmxBuffer{}; } else { findOutputBuffer(op.getResult(), candidate); } @@ -350,35 +287,6 @@ bool isAmxCandidate(cpu::DotOp op, bool supportInt8, bool supportFp16, return true; } -// Cast vector to a specified element type using ext or trunc -// operations. Return the original value if it already matches -// the required element type. -Value maybeCast(Location loc, Value val, Type dstElemTy, - PatternRewriter &rewriter) { - VectorType srcTy = cast(val.getType()); - if (srcTy.getElementType() == dstElemTy) - return val; - - VectorType dstTy = srcTy.cloneWith(std::nullopt, dstElemTy); - if (srcTy.getElementType().isInteger()) { - if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth()) - return rewriter.create(loc, dstTy, val); - return rewriter.create(loc, dstTy, val); - } - - if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth()) - return rewriter.create(loc, dstTy, val); - return rewriter.create(loc, dstTy, val); -} - -// Get initial value for a loop-carried accumulator. 
-Value getInitAccValue(Value val) { - auto blockArg = cast(val); - auto forOp = cast(blockArg.getOwner()->getParentOp()); - int initValIdx = blockArg.getArgNumber() - forOp.getNumInductionVars(); - return forOp.getInitArgs()[initValIdx]; -} - template T getSwizzledRhsTileType(T origTileType) { int64_t rowsPerGroup = 32 / origTileType.getElementTypeBitWidth(); SmallVector shape({origTileType.getDimSize(0) / rowsPerGroup, @@ -386,18 +294,6 @@ template T getSwizzledRhsTileType(T origTileType) { return origTileType.cloneWith(shape, origTileType.getElementType()); } -AmxBuffer allocateTmpBuffer(Location loc, VectorType vecTy, - Operation *allocaPoint, PatternRewriter &rewriter) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(allocaPoint); - auto memRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); - Value memRef = rewriter.create( - loc, memRefTy, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); - Value zeroIdx = rewriter.create(loc, 0); - SmallVector indices(2, zeroIdx); - return {memRef, indices}; -} - // In AMX, element values shoud be packed to 32-bit groups that would be // multiplied elementwise with following accumulation. It means that RHS // needs to be pre-packed. E.g. 
for the following input @@ -418,27 +314,80 @@ void interleaveAndStore(Location loc, Value val, Value buf, int64_t rowsPerGroup = 32 / valTy.getElementTypeBitWidth(); assert(rowsPerGroup == 2 || rowsPerGroup == 4); assert(valTy.getDimSize(0) % rowsPerGroup == 0); - Value zeroIdx = rewriter.create(loc, 0); + Value zeroIdx = index_cst(0); for (int64_t i = 0; i < valTy.getDimSize(0); i += rowsPerGroup) { Value row1, row2; if (rowsPerGroup == 2) { - row1 = rewriter.create(loc, val, i); - row2 = rewriter.create(loc, val, i + 1); + row1 = op_extract(val, i); + row2 = op_extract(val, i + 1); } else { - row1 = rewriter.create( - loc, rewriter.create(loc, val, i), - rewriter.create(loc, val, i + 2)); - row2 = rewriter.create( - loc, rewriter.create(loc, val, i + 1), - rewriter.create(loc, val, i + 3)); + row1 = op_interleave(op_extract(val, i), op_extract(val, i + 2)); + row2 = op_interleave(op_extract(val, i + 1), op_extract(val, i + 3)); } - Value shuffled = rewriter.create(loc, row1, row2); - Value idx = rewriter.create(loc, i / rowsPerGroup); - rewriter.create(loc, shuffled, buf, - SmallVector({idx, zeroIdx})); + Value shuffled = op_interleave(row1, row2); + Value idx = index_cst(i / rowsPerGroup); + op_store(shuffled, buf, SmallVector({idx, zeroIdx})); } } +Value loadWithPrefetch(Location loc, VectorType ty, Value memRef, + ArrayRef indices, ArrayRef step, + PatternRewriter &rewriter) { + Value res = op_read(ty, memRef, indices); + if (!step.empty()) { + SmallVector prefetchIndices; + for (int64_t i = 0; i < indices.size(); ++i) { + prefetchIndices.push_back( + op_addi(indices[i], rewriter.create( + loc, rewriter.getIndexType(), step[i]))); + } + rewriter.create(loc, memRef, prefetchIndices, false, 1, + true); + } + return res; +} + +// Copy tensor with packing using for-loop. See interleaveAndStore for more +// details. 
+void copyWithInterleave(Location loc, VectorType srcTy, const MemBuffer &src, + const MemBuffer &dst, PatternRewriter &rewriter) { + int64_t rowsPerGroup = 32 / srcTy.getElementTypeBitWidth(); + Value lower = index_cst(0); + Value upper = index_cst(srcTy.getDimSize(0) / rowsPerGroup); + Value one = index_cst(1); + Value rowsPerGroupVal = index_cst(rowsPerGroup); + VectorType srcVecTy = + VectorType::get({srcTy.getDimSize(1)}, srcTy.getElementType()); + auto forOp = rewriter.create(loc, lower, upper, one); + Value ivVal = forOp.getInductionVar(); + rewriter.setInsertionPointToStart(forOp.getBody()); + SmallVector srcIndices = src.indices; + int64_t mDimIdx = srcIndices.size() - 2; + Value scaledM = op_muli(ivVal, rowsPerGroupVal); + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], scaledM); + Value row1 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, src.step, + rewriter); + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], one); + Value row2 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, src.step, + rewriter); + if (rowsPerGroup == 4) { + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], one); + Value row3 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, + src.step, rewriter); + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], one); + Value row4 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, + src.step, rewriter); + row1 = op_interleave(row1, row3); + row2 = op_interleave(row2, row4); + } + Value shuffled = op_interleave(row1, row2); + SmallVector dstIndices = dst.indices; + dstIndices[dstIndices.size() - 2] = + op_addi(dstIndices[dstIndices.size() - 2], ivVal); + op_write(shuffled, dst.memRef, dstIndices); + rewriter.setInsertionPointAfter(forOp); +} + // Prepare temporary buffers to be used for tile loads. If the original // value can be directly loaded to tiles from its original memory, then // use it instead. 
Return empty buffer if source value is all zeros and @@ -446,18 +395,25 @@ void interleaveAndStore(Location loc, Value val, Value buf, // // If interleave flag is set, then pre-pack RHS before store. See // interleaveAndStore for more details. -AmxBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, +MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, bool skipForZeros, bool readOnly, Operation *allocaPoint, PatternRewriter &rewriter) { LDBG("Preparing buffer (interleave=" << interleave << ") for a vector: " << val); - auto valLoad = val.getDefiningOp(); - if (valLoad && !interleave && readOnly && !hasMaskOrBoundsCheck(valLoad)) { - Value memRef = valLoad.getSource(); - ValueRange indices = valLoad.getIndices(); - LDBG(" Reusing the original memref for a buffer: " << memRef); - return {memRef, indices}; + auto vecTy = cast(val.getType()); + MemBuffer inputBuf = findInputBuffer(val); + if (!inputBuf.empty()) { + if (interleave) { + LDBG(" Copying from the original memref with interleave: " + << inputBuf.memRef); + auto tmpBuf = allocateTmpBuffer(loc, getSwizzledRhsTileType(vecTy), + allocaPoint, rewriter); + copyWithInterleave(loc, vecTy, inputBuf, tmpBuf, rewriter); + return tmpBuf; + } + LDBG(" Reusing the original memref for a buffer: " << inputBuf.memRef); + return inputBuf; } if (skipForZeros && isZeroConst(val)) { @@ -465,15 +421,14 @@ AmxBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, return {}; } - auto vecTy = cast(val.getType()); if (interleave) vecTy = getSwizzledRhsTileType(vecTy); - AmxBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); + MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); if (interleave) { interleaveAndStore(loc, val, buf.memRef, rewriter); } else { - rewriter.create(loc, val, buf.memRef, buf.indices); + op_write(val, buf.memRef, buf.indices); } return buf; @@ -482,8 +437,8 @@ AmxBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, 
// Return a buffer where the final result should be stored. If result can // be directly stored to the output memory, then it is used as an output // buffer. Otherwise, re-use accumulator buffer or create a new one. -AmxBuffer prepareResultBuffer(Location loc, Value val, const AmxBuffer &accBuf, - const AmxBuffer &outBuf, Operation *allocaPoint, +MemBuffer prepareResultBuffer(Location loc, Value val, const MemBuffer &accBuf, + const MemBuffer &outBuf, Operation *allocaPoint, PatternRewriter &rewriter) { if (!outBuf.empty()) { LDBG("Output memory will be used for direct tile stores."); @@ -500,37 +455,22 @@ AmxBuffer prepareResultBuffer(Location loc, Value val, const AmxBuffer &accBuf, rewriter); } -Value shiftIndex(Location loc, Value index, int64_t offs, - PatternRewriter &rewriter) { - if (!offs) - return index; - - // Do constant folding right away here for better code readability - // after the pass. - auto cstOp = dyn_cast(index.getDefiningOp()); - if (cstOp) { - int64_t oldVal = cast(cstOp.getValue()).getInt(); - return rewriter.create(loc, oldVal + offs); - } - - Value offsVal = rewriter.create(loc, offs); - return rewriter.create(loc, index.getType(), index, offsVal); -} - -SmallVector shiftIndices(Location loc, ArrayRef indices, - amx::TileType tileTy, int64_t tilesInBlockM, - int64_t tilesInBlockN, int64_t blockM, - int64_t blockN, int64_t tileM, int64_t tileN, - PatternRewriter &rewriter) { +SmallVector shiftIndices(Location loc, ArrayRef indices, + amx::TileType tileTy, int64_t tilesInBlockM, + int64_t tilesInBlockN, int64_t blockM, + int64_t blockN, int64_t tileM, int64_t tileN, + PatternRewriter &rewriter) { int64_t blockOffsM = blockM * tilesInBlockM * tileTy.getDimSize(0); int64_t blockOffsN = blockN * tilesInBlockN * tileTy.getDimSize(1); int64_t tileOffsM = blockOffsM + tileM * tileTy.getDimSize(0); int64_t tileOffsN = blockOffsN + tileN * tileTy.getDimSize(1); - return {shiftIndex(loc, indices[0], tileOffsM, rewriter), - shiftIndex(loc, 
indices[1], tileOffsN, rewriter)}; + SmallVector res(indices.begin(), indices.end() - 2); + res.push_back(shiftIndex(loc, *(indices.end() - 2), tileOffsM, rewriter)); + res.push_back(shiftIndex(loc, *(indices.end() - 1), tileOffsN, rewriter)); + return res; } -Value loadTile(Location loc, amx::TileType tileTy, const AmxBuffer &buf, +Value loadTile(Location loc, amx::TileType tileTy, const MemBuffer &buf, int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, int64_t blockN, int64_t tileM, int64_t tileN, PatternRewriter &rewriter) { @@ -541,7 +481,7 @@ Value loadTile(Location loc, amx::TileType tileTy, const AmxBuffer &buf, } void storeTile(Location loc, amx::TileType tileTy, Value val, - const AmxBuffer &buf, int64_t tilesInBlockM, + const MemBuffer &buf, int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, int64_t blockN, int64_t tileM, int64_t tileN, PatternRewriter &rewriter) { auto indices = @@ -551,7 +491,7 @@ void storeTile(Location loc, amx::TileType tileTy, Value val, } SmallVector> -loadBlockTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, +loadBlockTiles(Location loc, amx::TileType tileTy, const MemBuffer &buf, int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, int64_t blockN, PatternRewriter &rewriter) { SmallVector> res(tilesInBlockM); @@ -567,22 +507,18 @@ loadBlockTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, return res; } -// Move acc to a tile for the whole loop. It might be loads from memory or -// zero tiles. -SmallVector> -moveLoopAccToTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, - int64_t tilesInBlockM, int64_t tilesInBlockN, - PatternRewriter &rewriter) { - LDBG("Loading accumulator to tiles before the loop."); - auto res = loadBlockTiles(loc, tileTy, buf, tilesInBlockM, tilesInBlockN, 0, - 0, rewriter); - - // TODO: add new block args into ForOp and return them instead. - // Yield directly uses them for now and will be patched after mul - // ops generation. 
- llvm_unreachable("Not yet supported."); - - return res; +void storeBlockTiles(Location loc, amx::TileType tileTy, const MemBuffer &buf, + int64_t blockM, int64_t blockN, + const SmallVector> &tiles, + PatternRewriter &rewriter) { + int64_t tilesInBlockM = tiles.size(); + int64_t tilesInBlockN = tiles[0].size(); + for (int64_t m = 0; m < tilesInBlockM; ++m) { + for (int64_t n = 0; n < tilesInBlockN; ++n) { + storeTile(loc, tileTy, tiles[m][n], buf, tilesInBlockM, tilesInBlockN, + blockM, blockN, m, n, rewriter); + } + } } // Multiply two blocks. LHS block is preloaded to tiles with the following @@ -590,8 +526,8 @@ moveLoopAccToTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, // Optionally, results can also be stored to accBuf. void multiplyBlocksPreloadLhs(Location loc, amx::TileType lhsTileTy, amx::TileType rhsTileTy, amx::TileType accTileTy, - const AmxBuffer &lhsBuf, const AmxBuffer &rhsBuf, - const AmxBuffer &accBuf, int64_t blockM, + const MemBuffer &lhsBuf, const MemBuffer &rhsBuf, + const MemBuffer &accBuf, int64_t blockM, int64_t blockN, int64_t blockK, int64_t tilesInBlockM, int64_t tilesInBlockN, SmallVector> &accTiles, @@ -626,8 +562,8 @@ void multiplyBlocksPreloadLhs(Location loc, amx::TileType lhsTileTy, // Similar to multiplyBlocksPreloadLhs but here RHS is preloaded to tiles. void multiplyBlocksPreloadRhs(Location loc, amx::TileType lhsTileTy, amx::TileType rhsTileTy, amx::TileType accTileTy, - const AmxBuffer &lhsBuf, const AmxBuffer &rhsBuf, - const AmxBuffer &accBuf, int64_t blockM, + const MemBuffer &lhsBuf, const MemBuffer &rhsBuf, + const MemBuffer &accBuf, int64_t blockM, int64_t blockN, int64_t blockK, int64_t tilesInBlockM, int64_t tilesInBlockN, SmallVector> &accTiles, @@ -691,11 +627,11 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, // Cast input data if required and prepare input buffer. It might be temporary // buffers with stored vectors or the original input memory. 
Value lhs = maybeCast(loc, op.getA(), candidate.lhsTileElemTy, rewriter); - AmxBuffer lhsBuf = + MemBuffer lhsBuf = prepareTensorBuffer(loc, lhs, false, false, true, allocaPoint, rewriter); Value rhs = maybeCast(loc, op.getB(), candidate.rhsTileElemTy, rewriter); - AmxBuffer rhsBuf = + MemBuffer rhsBuf = prepareTensorBuffer(loc, rhs, true, false, true, allocaPoint, rewriter); Value acc = maybeCast(loc, op.getC(), candidate.accTileElemTy, rewriter); @@ -705,7 +641,7 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, forOp = cast(op->getParentOp()); accToStore = getInitAccValue(acc); } - AmxBuffer accBuf; + MemBuffer accBuf; { // If accumulator is bufferized then we should move initial values before // the loop. @@ -717,14 +653,24 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, false, allocaPoint, rewriter); } - AmxBuffer resBuf = prepareResultBuffer( + MemBuffer resBuf = prepareResultBuffer( loc, op.getResult(), accBuf, candidate.outBuf, allocaPoint, rewriter); SmallVector> accTiles; - if (candidate.keepAccOnTiles) - accTiles = - moveLoopAccToTiles(loc, accTileTy, accBuf, candidate.tilesInBlockM, - candidate.tilesInBlockN, rewriter); + SmallVector> accInitTiles; + if (candidate.keepAccOnTiles) { + // Initial tile values are loaded before the loop and then directly + // used within the loop. Later, new iter values will be added to + // add loop carried-dependencies for accumulator tiles and accInitTiles + // will be used as initializers for them. 
+ OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forOp); + LDBG("Loading accumulator to tiles before the loop."); + accInitTiles = + loadBlockTiles(loc, accTileTy, accBuf, candidate.tilesInBlockM, + candidate.tilesInBlockN, 0, 0, rewriter); + accTiles = accInitTiles; + } int64_t blocksInAccM = accTy.getDimSize(0) / candidate.tileM / candidate.tilesInBlockM; @@ -743,6 +689,7 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, // TODO: enable forward store for acc kept in tiles. bool storeAcc = !candidate.keepAccOnTiles && (blocK == (tilesInVectorK - 1)); + // We need to choose which block (LHS or RHS) to keep on tiles. // E.g. for ACC block 4x1 tiles, LHS block is also 4 tiles, so // we would use all tile registers trying to keep both ACC and @@ -762,37 +709,98 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, } } - // TODO: For keepAccOnTiles fix YieldOp to use mul results. - // TODO: For keepAccOnTiles move all new forOp results to vector through a - // buffer. - if (candidate.keepAccOnTiles) - llvm_unreachable("Not yet supported."); + if (candidate.keepAccOnTiles) { + // In this case we have the whole accumulator/result on tiles. Loop + // carried dependencies are not in place yet and should be added. + // After the loop, resulting tiles should either be stored to the + // output buffer, or moved to a vector through a temporary buffer. + + // We don't need the original accumulator and contraction op anymore. + // Directly yield orig accumulator value, so it would be later removed + // as unused. The original contraction can be removed right away. + int64_t origResIdx = op.getResult().getUses().begin()->getOperandNumber(); + rewriter.replaceOp(op, op.getC()); + + // Now, replace the loop with a new one to add loop carried dependency for + // accumulator tiles. 
+ LDBG("Rewrite loop to introduce loop carried dependencies for accumulator " + "tiles."); + SmallVector newInitOperands; + SmallVector newYieldedValues; + for (int64_t m = 0; m < candidate.tilesInBlockM; ++m) + for (int64_t n = 0; n < candidate.tilesInBlockN; ++n) { + LDBG("Initial value\n " << accInitTiles[m][n] + << "\nis combined with\n " << accTiles[m][n]); + newInitOperands.push_back(accInitTiles[m][n]); + newYieldedValues.push_back(accTiles[m][n]); + } + auto newForOp = cast(*forOp.replaceWithAdditionalYields( + rewriter, newInitOperands, true, + [&newYieldedValues](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return newYieldedValues; + })); + + // The resulting tiles are now in the new loop results. + auto resTiles = newForOp.getResults().take_back(newYieldedValues.size()); + for (int64_t m = 0; m < candidate.tilesInBlockM; ++m) + for (int64_t n = 0; n < candidate.tilesInBlockN; ++n) { + accTiles[m][n] = resTiles[m * candidate.tilesInBlockN + n]; + } - if (candidate.keepAccInBuf) { - int resIdx = op.getResult().getUses().begin()->getOperandNumber(); - Value loopRes = forOp.getResult(resIdx); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(newForOp); + if (candidate.outBuf.empty()) { + // Move tiles to a vector through a temporary buffer and use it instead + // of the original one. + LDBG("Moving resulting tiles to a vector through memory."); + VectorType resTy = accTy.cloneWith(std::nullopt, candidate.accTileElemTy); + storeBlockTiles(loc, accTileTy, resBuf, 0, 0, accTiles, rewriter); + Value newVal = op_read(resTy, resBuf.memRef, resBuf.indices); + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceAllUsesWith(newForOp.getResult(origResIdx), newVal); + } else { + // Store tiles directly to the output buffer and remove the original + // store. 
+ LDBG("Storing resulting tiles to the output memory."); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(candidate.origStore); + storeBlockTiles(loc, accTileTy, candidate.outBuf, 0, 0, accTiles, + rewriter); + rewriter.eraseOp(candidate.origStore); + } + } else if (candidate.keepAccInBuf) { + // The result is in the buffer. We should load it and replace one of the + // loop results. The original contraction op can be removed. + // TODO: should we try to store to the output buffer on the last iteration? + Value loopRes = forOp.getTiedLoopResult(cast(op.getC())); LDBG( "Loading buffererized accumulator to a vector to replace loop result."); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(forOp); - Value newVal = rewriter.create( - loc, cast(acc.getType()), resBuf.memRef, resBuf.indices); + Value newVal = + op_read(cast(acc.getType()), resBuf.memRef, resBuf.indices); // We might need to cast back to the original type. newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); rewriter.replaceAllUsesWith(loopRes, newVal); - // For now, just use init value for unused ForOp result instead of - // its removal. + // Directly yield orig accumulator iter value. It will be removed as unused + // later. rewriter.replaceOp(op, op.getC()); } else if (candidate.outBuf.empty()) { + // The result is in the buffer. We should load it and replace the original + // contraction result. LDBG("Loading the result to a vector to replace orig op result."); - Value newVal = rewriter.create( - loc, cast(acc.getType()), resBuf.memRef, resBuf.indices); + Value newVal = + op_read(cast(acc.getType()), resBuf.memRef, resBuf.indices); // We might need to cast back to the original type. newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); rewriter.replaceOp(op, newVal); } else { + // The result is already in the output buffer. We just need to remove the + // original contraction and store operation. 
LDBG("Removing original operation and its use."); - rewriter.eraseOp(*op.getResult().user_begin()); + rewriter.eraseOp(candidate.origStore); rewriter.eraseOp(op); }