Fix review comments.

Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich committed Dec 11, 2024
1 parent f08f8b0 commit d19ce8c

Showing 3 changed files with 52 additions and 29 deletions.
47 changes: 34 additions & 13 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -164,6 +164,7 @@
CACHE_PADDING = os.getenv("CACHE_PADDING", "0") != "0"
PREPACKED = os.getenv("PREPACKED", "0") != "0"
PAD_B_ONLY = True
+USE_BLOCK_POINTERS = os.getenv("USE_BLOCK_POINTERS", "1") != "0"
GROUP_SIZE_M = 8
USE_GPU = False
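
Like the other knobs above, the new flag is read from the environment: any value other than "0" enables it, and it defaults to on. For example, USE_BLOCK_POINTERS=0 python python/tutorials/03-matrix-multiplication-cpu.py selects the manual-pointer path benchmarked below.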

@@ -197,6 +198,7 @@ def matmul_kernel(
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
+USE_BLOCK_POINTERS: tl.constexpr, #
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
Expand All @@ -219,12 +221,19 @@ def matmul_kernel(
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
-block_offset_m = pid_m * BLOCK_SIZE_M
-block_offset_n = pid_n * BLOCK_SIZE_N
-a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
-                               offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0))
-b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
-                               offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0))
+if USE_BLOCK_POINTERS:
+    block_offset_m = pid_m * BLOCK_SIZE_M
+    block_offset_n = pid_n * BLOCK_SIZE_N
+    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
+                                   offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0))
+    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
+                                   offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0))
+else:
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_tile_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_tile_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
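
The % M and % N wrap in the manual-pointer branch keeps every block load in bounds without masks: indices that run past the end of a dimension fold back into range. A minimal NumPy sketch of the same indexing (sizes are hypothetical, not from this commit):

    import numpy as np

    # Hypothetical sizes chosen so the second row block overruns M.
    M, BLOCK_SIZE_M, pid_m = 6, 4, 1

    # Same wrap-around as the kernel's manual-pointer branch: out-of-range
    # rows fold back to valid ones, so the block load needs no mask.
    offs_am = (pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)) % M
    print(offs_am)  # [4 5 0 1]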

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
Expand All @@ -241,18 +250,29 @@ def matmul_kernel(
# We accumulate along the K dimension.
accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32)
# Advance the ptrs to the next K block.
-a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
-b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
+if USE_BLOCK_POINTERS:
+    a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
+    b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
+else:
+    a_tile_ptr += BLOCK_SIZE_K * stride_ak
+    b_tile_ptr += BLOCK_SIZE_K * stride_bk


# Convert the accumulator to the output matrix C's type if needed.
c = accumulator

# -----------------------------------------------------------
# Write back the block of the output matrix C.
-c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
-                                offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
-                                order=(1, 0))
-tl.store(c_block_ptr, c)
+if USE_BLOCK_POINTERS:
+    c_tile_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
+                                   offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
+                                   order=(1, 0))
+    tl.store(c_tile_ptr, c)
+else:
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_tile_ptr = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    tl.store(c_tile_ptr, c)
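
Both branches advance through K identically: tl.advance shifts the block pointer by BLOCK_SIZE_K elements, while the manual branch adds BLOCK_SIZE_K * stride to the raw offsets. A small NumPy check of the manual arithmetic (toy shapes and a row-major layout are assumed):

    import numpy as np

    # Toy shapes; strides are measured in elements, row-major layout.
    M, K, BLOCK_K = 2, 8, 4
    a = np.arange(M * K).reshape(M, K)
    stride_am, stride_ak = K, 1

    flat = np.arange(M)[:, None] * stride_am + np.arange(BLOCK_K)[None, :] * stride_ak
    for step in range(K // BLOCK_K):
        # Adding BLOCK_K * stride_ak moves the window to the next K block.
        block = a.ravel()[flat + step * BLOCK_K * stride_ak]
        assert (block == a[:, step * BLOCK_K:(step + 1) * BLOCK_K]).all()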


# %%
@@ -306,6 +326,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K:
c.stride(0), c.stride(1), #
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, #
GROUP_SIZE_M=GROUP_SIZE_M, #
+USE_BLOCK_POINTERS=USE_BLOCK_POINTERS, #
num_threads=num_threads, #
)
return c
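
A quick way to validate both addressing modes is to compare the kernel against a PyTorch reference. A hypothetical harness (not part of this commit) around the matmul() wrapper above:

    import torch

    def check_matmul(matmul_fn, M=256, N=256, K=256):
        # Run the wrapper in the currently selected addressing mode and
        # compare against the PyTorch reference.
        a = torch.randn(M, K, dtype=torch.float32)
        b = torch.randn(K, N, dtype=torch.float32)
        c = torch.empty(M, N, dtype=torch.float32)
        matmul_fn(a, b, c, M, N, K)
        torch.testing.assert_close(c, a @ b, rtol=1e-4, atol=1e-4)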
Expand Down Expand Up @@ -393,7 +414,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K:
ylabel='GFLOPS', # Label name for the y-axis.
plot_name=
# Name for the plot. Used also as a file name for saving the plot.
-f'matmul-performance-{DTYPE} (CACHE_PADDING={CACHE_PADDING} PREPACKED={PREPACKED} PAD_B_ONLY={PAD_B_ONLY} GROUP_SIZE_M={GROUP_SIZE_M})',
+f'matmul-performance-{DTYPE} (USE_BLOCK_POINTERS={USE_BLOCK_POINTERS} CACHE_PADDING={CACHE_PADDING} PREPACKED={PREPACKED} PAD_B_ONLY={PAD_B_ONLY} GROUP_SIZE_M={GROUP_SIZE_M})',
args={}, # Values for function arguments not in `x_names` and `y_name`.
))
def benchmark(M, N, K, provider):
@@ -29,8 +29,8 @@ struct MemBuffer {
};

// Check if accumulator value is updated in a loop and has no other
-// usages than a dot op, that updates it. Tile loads/stores and casts
-// for such accumulators can be done outside of the loop.
+// usages than a dot op that updates it. Loads, stores, and casts
+// for such an accumulator can be done outside of the loop.
bool isLoopCarriedAcc(Value acc);

// Get initial value for a loop-carried accumulator.
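
In Python terms, the pattern this check looks for is an accumulator that is written only by the dot inside the loop and read only after it, so conversions of it can be hoisted past the loop. An illustrative NumPy analogy (not the MLIR pass itself):

    import numpy as np

    M, N, K, BLOCK_K = 4, 4, 16, 4
    a = np.random.rand(M, K).astype(np.float32)
    b = np.random.rand(K, N).astype(np.float32)

    acc = np.zeros((M, N), dtype=np.float32)  # loop-carried accumulator
    for k0 in range(0, K, BLOCK_K):
        acc += a[:, k0:k0 + BLOCK_K] @ b[k0:k0 + BLOCK_K, :]  # sole in-loop use
    c = acc.astype(np.float16)  # cast hoisted outside the loop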
@@ -24,8 +24,7 @@ using namespace mlir::triton::cpu;

namespace {

-// This structure is used to hold candidates for conversion to AMX
-// Mul[F|I]Op operations.
+// This structure is used to hold candidates for conversion to FMA operations.
struct FmaDotOpCandidate {
// Operation to convert.
cpu::DotOp op;
Expand All @@ -48,10 +47,10 @@ struct FmaDotOpCandidate {
MemBuffer rhsBuf;
};

-// Check if input and output types can be handled by AMX (possibly, using
-// additional casts for input/output). Returns true if AMX usage is possible.
-// In this case, tile element type fields of the candidate structure are
-// filled with actual types to be used in lowering.
+// Check if input and output types can be handled by FMA (possibly using
+// additional casts for input/output). Returns true if FMA lowering is possible.
+// In this case, element type fields of the candidate structure are filled
+// with actual types to be used in lowering.
bool checkElemTypes(Type lhsElemTy, Type rhsElemTy, Type accElemTy,
Type resElemTy, FmaDotOpCandidate &candidate) {
MLIRContext *ctx = lhsElemTy.getContext();
Expand Down Expand Up @@ -88,7 +87,7 @@ bool checkInputShapes(VectorType lhsTy, VectorType resTy) {
return true;
}

-// Check if specified ContractionOp can be lowered to AMX operations.
+// Check if the specified ContractionOp can be lowered to FMA operations.
// If conversion is possible, then true is returned and candidate
// structure is filled with detailed transformation info.
bool isFmaCandidate(cpu::DotOp op, FmaDotOpCandidate &candidate) {
Expand Down Expand Up @@ -314,9 +313,9 @@ LogicalResult convertCandidate(FmaDotOpCandidate &candidate,
if (k != lhsTy.getDimSize(1) - 1)
nextRhsVec = loadRow(loc, rhsVecTy, rhsBuf, k + 1, rewriter);

-// Prefetch RHS to L2 cache.
+// Prefetch RHS to the LLC.
 if (!rhsPrefetchIndices.empty())
-  prefetch(loc, candidate.rhsBuf, k, 0, rhsPrefetchIndices, 3, rewriter);
+  prefetch(loc, candidate.rhsBuf, k, 0, rhsPrefetchIndices, 1, rewriter);

Value nextLhsBroadcasted =
broadcastElem(loc, accVecTy, lhsBuf, 0, k, rewriter);
Expand All @@ -328,10 +327,13 @@ LogicalResult convertCandidate(FmaDotOpCandidate &candidate,
nextLhsBroadcasted =
broadcastElem(loc, accVecTy, lhsBuf, m + 1, k, rewriter);

-// Prefetch LHS to L0 cache.
-if (!lhsPrefetchIndices.empty() &&
-    ((k * candidate.accRows + m) % 16 == 0))
-  prefetch(loc, candidate.lhsBuf, m, k, lhsPrefetchIndices, 1, rewriter);
+// Prefetch LHS to the L1 cache.
+if (!lhsPrefetchIndices.empty()) {
+  if ((candidate.lhsBuf.transposed && (m % 8 == 0)) ||
+      (!candidate.lhsBuf.transposed && (k % 8 == 0)))
+    prefetch(loc, candidate.lhsBuf, m, k, lhsPrefetchIndices, 3,
+             rewriter);
+}

accVecs[m] = rewriter.create<vector::FMAOp>(loc, rhsVec, lhsBroadcasted,
accVecs[m]);
Expand All @@ -342,7 +344,7 @@ LogicalResult convertCandidate(FmaDotOpCandidate &candidate,
// In this case we have the whole accumulator/result on tiles. Loop
// carried dependencies are not in place yet and should be added.
// After the loop, resulting tiles should either be stored to the
-// output buffer, or moved to a vector though a temporary buffer.
+// output buffer, or moved to a vector through a temporary buffer.

// We don't need the original accumulator and contraction op anymore.
// Directly yield orig accumulator value, so it would be later removed