[kernel] introduce fused_moe_triton_splitk to sglang #2333

@BBuf BBuf commented Dec 3, 2024

Background

I got the idea from this blog and its accompanying code.

The author reports that fused_moe_triton_splitk outperforms the original fused_moe_triton kernel on A100 and H100 GPUs when deploying Mixtral-MoE.

Therefore, in this PR I introduce fused_moe_triton_splitk to sglang, derived from the original fused_moe_triton kernel with some modifications. The atomic_add operations used to write the output tensor produced random values, so I could not get correct results for a long time, but I tracked down and fixed the bug today. I did the development, accuracy testing, and performance validation on an RTX 4090, where I did not see performance gains. However, I believe that newer GPUs such as the A100 or H100 would show larger improvements, as claimed by Meta.
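For readers unfamiliar with the technique, the sketch below shows the core split-k idea in a plain Triton GEMM. It is a simplified illustration, not this PR's actual kernel (the real fused MoE kernel additionally handles expert routing, sorted token indices, and FP8 weights): the K dimension is partitioned across a third grid axis, each program instance computes a partial product over its own K-slice, and the partial results are combined with tl.atomic_add into a zero-initialized output.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def splitk_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_k = tl.program_id(2)  # which K-slice this program instance handles

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Each instance walks K in strides of BLOCK_K * SPLIT_K, starting at its own slice.
    for k_start in range(pid_k * BLOCK_K, K, BLOCK_K * SPLIT_K):
        offs_k = k_start + tl.arange(0, BLOCK_K)
        a = tl.load(
            a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak,
            mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0,
        )
        b = tl.load(
            b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn,
            mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0,
        )
        acc += tl.dot(a, b)

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Different K-slices write the same C tile, so partial sums must be combined
    # atomically. Triton supports atomic_add on fp16 but not bf16, which is why
    # this example (and the PR) assumes a float16 output.
    tl.atomic_add(c_ptrs, acc.to(tl.float16), mask=c_mask)


def splitk_matmul(a: torch.Tensor, b: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    M, K = a.shape
    _, N = b.shape
    # The output must be zero-initialized because the kernel accumulates into it.
    c = torch.zeros((M, N), device=a.device, dtype=torch.float16)
    grid = (triton.cdiv(M, 64), triton.cdiv(N, 64), split_k)
    splitk_matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, SPLIT_K=split_k,
    )
    return c
```

Because several program instances accumulate into the same output tile, the output must be zeroed before launch, and the atomic_add dtype restriction is what motivates the float16 requirement discussed in the Limitation section below.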

Accuracy

fused_moe_triton

serving command:

python3 -m sglang.launch_server --model-path /mnt/bbuf/benchmark_data/Qwen2-57B-A14B-Instruct-FP8 --context-length 4096 --chunked-prefill-size 512 --schedule-policy lpm --load-balance-method round_robin --enable-mixed-chunk --kv-cache-dtype auto --port 8000 --tp 4 --dp 2 --host 0.0.0.0 --disable-custom-all-reduce --mem-fraction-static 0.8 --disable-cuda-graph --dtype float16

eval command:

lm_eval --model local-completions --tasks lambada_openai,hellaswag --model_args model=/mnt/bbuf/benchmark_data/Qwen2-57B-A14B-Instruct-FP8/,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False --batch_size 64 --output_path sglang-fp8-output/

result:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| hellaswag | 1 | none | 0 | acc | 0.6513 | ± 0.0048 |
| | | none | 0 | acc_norm | 0.8396 | ± 0.0037 |
| lambada_openai | 1 | none | 0 | acc | 0.7607 | ± 0.0059 |
| | | none | 0 | perplexity | 2.9287 | ± 0.0574 |

fused_moe_triton_splitk

serving command:

SGLANG_FUSED_MOE_BACKEND=GEMM_SPLITK python3 -m sglang.launch_server --model-path /mnt/bbuf/benchmark_data/Qwen2-57B-A14B-Instruct-FP8 --context-length 4096 --chunked-prefill-size 512 --schedule-policy lpm --load-balance-method round_robin --enable-mixed-chunk --kv-cache-dtype auto --port 8000 --tp 4 --dp 2 --host 0.0.0.0 --disable-custom-all-reduce --mem-fraction-static 0.8 --disable-cuda-graph --dtype float16

eval command:

lm_eval --model local-completions --tasks lambada_openai,hellaswag --model_args model=/mnt/bbuf/benchmark_data/Qwen2-57B-A14B-Instruct-FP8/,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False --batch_size 64 --output_path sglang-fp8-output/

result:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| hellaswag | 1 | none | 0 | acc | 0.6498 | ± 0.0048 |
| | | none | 0 | acc_norm | 0.8383 | ± 0.0037 |
| lambada_openai | 1 | none | 0 | acc | 0.7632 | ± 0.0059 |
| | | none | 0 | perplexity | 2.9276 | ± 0.0575 |

RTX 4090 Performance

(Figure: performance benchmark results on the RTX 4090)

I did not see performance gains on the RTX 4090 with tp4 fp8_w8a8 for the Qwen2-57B-A14B model. However, I believe that newer GPUs like the A100 or H100 would show larger performance improvements, as claimed by Meta.

Limitation

It is worth noting that since Triton does not support atomic_add operations on BF16, you need to specify the dtype as float16 when deploying the model. Please refer to triton-lang/triton#2708.
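As a hedged illustration only (the function name and placement are hypothetical, not part of this PR), a dtype guard for the split-k path could look like this:

```python
import torch


def check_splitk_dtype(dtype: torch.dtype) -> None:
    """Hypothetical guard: the split-k kernel accumulates with tl.atomic_add,
    which Triton does not support for bfloat16 (triton-lang/triton#2708)."""
    if dtype == torch.bfloat16:
        raise ValueError(
            "fused_moe_triton_splitk requires float16; "
            "launch the server with --dtype float16."
        )
```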
