
Introduce FMA lowering for DotOp. #193

Merged: 3 commits into triton-lang:main on Dec 12, 2024

Conversation

ienkovich (Collaborator)

This PR adds lowering of DotOp through vector FMA and broadcast operations. The lowering is quite simple and doesn't work well for all block sizes, so a block size that fits into the register file is preferred. It provides much better performance than the existing lowering through the vector contraction operation.
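For intuition, here is a minimal NumPy-style sketch of the idea behind this lowering (a conceptual illustration with illustrative names, not the actual generated code): each accumulator row is kept as a 1D vector, and for every K element the corresponding LHS value is broadcast and combined with an RHS row via a multiply-add.

    import numpy as np

    def dot_via_fma(a, b, acc):
        # a: (BLOCK_M, BLOCK_K), b: (BLOCK_K, BLOCK_N), acc: (BLOCK_M, BLOCK_N)
        # The accumulator is handled as BLOCK_M separate 1D vectors of length BLOCK_N.
        for m in range(a.shape[0]):
            acc_row = acc[m]                        # 1D accumulator vector for row m
            for k in range(a.shape[1]):
                lhs = np.full(b.shape[1], a[m, k])  # broadcast a single LHS element
                acc_row = lhs * b[k] + acc_row      # vector multiply-add (FMA)
            acc[m] = acc_row
        return acc

The real pass operates on MLIR vector ops, but the data flow it creates should roughly follow this pattern: one broadcast plus one FMA per (row, k) pair.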

The patch also adds some fixes to the matmul tutorial to provide more measurement options. First, it uses block pointers to improve analysis and simplify code generation. Second, it introduces a padding feature that improves cache behavior by avoiding power-of-two strides.
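As a rough illustration of the padding idea (a NumPy sketch; the tutorial uses a dedicated pad_kernel, and the function name and pad width below are only examples): copying a matrix into a scratch buffer with a slightly larger row stride turns a power-of-two leading dimension into a non-power-of-two one, which spreads rows across more cache sets.

    import numpy as np

    def pad_row_stride(x, pad=32):
        # Copy x into a scratch buffer whose row stride is cols + pad elements,
        # so the leading dimension is no longer a power of two.
        rows, cols = x.shape
        scratch = np.zeros((rows, cols + pad), dtype=x.dtype)
        scratch[:, :cols] = x
        return scratch[:, :cols]  # same values, padded row stride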

Here are the performance results for the current lowering through the vector contraction op:

matmul-performance-torch.float32 (CACHE_PADDING=False PREPACKED=False PAD_B_ONLY=True GROUP_SIZE_M=8):
         M       N       K  TritonCPU 1   TritonCPU  TorchCPU (native)
0    256.0   256.0   256.0     7.008444  208.756250         266.680477
1    384.0   384.0   384.0     7.037777  269.310043        1143.445958
2    512.0   512.0   512.0     7.038876  277.117546         609.675895
3    640.0   640.0   640.0     7.070787  286.993461        1855.767899
4    768.0   768.0   768.0     7.009782  290.195857        1537.823776
5    896.0   896.0   896.0     7.028581  293.512130        2740.863129
6   1024.0  1024.0  1024.0     6.481916  275.566048        2517.403857
7   1152.0  1152.0  1152.0     7.030157  293.743336        3999.967758
8   1280.0  1280.0  1280.0     6.964409  292.484953        3277.685933
9   1408.0  1408.0  1408.0     7.021666  295.316153        3505.679013
10  1536.0  1536.0  1536.0     6.502502  277.653762        4338.997662
11  1664.0  1664.0  1664.0     6.970083  294.638160        4488.973249
12  1792.0  1792.0  1792.0     6.760696  287.476124        4696.579705
13  1920.0  1920.0  1920.0     6.970649  293.270995        4989.509636
14  2048.0  2048.0  2048.0     6.453451  275.972042        5102.363950
15  2176.0  2176.0  2176.0     6.892996  292.064724        5242.598378
16  2304.0  2304.0  2304.0     6.516752  278.238428        5339.505081
17  2432.0  2432.0  2432.0     6.830341  289.582060        5389.065772
18  2560.0  2560.0  2560.0     6.374564  272.799380        5427.722433

Here are the results with the FMA pass enabled:

matmul-performance-torch.float32 (CACHE_PADDING=False PREPACKED=False PAD_B_ONLY=True GROUP_SIZE_M=8):
         M       N       K  TritonCPU 1    TritonCPU  TorchCPU (native)
0    256.0   256.0   256.0    84.234192   622.682483         272.395227
1    384.0   384.0   384.0    85.787907  1039.313328        1138.971759
2    512.0   512.0   512.0    89.306110  1495.700062         611.135505
3    640.0   640.0   640.0    87.683793  1757.461846        1860.038651
4    768.0   768.0   768.0    86.769409  1974.581834        1527.227452
5    896.0   896.0   896.0    87.748022  2170.535765        2760.432819
6   1024.0  1024.0  1024.0    53.942684  1996.266476        2493.590160
7   1152.0  1152.0  1152.0    87.103365  2573.812399        4030.042541
8   1280.0  1280.0  1280.0    78.312309  2557.204496        3291.480726
9   1408.0  1408.0  1408.0    86.944922  2821.794459        3539.799941
10  1536.0  1536.0  1536.0    54.837263  2019.201287        4284.861570
11  1664.0  1664.0  1664.0    85.833454  2835.290641        4504.741425
12  1792.0  1792.0  1792.0    64.847520  2287.982110        4701.276260
13  1920.0  1920.0  1920.0    85.060913  2816.037518        5005.366543
14  2048.0  2048.0  2048.0    54.181541  1970.398142        5108.159571
15  2176.0  2176.0  2176.0    81.200119  2779.553123        5230.722931
16  2304.0  2304.0  2304.0    55.733256  2139.537010        5310.082279
17  2432.0  2432.0  2432.0    76.065209  2727.850897        5401.544195
18  2560.0  2560.0  2560.0    52.850840  1965.135308        5448.324440

These are the FMA results with padding enabled. Padding lets us avoid the performance drops on some sizes:

matmul-performance-torch.float32 (CACHE_PADDING=True PREPACKED=False PAD_B_ONLY=True GROUP_SIZE_M=8):
         M       N       K  TritonCPU 1    TritonCPU  TorchCPU (native)
0    256.0   256.0   256.0    72.949955   558.867600         268.741773
1    384.0   384.0   384.0    77.125823   929.806476        1137.175635
2    512.0   512.0   512.0    83.999143  1401.244439         614.273587
3    640.0   640.0   640.0    84.664957  1524.495666        1864.343529
4    768.0   768.0   768.0    84.604859  1801.298949        1553.238187
5    896.0   896.0   896.0    84.444551  1878.818698        2723.251686
6   1024.0  1024.0  1024.0    85.769258  2103.116325        2480.075410
7   1152.0  1152.0  1152.0    86.848168  2320.548872        4011.834207
8   1280.0  1280.0  1280.0    86.634951  2402.316923        3278.833805
9   1408.0  1408.0  1408.0    86.930059  2469.336265        3523.858020
10  1536.0  1536.0  1536.0    87.198629  2506.812047        4335.929244
11  1664.0  1664.0  1664.0    87.281008  2527.454260        4488.361247
12  1792.0  1792.0  1792.0    86.853180  2601.301105        4718.554866
13  1920.0  1920.0  1920.0    86.117222  2621.744387        4981.533116
14  2048.0  2048.0  2048.0    82.660039  2676.202552        5038.594435
15  2176.0  2176.0  2176.0    84.022500  2645.374144        5233.303856
16  2304.0  2304.0  2304.0    82.671846  2719.971484        5307.135532
17  2432.0  2432.0  2432.0    82.197134  2682.406711        5399.380956
18  2560.0  2560.0  2560.0    81.888011  2718.030500        5184.562840

If we are interested in performance in scenarios where padding costs can be ignored (e.g. weights are processed only once for inference), we can use the PREPACKED option to exclude the padding cost from the measurements:

matmul-performance-torch.float32 (CACHE_PADDING=True PREPACKED=True PAD_B_ONLY=True GROUP_SIZE_M=8):
         M       N       K  TritonCPU 1    TritonCPU  TorchCPU (native)
0    256.0   256.0   256.0    79.610036   653.770435         267.452286
1    384.0   384.0   384.0    85.079409  1005.631517        1134.799678
2    512.0   512.0   512.0    87.532636  1527.301199         619.642050
3    640.0   640.0   640.0    86.998497  1747.420024        1871.767418
4    768.0   768.0   768.0    85.796267  1952.906215        1509.585868
5    896.0   896.0   896.0    87.414424  2157.243564        2737.679073
6   1024.0  1024.0  1024.0    87.174831  2480.578106        2460.009481
7   1152.0  1152.0  1152.0    88.315111  2617.867193        4003.993667
8   1280.0  1280.0  1280.0    87.280672  2724.798712        3259.025916
9   1408.0  1408.0  1408.0    88.526457  2850.948546        3531.726690
10  1536.0  1536.0  1536.0    87.764679  2749.347624        4315.063324
11  1664.0  1664.0  1664.0    88.147403  2856.626217        4476.378033
12  1792.0  1792.0  1792.0    87.558864  2915.255715        4344.708642
13  1920.0  1920.0  1920.0    86.867476  2837.949546        4756.190472
14  2048.0  2048.0  2048.0    85.014205  2915.286143        5110.188233
15  2176.0  2176.0  2176.0    84.818367  2936.145875        5211.733589
16  2304.0  2304.0  2304.0    83.467489  2972.995680        5315.340594
17  2432.0  2432.0  2432.0    82.483407  2943.624111        5383.422364
18  2560.0  2560.0  2560.0    82.622524  2940.045080        5115.043590

To evaluate how results can be improved through kernel modifications, I added a new blocked matmul tutorial. It optionally changes the layout of the input data for better data locality, supports multiple options for pre-processing both LHS and RHS, and allows comparing their performance. Here are the results when a blocked layout is used for the RHS (the prepack suffix in a column name means the cost of the layout change was excluded).
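To make the blocked-layout option more concrete, here is a NumPy sketch of one way to pre-process the RHS (the exact layout produced by the tutorial's pre-processing kernel may differ; the defaults mirror BLOCK_SIZE_K=8 and BLOCK_SIZE_N=32 from the benchmark configuration): the (K, N) matrix is rearranged so that each BLOCK_K x BLOCK_N tile is contiguous in memory, improving locality when the kernel consumes one tile at a time.

    import numpy as np

    def block_rhs(b, block_k=8, block_n=32):
        # Rearrange a (K, N) row-major matrix into contiguous (block_k, block_n) tiles.
        # Result shape: (N // block_n, K // block_k, block_k, block_n).
        K, N = b.shape
        assert K % block_k == 0 and N % block_n == 0
        tiles = b.reshape(K // block_k, block_k, N // block_n, block_n)
        return np.ascontiguousarray(tiles.transpose(2, 0, 1, 3))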

Single thread:

matmul-performance-bf16 (BLOCK_SIZE_M=8, BLOCK_SIZE_N=32, BLOCK_SIZE_K=8, GROUP_SIZE_M=8):
         M       N       K  triton-cpu-bb-tb-st-float32  triton-cpu-bb-tb-prepack-st-float32  triton-cpu-st-float32  torch-cpu-native-st-float32
0    256.0   256.0   256.0                    81.696607                            84.464895              70.210052                    79.359020
1    384.0   384.0   384.0                   108.524852                           111.976063              86.270084                   113.624339
2    512.0   512.0   512.0                   123.694884                           128.112250              91.915599                   129.039819
3    640.0   640.0   640.0                   128.713663                           132.301042              93.096064                   135.867300
4    768.0   768.0   768.0                   127.903844                           131.282430              90.376665                   140.786211
5    896.0   896.0   896.0                   130.744368                           132.967943              91.334762                   141.910865
6   1024.0  1024.0  1024.0                   131.381638                           133.511512              54.326675                   142.349877
7   1152.0  1152.0  1152.0                   132.352811                           134.581067              90.521744                   145.063204
8   1280.0  1280.0  1280.0                   133.143446                           135.645516              83.127549                   145.091582
9   1408.0  1408.0  1408.0                   133.636818                           136.120497              88.607280                   146.154894
10  1536.0  1536.0  1536.0                   134.109253                           136.471275              54.921499                   147.217501
11  1664.0  1664.0  1664.0                   134.017928                           136.431090              86.824743                   145.044889
12  1792.0  1792.0  1792.0                   134.419619                           136.573808              63.146663                   147.900299
13  1920.0  1920.0  1920.0                   134.614521                           136.579132              85.544668                   148.478316
14  2048.0  2048.0  2048.0                   133.922156                           135.539775              54.563150                   147.964276
15  2176.0  2176.0  2176.0                   134.651667                           136.622749              79.855244                   148.696436
16  2304.0  2304.0  2304.0                   134.211121                           136.606094              56.165119                   149.060219
17  2432.0  2432.0  2432.0                   134.335383                           136.488440              75.453513                   148.931097
18  2560.0  2560.0  2560.0                   133.849690                           135.728972              52.652949                   149.094247

Multiple threads:

matmul-performance-bf16 (BLOCK_SIZE_M=8, BLOCK_SIZE_N=32, BLOCK_SIZE_K=8, GROUP_SIZE_M=8):
         M       N       K  triton-cpu-bb-tb-float32  triton-cpu-bb-tb-prepack-float32  triton-cpu-float32  torch-cpu-native-float32
0    256.0   256.0   256.0                561.478249                        736.707959          589.087239                257.708635
1    384.0   384.0   384.0               1050.244640                       1267.870611         1056.065136               1109.414015
2    512.0   512.0   512.0               1593.944323                       1813.041868         1532.087005                612.623454
3    640.0   640.0   640.0               1906.903928                       2319.499051         1798.642694               1878.252881
4    768.0   768.0   768.0               2255.633723                       2416.360401         1978.372343               1551.976881
5    896.0   896.0   896.0               2334.537906                       2890.740335         2160.359976               2764.163797
6   1024.0  1024.0  1024.0               2666.057098                       3040.612394         2062.177442               2424.649109
7   1152.0  1152.0  1152.0               2945.359195                       3501.236723         2347.677064               3988.522005
8   1280.0  1280.0  1280.0               3256.701333                       4012.643627         2438.030721               3165.916235
9   1408.0  1408.0  1408.0               3499.052301                       4361.914879         2534.732206               3499.218586
10  1536.0  1536.0  1536.0               3320.661094                       3933.713772         2054.186745               4327.383445
11  1664.0  1664.0  1664.0               3834.574392                       4681.549549         2570.068571               4456.700709
12  1792.0  1792.0  1792.0               3833.164000                       4765.670018         2267.101573               4706.207026
13  1920.0  1920.0  1920.0               4006.826231                       4814.920158         2584.381985               4988.130529
14  2048.0  2048.0  2048.0               4056.631820                       4909.244017         2042.732060               4909.326631
15  2176.0  2176.0  2176.0               4284.935721                       5140.160445         2602.996309               5202.643604
16  2304.0  2304.0  2304.0               4344.113066                       5101.554051         2119.188147               5337.437308
17  2432.0  2432.0  2432.0               4474.744440                       5312.884287         2597.644915               5384.322708
18  2560.0  2560.0  2560.0               4584.410834                       5401.233470         2025.822987               5311.491207

All results were measured on a 48-core Intel Xeon Platinum 8468V. For better stability, the frequency was fixed, hyperthreading was disabled, and Intel OpenMP was used with KMP_AFFINITY to pin threads.

ienkovich requested review from int3, minjang and Devjiu December 11, 2024 03:55
ienkovich requested a review from ptillet as a code owner December 11, 2024 03:55
@minjang (Collaborator) left a comment:

Thanks so much for doing this work! I'm going to test it in my environment soon. I quickly skimmed through; mostly looking good, just some questions and nits.

b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
block_offset_m = pid_m * BLOCK_SIZE_M
block_offset_n = pid_n * BLOCK_SIZE_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
Collaborator:

Can we keep this matmul_kernel as is? Instead, what about having a matmul_kernel_block_ptr or something? It'd be good to keep the baseline implementations.

Collaborator Author:

I can leave both options in a single tutorial and choose by a flag. Having them in different tutorials would make the comparison of these two options less reliable.

Comment on lines -221 to -222
# a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
Collaborator:

Oh I forgot to remove these restrictions. We can handle masks. So, let's remove them.

Comment on lines -241 to -243
# c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
# tl.store(c_ptrs, c, mask=c_mask)
tl.store(c_ptrs, c)
Collaborator:

Ditto. Masks work.

a = a_scratch
pad_kernel[(K // BLOCK_SIZE_K, )](b, b_scratch, N, BLOCK_SIZE_K, BLOCK_SIZE_N, 32, num_threads=num_threads)
b = b_scratch

#TODO: Currently masked load is not supported yet.
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (
K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size"
Collaborator:

This can be removed.

Comment on lines +239 to +240
a = tl.load(a_tile_ptr)
b = tl.load(b_tile_ptr)
Collaborator:

So, could you add the masking? It works, and I believe masking will also work with this PR.

Collaborator Author:

Masks are not related to my changes, and I don't want to bring them back yet (at least not unconditionally). We have mask optimizations working for the 1D case, but I'm not sure they work for the 2D case, and the masked variant can be much slower.

Collaborator:

Okay, then let me (or anyone) update the masking in a separate PR. This restriction was added a long time ago, and now it should be removed. It's perfectly fine not to have the best performance, but masking is a must for matmul.

Collaborator Author:

If it prevents getting the best performance when masks aren't needed, then I'm not sure. We see that padding is profitable anyway, so it can be used both to improve cache hits and to avoid masks by making sure we never access a partial block. And padding doesn't require any modification of the matmul kernel.

Comment on lines +365 to +370
auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
rewriter, newInitOperands, true,
[&newYieldedValues](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> newBBArgs) {
return newYieldedValues;
}));
Collaborator:

Note that I didn't fully understand this approach in AMX and here.

Is keepAccOnRegs = true a normal case as in a typical tutorial example?

    accumulator = tl.zeros(...)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator = tl.dot(a, b, accumulator)
        ...

accumulator has loop-carried dependencies. So I guess keepAccOnRegs is true.

So, why do we generate an scf::for here? It looks more confusing and complex to me. I see that this is mostly based on the current AMX approach, but naively thinking, I think we could avoid scf::for here and just emit FMAs...

Anyhow, we have good perf numbers :) There are multiple ways to do it!

Collaborator Author:

keepAccOnRegs means we want to load the accumulator into registers before the loop and keep it there for the whole loop. In this case the accumulator is represented as a set of 1D vectors, one vector per accumulator row. Those accumulator rows are used in FMA operations, and the FMA results go to the yield operation to form the accumulator's loop-carried dependencies. Adding new values to the yield operation is done via a call to replaceWithAdditionalYields, which replaces the original ForOp with an extended one; the original loop body is moved to the new operation.

Collaborator Author:

Just to be clear: this for-loop is the original loop written by the user. We don't create new loops; we just modify the existing one.
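For intuition only, here is a Python sketch of what the rewritten loop conceptually computes (an assumption about the overall shape of the transform, with illustrative names, not the actual IR): the accumulator rows become separate values carried from one iteration to the next, which is why they have to be added to the loop's yield.

    import numpy as np

    def k_loop_with_row_accumulators(a_tiles, b_tiles, BLOCK_M, BLOCK_N):
        # a_tiles: per-iteration (BLOCK_M, BLOCK_K) LHS tiles,
        # b_tiles: per-iteration (BLOCK_K, BLOCK_N) RHS tiles.
        # The accumulator lives as BLOCK_M separate row vectors for the whole loop;
        # in the pass these rows are the extra loop-carried (yielded) values.
        acc_rows = [np.zeros(BLOCK_N, dtype=np.float32) for _ in range(BLOCK_M)]
        for a_tile, b_tile in zip(a_tiles, b_tiles):  # the user's original K loop
            for m in range(BLOCK_M):
                for k in range(a_tile.shape[1]):
                    acc_rows[m] = a_tile[m, k] * b_tile[k] + acc_rows[m]  # broadcast + FMA
        return np.stack(acc_rows)  # final accumulator, one row per vector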

# 1D launch kernel where each block gets its own program.
grid = ((M // BLOCK_SIZE_M) * (N // BLOCK_SIZE_N), )
if (BLOCKED_A or BLOCKED_B) and not PREPACKED:
block_transpose_combined_kernel[grid](
Collaborator:

Curious whether you have evaluated the performance impact of doing the packing as threads consume the tiles, as opposed to materializing the full packed tensors before doing the matmul.

Not sure how the grid-to-thread mapping would work, given that some coordination across threads would be required, I guess.

Collaborator Author:

I didn't explore that option. It should be possible to organize it through atomics. Do you think it would give additional performance gains?

Collaborator:

Not sure TBH, I don't have useful data handy.

I would imagine perf gains would depend on the size of the packed tensors, the number of threads, and whether we can leverage cache locality better.

> It should be possible to organize through atomics

Are you talking about Triton atomic_ ops? Interesting... yeah, I would guess so. I'm not too sure what the input packing calls and the GEMM call orchestration would look like for the CPU setup.

python/tutorials/cpu-blocked-matmul-fp32.py (review thread resolved)
Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich (Collaborator Author):

@minjang Thanks for the review. I've made changes to address all the issues found.

ienkovich merged commit 3f11034 into triton-lang:main Dec 12, 2024
2 checks passed
ienkovich deleted the ienkovich/cpu/fma branch December 12, 2024 00:46