AMX lowering improvements (#194)
* Improve AMX lowering to minimize loads and stores.

Signed-off-by: Ilya Enkovich <[email protected]>

* Support bfloat16 in CPU matmul tutorials.

Signed-off-by: Ilya Enkovich <[email protected]>

---------

Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich authored Dec 12, 2024
1 parent 3f11034 commit feb95c3
Showing 4 changed files with 261 additions and 242 deletions.
4 changes: 2 additions & 2 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -156,7 +156,7 @@
import os

DTYPE = getattr(torch, (os.getenv("DTYPE", "float32")))
-# Chosse block size depending on dtype. We have more register
+# Choose block size depending on dtype. We have more register
# capacity for bfloat16/float16 compared to float32.
BLOCK_SIZE_M = 8 if DTYPE == torch.float32 else 32
BLOCK_SIZE_N = 32
@@ -306,7 +306,7 @@ def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, n
K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
if c is None:
# Allocates output.
-c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+c = torch.empty((M, N), device=a.device, dtype=torch.float32)
else:
assert c.shape == (M, N), "Incompatible dimensions"

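The dtype handling this tutorial change enables is easier to see outside the diff. Below is a minimal sketch (not code from this commit; shapes and names are illustrative): the dtype comes from the DTYPE environment variable, e.g. DTYPE=bfloat16, while the output buffer stays float32 to match the kernel's float32 accumulator.

import os

import torch

# Mirror the tutorial's env-based dtype selection (defaults to float32).
DTYPE = getattr(torch, os.getenv("DTYPE", "float32"))

a = torch.randn((512, 512), device="cpu").to(DTYPE)
b = torch.randn((512, 512), device="cpu").to(DTYPE)

# The output is allocated as float32 even for bfloat16 inputs, matching the
# float32 accumulator used by the matmul kernel.
c = torch.empty((512, 512), device=a.device, dtype=torch.float32)
torch.matmul(a.to(torch.float32), b.to(torch.float32), out=c)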
58 changes: 29 additions & 29 deletions python/tutorials/cpu-blocked-matmul-fp32.py
@@ -15,10 +15,14 @@

import triton
import triton.language as tl
+import os

-BLOCK_SIZE_M = 8
+DTYPE = os.getenv("DTYPE", "float32")
+# Choose block size depending on dtype. We have more register
+# capacity for bfloat16/float16 compared to float32.
+BLOCK_SIZE_M = 8 if DTYPE == "float32" else 32
BLOCK_SIZE_N = 32
-BLOCK_SIZE_K = 8
+BLOCK_SIZE_K = 8 if DTYPE == "float32" else 32
GROUP_SIZE_M = 8


@@ -39,13 +43,10 @@
# used by Triton CPU backend which processes RHS block row-by-row and LHS
# block column-by-column.
@triton.jit
-def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
-BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr,
-TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr,
+def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZE_M: tl.constexpr,
+BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
+BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr,
TRANSPOSED_B: tl.constexpr):
-tl.static_assert(M % BLOCK_SIZE_M == 0)
-tl.static_assert(N % BLOCK_SIZE_N == 0)
tl.static_assert(BLOCKED_A or not TRANSPOSED_BLOCK_A)
tl.static_assert(BLOCKED_B or not TRANSPOSED_B)
pid = tl.program_id(axis=0)
@@ -62,10 +63,10 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N
a_out_block_m = in_block_m
A_OUT_BLOCK_SIZE_M: tl.constexpr = BLOCK_SIZE_K if TRANSPOSED_BLOCK_A else BLOCK_SIZE_M
A_OUT_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_M if TRANSPOSED_BLOCK_A else BLOCK_SIZE_K
-A_OUT_BLOCKS_M: tl.constexpr = M // BLOCK_SIZE_M
-A_OUT_BLOCKS_K: tl.constexpr = K // BLOCK_SIZE_K
+A_OUT_BLOCKS_M = M // BLOCK_SIZE_M
+A_OUT_BLOCKS_K = K // BLOCK_SIZE_K
A_OUT_STRIDE_M: tl.constexpr = A_OUT_BLOCK_SIZE_K
-A_OUT_STRIDE_BLOCK_M: tl.constexpr = BLOCK_SIZE_M * K
+A_OUT_STRIDE_BLOCK_M = BLOCK_SIZE_M * K
A_OUT_STRIDE_BLOCK_K: tl.constexpr = BLOCK_SIZE_M * BLOCK_SIZE_K
for in_block_k in tl.range(in_block_n, A_OUT_BLOCKS_K, N // BLOCK_SIZE_N):
a_out_block_k = in_block_k
@@ -84,10 +85,10 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N
tl.store(a_out_ptr, val)

if BLOCKED_B:
-B_OUT_BLOCKS_K: tl.constexpr = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K
-B_OUT_BLOCKS_N: tl.constexpr = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N
+B_OUT_BLOCKS_K = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K
+B_OUT_BLOCKS_N = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N
B_OUT_STRIDE_K: tl.constexpr = BLOCK_SIZE_N
-B_OUT_STRIDE_BLOCK_K: tl.constexpr = (K * BLOCK_SIZE_N if TRANSPOSED_B else BLOCK_SIZE_K * N)
+B_OUT_STRIDE_BLOCK_K = (K * BLOCK_SIZE_N if TRANSPOSED_B else BLOCK_SIZE_K * N)
B_OUT_STRIDE_BLOCK_N: tl.constexpr = BLOCK_SIZE_K * BLOCK_SIZE_N
for in_block_k in tl.range(in_block_m, K // BLOCK_SIZE_K, M // BLOCK_SIZE_M):
b_out_block_k = in_block_n if TRANSPOSED_B else in_block_k
@@ -120,11 +121,11 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N
# Reshape is used to remove the heading (1, 1) dimensions, but CPU backend folds it with the load
# operation and it doesn't prevent direct vector loads from the input memory.
@triton.jit
-def matmul_kernel_fma(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
-BLOCK_SIZE_K: tl.constexpr,
-# number of blocks in a group
-GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr,
-BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr):
+def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
+BLOCK_SIZE_K: tl.constexpr,
+# number of blocks in a group
+GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr,
+BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr):
# TRANSPOSED_BLOCK_A means that each block in A is transposed.
# It is allowed only for blocked input.
assert (BLOCKED_A or not TRANSPOSED_BLOCK_A)
@@ -188,8 +189,8 @@ def matmul_kernel_fma(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr,
tl.store(c_block_ptr, c)


-def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K,
-PREPACKED, BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, num_threads=0):
+def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, PREPACKED,
+BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, num_threads=0):
#TODO: Currently masked load is not supported yet.
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
@@ -207,7 +208,7 @@ def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tens
a = ab
if BLOCKED_B:
b = bb
-matmul_kernel_fma[grid](
+matmul_kernel[grid](
a, b, c, #
M, N, K, #
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, #
@@ -233,14 +234,14 @@ def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tens
rtol = 0
a_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_M) * (512 // BLOCK_SIZE_K) * 64), device='cpu', dtype=torch.float32)
b_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_K) * (512 // BLOCK_SIZE_N) * 64), device='cpu', dtype=torch.float32)
-triton_output = matmul_fma(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False)
+triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False)
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ TritonCPU and TorchCPU match")
else:
print("❌ TritonCPU and TorchCPU differ, the maximum difference is "
f'{torch.max(torch.abs(triton_output - torch_output))}')
assert False
-triton_output = matmul_fma(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True)
+triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True)
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ TritonCPU pre-packed and TorchCPU match")
else:
@@ -289,7 +290,7 @@ def decode_provider(provider):
BLOCK_TRANSPOSE_B_OPTS = [(True, True), (False, False)]
PREPACK_OPTS = [False, True]
SINGLE_THREAD_OPTS = [False]
-DTYPE_OPTS = ['float32']
+DTYPE_OPTS = [DTYPE]
LINE_VALS = [
encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype)
for single_thread in SINGLE_THREAD_OPTS
@@ -316,7 +317,7 @@ def decode_provider(provider):
ylabel='GFLOPS', # Label name for the y-axis.
plot_name=
# Name for the plot. Used also as a file name for saving the plot.
-f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})',
+f'matmul-performance-{DTYPE} (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))
def benchmark(M, N, K, provider):
@@ -360,9 +361,8 @@ def benchmark(M, N, K, provider):
ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles)
elif backend == 'triton-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(
-lambda: matmul_fma(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b,
-transposed_b, num_threads=int(single_thread)), quantiles=quantiles,
-measure_time_with_hooks=True, rep=1000)
+lambda: matmul(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, transposed_b,
+num_threads=int(single_thread)), quantiles=quantiles, measure_time_with_hooks=True, rep=1000)
perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)

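For readers unfamiliar with the pre-packing step benchmarked above, here is a conceptual sketch (an illustration under simplifying assumptions: plain row-major tiles, no transposed blocks; not the kernel's exact output layout) of what blocking B into BLOCK_SIZE_K x BLOCK_SIZE_N tiles means. Each tile ends up contiguous in memory so the matmul kernel can read it with dense vector loads; block_transpose_combined_kernel additionally supports transposed blocks for A and a transposed B.

import torch

K, N = 64, 64
BLOCK_SIZE_K, BLOCK_SIZE_N = 8, 32

b = torch.randn((K, N), dtype=torch.float32)

# Split B into (K // BLOCK_SIZE_K) x (N // BLOCK_SIZE_N) tiles and make each
# tile contiguous in memory.
blocked_b = (b.reshape(K // BLOCK_SIZE_K, BLOCK_SIZE_K, N // BLOCK_SIZE_N, BLOCK_SIZE_N)
             .permute(0, 2, 1, 3)
             .contiguous())
assert blocked_b.shape == (K // BLOCK_SIZE_K, N // BLOCK_SIZE_N, BLOCK_SIZE_K, BLOCK_SIZE_N)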
11 changes: 11 additions & 0 deletions third_party/cpu/include/TritonCPUTransforms/OptCommon.h
@@ -125,12 +125,14 @@ inline Value shapeCast(Location loc, Value in,
} // namespace mlir

#define int_cst(ty, val) intCst(loc, ty, val, rewriter)
+#define index_cst(val) rewriter.create<arith::ConstantIndexOp>(loc, val)
#define cst_like(src, val) cstLike(loc, src, val, rewriter)

#define op_addi(lhs, rhs) rewriter.create<arith::AddIOp>(loc, lhs, rhs)
#define op_addf(lhs, rhs) rewriter.create<arith::AddFOp>(loc, lhs, rhs)
#define op_subi(lhs, rhs) rewriter.create<arith::SubIOp>(loc, lhs, rhs)
#define op_subf(lhs, rhs) rewriter.create<arith::SubFOp>(loc, lhs, rhs)
+#define op_muli(lhs, rhs) rewriter.create<arith::MulIOp>(loc, lhs, rhs)
#define op_mulf(lhs, rhs) rewriter.create<arith::MulFOp>(loc, lhs, rhs)
#define op_bitcast(ty, val) rewriter.create<arith::BitcastOp>(loc, ty, val)
#define op_lshr(lhs, rhs) rewriter.create<arith::ShRUIOp>(loc, lhs, rhs)
@@ -146,6 +148,15 @@ inline Value shapeCast(Location loc, Value in,
rewriter.create<arith::SelectOp>(loc, cond, val, other)
#define op_sitofp(ty, val) rewriter.create<arith::SIToFPOp>(loc, ty, val)
#define op_fptosi(ty, val) rewriter.create<arith::FPToSIOp>(loc, ty, val)
+#define op_read(ty, memRef, indices) \
+rewriter.create<vector::TransferReadOp>(loc, ty, memRef, indices)
+#define op_write(val, memRef, indices) \
+rewriter.create<vector::TransferWriteOp>(loc, val, memRef, indices)
+#define op_interleave(lhs, rhs) \
+rewriter.create<vector::InterleaveOp>(loc, lhs, rhs)
+#define op_extract(vec, idx) rewriter.create<vector::ExtractOp>(loc, vec, idx)
+#define op_store(val, mem, idx) \
+rewriter.create<vector::StoreOp>(loc, val, mem, idx)

#define op_icmp_eq(lhs, rhs) \
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, lhs, rhs)