Introduce FMA lowering for DotOp. #193
@@ -153,14 +153,36 @@
 import triton
 import triton.language as tl
+import os

-BLOCK_SIZE_M = 32
+DTYPE = getattr(torch, (os.getenv("DTYPE", "float32")))
+# Choose block size depending on dtype. We have more register
+# capacity for bfloat16/float16 compared to float32.
+BLOCK_SIZE_M = 8 if DTYPE == torch.float32 else 32
 BLOCK_SIZE_N = 32
-BLOCK_SIZE_K = 32
+BLOCK_SIZE_K = 8 if DTYPE == torch.float32 else 32
+CACHE_PADDING = os.getenv("CACHE_PADDING", "0") != "0"
+PREPACKED = os.getenv("PREPACKED", "0") != "0"
+PAD_B_ONLY = True
+USE_BLOCK_POINTERS = os.getenv("USE_BLOCK_POINTERS", "1") != "0"
 GROUP_SIZE_M = 8
 USE_GPU = False


+@triton.jit
+def pad_kernel(in_ptr, out_ptr, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, PADDING: tl.constexpr):
+    in_offset = tl.program_id(axis=0) * N * BLOCK_SIZE_M
+    out_offset = tl.program_id(axis=0) * (N + PADDING) * BLOCK_SIZE_M
+    for row in tl.range(0, BLOCK_SIZE_M):
+        for block in tl.range(0, N // BLOCK_SIZE_N):
+            val = tl.load(in_ptr + in_offset + block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
+            tl.store(out_ptr + out_offset + block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N), val)
+        zero = tl.full((PADDING, ), 0, dtype=in_ptr.type.element_ty)
+        tl.store(out_ptr + out_offset + N + tl.arange(0, PADDING), zero)
+        in_offset += N
+        out_offset += N + PADDING
+
+
 @triton.jit
 def matmul_kernel(
     # Pointers to matrices
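To make the effect of pad_kernel concrete, here is a hedged host-side sketch: it checks that the kernel simply copies each row and appends PADDING zeros, growing the leading dimension from N to N + PADDING. The shapes, the PAD value of 32, and the launch without num_threads are illustrative assumptions, not code from the PR.

import torch
import torch.nn.functional as F

K, N, PAD = 128, 256, 32
b = torch.randn(K, N, dtype=torch.float32)

# Reference padding on the host: append PAD zero columns to every row.
b_padded_ref = F.pad(b, (0, PAD))  # shape (K, N + PAD)

# Same padding via the Triton kernel: one program per block of BLOCK_SIZE_K rows.
b_scratch = torch.empty(K, N + PAD, dtype=torch.float32)
pad_kernel[(K // BLOCK_SIZE_K, )](b, b_scratch, N, BLOCK_SIZE_K, BLOCK_SIZE_N, PAD)

assert torch.equal(b_scratch, b_padded_ref)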
@@ -176,6 +198,7 @@ def matmul_kernel(
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
         GROUP_SIZE_M: tl.constexpr, #
+        USE_BLOCK_POINTERS: tl.constexpr, #
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -198,14 +221,21 @@ def matmul_kernel(
     # Create pointers for the first blocks of A and B.
     # We will advance this pointer as we move in the K direction
     # and accumulate
-    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
-    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
-    # See above `Pointer Arithmetic` section for details
-    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    if USE_BLOCK_POINTERS:
+        block_offset_m = pid_m * BLOCK_SIZE_M
+        block_offset_n = pid_n * BLOCK_SIZE_N
+        a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
+                                       offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
+                                       order=(1, 0))
+        b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
+                                       offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
+                                       order=(1, 0))
+    else:
+        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_tile_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+        b_tile_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -217,43 +247,60 @@ def matmul_kernel(
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.

         # TODO: Currently masked load is not supported yet.
         # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-        a = tl.load(a_ptrs)
-        b = tl.load(b_ptrs)
+        a = tl.load(a_tile_ptr)
+        b = tl.load(b_tile_ptr)
Comment on lines +250 to +251:

So, could you add the masking? It works, and I believe masking will also work with this PR.

Masks are not related to my changes, and I don't want to bring them back yet (at least not unconditionally). We have mask optimizations working for the 1D case, but I'm not sure they work for the 2D case, and the masked variant can be much slower.

Okay, then let me (or anyone) update the masking in a separate PR. This restriction was put in place a long time ago, and it should be removed now. It's perfectly fine not to have the best performance, but masking is a must for matmul.

If it stops you from getting the best performance when you don't need masks, then I'm not sure. We see that padding is profitable anyway, so it can be used both to improve cache hits and to avoid masks by making sure we never access a partial block. And padding doesn't require modifications to the matmul kernel.
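As background for this thread, here is a minimal sketch of what re-enabled masked loads might look like in the two addressing modes this PR introduces. It reuses the kernel's local names and the standard Triton masking APIs (mask/other for plain pointer blocks, boundary_check/padding_option for block pointers); it is an illustration, not code from the PR, and whether the CPU backend handles it efficiently is exactly what is being discussed above.

        if USE_BLOCK_POINTERS:
            # Block pointers carry the logical shape, so the load itself can
            # zero-fill out-of-bounds elements along K (dim 1 of A, dim 0 of B).
            a = tl.load(a_tile_ptr, boundary_check=(1, ), padding_option="zero")
            b = tl.load(b_tile_ptr, boundary_check=(0, ), padding_option="zero")
        else:
            # Plain pointer blocks need an explicit mask over the K-dimension tail.
            a = tl.load(a_tile_ptr, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_tile_ptr, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)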
         # We accumulate along the K dimension.
         accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32)
         # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
+        if USE_BLOCK_POINTERS:
+            a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
+            b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
+        else:
+            a_tile_ptr += BLOCK_SIZE_K * stride_ak
+            b_tile_ptr += BLOCK_SIZE_K * stride_bk

     # Convert the accumulator to the output matrix C's type if needed.
     c = accumulator

     # -----------------------------------------------------------
-    # Write back the block of the output matrix C with masks.
-    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-
-    # TODO: Currently masked load is not supported yet.
-    # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-    # tl.store(c_ptrs, c, mask=c_mask)
-    tl.store(c_ptrs, c)
Comment on lines -241 to -243:

Ditto. Masks work.
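And the corresponding hedged sketch for a masked write-back, again using standard Triton APIs and the local names from the write-back code below; an illustration, not part of the PR:

    if USE_BLOCK_POINTERS:
        # boundary_check clips the store to the logical (M, N) shape of the block pointer.
        tl.store(c_tile_ptr, c, boundary_check=(0, 1))
    else:
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_tile_ptr, c, mask=c_mask)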
+    # Write back the block of the output matrix C.
+    if USE_BLOCK_POINTERS:
+        c_tile_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
+                                       offsets=(block_offset_m, block_offset_n),
+                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
+        tl.store(c_tile_ptr, c)
+    else:
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_tile_ptr = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+        tl.store(c_tile_ptr, c)


 # %%
 # We can now create a convenience wrapper function that only takes two input tensors,
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.

+a_scratch = torch.empty((), dtype=DTYPE)
+b_scratch = torch.empty((), dtype=DTYPE)
+

-def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
+def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
     M, K = a.shape
     K, N = b.shape

+    # TODO: Check if padding is needed at all.
+    if CACHE_PADDING:
+        a_scratch.resize_(M, K + 32)
+        b_scratch.resize_(K, N + 32)
+        if not PAD_B_ONLY:
+            pad_kernel[(M // BLOCK_SIZE_M, )](a, a_scratch, K, BLOCK_SIZE_M, BLOCK_SIZE_K, 32, num_threads=num_threads)
+            a = a_scratch
+        pad_kernel[(K // BLOCK_SIZE_K, )](b, b_scratch, N, BLOCK_SIZE_K, BLOCK_SIZE_N, 32, num_threads=num_threads)
+        b = b_scratch

     #TODO: Currently masked load is not supported yet.
     assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
         K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"

Comment:

This can be removed.
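One point worth making explicit (this is my reading of the code above, not a statement from the PR): cache padding never touches matmul_kernel itself, because the padded scratch buffer only changes the row stride the kernel receives, while the logical M, N and K stay the same. A tiny illustrative check with assumed sizes:

import torch

K, N = 512, 512
b_scratch = torch.empty(K, N + 32, dtype=torch.float32)
# The row stride grows by the 32-element pad; the column stride is unchanged.
assert b_scratch.stride() == (N + 32, 1)
# matmul_kernel is still launched with the logical N; only stride_bk differs,
# so each program reads the same K x N values from a padded layout.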
@@ -262,6 +309,14 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
         c = torch.empty((M, N), device=a.device, dtype=a.dtype)
     else:
         assert c.shape == (M, N), "Incompatible dimensions"

+    return a, b, c
+
+
+def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K: int, num_threads=0):
+    if not PREPACKED:
+        a, b, c = matmul_preprocess_input(a, b, c, num_threads=num_threads)
+
     # 1D launch kernel where each block gets its own program.
     grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), )
     matmul_kernel[grid](
@@ -272,6 +327,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
         c.stride(0), c.stride(1), #
         BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, #
         GROUP_SIZE_M=GROUP_SIZE_M, #
+        USE_BLOCK_POINTERS=USE_BLOCK_POINTERS, #
         num_threads=num_threads, #
     )
     return c
@@ -287,10 +343,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):

 triton.runtime.driver.set_active_to_cpu()

-a = torch.randn((512, 512), device='cpu', dtype=torch.float32)
-b = torch.randn((512, 512), device='cpu', dtype=torch.float32)
-triton_output = matmul(a, b, None)
-torch_output = torch.matmul(a, b)
+a = torch.randn((512, 512), device='cpu', dtype=DTYPE)
+b = torch.randn((512, 512), device='cpu', dtype=DTYPE)
+c = None
+torch_output = torch.matmul(a.to(torch.float32), b.to(torch.float32))
+if PREPACKED:
+    a, b, c = matmul_preprocess_input(a, b, c)
+triton_output = matmul(a, b, c, 512, 512, 512)
 print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}")
 print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}")
 rtol = 0
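A side note on the reference computation (my interpretation, not stated in the PR): casting a and b to float32 before torch.matmul gives a higher-precision reference when DTYPE is float16 or bfloat16, which matches the kernel's float32 accumulator (out_dtype=tl.float32). A small hedged illustration of the gap, with assumed shapes:

import torch

x = torch.randn(256, 256, dtype=torch.bfloat16)
y = torch.randn(256, 256, dtype=torch.bfloat16)
ref_fp32 = torch.matmul(x.to(torch.float32), y.to(torch.float32))
ref_bf16 = torch.matmul(x, y).to(torch.float32)
# The bf16-rounded result typically differs noticeably from the fp32 reference.
print((ref_fp32 - ref_bf16).abs().max())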
@@ -310,9 +369,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
 # We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices,
 # but feel free to arrange this script as you wish to benchmark any other matrix shape.

-LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native', 'torch-cpu-compile']
-LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)', 'TorchCPU (compile)']
-LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--'), ('green', '-')]
+LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native']
+LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)']
+LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--')]

 if USE_GPU and triton.runtime.driver.get_active_gpus():
     triton.runtime.driver.set_active_to_gpu()
@@ -356,36 +415,47 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
         ylabel='GFLOPS', # Label name for the y-axis.
         plot_name=
         # Name for the plot. Used also as a file name for saving the plot.
-        f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})',
+        f'matmul-performance-{DTYPE} (USE_BLOCK_POINTERS={USE_BLOCK_POINTERS} CACHE_PADDING={CACHE_PADDING} PREPACKED={PREPACKED} PAD_B_ONLY={PAD_B_ONLY} GROUP_SIZE_M={GROUP_SIZE_M})',
         args={}, # Values for function arguments not in `x_names` and `y_name`.
     ))
 def benchmark(M, N, K, provider):

     device = 'cpu' if 'cpu' in provider else 'cuda'
-    a = torch.randn((M, K), device=device, dtype=torch.float32)
-    b = torch.randn((K, N), device=device, dtype=torch.float32)
+    a = torch.randn((M, K), device=device, dtype=DTYPE)
+    b = torch.randn((K, N), device=device, dtype=DTYPE)

     if device == 'cpu':
-        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+        if 'triton-cpu' in provider:
+            c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
+        else:
+            c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
         triton.runtime.driver.set_active_to_cpu()
     else:
         c = None
         triton.runtime.driver.set_active_to_gpu()

+    if PREPACKED:
+        triton_a, triton_b, triton_c = matmul_preprocess_input(a, b, c)
+    else:
+        triton_a, triton_b, triton_c = a, b, c
+
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch-gpu':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     elif provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(triton_a, triton_b, None), quantiles=quantiles)
     elif provider == 'torch-cpu-native':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles)
     elif provider == 'torch-cpu-compile':
         compiled = torch.compile(torch.matmul)
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles)
     elif provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, num_threads=1), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: matmul(triton_a, triton_b, triton_c, M, N, K, num_threads=1), quantiles=quantiles,
+            measure_time_with_hooks=True)
     elif provider == 'triton-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(triton_a, triton_b, triton_c, M, N, K),
+                                                     quantiles=quantiles, measure_time_with_hooks=True)
     perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)
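For reference, the perf lambda above converts the measured latency into GFLOPS: a matmul does 2*M*N*K floating-point operations (one multiply and one add per accumulated element), and the milliseconds reported by do_bench are converted to seconds. A worked example with assumed numbers:

M = N = K = 1024
ms = 10.0                                      # hypothetical median latency in milliseconds
gflops = 2 * M * N * K * 1e-9 / (ms * 1e-3)    # flops * 1e-9 / seconds
print(round(gflops, 1))                        # 214.7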
Comment:

Oh, I forgot to remove these restrictions. We can handle masks. So, let's remove them.