[kernel] introduce fused_moe_triton_splitk to sglang #2333
Background
I got the idea from this blog and its code. The author reports that `fused_moe_triton_splitk` outperforms the original `fused_moe_triton` kernel on A100 and H100 when deploying Mixtral-MoE. In this PR, I therefore introduce `fused_moe_triton_splitk` to sglang, derived from the original `fused_moe_triton` with some modifications. Because the `atomic_add` operations wrote random values into the output tensor, I could not get correct results for a long time, but I resolved that bug through debugging today. I did the development, accuracy testing, and performance validation on an RTX 4090, where I did not see performance gains; however, I believe newer GPUs such as the A100 or H100 would show much larger improvements, as claimed by Meta.
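For context, the core of the split-k idea is sketched below: the K (reduction) dimension of each GEMM is split across an extra grid axis, every program instance computes a partial sum for its K slice, and the partial sums are combined into the output with `tl.atomic_add`. This is a minimal illustrative sketch, not the kernel in this PR; the function names, block sizes, and the assumption that all shapes are multiples of the block sizes are mine.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def splitk_gemm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr,
):
    pid = tl.program_id(0)    # which (M, N) output tile
    pid_k = tl.program_id(1)  # which K slice this instance accumulates
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Each pid_k walks only its own 1/SPLIT_K share of the K dimension
    # (shapes are assumed to be multiples of the block sizes, so no masks here).
    for _ in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * SPLIT_K * stride_ak
        b_ptrs += BLOCK_K * SPLIT_K * stride_bk

    # The SPLIT_K partial sums land in the same output tile, so they are combined
    # with atomic_add; the output must be zero-initialized and must be fp16/fp32,
    # since Triton has no bf16 atomic_add (see the Limitation section below).
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.atomic_add(c_ptrs, acc.to(c_ptr.dtype.element_ty))


def splitk_gemm(a: torch.Tensor, b: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    # a: (M, K) fp16, b: (K, N) fp16; output is fp16 so atomic_add is supported.
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    M, K = a.shape
    _, N = b.shape
    c = torch.zeros((M, N), device=a.device, dtype=torch.float16)
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), split_k)
    splitk_gemm_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, SPLIT_K=split_k,
    )
    return c
```

Because every `pid_k` instance adds into the same output tile, the output buffer has to start at zero and its dtype has to support atomics, which is where the BF16 limitation described below comes from.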
Accuracy

`fused_moe_triton`
serving command:
eval command:
result:
`fused_moe_triton_splitk`
serving command:
eval command:
result:
RTX 4090 Performance
I did not see performance gains on an RTX 4090 with TP=4 and fp8_w8a8 on the Qwen2-57B-A14B model. However, I believe newer GPUs such as the A100 or H100 would show much larger improvements, as claimed by Meta.
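For reference, a comparison like this could be measured with sglang's serving benchmark; the parameters below are assumptions rather than the exact settings behind the numbers reported here.

```bash
# Assumed benchmark invocation against a running sglang server
# (see the Limitation section below for a matching launch command).
python -m sglang.bench_serving --backend sglang --num-prompts 200 --request-rate 8
```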
Limitation
Note that since Triton does not support `atomic_add` operations on BF16, you need to specify the dtype as float16 when deploying the model. Please refer to triton-lang/triton#2708.
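In practice this means passing an explicit dtype when launching the server. A hypothetical launch command matching the TP=4 fp8 setup above might look like the following; the model path and flags are assumptions, not the exact command used for the results in this PR.

```bash
# Hypothetical launch command: --dtype float16 avoids the BF16 atomic_add limitation
python -m sglang.launch_server --model-path Qwen/Qwen2-57B-A14B-Instruct \
    --tp 4 --quantization fp8 --dtype float16
```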