
Condition to achieve linear speedup? #15

Open
jiwonsong-dev opened this issue Sep 12, 2024 · 18 comments

@jiwonsong-dev

I measured the latency of the QuantLinear forward pass across various input and feature sizes.
For token counts from 1 to 1024, I see no speedup over the AWQ W4A16 kernel, and in most cases it is slower than PyTorch's FP16 Linear.
I tested weight shapes (4096, 4096), (5120, 5120), (6656, 6656), and (8192, 8192), which match the linear-layer sizes of the LLaMA model family, on A6000 and RTX 3090 GPUs.
I see that the experiments in the paper were run on an A100 GPU.
Are there any specific settings or conditions needed to see a speedup consistent with the results in the paper?

@jiwonsong-dev
Author

The overhead of activation quantization with plain PyTorch ops is substantial, but even the kernel by itself is slower than nn.Linear in most cases.

@HandH1998
Owner

HandH1998 commented Sep 13, 2024

@jiwonsong-dev QuantLinear performs online activation quantization with plain PyTorch ops, which is very slow. The GEMM speedup in our paper is measured without activation quantization. If you want to reproduce the speedup, please refer to #2 (comment). By the way, in our vLLM PR the activation quantization is fused into element-wise kernels such as RMSNorm, so it does not affect inference speed much.
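For reference, the math of the online activation quantization, together with the RMSNorm it can be fused into, is sketched below in plain Python. It assumes symmetric per-token INT8 scales (max |x| mapped to 127); that choice, the rounding and clamping, and all function names here are assumptions, so check the QQQ code for the exact scheme. The speed benefit of fusion comes from doing both steps in a single GPU pass instead of several separate PyTorch ops, which a pure-Python sketch cannot show.

```python
import math

def rmsnorm(row, weight, eps=1e-6):
    """RMSNorm of one token's hidden vector: x / sqrt(mean(x^2) + eps) * weight."""
    inv_rms = 1.0 / math.sqrt(sum(x * x for x in row) / len(row) + eps)
    return [x * inv_rms * w for x, w in zip(row, weight)]

def quantize_per_token_int8(vec):
    """Symmetric per-token INT8 quantization: one scale per row, max |x| -> 127."""
    amax = max(abs(v) for v in vec)
    scale = amax / 127.0 if amax > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in vec]
    return q, scale

def rmsnorm_quant_int8(row, weight, eps=1e-6):
    """The values a fused 'RMSNorm + quantize' kernel would produce for one token.

    A real fused kernel does both steps in one pass over the row on the GPU;
    this hypothetical helper only reproduces the arithmetic.
    """
    return quantize_per_token_int8(rmsnorm(row, weight, eps))

# One 4-dim token for illustration (real hidden sizes are e.g. 4096).
q, scale = rmsnorm_quant_int8([0.5, -1.0, 2.0, 0.25], [1.0] * 4)
print(q, scale)
```

Dequantizing with `q[i] * scale` recovers each normalized value to within half a quantization step.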

@jiwonsong-dev
Author

Is the kernel integrated into vLLM the same one as in this repo?
I see QuantLinear being slower than nn.Linear for M from 1 to 1024 with N and K fixed to 4096, even without counting the quantization overhead.

@HandH1998
Owner

@jiwonsong-dev The kernel is the same as the one in vLLM. If your modified QuantLinear adds no other operations such as dtype conversions or reshapes, it should perform about the same as calling the GEMM kernel directly. In general, QuantLinear in our repo is only meant for simple inference. I recommend trying vLLM for practical inference.
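To compare against the GEMM-only numbers in the paper, it helps to time the activation quantization and the GEMM separately. Below is a minimal, stdlib-only timing sketch; `quant_fn` and `gemm_fn` are hypothetical callables you would supply (for example, the plain-PyTorch quantization step and the raw GEMM call). Because CUDA kernels launch asynchronously, pass `sync=torch.cuda.synchronize` when timing GPU work, otherwise only launch overhead is measured.

```python
import statistics
import time

def time_ms(fn, *args, warmup=10, iters=100, sync=None):
    """Median latency of fn(*args) in milliseconds.

    For GPU work pass a barrier such as sync=torch.cuda.synchronize so that
    asynchronous kernel launches are fully included in each sample.
    """
    for _ in range(warmup):
        fn(*args)
    if sync:
        sync()
    samples = []
    for _ in range(iters):
        if sync:
            sync()
        start = time.perf_counter()
        fn(*args)
        if sync:
            sync()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)

def overhead_breakdown(quant_fn, gemm_fn, x, **kw):
    """Split QuantLinear-style latency into activation quantization and GEMM.

    quant_fn(x) -> quantized activations and gemm_fn(xq) -> output are
    hypothetical callables supplied by the caller.
    """
    xq = quant_fn(x)
    gemm = time_ms(gemm_fn, xq, **kw)
    total = time_ms(lambda v: gemm_fn(quant_fn(v)), x, **kw)
    return {"gemm_ms": gemm, "quant_ms": total - gemm, "total_ms": total}
```

The `gemm_ms` value is the one comparable to a GEMM-only benchmark; `quant_ms` is a difference of two medians and can be noisy for very small inputs. `torch.cuda.Event(enable_timing=True)` is a more precise alternative to wall-clock timing on the GPU.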

@jiwonsong-dev
Author

I checked your fork of the Marlin repository and saw the actual speedup with its benchmark code. Thank you for the kind response!

@jiwonsong-dev
Author

Is there a specific reason why the permutation is different when packing per-channel quantized weights? The per-group path follows the original Marlin format.

@HandH1998
Owner

@brisker

brisker commented Sep 29, 2024

@HandH1998
I have tried the QQQ w4a8 no-group (per-channel) version on InternVL-20B for my own task. Unfortunately, compared to w8a8, w4a8 is indeed faster at decoding, as expected, but slower at first-token generation.
Because of this trade-off between first-token and decoding latency, the end-to-end speed of w4a8 is even slightly slower than w8a8.

What puzzles me is that I am already using the per-channel w4a8 version with no groups, so why is the w4a8 first token still so slow?
According to your paper, converting w4 to w8 for the w8a8 GEMM is simply a multiplication by 16, which should not be that slow.

Have you analyzed details like this for your w4a8 no-group kernel? Any further advice on optimizing the kernel?
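The "multiply by 16" conversion can be sketched in plain Python, assuming per-channel symmetric INT4 weights stored as two's-complement nibbles (an assumption; QQQ's actual packing layout and instruction sequence may differ). Moving a nibble into the high half of a byte yields 16 times its signed value as an int8, and the factor 16 is then folded into the per-channel weight scale. All helper names below are hypothetical.

```python
def signed_int4(nibble):
    """Interpret a 4-bit two's-complement field (0..15) as a value in -8..7."""
    return nibble - 16 if nibble >= 8 else nibble

def int4_field_to_int8(byte, high):
    """Move one 4-bit field of a packed byte into the top 4 bits of an int8.

    With two's-complement storage this equals 16 * signed_int4(field), so the
    'multiply by 16' is just a shift and a mask, with no floating-point math.
    """
    nibble = (byte >> 4) & 0xF if high else byte & 0xF
    shifted = nibble << 4                                  # 0..240 as an unsigned byte
    return shifted - 256 if shifted >= 128 else shifted    # reinterpret as int8

def w4a8_dot(q4_w, a8, w_scale, a_scale):
    """Per-channel W4A8 dot product evaluated as an INT8 x INT8 dot product.

    Weights are scaled up by 16 to int8, so the per-channel weight scale is
    divided by 16 to compensate.
    """
    acc = sum((16 * w) * a for w, a in zip(q4_w, a8))      # integer accumulation
    return acc * (w_scale / 16) * a_scale
```

Each conversion is only a few integer operations, but with on-the-fly dequantization it is repeated for every weight tile loaded inside the GEMM main loop.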

@HandH1998
Owner

@brisker It is normal for the w4a8 first token to be slower than w8a8: the additional dequantization in w4a8 runs on the slower CUDA cores and slows down the main loop, even though the dequantization overhead itself is small. In my experiments, as long as a workload has a reasonable number of decoding iterations, end-to-end w4a8 is always faster than w8a8 thanks to its faster decoding. Here are the detailed results.

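Setup: vLLM, LLaMA-2 series, input length = 1024, output length = 128.

TTFT (ms)

| Model | bsz | fp16 | sq-w8a8 | awq-g128 | marlin | marlin-g128 | qqq | qqq-g128 |
|---|---|---|---|---|---|---|---|---|
| 7b | 1 | 71.63 | 46.71 | 104.29 | 71.94 | 78.35 | 53.44 | 65.99 |
| 7b | 4 | 270.68 | 171.24 | 293.73 | 285.65 | 313.15 | 208.25 | 261.54 |
| 7b | 16 | 274.95 | 175.09 | 299.09 | 290.64 | 318.39 | 212.36 | 266.04 |
| 7b | 64 | 294.64 | 198.12 | 318.90 | 315.41 | 343.35 | 238.14 | 290.36 |
| 13b | 1 | 133.48 | 78.37 | 204.17 | 132.94 | 146.03 | 90.93 | 117.30 |
| 13b | 4 | 241.39 | 155.24 | 312.20 | 265.77 | 293.79 | 180.91 | 234.62 |
| 13b | 16 | 245.51 | 158.43 | 316.91 | 269.22 | 297.37 | 184.73 | 238.52 |
| 13b | 64 | 285.86 | 180.74 | 337.22 | 289.07 | 317.10 | 204.47 | 257.47 |
| 70b | 1 | - | 356.51 | 992.45 | 662.55 | 756.93 | 417.54 | 571.21 |
| 70b | 4 | - | 1400.66 | 2766.41 | 2627.26 | 3010.08 | 1674.70 | 2292.46 |
| 70b | 16 | - | 1402.62 | 2775.86 | 2635.86 | 3016.53 | 1682.51 | 2296.78 |
| 70b | 64 | - | 1425.73 | 2807.03 | 2661.51 | 3023.35 | 1712.25 | 2326.96 |

TPOT (ms)

| Model | bsz | fp16 | sq-w8a8 | awq-g128 | marlin | marlin-g128 | qqq | qqq-g128 |
|---|---|---|---|---|---|---|---|---|
| 7b | 1 | 11.70 | 15.81 | 9.32 | 6.33 | 6.47 | 6.34 | 6.59 |
| 7b | 4 | 12.78 | 17.33 | 10.51 | 7.39 | 7.45 | 7.47 | 7.74 |
| 7b | 16 | 24.62 | 23.34 | 23.64 | 19.59 | 20.40 | 17.86 | 19.37 |
| 7b | 64 | 71.74 | 57.75 | 82.38 | 70.03 | 74.03 | 62.28 | 69.45 |
| 13b | 1 | 20.10 | 18.29 | 14.39 | 8.99 | 9.18 | 8.94 | 9.31 |
| 13b | 4 | 23.76 | 22.12 | 18.77 | 12.85 | 13.24 | 12.18 | 12.90 |
| 13b | 16 | 43.33 | 34.32 | 42.85 | 33.15 | 34.98 | 28.45 | 31.72 |
| 13b | 64 | 146.53 | 92.54 | 151.86 | 117.41 | 125.15 | 96.60 | 111.25 |
| 70b | 1 | - | 54.27 | 50.79 | 29.29 | 30.20 | 28.88 | 30.48 |
| 70b | 4 | - | 61.20 | 54.16 | 32.06 | 32.84 | 31.22 | 32.98 |
| 70b | 16 | - | 160.92 | 135.19 | 104.32 | 114.00 | 80.27 | 96.51 |
| 70b | 64 | - | 526.42 | 546.04 | 408.20 | 453.59 | 283.48 | 363.82 |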

@brisker

brisker commented Sep 29, 2024

@HandH1998
What do TTFT (ms) and TPOT (ms) mean in your tables?

@HandH1998
Owner

HandH1998 commented Sep 29, 2024

TTFT: Time To First Token
TPOT: Time Per decoding Output Token

@brisker

brisker commented Sep 29, 2024

@HandH1998
For sq-w8a8 in your tables, which kernel are you referring to?
In my experiments, I used the official w8a8 kernel from vLLM (CUTLASS backend).

@HandH1998
Owner

The cuBLAS w8a8 GEMM from vllm-project/vllm#1508. cuBLAS and CUTLASS should have similar performance, though.

@brisker

brisker commented Sep 29, 2024

> TPOT: Time Per decoding Output Token

Does TPOT already include the first decoding step, or have you excluded the first-token time?

@HandH1998
Owner

It doesn't include the first token.
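Since TPOT excludes the first token, the two tables can be combined into a rough per-request latency, TTFT + (output_len − 1) × TPOT, and into the number of decode steps after which a slower-prefill, faster-decode configuration (such as qqq versus sq-w8a8) comes out ahead. This is a simplified model that ignores vLLM scheduling and batching effects, and the function names are hypothetical.

```python
import math

def end_to_end_ms(ttft_ms, tpot_ms, output_len):
    """Rough request latency: prefill emits token 1, then (output_len - 1)
    decode steps each cost TPOT (TPOT excludes the first token)."""
    return ttft_ms + (output_len - 1) * tpot_ms

def break_even_steps(ttft_a, tpot_a, ttft_b, tpot_b):
    """Smallest number of decode steps n for which config A (faster decode)
    beats config B: ttft_a + n * tpot_a < ttft_b + n * tpot_b."""
    if tpot_a >= tpot_b:
        raise ValueError("config A must have the faster decode step")
    x = (ttft_a - ttft_b) / (tpot_b - tpot_a)
    return max(0, math.floor(x) + 1)

# LLaMA-2-70B, bsz=64 rows of the tables above: qqq vs sq-w8a8
print(end_to_end_ms(1712.25, 283.48, 128), end_to_end_ms(1425.73, 526.42, 128))
print(break_even_steps(1712.25, 283.48, 1425.73, 526.42))  # -> 2
```

Plugging in TTFT, TPOT, and output length measured on your own workload shows whether the decoding gain can pay back a slower prefill; with very short outputs it may not.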

@brisker

brisker commented Oct 9, 2024

Considering the tables above:

  1. Which GPU are you using?
  2. For TTFT (Time To First Token), why is llama2-13b faster than llama2-7b at batch sizes 4, 16, and 64?
  3. For TTFT, is the time for the whole batch or for a single sample? For example, for 245.51 ms at batch size 16, does this mean the first token of llama2-13b takes about 0.25 s × 16 ≈ 3.9 seconds?

@HandH1998

@HandH1998
Owner

@brisker

  1. A100-80G.
  2. I think it is because the matrix multiplications with llama2-13b's shapes achieve a larger speedup than those with llama2-7b's shapes.
  3. No, it means that 245.51 ms covers the first tokens of all 16 sequences in the batch (bsz=16).

@Andy0422

Andy0422 commented Oct 31, 2024

@HandH1998 I'm confused about why w4a8 is faster than w8a8 on the 70B model. It does not seem to match the theoretical roofline model (the figure in QServe)...
I think bs=64 still falls in the memory-bound regime, and larger batch sizes would soon run out of memory (OOM).
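The memory-bound question can be checked with back-of-the-envelope roofline arithmetic. The sketch below assumes idealized DRAM traffic (each operand and the output moved exactly once), an 8192 × 8192 projection as a stand-in for one LLaMA-2-70B linear layer (hidden size 8192), and A100-80GB (SXM) datasheet figures of about 624 dense INT8 TOPS and 2.039 TB/s HBM bandwidth; treat those constants as assumptions and check them for your exact SKU. It ignores attention, KV-cache traffic, and kernel efficiency, so it is only a rough guide.

```python
def gemm_intensity(m, n, k, w_bits, a_bits=8, out_bytes=2):
    """Ops per DRAM byte for an (m x k) @ (k x n) GEMM.

    Counts a multiply-add as 2 ops and assumes every operand and the FP16
    output are moved to/from DRAM exactly once (idealized traffic).
    """
    ops = 2 * m * n * k
    traffic = n * k * w_bits / 8 + m * k * a_bits / 8 + m * n * out_bytes
    return ops / traffic

# Assumed A100-80GB (SXM) datasheet values; verify for your GPU.
PEAK_INT8_OPS = 624e12      # dense INT8 tensor-core throughput, ops/s
HBM_BYTES_PER_S = 2.039e12  # HBM bandwidth, bytes/s

def memory_bound(intensity, peak=PEAK_INT8_OPS, bw=HBM_BYTES_PER_S):
    """Roofline test: below the ridge point peak/bw, bandwidth is the limit."""
    return intensity < peak / bw

# M = tokens per GEMM: 64 for decoding at bsz=64, 1024 for a 1024-token prefill.
for m in (1, 64, 1024):
    for w_bits in (8, 4):
        ai = gemm_intensity(m, 8192, 8192, w_bits)
        print(f"M={m:5d} W{w_bits}A8: {ai:7.1f} ops/B, memory-bound={memory_bound(ai)}")
```

In this idealized model both W8A8 and W4A8 stay below the ridge point (about 306 ops/byte) at M=64, so halving the weight bytes should shorten the GEMM, while a 1024-token prefill is compute-bound, where extra dequantization work in the main loop is no longer hidden behind memory traffic.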
