Introduce FMA lowering for DotOp. #193
Conversation
ccc1d72 to f08f8b0 (Compare)
Thanks so MUCH for doing this work! I'm going to test it in my environment soon. I quickly skimmed through it; mostly looking good, just some questions and nits.
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
block_offset_m = pid_m * BLOCK_SIZE_M
block_offset_n = pid_n * BLOCK_SIZE_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
Can we keep this matmul_kernel as is? Instead, what about having matmul_kernel_block_ptr or something? It'd be good to keep the baseline implementations.
I can leave both options in a single tutorial and choose by a flag. Having them in different tutorials would make the comparison of these two options less reliable.
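For illustration, here is a minimal sketch of the flag-based approach, with both addressing schemes in one kernel behind a constexpr switch. The name USE_BLOCK_PTR, the 2D grid indexing, and the exact argument list are hypothetical, not necessarily what the tutorial ends up with; the sketch also assumes fp32 output and block-aligned sizes (no masks).

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr, USE_BLOCK_PTR: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if USE_BLOCK_PTR:
        # Block-pointer variant: shape/strides are described once, loads stay simple.
        a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                       offsets=(pid_m * BLOCK_SIZE_M, 0),
                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0))
        b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                       offsets=(0, pid_n * BLOCK_SIZE_N),
                                       block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0))
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            acc += tl.dot(tl.load(a_tile_ptr), tl.load(b_tile_ptr))
            a_tile_ptr = tl.advance(a_tile_ptr, (0, BLOCK_SIZE_K))
            b_tile_ptr = tl.advance(b_tile_ptr, (BLOCK_SIZE_K, 0))
    else:
        # Baseline variant: plain pointer arithmetic, as in the original tutorial.
        offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            acc += tl.dot(tl.load(a_ptrs), tl.load(b_ptrs))
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    tl.store(c_ptrs, acc)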
# a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
Oh, I forgot to remove these restrictions. We can handle masks now, so let's remove the restrictions.
# c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
# tl.store(c_ptrs, c, mask=c_mask)
tl.store(c_ptrs, c)
Ditto. Masks work.
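For reference, restoring the masks is just re-enabling the commented-out lines from the diff above (the standard tutorial form, reusing the same names), so partial blocks at the edges are handled:

# inside the K loop:
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

# for the final store:
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)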
a = a_scratch
pad_kernel[(K // BLOCK_SIZE_K, )](b, b_scratch, N, BLOCK_SIZE_K, BLOCK_SIZE_N, 32, num_threads=num_threads)
b = b_scratch

# TODO: Currently masked load is not supported yet.
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
    K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
This can be removed.
a = tl.load(a_tile_ptr)
b = tl.load(b_tile_ptr)
So, could you add the masking? It works, and I believe masking will also work with this PR.
Masks are not related to my changes, and I don't want to bring them back yet (at least not unconditionally). We have mask optimizations working for the 1D case, but I'm not sure they work for the 2D case, and the masked variant can be much slower.
Okay, then let me (or anyone) update the masking in a separate PR. This restriction was placed a long time ago, and now it should be removed. It's perfectly fine not to have the best performance, but masking is a must for matmul.
If it stops you from getting the best performance when you don't need masks, then I'm not sure. We see that padding is profitable anyway, so it can be used both to improve cache hits and to avoid masks by making sure we never access a partial block. And padding doesn't require modifications to the matmul kernel.
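A minimal host-side sketch of that padding idea, under assumed names (the tutorial uses its own pad_kernel): inputs are copied into scratch buffers whose sizes are rounded up to the block sizes, with a few extra columns to break power-of-two strides, so the kernel never touches a partial block and needs no masks.

import torch

def pad_to_blocks(x, block_rows, block_cols, extra_cols=32):
    rows, cols = x.shape
    pad_rows = (block_rows - rows % block_rows) % block_rows
    # extra_cols also makes the leading stride non power-of-two for better caching
    pad_cols = (block_cols - cols % block_cols) % block_cols + extra_cols
    out = torch.zeros(rows + pad_rows, cols + pad_cols, dtype=x.dtype, device=x.device)
    # zero padding in K contributes nothing to the dot; padded rows/columns of the
    # result are simply sliced off afterwards
    out[:rows, :cols] = x
    return out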
(Several review threads on third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp and third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h are marked outdated/resolved.)
// Extend the user's original scf.for with the accumulator rows as extra
// loop-carried values: the new init operands become block arguments and the
// FMA results are yielded back on every iteration. replaceWithAdditionalYields
// replaces the original ForOp with an extended one and moves the loop body there.
auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
    rewriter, newInitOperands, true,
    [&newYieldedValues](OpBuilder &b, Location loc,
                        ArrayRef<BlockArgument> newBBArgs) {
      return newYieldedValues;
    }));
Note that I didn't fully understand this approach in AMX and here.
Is keepAccOnRegs = true the normal case, as in a typical tutorial example?

accumulator = tl.zeros(...)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    a = tl.load(a_ptrs)
    b = tl.load(b_ptrs)
    accumulator = tl.dot(a, b, accumulator)
...

accumulator has loop-carried dependencies, so I guess keepAccOnRegs is true.
So, why do we generate scf::for here? It looks more confusing and complex to me. I see that this is mostly based on the current AMX approach, but naively thinking, I think we could avoid scf::for here and just emit FMAs...
Anyhow, we have good perf numbers :) There are multiple ways to do it!
keepAccOnRegs means we want to load the accumulator into registers before the loop and keep it there for the whole loop. In this case the accumulator is represented as a set of 1D vectors, one vector per accumulator row. Those accumulator rows are used in FMA operations, and the FMA results go to the yield operation to form the accumulator's loop-carried dependencies. Adding new values to the yield operation is done via a call to replaceWithAdditionalYields. That method replaces the original ForOp with an extended one; the original loop body is moved to the new operation.
Just to be clear: this for-loop is the original loop written by a user. We don't create new loops, just modify the existing one.
# 1D launch kernel where each block gets its own program.
grid = ((M // BLOCK_SIZE_M) * (N // BLOCK_SIZE_N), )
if (BLOCKED_A or BLOCKED_B) and not PREPACKED:
    block_transpose_combined_kernel[grid](
Curious whether you have evaluated the performance impact of doing the packing as threads consume the tiles, as opposed to materializing full packed tensors before doing the matmul.
Not sure how the grid-to-thread mapping would work, given that some coordination across threads would be required, I guess.
I didn't explore that option. It should be possible to organize it through atomics. Do you think it would give additional performance gains?
Not sure TBH, I don't have useful data handy.
I would imagine the perf gains would depend on the size of the packed tensors, the number of threads, and whether we can leverage cache locality better.
Regarding "It should be possible to organize it through atomics": are you talking about Triton atomic_ ops? Interesting... yeah, I would guess so. I am not too sure what the input packing calls and GEMM call orchestration would look like for a CPU setup.
Signed-off-by: Ilya Enkovich <[email protected]>
d19ce8c to fe8dc83 (Compare)
@minjang Thanks for the review. I've made changes to address all the issues found.
This PR adds lowering for DotOp through vector FMA and broadcast operations. The lowering is quite simple and doesn't work well for all block sizes, so a block size that fits the register file is preferred. It provides much better performance than the contraction-op lowering.
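Conceptually, the lowering can be pictured as follows (a rough model in plain Python/NumPy, not the actual generated IR): each accumulator row is a 1D vector, and for every K step an element of A is broadcast and fused-multiply-added with the corresponding row of B.

import numpy as np

def dot_via_fma(a, b, acc):
    # a: (M, K), b: (K, N), acc: (M, N); each acc row plays the role of a
    # 1D accumulator vector kept in registers by the pass.
    M, K = a.shape
    for i in range(M):
        for k in range(K):
            # broadcast a single element of A and FMA it into accumulator row i
            acc[i, :] += a[i, k] * b[k, :]
    return acc

# e.g. dot_via_fma(np.ones((4, 8)), np.ones((8, 16)), np.zeros((4, 16)))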
The patch also adds some fixes to the matmul tutorial to provide more measuring options. First, it utilizes block pointers to improve analysis and simplify code generation. Second, it introduces a padding feature to improve caching by avoiding power-of-two strides.
Here are the performance results for the current lowering through the vector contraction op:
Here are the results with the FMA pass used:
These are the FMA results with padding enabled; here we can avoid the performance drops on some sizes:
If we are interested in performance when padding costs can be ignored (e.g., weights are processed only once for inference), we can use the PREPACKED option:
To evaluate possibilities to improve results through kernel modifications, I added a new blocked matmul tutorial. There we optionally change the layout of the input data for better data locality. It supports multiple options for pre-processing both LHS and RHS and allows comparing their performance (a rough sketch of the RHS layout change is included at the end of this description). Here are the results when we use a blocked layout for RHS (prepack in the column name means the layout change cost was ignored).
Single thread:
Multiple threads:
All results were measured on a 48-core Intel Xeon Platinum 8468V. For better stability, the frequency was fixed, hyperthreading was disabled, and Intel OpenMP was used with KMP_AFFINITY to pin threads.
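For reference, here is a rough sketch of what the blocked RHS layout means (illustrative PyTorch code, not the tutorial's exact implementation): B is reorganized so that each BLOCK_SIZE_K x BLOCK_SIZE_N tile consumed by a program is contiguous in memory.

import torch

def block_rhs(b, BLOCK_SIZE_K, BLOCK_SIZE_N):
    # b is (K, N), with K and N assumed to be multiples of the block sizes.
    K, N = b.shape
    t = b.reshape(K // BLOCK_SIZE_K, BLOCK_SIZE_K, N // BLOCK_SIZE_N, BLOCK_SIZE_N)
    # Order tiles by column block, then by K block, keeping each
    # BLOCK_SIZE_K x BLOCK_SIZE_N tile contiguous.
    return t.permute(2, 0, 1, 3).contiguous()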