CUDA: stream-k decomposition for MMQ #8018

Merged

JohannesGaessler (Collaborator)

This PR implements "stream-k decomposition" as described in this paper for the MMQ kernels. The idea is that instead of dividing the output tensor into tiles and assigning those tiles to the streaming multiprocessors in waves, you assign partial tiles to the streaming multiprocessors and then combine the partial results with an additional "fixup" kernel. Notably, the number of tiles that need fixing up is constant and equal to the number of streaming multiprocessors (vs. e.g. splitting all tiles).
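For illustration, here is a minimal sketch of the work-assignment idea, under assumed names (stream_k_sketch, one fixup slot per block); the actual MMQ kernel works on 2D output tiles and stages whole partial tiles rather than single floats:

```cpp
#include <cstdint>

// Sketch only: the total number of k-iterations over all output tiles is split
// evenly across a fixed grid (roughly one block per SM), so a block may start
// or stop in the middle of a tile. The block that processes the start of a
// tile writes that tile's result; any other contribution is staged in a fixup
// buffer (at most one per block) and added in by the separate fixup kernel.
__global__ void stream_k_sketch(float * C, float * fixup, const int ntiles, const int iters_per_tile) {
    const int64_t total_iters = (int64_t) ntiles * iters_per_tile;
    const int64_t iter_begin  =  blockIdx.x      * total_iters / gridDim.x;
    const int64_t iter_end    = (blockIdx.x + 1) * total_iters / gridDim.x;

    int64_t it = iter_begin;
    while (it < iter_end) {
        const int     tile       = (int)(it / iters_per_tile);
        const int64_t tile_begin = (int64_t) tile * iters_per_tile;
        const int64_t tile_end   = tile_begin + iters_per_tile;
        const int64_t stop       = tile_end < iter_end ? tile_end : iter_end;

        float acc = 0.0f;
        // ... accumulate the k-slice [it, stop) of this tile into acc ...

        if (threadIdx.x == 0) { // sketch: one value per tile; the real kernel writes a full 2D tile
            if (it == tile_begin) {
                C[tile] = acc;           // this block owns the tile and writes it directly
            } else {
                fixup[blockIdx.x] = acc; // partial contribution, added later by the fixup kernel
            }
        }
        it = stop;
    }
}
```

With this split every block gets the same amount of work regardless of how many tiles there are, which is what removes the tail effect of wave-based tile scheduling.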

Performance vs. master MMQ
GPU | Model | Microbatch size | Test | t/s master | t/s cuda-mmq-stream-k-2 | Speedup
RTX 4090 llama 8B Q2_K_M 16 pp2048 1845.03 1974.88 1.07
RTX 4090 llama 8B Q2_K_M 32 pp2048 2704.27 3052.64 1.13
RTX 4090 llama 8B Q2_K_M 63 pp2048 3974.88 4310.72 1.08
RTX 4090 llama 8B Q2_K_M 128 pp2048 5464.22 5528.70 1.01
RTX 4090 llama 8B Q2_K_M 256 pp2048 6815.42 7109.27 1.04
RTX 4090 llama 8B Q2_K_M 512 pp2048 7416.65 7558.46 1.02
RTX 4090 llama 8B Q2_K_M 1024 pp2048 7537.06 7469.47 0.99
RTX 4090 llama 8B Q2_K_M 2048 pp2048 7004.78 6950.96 0.99
RTX 4090 llama 8B Q3_K_S 16 pp2048 1657.97 1766.89 1.07
RTX 4090 llama 8B Q3_K_S 32 pp2048 2615.77 2986.33 1.14
RTX 4090 llama 8B Q3_K_S 63 pp2048 4055.24 4484.16 1.11
RTX 4090 llama 8B Q3_K_S 128 pp2048 5920.58 6369.46 1.08
RTX 4090 llama 8B Q3_K_S 256 pp2048 7733.07 8454.78 1.09
RTX 4090 llama 8B Q3_K_S 512 pp2048 8676.65 8966.30 1.03
RTX 4090 llama 8B Q3_K_S 1024 pp2048 8708.10 8828.17 1.01
RTX 4090 llama 8B Q3_K_S 2048 pp2048 8053.46 8111.16 1.01
RTX 4090 llama 8B Q4_0 16 pp2048 1873.62 1989.96 1.06
RTX 4090 llama 8B Q4_0 32 pp2048 3269.12 3456.99 1.06
RTX 4090 llama 8B Q4_0 63 pp2048 5104.59 5284.82 1.04
RTX 4090 llama 8B Q4_0 128 pp2048 7114.91 7227.59 1.02
RTX 4090 llama 8B Q4_0 256 pp2048 9318.58 9514.91 1.02
RTX 4090 llama 8B Q4_0 512 pp2048 10198.79 10317.21 1.01
RTX 4090 llama 8B Q4_0 1024 pp2048 10158.53 10065.18 0.99
RTX 4090 llama 8B Q4_0 2048 pp2048 9241.77 9119.39 0.99
RTX 4090 llama 8B Q4_1 16 pp2048 1858.86 1880.17 1.01
RTX 4090 llama 8B Q4_1 32 pp2048 3179.30 2999.65 0.94
RTX 4090 llama 8B Q4_1 63 pp2048 5412.60 5084.00 0.94
RTX 4090 llama 8B Q4_1 128 pp2048 6581.02 6982.92 1.06
RTX 4090 llama 8B Q4_1 256 pp2048 8990.72 9070.09 1.01
RTX 4090 llama 8B Q4_1 512 pp2048 9667.64 9783.97 1.01
RTX 4090 llama 8B Q4_1 1024 pp2048 9930.18 9568.42 0.96
RTX 4090 llama 8B Q4_1 2048 pp2048 9154.70 8782.85 0.96
RTX 4090 llama 8B Q4_K_S 16 pp2048 1937.52 2004.38 1.03
RTX 4090 llama 8B Q4_K_S 32 pp2048 3342.11 3399.52 1.02
RTX 4090 llama 8B Q4_K_S 63 pp2048 5154.91 5080.20 0.99
RTX 4090 llama 8B Q4_K_S 128 pp2048 6874.51 6758.06 0.98
RTX 4090 llama 8B Q4_K_S 256 pp2048 8434.04 8750.48 1.04
RTX 4090 llama 8B Q4_K_S 512 pp2048 9195.77 9421.38 1.02
RTX 4090 llama 8B Q4_K_S 1024 pp2048 9169.23 9306.70 1.01
RTX 4090 llama 8B Q4_K_S 2048 pp2048 8629.07 8523.98 0.99
RTX 4090 llama 8B Q5_0 16 pp2048 1480.94 1673.97 1.13
RTX 4090 llama 8B Q5_0 32 pp2048 2509.59 3051.97 1.22
RTX 4090 llama 8B Q5_0 63 pp2048 4162.96 4628.07 1.11
RTX 4090 llama 8B Q5_0 128 pp2048 6207.95 6629.96 1.07
RTX 4090 llama 8B Q5_0 256 pp2048 8006.74 8867.56 1.11
RTX 4090 llama 8B Q5_0 512 pp2048 9223.85 9702.41 1.05
RTX 4090 llama 8B Q5_0 1024 pp2048 9277.39 9535.97 1.03
RTX 4090 llama 8B Q5_0 2048 pp2048 8710.12 8780.70 1.01
RTX 4090 llama 8B Q5_1 16 pp2048 1619.67 1580.42 0.98
RTX 4090 llama 8B Q5_1 32 pp2048 2584.57 2547.53 0.99
RTX 4090 llama 8B Q5_1 63 pp2048 4297.61 4222.33 0.98
RTX 4090 llama 8B Q5_1 128 pp2048 5975.51 6523.30 1.09
RTX 4090 llama 8B Q5_1 256 pp2048 7744.80 8451.57 1.09
RTX 4090 llama 8B Q5_1 512 pp2048 8823.13 9127.40 1.03
RTX 4090 llama 8B Q5_1 1024 pp2048 8931.16 9039.22 1.01
RTX 4090 llama 8B Q5_1 2048 pp2048 8335.92 8360.64 1.00
RTX 4090 llama 8B Q5_K_S 16 pp2048 1669.68 1705.77 1.02
RTX 4090 llama 8B Q5_K_S 32 pp2048 2776.29 2777.08 1.00
RTX 4090 llama 8B Q5_K_S 63 pp2048 4419.73 4356.69 0.99
RTX 4090 llama 8B Q5_K_S 128 pp2048 5981.22 6426.45 1.07
RTX 4090 llama 8B Q5_K_S 256 pp2048 7698.50 8275.37 1.07
RTX 4090 llama 8B Q5_K_S 512 pp2048 8615.20 8939.98 1.04
RTX 4090 llama 8B Q5_K_S 1024 pp2048 8703.88 8884.14 1.02
RTX 4090 llama 8B Q5_K_S 2048 pp2048 8158.43 8242.38 1.01
RTX 4090 llama 8B Q6_K 16 pp2048 1410.39 1455.31 1.03
RTX 4090 llama 8B Q6_K 32 pp2048 2452.60 2754.07 1.12
RTX 4090 llama 8B Q6_K 63 pp2048 4258.11 4430.69 1.04
RTX 4090 llama 8B Q6_K 128 pp2048 6049.71 6309.51 1.04
RTX 4090 llama 8B Q6_K 256 pp2048 7968.27 8277.67 1.04
RTX 4090 llama 8B Q6_K 512 pp2048 8751.79 8968.23 1.02
RTX 4090 llama 8B Q6_K 1024 pp2048 8837.98 8852.53 1.00
RTX 4090 llama 8B Q6_K 2048 pp2048 8155.85 8044.97 0.99
RTX 4090 llama 8B Q8_0 16 pp2048 1123.00 1276.92 1.14
RTX 4090 llama 8B Q8_0 32 pp2048 2178.68 2374.20 1.09
RTX 4090 llama 8B Q8_0 63 pp2048 3911.79 4065.28 1.04
RTX 4090 llama 8B Q8_0 128 pp2048 6343.39 6563.01 1.03
RTX 4090 llama 8B Q8_0 256 pp2048 8656.65 9411.70 1.09
RTX 4090 llama 8B Q8_0 512 pp2048 9933.38 10591.37 1.07
RTX 4090 llama 8B Q8_0 1024 pp2048 10310.62 10554.57 1.02
RTX 4090 llama 8B Q8_0 2048 pp2048 9672.82 9544.11 0.99
RTX 3090 llama 8B Q2_K_M 16 pp2048 860.08 983.08 1.14
RTX 3090 llama 8B Q2_K_M 32 pp2048 1191.51 1425.70 1.20
RTX 3090 llama 8B Q2_K_M 63 pp2048 1575.38 1893.31 1.20
RTX 3090 llama 8B Q2_K_M 128 pp2048 2007.22 2262.78 1.13
RTX 3090 llama 8B Q2_K_M 256 pp2048 2487.20 2536.22 1.02
RTX 3090 llama 8B Q2_K_M 512 pp2048 2592.80 2647.31 1.02
RTX 3090 llama 8B Q2_K_M 1024 pp2048 2646.01 2639.46 1.00
RTX 3090 llama 8B Q2_K_M 2048 pp2048 2604.77 2608.64 1.00
RTX 3090 llama 8B Q3_K_S 16 pp2048 809.98 920.54 1.14
RTX 3090 llama 8B Q3_K_S 32 pp2048 1139.97 1485.39 1.30
RTX 3090 llama 8B Q3_K_S 63 pp2048 1644.27 2087.77 1.27
RTX 3090 llama 8B Q3_K_S 128 pp2048 2331.84 2736.05 1.17
RTX 3090 llama 8B Q3_K_S 256 pp2048 2906.75 3057.93 1.05
RTX 3090 llama 8B Q3_K_S 512 pp2048 3053.04 3191.71 1.05
RTX 3090 llama 8B Q3_K_S 1024 pp2048 3173.62 3236.57 1.02
RTX 3090 llama 8B Q3_K_S 2048 pp2048 3092.63 3172.58 1.03
RTX 3090 llama 8B Q4_0 16 pp2048 1076.43 1173.33 1.09
RTX 3090 llama 8B Q4_0 32 pp2048 1556.26 1788.35 1.15
RTX 3090 llama 8B Q4_0 63 pp2048 2144.28 2534.95 1.18
RTX 3090 llama 8B Q4_0 128 pp2048 2830.76 3166.94 1.12
RTX 3090 llama 8B Q4_0 256 pp2048 3438.60 3599.03 1.05
RTX 3090 llama 8B Q4_0 512 pp2048 3701.21 3784.69 1.02
RTX 3090 llama 8B Q4_0 1024 pp2048 3832.97 3795.66 0.99
RTX 3090 llama 8B Q4_0 2048 pp2048 3643.18 3634.57 1.00
RTX 3090 llama 8B Q4_1 16 pp2048 1179.37 1268.37 1.08
RTX 3090 llama 8B Q4_1 32 pp2048 1614.90 1652.84 1.02
RTX 3090 llama 8B Q4_1 63 pp2048 2031.74 2382.21 1.17
RTX 3090 llama 8B Q4_1 128 pp2048 2755.17 2966.05 1.08
RTX 3090 llama 8B Q4_1 256 pp2048 3232.95 3322.46 1.03
RTX 3090 llama 8B Q4_1 512 pp2048 3579.61 3469.01 0.97
RTX 3090 llama 8B Q4_1 1024 pp2048 3718.61 3526.51 0.95
RTX 3090 llama 8B Q4_1 2048 pp2048 3624.65 3442.58 0.95
RTX 3090 llama 8B Q4_K_S 16 pp2048 1125.44 1242.52 1.10
RTX 3090 llama 8B Q4_K_S 32 pp2048 1585.72 1693.02 1.07
RTX 3090 llama 8B Q4_K_S 63 pp2048 2032.91 2312.02 1.14
RTX 3090 llama 8B Q4_K_S 128 pp2048 2564.13 2812.79 1.10
RTX 3090 llama 8B Q4_K_S 256 pp2048 3073.94 3181.44 1.03
RTX 3090 llama 8B Q4_K_S 512 pp2048 3313.49 3325.57 1.00
RTX 3090 llama 8B Q4_K_S 1024 pp2048 3444.24 3376.12 0.98
RTX 3090 llama 8B Q4_K_S 2048 pp2048 3250.36 3291.11 1.01
RTX 3090 llama 8B Q5_0 16 pp2048 776.15 946.39 1.22
RTX 3090 llama 8B Q5_0 32 pp2048 1217.41 1621.45 1.33
RTX 3090 llama 8B Q5_0 63 pp2048 1687.45 2198.64 1.30
RTX 3090 llama 8B Q5_0 128 pp2048 2404.10 2902.89 1.21
RTX 3090 llama 8B Q5_0 256 pp2048 3087.93 3296.47 1.07
RTX 3090 llama 8B Q5_0 512 pp2048 3219.13 3451.26 1.07
RTX 3090 llama 8B Q5_0 1024 pp2048 3363.65 3484.96 1.04
RTX 3090 llama 8B Q5_0 2048 pp2048 3298.89 3401.04 1.03
RTX 3090 llama 8B Q5_1 16 pp2048 930.03 987.19 1.06
RTX 3090 llama 8B Q5_1 32 pp2048 1365.85 1423.36 1.04
RTX 3090 llama 8B Q5_1 63 pp2048 1727.76 2091.52 1.21
RTX 3090 llama 8B Q5_1 128 pp2048 2362.19 2726.86 1.15
RTX 3090 llama 8B Q5_1 256 pp2048 2978.09 3068.63 1.03
RTX 3090 llama 8B Q5_1 512 pp2048 3090.98 3242.92 1.05
RTX 3090 llama 8B Q5_1 1024 pp2048 3216.80 3274.03 1.02
RTX 3090 llama 8B Q5_1 2048 pp2048 3187.65 3211.42 1.01
RTX 3090 llama 8B Q5_K_S 16 pp2048 903.89 1016.46 1.12
RTX 3090 llama 8B Q5_K_S 32 pp2048 1296.51 1448.98 1.12
RTX 3090 llama 8B Q5_K_S 63 pp2048 1715.60 2032.78 1.18
RTX 3090 llama 8B Q5_K_S 128 pp2048 2326.05 2652.85 1.14
RTX 3090 llama 8B Q5_K_S 256 pp2048 2894.26 3004.17 1.04
RTX 3090 llama 8B Q5_K_S 512 pp2048 3061.19 3129.73 1.02
RTX 3090 llama 8B Q5_K_S 1024 pp2048 3163.91 3190.35 1.01
RTX 3090 llama 8B Q5_K_S 2048 pp2048 3097.12 3135.12 1.01
RTX 3090 llama 8B Q6_K 16 pp2048 813.79 926.20 1.14
RTX 3090 llama 8B Q6_K 32 pp2048 1251.80 1552.21 1.24
RTX 3090 llama 8B Q6_K 63 pp2048 1771.66 2216.34 1.25
RTX 3090 llama 8B Q6_K 128 pp2048 2452.64 2705.40 1.10
RTX 3090 llama 8B Q6_K 256 pp2048 3003.60 3084.47 1.03
RTX 3090 llama 8B Q6_K 512 pp2048 3186.07 3211.52 1.01
RTX 3090 llama 8B Q6_K 1024 pp2048 3283.36 3254.43 0.99
RTX 3090 llama 8B Q6_K 2048 pp2048 3252.64 3165.70 0.97
RTX 3090 llama 8B Q8_0 16 pp2048 766.31 935.77 1.22
RTX 3090 llama 8B Q8_0 32 pp2048 1314.52 1622.77 1.23
RTX 3090 llama 8B Q8_0 63 pp2048 1955.76 2426.60 1.24
RTX 3090 llama 8B Q8_0 128 pp2048 2800.84 3208.42 1.15
RTX 3090 llama 8B Q8_0 256 pp2048 3525.69 3732.82 1.06
RTX 3090 llama 8B Q8_0 512 pp2048 3758.75 3876.41 1.03
RTX 3090 llama 8B Q8_0 1024 pp2048 3890.72 3932.16 1.01
RTX 3090 llama 8B Q8_0 2048 pp2048 3820.42 3822.87 1.00
Performance vs. master cuBLAS
GPU | Model | Microbatch size | Test | t/s master | t/s cuda-mmq-stream-k-2 | Speedup
RTX 4090 llama 8B Q2_K_M 16 pp2048 1849.42 1961.85 1.06
RTX 4090 llama 8B Q2_K_M 32 pp2048 2702.03 3033.47 1.12
RTX 4090 llama 8B Q2_K_M 63 pp2048 3978.71 4287.27 1.08
RTX 4090 llama 8B Q2_K_M 128 pp2048 3648.30 5499.50 1.51
RTX 4090 llama 8B Q2_K_M 256 pp2048 5905.04 7096.81 1.20
RTX 4090 llama 8B Q2_K_M 512 pp2048 7778.63 7527.15 0.97
RTX 4090 llama 8B Q2_K_M 1024 pp2048 9022.73 7470.52 0.83
RTX 4090 llama 8B Q2_K_M 2048 pp2048 8976.75 6932.79 0.77
RTX 4090 llama 8B Q3_K_S 16 pp2048 1653.20 1758.20 1.06
RTX 4090 llama 8B Q3_K_S 32 pp2048 2612.20 2975.96 1.14
RTX 4090 llama 8B Q3_K_S 63 pp2048 4051.31 4482.28 1.11
RTX 4090 llama 8B Q3_K_S 128 pp2048 3541.55 6370.11 1.80
RTX 4090 llama 8B Q3_K_S 256 pp2048 5774.83 8420.01 1.46
RTX 4090 llama 8B Q3_K_S 512 pp2048 7685.85 8980.66 1.17
RTX 4090 llama 8B Q3_K_S 1024 pp2048 9014.36 8813.25 0.98
RTX 4090 llama 8B Q3_K_S 2048 pp2048 8885.37 8100.26 0.91
RTX 4090 llama 8B Q4_0 16 pp2048 1876.38 1987.65 1.06
RTX 4090 llama 8B Q4_0 32 pp2048 3268.51 3451.60 1.06
RTX 4090 llama 8B Q4_0 63 pp2048 5112.46 5256.57 1.03
RTX 4090 llama 8B Q4_0 128 pp2048 3472.33 7181.49 2.07
RTX 4090 llama 8B Q4_0 256 pp2048 5737.60 9492.50 1.65
RTX 4090 llama 8B Q4_0 512 pp2048 7773.21 10256.46 1.32
RTX 4090 llama 8B Q4_0 1024 pp2048 9080.33 10055.02 1.11
RTX 4090 llama 8B Q4_0 2048 pp2048 8988.66 9095.03 1.01
RTX 4090 llama 8B Q4_1 16 pp2048 1855.97 1877.27 1.01
RTX 4090 llama 8B Q4_1 32 pp2048 3160.56 2991.37 0.95
RTX 4090 llama 8B Q4_1 63 pp2048 5428.70 5071.75 0.93
RTX 4090 llama 8B Q4_1 128 pp2048 3462.05 6973.87 2.01
RTX 4090 llama 8B Q4_1 256 pp2048 5685.58 9015.68 1.59
RTX 4090 llama 8B Q4_1 512 pp2048 7639.11 9708.85 1.27
RTX 4090 llama 8B Q4_1 1024 pp2048 8947.06 9524.67 1.06
RTX 4090 llama 8B Q4_1 2048 pp2048 8868.91 8692.50 0.98
RTX 4090 llama 8B Q4_K_S 16 pp2048 1938.40 2004.50 1.03
RTX 4090 llama 8B Q4_K_S 32 pp2048 3339.09 3399.47 1.02
RTX 4090 llama 8B Q4_K_S 63 pp2048 5190.26 5088.44 0.98
RTX 4090 llama 8B Q4_K_S 128 pp2048 3478.24 6745.80 1.94
RTX 4090 llama 8B Q4_K_S 256 pp2048 5745.87 8736.13 1.52
RTX 4090 llama 8B Q4_K_S 512 pp2048 7658.49 9375.75 1.22
RTX 4090 llama 8B Q4_K_S 1024 pp2048 8967.41 9290.54 1.04
RTX 4090 llama 8B Q4_K_S 2048 pp2048 8946.40 8517.78 0.95
RTX 4090 llama 8B Q5_0 16 pp2048 1472.88 1673.13 1.14
RTX 4090 llama 8B Q5_0 32 pp2048 2501.81 3049.99 1.22
RTX 4090 llama 8B Q5_0 63 pp2048 4154.21 4622.38 1.11
RTX 4090 llama 8B Q5_0 128 pp2048 3390.37 6613.48 1.95
RTX 4090 llama 8B Q5_0 256 pp2048 5620.71 8840.13 1.57
RTX 4090 llama 8B Q5_0 512 pp2048 7589.73 9679.36 1.28
RTX 4090 llama 8B Q5_0 1024 pp2048 8914.41 9532.62 1.07
RTX 4090 llama 8B Q5_0 2048 pp2048 8969.12 8679.60 0.97
RTX 4090 llama 8B Q5_1 16 pp2048 1622.59 1575.96 0.97
RTX 4090 llama 8B Q5_1 32 pp2048 2590.62 2539.50 0.98
RTX 4090 llama 8B Q5_1 63 pp2048 4292.55 4213.50 0.98
RTX 4090 llama 8B Q5_1 128 pp2048 3427.44 6527.37 1.90
RTX 4090 llama 8B Q5_1 256 pp2048 5601.15 8404.73 1.50
RTX 4090 llama 8B Q5_1 512 pp2048 7524.07 9037.00 1.20
RTX 4090 llama 8B Q5_1 1024 pp2048 8825.99 8971.98 1.02
RTX 4090 llama 8B Q5_1 2048 pp2048 8949.14 8316.25 0.93
RTX 4090 llama 8B Q5_K_S 16 pp2048 1672.64 1701.35 1.02
RTX 4090 llama 8B Q5_K_S 32 pp2048 2755.45 2780.01 1.01
RTX 4090 llama 8B Q5_K_S 63 pp2048 4430.13 4369.05 0.99
RTX 4090 llama 8B Q5_K_S 128 pp2048 3455.75 6424.47 1.86
RTX 4090 llama 8B Q5_K_S 256 pp2048 5679.23 8255.02 1.45
RTX 4090 llama 8B Q5_K_S 512 pp2048 7603.42 8940.02 1.18
RTX 4090 llama 8B Q5_K_S 1024 pp2048 8911.92 8860.79 0.99
RTX 4090 llama 8B Q5_K_S 2048 pp2048 9002.91 8208.66 0.91
RTX 4090 llama 8B Q6_K 16 pp2048 1407.55 1454.38 1.03
RTX 4090 llama 8B Q6_K 32 pp2048 2451.52 2755.87 1.12
RTX 4090 llama 8B Q6_K 63 pp2048 4250.79 4425.15 1.04
RTX 4090 llama 8B Q6_K 128 pp2048 3379.81 6300.29 1.86
RTX 4090 llama 8B Q6_K 256 pp2048 5569.55 8264.67 1.48
RTX 4090 llama 8B Q6_K 512 pp2048 7487.04 8990.42 1.20
RTX 4090 llama 8B Q6_K 1024 pp2048 8725.02 8816.08 1.01
RTX 4090 llama 8B Q6_K 2048 pp2048 8608.72 8085.01 0.94
RTX 4090 llama 8B Q8_0 16 pp2048 1122.75 1279.34 1.14
RTX 4090 llama 8B Q8_0 32 pp2048 2178.11 2375.24 1.09
RTX 4090 llama 8B Q8_0 63 pp2048 3924.24 4062.17 1.04
RTX 4090 llama 8B Q8_0 128 pp2048 3304.34 6570.09 1.99
RTX 4090 llama 8B Q8_0 256 pp2048 5469.02 9388.70 1.72
RTX 4090 llama 8B Q8_0 512 pp2048 7388.73 10572.18 1.43
RTX 4090 llama 8B Q8_0 1024 pp2048 8741.50 10490.88 1.20
RTX 4090 llama 8B Q8_0 2048 pp2048 8834.26 9645.92 1.09
RTX 3090 llama 8B Q2_K_M 16 pp2048 876.71 980.01 1.12
RTX 3090 llama 8B Q2_K_M 32 pp2048 1204.85 1421.12 1.18
RTX 3090 llama 8B Q2_K_M 63 pp2048 1592.05 1855.78 1.17
RTX 3090 llama 8B Q2_K_M 128 pp2048 2085.35 2216.94 1.06
RTX 3090 llama 8B Q2_K_M 256 pp2048 3072.04 2480.35 0.81
RTX 3090 llama 8B Q2_K_M 512 pp2048 3630.75 2584.63 0.71
RTX 3090 llama 8B Q2_K_M 1024 pp2048 4259.26 2630.49 0.62
RTX 3090 llama 8B Q2_K_M 2048 pp2048 4282.13 2602.09 0.61
RTX 3090 llama 8B Q3_K_S 16 pp2048 820.72 919.68 1.12
RTX 3090 llama 8B Q3_K_S 32 pp2048 1143.83 1472.94 1.29
RTX 3090 llama 8B Q3_K_S 63 pp2048 1652.83 2081.18 1.26
RTX 3090 llama 8B Q3_K_S 128 pp2048 1920.37 2706.25 1.41
RTX 3090 llama 8B Q3_K_S 256 pp2048 2907.74 3046.21 1.05
RTX 3090 llama 8B Q3_K_S 512 pp2048 3489.15 3146.05 0.90
RTX 3090 llama 8B Q3_K_S 1024 pp2048 4142.09 3195.93 0.77
RTX 3090 llama 8B Q3_K_S 2048 pp2048 4240.70 3149.05 0.74
RTX 3090 llama 8B Q4_0 16 pp2048 1106.45 1145.99 1.04
RTX 3090 llama 8B Q4_0 32 pp2048 1634.97 1784.24 1.09
RTX 3090 llama 8B Q4_0 63 pp2048 2259.05 2485.26 1.10
RTX 3090 llama 8B Q4_0 128 pp2048 2204.20 3092.43 1.40
RTX 3090 llama 8B Q4_0 256 pp2048 3232.61 3500.21 1.08
RTX 3090 llama 8B Q4_0 512 pp2048 3803.77 3683.54 0.97
RTX 3090 llama 8B Q4_0 1024 pp2048 4434.49 3729.28 0.84
RTX 3090 llama 8B Q4_0 2048 pp2048 4463.80 3628.15 0.81
RTX 3090 llama 8B Q4_1 16 pp2048 1219.23 1276.05 1.05
RTX 3090 llama 8B Q4_1 32 pp2048 1682.41 1652.77 0.98
RTX 3090 llama 8B Q4_1 63 pp2048 2078.54 2366.28 1.14
RTX 3090 llama 8B Q4_1 128 pp2048 2146.32 2931.50 1.37
RTX 3090 llama 8B Q4_1 256 pp2048 3149.25 3271.09 1.04
RTX 3090 llama 8B Q4_1 512 pp2048 3694.63 3403.65 0.92
RTX 3090 llama 8B Q4_1 1024 pp2048 4349.90 3472.35 0.80
RTX 3090 llama 8B Q4_1 2048 pp2048 4411.96 3381.65 0.77
RTX 3090 llama 8B Q4_K_S 16 pp2048 1134.98 1235.53 1.09
RTX 3090 llama 8B Q4_K_S 32 pp2048 1592.72 1706.50 1.07
RTX 3090 llama 8B Q4_K_S 63 pp2048 2036.81 2307.37 1.13
RTX 3090 llama 8B Q4_K_S 128 pp2048 2110.53 2821.77 1.34
RTX 3090 llama 8B Q4_K_S 256 pp2048 3107.55 3172.70 1.02
RTX 3090 llama 8B Q4_K_S 512 pp2048 3586.36 3325.16 0.93
RTX 3090 llama 8B Q4_K_S 1024 pp2048 4164.70 3379.66 0.81
RTX 3090 llama 8B Q4_K_S 2048 pp2048 4271.58 3321.58 0.78
RTX 3090 llama 8B Q5_0 16 pp2048 802.70 948.99 1.18
RTX 3090 llama 8B Q5_0 32 pp2048 1261.06 1624.05 1.29
RTX 3090 llama 8B Q5_0 63 pp2048 1742.79 2191.72 1.26
RTX 3090 llama 8B Q5_0 128 pp2048 2041.06 2895.32 1.42
RTX 3090 llama 8B Q5_0 256 pp2048 3016.43 3273.40 1.09
RTX 3090 llama 8B Q5_0 512 pp2048 3638.16 3439.15 0.95
RTX 3090 llama 8B Q5_0 1024 pp2048 4263.78 3472.91 0.81
RTX 3090 llama 8B Q5_0 2048 pp2048 4350.71 3380.50 0.78
RTX 3090 llama 8B Q5_1 16 pp2048 957.84 985.70 1.03
RTX 3090 llama 8B Q5_1 32 pp2048 1392.74 1412.94 1.01
RTX 3090 llama 8B Q5_1 63 pp2048 1753.05 2091.18 1.19
RTX 3090 llama 8B Q5_1 128 pp2048 2027.25 2727.02 1.35
RTX 3090 llama 8B Q5_1 256 pp2048 2998.97 3069.79 1.02
RTX 3090 llama 8B Q5_1 512 pp2048 3619.35 3220.80 0.89
RTX 3090 llama 8B Q5_1 1024 pp2048 4230.62 3253.31 0.77
RTX 3090 llama 8B Q5_1 2048 pp2048 4290.47 3189.35 0.74
RTX 3090 llama 8B Q5_K_S 16 pp2048 913.16 1007.55 1.10
RTX 3090 llama 8B Q5_K_S 32 pp2048 1306.56 1449.53 1.11
RTX 3090 llama 8B Q5_K_S 63 pp2048 1722.17 2050.22 1.19
RTX 3090 llama 8B Q5_K_S 128 pp2048 2056.14 2679.60 1.30
RTX 3090 llama 8B Q5_K_S 256 pp2048 3048.91 3022.22 0.99
RTX 3090 llama 8B Q5_K_S 512 pp2048 3558.73 3169.98 0.89
RTX 3090 llama 8B Q5_K_S 1024 pp2048 4161.86 3211.52 0.77
RTX 3090 llama 8B Q5_K_S 2048 pp2048 4225.06 3153.45 0.75
RTX 3090 llama 8B Q6_K 16 pp2048 814.28 919.77 1.13
RTX 3090 llama 8B Q6_K 32 pp2048 1242.22 1532.72 1.23
RTX 3090 llama 8B Q6_K 63 pp2048 1760.44 2197.74 1.25
RTX 3090 llama 8B Q6_K 128 pp2048 2081.49 2668.01 1.28
RTX 3090 llama 8B Q6_K 256 pp2048 3048.52 3032.60 0.99
RTX 3090 llama 8B Q6_K 512 pp2048 3583.21 3196.87 0.89
RTX 3090 llama 8B Q6_K 1024 pp2048 4172.73 3208.81 0.77
RTX 3090 llama 8B Q6_K 2048 pp2048 4198.13 3147.56 0.75
RTX 3090 llama 8B Q8_0 16 pp2048 766.11 934.73 1.22
RTX 3090 llama 8B Q8_0 32 pp2048 1324.08 1587.96 1.20
RTX 3090 llama 8B Q8_0 63 pp2048 1970.85 2371.24 1.20
RTX 3090 llama 8B Q8_0 128 pp2048 2062.85 3159.14 1.53
RTX 3090 llama 8B Q8_0 256 pp2048 3058.36 3675.12 1.20
RTX 3090 llama 8B Q8_0 512 pp2048 3644.68 3868.97 1.06
RTX 3090 llama 8B Q8_0 1024 pp2048 4276.15 3900.21 0.91
RTX 3090 llama 8B Q8_0 2048 pp2048 4334.50 3824.55 0.88

The performance increase for small batch sizes (where you suffer more from tail effects) is quite noticeable; for large batch sizes it's only a few percent. You could potentially get more performance by making the implementation nondeterministic with atomic adds (on my desktop with the RTX 3090 the matrix multiplication kernels take 70-760 µs, the fixup kernel ~20 µs). On my P40 and RX 6800 the stream-k kernel was slower, so there is no change for those cards.
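For reference, a hedged sketch of what that nondeterministic alternative could look like (assumed helper name, not part of this PR):

```cpp
// Instead of staging partial sums in a fixup buffer and combining them with a
// separate fixup kernel, each block could add its partial tile directly into
// the output with atomicAdd (assuming dst was zeroed beforehand). The summation
// order then depends on block scheduling, so results would no longer be
// bit-for-bit reproducible between runs.
__device__ void add_partial_tile_atomic(float * dst, const float * partial, const int tile_offset, const int tile_size) {
    for (int i = threadIdx.x; i < tile_size; i += blockDim.x) {
        atomicAdd(dst + tile_offset + i, partial[i]);
    }
}
```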

@ggerganov IIRC you have an RTX 2060; can you check the performance of this PR vs. master (both with LLAMA_CUDA_FORCE_MMQ)? Compared to my GPUs it has significantly fewer SMs, so I would expect this PR to provide comparatively less benefit (while still adding some overhead from the fixup kernel, which may outweigh the speedup).

github-actions bot added the Nvidia GPU (issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels Jun 19, 2024
JohannesGaessler added the Review Complexity: High (generally requires in-depth knowledge of LLMs or GPUs) label Jun 19, 2024
ggerganov (Owner) commented Jun 20, 2024

Here are the results on the RTX 2060 (data for 1B and 7B models):

LLAMA_CUDA=1 ./scripts/compare-commits.sh master da1db13d6aaabd0889a1bc1c7755a462621bd1d6 \
  -m models/llama-7b-v2/ggml-model-q4_0.gguf \
  -m models/llama-7b-v2/ggml-model-q4_k.gguf \
  -m models/tinyllama-1b/ggml-model-q4_0.gguf \
  -m models/tinyllama-1b/ggml-model-q8_0.gguf \
  -p 2048 -ub 16,32,64,128,256,512,1024,2048 -ngl 99 -n 0
Performance vs. master cuBLAS

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes

Model | Size [GiB] | Microbatch size | Test | t/s master | t/s da1db13 | Speedup
llama 1B Q4_0 0.59 16 pp2048 1181.60 1146.10 0.97
llama 1B Q4_0 0.59 32 pp2048 2544.40 2376.44 0.93
llama 1B Q4_0 0.59 64 pp2048 3302.75 3231.90 0.98
llama 1B Q4_0 0.59 128 pp2048 3330.94 3327.80 1.00
llama 1B Q4_0 0.59 256 pp2048 4179.32 4177.13 1.00
llama 1B Q4_0 0.59 512 pp2048 4597.28 4595.68 1.00
llama 1B Q4_0 0.59 1024 pp2048 4587.57 4581.55 1.00
llama 1B Q4_0 0.59 2048 pp2048 4218.85 4211.63 1.00
llama 1B Q8_0 1.09 16 pp2048 957.18 1028.21 1.07
llama 1B Q8_0 1.09 32 pp2048 2055.63 2281.86 1.11
llama 1B Q8_0 1.09 64 pp2048 2814.76 3077.22 1.09
llama 1B Q8_0 1.09 128 pp2048 3662.31 3659.32 1.00
llama 1B Q8_0 1.09 256 pp2048 4428.42 4427.45 1.00
llama 1B Q8_0 1.09 512 pp2048 4745.80 4746.90 1.00
llama 1B Q8_0 1.09 1024 pp2048 4665.55 4660.40 1.00
llama 1B Q8_0 1.09 2048 pp2048 4255.26 4247.57 1.00
llama 7B Q4_0 3.56 16 pp2048 429.42 405.31 0.94
llama 7B Q4_0 3.56 32 pp2048 707.97 675.67 0.95
llama 7B Q4_0 3.56 64 pp2048 962.51 935.46 0.97
llama 7B Q4_0 3.56 128 pp2048 797.75 797.07 1.00
llama 7B Q4_0 3.56 256 pp2048 1097.33 1096.56 1.00
llama 7B Q4_0 3.56 512 pp2048 1330.73 1329.57 1.00
llama 7B Q4_0 3.56 1024 pp2048 1440.01 1438.79 1.00
llama 7B Q4_0 3.56 2048 pp2048 1458.11 1456.14 1.00
llama 7B Q4_K_M 3.80 16 pp2048 414.12 388.30 0.94
llama 7B Q4_K_M 3.80 32 pp2048 716.33 726.15 1.01
llama 7B Q4_K_M 3.80 64 pp2048 952.33 899.34 0.94
llama 7B Q4_K_M 3.80 128 pp2048 803.76 803.57 1.00
llama 7B Q4_K_M 3.80 256 pp2048 1103.59 1102.60 1.00
llama 7B Q4_K_M 3.80 512 pp2048 1334.25 1333.16 1.00
llama 7B Q4_K_M 3.80 1024 pp2048 1442.62 1440.66 1.00
llama 7B Q4_K_M 3.80 2048 pp2048 1458.86 1457.32 1.00
LLAMA_CUDA_FORCE_MMQ=1 LLAMA_CUDA=1 ./scripts/compare-commits.sh master da1db13d6aaabd0889a1bc1c7755a462621bd1d6 \
  -m models/llama-7b-v2/ggml-model-q4_0.gguf \
  -m models/llama-7b-v2/ggml-model-q4_k.gguf \
  -m models/tinyllama-1b/ggml-model-q4_0.gguf \
  -m models/tinyllama-1b/ggml-model-q8_0.gguf \
  -p 2048 -ub 16,32,64,128,256,512,1024,2048 -ngl 99 -n 0
Performance vs. master MMQ

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: yes
ggml_cuda_init: CUDA_USE_TENSOR_CORES: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes

Model | Size [GiB] | Microbatch size | Test | t/s master | t/s da1db13 | Speedup
llama 1B Q4_0 0.59 16 pp2048 1182.36 1145.77 0.97
llama 1B Q4_0 0.59 32 pp2048 2543.57 2364.53 0.93
llama 1B Q4_0 0.59 64 pp2048 3300.52 3229.23 0.98
llama 1B Q4_0 0.59 128 pp2048 3824.02 3710.79 0.97
llama 1B Q4_0 0.59 256 pp2048 4388.32 4247.36 0.97
llama 1B Q4_0 0.59 512 pp2048 4280.59 4312.24 1.01
llama 1B Q4_0 0.59 1024 pp2048 4108.44 4111.93 1.00
llama 1B Q4_0 0.59 2048 pp2048 3755.14 3735.72 0.99
llama 1B Q8_0 1.09 16 pp2048 956.76 1027.92 1.07
llama 1B Q8_0 1.09 32 pp2048 2054.94 2278.54 1.11
llama 1B Q8_0 1.09 64 pp2048 2814.72 3077.43 1.09
llama 1B Q8_0 1.09 128 pp2048 3568.13 3609.68 1.01
llama 1B Q8_0 1.09 256 pp2048 4169.34 4041.25 0.97
llama 1B Q8_0 1.09 512 pp2048 4127.56 4104.11 0.99
llama 1B Q8_0 1.09 1024 pp2048 3989.79 3912.26 0.98
llama 1B Q8_0 1.09 2048 pp2048 3663.29 3574.50 0.98
llama 7B Q4_0 3.56 16 pp2048 429.39 405.35 0.94
llama 7B Q4_0 3.56 32 pp2048 707.94 676.26 0.96
llama 7B Q4_0 3.56 64 pp2048 962.28 936.64 0.97
llama 7B Q4_0 3.56 128 pp2048 1133.00 1139.91 1.01
llama 7B Q4_0 3.56 256 pp2048 1135.68 1179.10 1.04
llama 7B Q4_0 3.56 512 pp2048 1164.37 1188.67 1.02
llama 7B Q4_0 3.56 1024 pp2048 1153.37 1164.52 1.01
llama 7B Q4_0 3.56 2048 pp2048 1112.35 1117.58 1.00
llama 7B Q4_K_M 3.80 16 pp2048 414.09 388.46 0.94
llama 7B Q4_K_M 3.80 32 pp2048 716.26 726.54 1.01
llama 7B Q4_K_M 3.80 64 pp2048 952.00 900.44 0.95
llama 7B Q4_K_M 3.80 128 pp2048 1072.28 976.05 0.91
llama 7B Q4_K_M 3.80 256 pp2048 1068.74 1076.58 1.01
llama 7B Q4_K_M 3.80 512 pp2048 1113.67 1106.75 0.99
llama 7B Q4_K_M 3.80 1024 pp2048 1105.77 1104.62 1.00
llama 7B Q4_K_M 3.80 2048 pp2048 1077.17 1072.15 1.00

JohannesGaessler (Collaborator, Author)

Thank you! That looks fine to me.

JohannesGaessler (Collaborator, Author)

I cannot reproduce the failing server test from the CI on my local machine. But the CI server tests are run with the CPU backend anyway, right? So this PR should have no effect on them.

slaren (Collaborator) commented Jun 20, 2024

It's not related; the error was first seen in #7993.

slaren (Collaborator) commented Jun 20, 2024

I got these errors with test-backend-ops:

  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]): [MUL_MAT] NaN at index 20 (CUDA0=nan CPU=6.696930) FAIL
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1]): [MUL_MAT] NaN at index 4 (CUDA0=nan CPU=-2.965016) FAIL
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1]): [MUL_MAT] NaN at index 1 (CUDA0=nan CPU=-0.442069) FAIL
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1]): [MUL_MAT] NaN at index 9 (CUDA0=nan CPU=0.075346) FAIL

JohannesGaessler (Collaborator, Author)

The issue should be fixed now. The TLDR is that I had not considered an edge case for very small matrices.

More specifically, for very small matrices there were instances where the fixup kernel assumed an MMQ CUDA block had written something to the fixup buffer when in reality it had simply returned, because the k slice assigned to it had a length of zero (there just wasn't enough work to distribute). The fixup kernel then read undefined memory from the pool buffer and added it to the final result. If that memory happened to be zero there was no issue; if it contained random garbage the result was NaN. So whether the defect manifested as a bug depended heavily on the exact setup the tests were run with.

I think it would make sense to add a test or debug mode that fills the pool buffers with NaN before returning them. That way, any kernel that uses memory that was not previously written to should become detectable (unless there already exists a tool to detect this kind of problem that I am not aware of).
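A minimal sketch of what such a debug fill could look like, with a hypothetical compile-time flag and function name (CUDA_CHECK as used elsewhere in the CUDA backend):

```cpp
// Hypothetical debug helper: fill a pool buffer with 0xFF bytes before handing
// it out. 0xFFFFFFFF and 0xFFFF are NaN bit patterns for float and half, so any
// kernel that reads pool memory it never wrote propagates NaN into the result
// and the defect becomes visible in test-backend-ops.
#ifdef GGML_CUDA_DEBUG_POOL_FILL_NAN
static void ggml_cuda_pool_fill_nan(void * ptr, const size_t size, cudaStream_t stream) {
    CUDA_CHECK(cudaMemsetAsync(ptr, 0xFF, size, stream));
}
#endif
```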

slaren (Collaborator) commented Jun 20, 2024

Anything that improves the testing capabilities would be good. I plan to do some work on test-backend-ops to allow testing random values for the parameters of the ops (fuzzing); that should make it easier to find edge cases in ops.

I tested the performance with the 3090 Ti and, as usual, it is similar to your results with the 3090. I also tried with a 3080 and it is not always an improvement. However, that GPU is my display device and the results are not as repeatable as with the 3090 Ti.

GPU | Model | Model Size [GiB] | Microbatch size | Test | t/s master | t/s cuda-mmq-stream-k-2 | Speedup
RTX 3080 llama 7B Q4_0 3.56 16 pp2048 892.93 817.34 0.92
RTX 3080 llama 7B Q4_0 3.56 32 pp2048 1545.34 1502.87 0.97
RTX 3080 llama 7B Q4_0 3.56 64 pp2048 2103.63 2111.72 1.00
RTX 3080 llama 7B Q4_0 3.56 128 pp2048 2679.05 2657.88 0.99
RTX 3080 llama 7B Q4_0 3.56 256 pp2048 3159.96 2935.99 0.93
RTX 3080 llama 7B Q4_0 3.56 512 pp2048 3144.34 3035.00 0.97
RTX 3080 llama 7B Q4_0 3.56 1024 pp2048 3121.95 2937.57 0.94
RTX 3080 llama 7B Q4_0 3.56 2048 pp2048 2986.47 2825.06 0.95
RTX 3080 llama 7B Q8_0 6.67 16 pp2048 675.32 697.10 1.03
RTX 3080 llama 7B Q8_0 6.67 32 pp2048 1312.00 1368.82 1.04
RTX 3080 llama 7B Q8_0 6.67 64 pp2048 1931.47 2065.59 1.07
RTX 3080 llama 7B Q8_0 6.67 128 pp2048 2596.44 2714.64 1.05
RTX 3080 llama 7B Q8_0 6.67 256 pp2048 3139.19 2995.61 0.95
RTX 3080 llama 7B Q8_0 6.67 512 pp2048 3149.73 3132.67 0.99
RTX 3080 llama 7B Q8_0 6.67 1024 pp2048 3137.40 3086.56 0.98
RTX 3080 llama 7B Q8_0 6.67 2048 pp2048 3016.30 2938.36 0.97

JohannesGaessler (Collaborator, Author)

There is compute-sanitizer --tool=memcheck, but it did not detect this particular issue (tested both with and without VMM).

JohannesGaessler merged commit d50f889 into ggerganov:master Jun 20, 2024 (59 of 64 checks passed)
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jun 30, 2024
* CUDA: stream-k decomposition for MMQ

* fix undefined memory reads for small matrices