Skip to content

Commit

Permalink
Fix formatting.
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <[email protected]>
  • Loading branch information
ienkovich committed Jun 17, 2024
1 parent 02db84c commit b0689c5
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions python/tutorials/03-matrix-multiplication-cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@
import triton
import triton.language as tl


BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 32
BLOCK_SIZE_K = 32
GROUP_SIZE_M = 8
USE_GPU = True


@triton.jit
def matmul_kernel(
# Pointers to matrices
Expand Down Expand Up @@ -227,7 +227,7 @@ def matmul_kernel(
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

# Convert the accumulator to the output matrix C's type if needed.
c = accumulator

Expand All @@ -236,14 +236,13 @@ def matmul_kernel(
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

#TODO: Currently masked load is not supported yet.
#c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
#tl.store(c_ptrs, c, mask=c_mask)
tl.store(c_ptrs, c)



# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
Expand All @@ -256,9 +255,10 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
M, K = a.shape
K, N = b.shape
#TODO: Currently masked load is not supported yet.
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
if c is None:
# Allocates output.
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
else:
assert c.shape == (M, N), "Incompatible dimensions"
Expand All @@ -270,9 +270,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K, #
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, #
GROUP_SIZE_M=GROUP_SIZE_M, #
)
return c
Expand All @@ -298,7 +296,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ TritonCPU and TorchCPU match")
else:
print("❌ TritonCPU and TorchCPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}')
print("❌ TritonCPU and TorchCPU differ, the maximum difference is "
f'{torch.max(torch.abs(triton_output - torch_output))}')

# %%
# Benchmark
Expand Down Expand Up @@ -326,13 +325,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ TritonGPU and TorchGPU match")
else:
print("❌ TritonGPU and TorchGPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}')
print("❌ TritonGPU and TorchGPU differ, the maximum difference is "
f'{torch.max(torch.abs(triton_output - torch_output))}')

LINE_VALS += ['triton-gpu', 'torch-gpu']
LINE_NAMES += ['TritonGPU', 'TorchGPU']
LINE_STYLES += [('yellow', '-'), ('red', '-')]


# %%
# Seems like we're good to go!

Expand All @@ -359,7 +358,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))

def benchmark(M, N, K, provider):
import os

Expand All @@ -383,7 +381,8 @@ def benchmark(M, N, K, provider):
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
elif provider == 'torch-cpu':
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles,
is_cpu=True)
elif provider == 'triton-cpu-single':
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True)
Expand Down

0 comments on commit b0689c5

Please sign in to comment.