Introduce FMA lowering for DotOp. #193

Merged: 3 commits, Dec 12, 2024
Changes from 2 commits
130 changes: 89 additions & 41 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -153,14 +153,35 @@

import triton
import triton.language as tl
import os

BLOCK_SIZE_M = 32
DTYPE = getattr(torch, (os.getenv("DTYPE", "float32")))
# Choose block size depending on dtype. We have more register
# capacity for bfloat16/float16 compared to float32.
BLOCK_SIZE_M = 8 if DTYPE == torch.float32 else 32
BLOCK_SIZE_N = 32
BLOCK_SIZE_K = 32
BLOCK_SIZE_K = 8 if DTYPE == torch.float32 else 32
CACHE_PADDING = os.getenv("CACHE_PADDING", "0") != "0"
PREPACKED = os.getenv("PREPACKED", "0") != "0"
PAD_B_ONLY = True
GROUP_SIZE_M = 8
USE_GPU = False


@triton.jit
def pad_kernel(in_ptr, out_ptr, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, PADDING: tl.constexpr):
in_offset = tl.program_id(axis=0) * N * BLOCK_SIZE_M
out_offset = tl.program_id(axis=0) * (N + PADDING) * BLOCK_SIZE_M
for row in tl.range(0, BLOCK_SIZE_M):
for block in tl.range(0, N // BLOCK_SIZE_N):
val = tl.load(in_ptr + in_offset + block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
tl.store(out_ptr + out_offset + block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N), val)
zero = tl.full((PADDING, ), 0, dtype=in_ptr.type.element_ty)
tl.store(out_ptr + out_offset + N + tl.arange(0, PADDING), zero)
in_offset += N
out_offset += N + PADDING
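For reference, pad_kernel is launched further down in matmul_preprocess_input with one program per block of rows; the call that pads B (copied from that wrapper) looks like this:

# One program per BLOCK_SIZE_K rows of B; each row of length N is copied and 32 zero elements are appended.
pad_kernel[(K // BLOCK_SIZE_K, )](b, b_scratch, N, BLOCK_SIZE_K, BLOCK_SIZE_N, 32, num_threads=num_threads)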


@triton.jit
def matmul_kernel(
# Pointers to matrices
@@ -198,14 +219,12 @@ def matmul_kernel(
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
block_offset_m = pid_m * BLOCK_SIZE_M
block_offset_n = pid_n * BLOCK_SIZE_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
Collaborator:
Can we keep this matmul_kernel as is? Instead, what about having matmul_kernel_block_ptr or something similar? It'd be good to keep the baseline implementations.

Collaborator Author:
I can leave both options in a single tutorial and choose by a flag. Having them in different tutorials would make the comparison of these two options less reliable.

offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0))
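A minimal sketch of the flag-based selection suggested in the thread above (hypothetical flag and kernel name, not part of this PR):

# Hypothetical sketch: keep the baseline pointer-arithmetic kernel alongside this
# block-pointer variant and pick one with an environment flag. `matmul_kernel_block_ptr`
# is the name the reviewer suggested; both kernels would share the same launch arguments.
USE_BLOCK_PTRS = os.getenv("USE_BLOCK_PTRS", "1") != "0"

def select_matmul_kernel():
    return matmul_kernel_block_ptr if USE_BLOCK_PTRS else matmul_kernel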

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
@@ -217,43 +236,50 @@ def matmul_kernel(
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.

# TODO: Currently masked load is not supported yet.
# a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
Comment on lines -221 to -222
Collaborator:
Oh I forgot to remove these restrictions. We can handle masks. So, let's remove them.

a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
a = tl.load(a_tile_ptr)
b = tl.load(b_tile_ptr)
Comment on lines +250 to +251
Collaborator:
So, could you add the masking? It works, and I believe masking will also work with this PR.

Collaborator Author:
Masks are not related to my changes, and I don't want to bring them back yet (at least not unconditionally). We have mask optimizations working for the 1D case, but I'm not sure they can work for the 2D case, and the masked variant can be much slower.

Collaborator:
Okay, then let me (or anyone) update the masking in a separate PR. This restriction was put in place a long time ago, and now it should be removed. It's perfectly fine not to have the best performance, but masking is a must for matmul.

Collaborator Author:
If it stops you from getting the best performance when you don't need masks, then I'm not sure. We see that padding is profitable anyway, so it can be used both to improve cache hits and to avoid masks by making sure we never access a partial block. And padding doesn't require modifications to the matmul kernel.
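For context on the thread above: with block pointers, masking is expressed as a boundary check on the load itself rather than an explicit mask tensor. A minimal sketch, not part of this diff, assuming the CPU backend supports Triton's boundary_check/padding_option and that K may not be a multiple of BLOCK_SIZE_K:

# Zero-fill elements that fall outside K along the relevant dimension.
a = tl.load(a_tile_ptr, boundary_check=(1, ), padding_option="zero")
b = tl.load(b_tile_ptr, boundary_check=(0, ), padding_option="zero")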

# We accumulate along the K dimension.
accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])

# Convert the accumulator to the output matrix C's type if needed.
c = accumulator

# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

# TODO: Currently masked load is not supported yet.
# c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
# tl.store(c_ptrs, c, mask=c_mask)
tl.store(c_ptrs, c)
Comment on lines -241 to -243
Collaborator:
Ditto. Masks work.

# Write back the block of the output matrix C.
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
order=(1, 0))
tl.store(c_block_ptr, c)
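Likewise, if output masking were reinstated on this path, the block-pointer store would take a boundary check instead of a mask (a sketch under the same assumption of backend support):

# Only store elements that fall inside the M x N bounds of C.
tl.store(c_block_ptr, c, boundary_check=(0, 1))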


# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.

a_scratch = torch.empty((), dtype=DTYPE)
b_scratch = torch.empty((), dtype=DTYPE)

def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):

def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape

# TODO: Check if padding is needed at all.
if CACHE_PADDING:
a_scratch.resize_(M, K + 32)
b_scratch.resize_(K, N + 32)
if not PAD_B_ONLY:
pad_kernel[(M // BLOCK_SIZE_M, )](a, a_scratch, K, BLOCK_SIZE_M, BLOCK_SIZE_K, 32, num_threads=num_threads)
a = a_scratch
pad_kernel[(K // BLOCK_SIZE_K, )](b, b_scratch, N, BLOCK_SIZE_K, BLOCK_SIZE_N, 32, num_threads=num_threads)
b = b_scratch
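As a rough illustration of why a 32-element pad can improve cache behavior (assumed cache geometry, not measured in this PR): with 64-byte lines and a 64-set L1, walking down a column of an unpadded 512 x 512 fp32 matrix keeps hitting the same few sets, while the padded stride spreads accesses across most sets.

# Illustration only; assumes 64-byte cache lines and 64 sets.
line_bytes, num_sets = 64, 64
row_bytes = 512 * 4                                         # unpadded fp32 row: 2048 bytes
padded_row_bytes = (512 + 32) * 4                           # padded row: 2176 bytes
step = (row_bytes // line_bytes) % num_sets                 # 32 -> a column walk reuses only 2 distinct sets
padded_step = (padded_row_bytes // line_bytes) % num_sets   # 34 -> a column walk cycles through 32 distinct sets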

#TODO: Currently masked load is not supported yet.
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
Collaborator:
This can be removed.

@@ -262,6 +288,14 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
else:
assert c.shape == (M, N), "Incompatible dimensions"

return a, b, c


def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K: int, num_threads=0):
if not PREPACKED:
a, b, c = matmul_preprocess_input(a, b, c, num_threads=num_threads)

# 1D launch kernel where each block gets its own program.
grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), )
matmul_kernel[grid](
@@ -287,10 +321,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):

triton.runtime.driver.set_active_to_cpu()

a = torch.randn((512, 512), device='cpu', dtype=torch.float32)
b = torch.randn((512, 512), device='cpu', dtype=torch.float32)
triton_output = matmul(a, b, None)
torch_output = torch.matmul(a, b)
a = torch.randn((512, 512), device='cpu', dtype=DTYPE)
b = torch.randn((512, 512), device='cpu', dtype=DTYPE)
c = None
torch_output = torch.matmul(a.to(torch.float32), b.to(torch.float32))
if PREPACKED:
a, b, c = matmul_preprocess_input(a, b, c)
triton_output = matmul(a, b, c, 512, 512, 512)
print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}")
print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}")
rtol = 0
@@ -310,9 +347,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
# We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices,
# but feel free to arrange this script as you wish to benchmark any other matrix shape.

LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native', 'torch-cpu-compile']
LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)', 'TorchCPU (compile)']
LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--'), ('green', '-')]
LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native']
LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)']
LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--')]

if USE_GPU and triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
@@ -356,36 +393,47 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
ylabel='GFLOPS', # Label name for the y-axis.
plot_name=
# Name for the plot. Used also as a file name for saving the plot.
f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})',
f'matmul-performance-{DTYPE} (CACHE_PADDING={CACHE_PADDING} PREPACKED={PREPACKED} PAD_B_ONLY={PAD_B_ONLY} GROUP_SIZE_M={GROUP_SIZE_M})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))
def benchmark(M, N, K, provider):

device = 'cpu' if 'cpu' in provider else 'cuda'
a = torch.randn((M, K), device=device, dtype=torch.float32)
b = torch.randn((K, N), device=device, dtype=torch.float32)
a = torch.randn((M, K), device=device, dtype=DTYPE)
b = torch.randn((K, N), device=device, dtype=DTYPE)

if device == 'cpu':
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
if 'triton-cpu' in provider:
c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
else:
c = torch.zeros((M, N), device=a.device, dtype=a.dtype)
triton.runtime.driver.set_active_to_cpu()
else:
c = None
triton.runtime.driver.set_active_to_gpu()

if PREPACKED:
triton_a, triton_b, triton_c = matmul_preprocess_input(a, b, c)
else:
triton_a, triton_b, triton_c = a, b, c

quantiles = [0.5, 0.2, 0.8]
if provider == 'torch-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
elif provider == 'triton-gpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(triton_a, triton_b, None), quantiles=quantiles)
elif provider == 'torch-cpu-native':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles)
elif provider == 'torch-cpu-compile':
compiled = torch.compile(torch.matmul)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles)
elif provider == 'triton-cpu-single':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, num_threads=1), quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(triton_a, triton_b, triton_c, M, N, K, num_threads=1), quantiles=quantiles,
measure_time_with_hooks=True)
elif provider == 'triton-cpu':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(triton_a, triton_b, triton_c, M, N, K),
quantiles=quantiles, measure_time_with_hooks=True)
perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
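As a quick check of the GFLOPS formula above, with illustrative numbers M = N = K = 512 and a 1 ms run:

gflops = 2 * 512**3 * 1e-9 / (1.0 * 1e-3)  # ~268.4 GFLOPS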
