[Kernel] Update fused_moe tuning script for FP8 (vllm-project#4457)
This PR updates the tuning script for the fused_moe kernel to support FP8 and also adds configurations for TP4. Note that I removed num_warps and num_stages from the configurations for small batch sizes, since that improved performance and brought the benchmarks in that regime back on par with the previous numbers, making this a strict improvement over the status quo (see the sketch after the numbers below).

All numbers below are for mistralai/Mixtral-8x7B-Instruct-v0.1 with 1000 input and 50 output tokens.

Before this PR (with static activation scaling):
qps = 1:  9.8 ms ITL, 0.49s e2e latency
qps = 2:  9.7 ms ITL, 0.49s e2e latency
qps = 4:  10.1 ms ITL, 0.52s e2e latency
qps = 6:  11.9 ms ITL, 0.59s e2e latency
qps = 8:  14.0 ms ITL, 0.70s e2e latency
qps = 10: 15.7 ms ITL, 0.79s e2e latency

After this PR (with static activation scaling):
qps = 1:  9.8 ms ITL, 0.49s e2e latency
qps = 2:  9.7 ms ITL, 0.49s e2e latency
qps = 4:  10.2 ms ITL, 0.53s e2e latency
qps = 6:  11.9 ms ITL, 0.59s e2e latency
qps = 8:  11.9 ms ITL, 0.59s e2e latency
qps = 10: 12.1 ms ITL, 0.61s e2e latency
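As a minimal sketch (not vLLM's actual code; the helper name and dict layout are illustrative assumptions), this is how a loader could treat num_warps and num_stages as optional, so that the small-batch entries of the new config file simply fall back to Triton's defaults when those keys are absent:

```python
# Hedged sketch: separate the required block-size keys of one tuned entry from
# its optional launch parameters. When "num_warps"/"num_stages" are absent (as
# in the small-batch entries of the new config file), nothing is passed for
# them, so Triton's own defaults take effect.
# split_config is an illustrative helper, not part of vLLM.
from typing import Any, Dict, Tuple


def split_config(entry: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    blocks = {k: v for k, v in entry.items()
              if k.startswith(("BLOCK_SIZE_", "GROUP_SIZE_"))}
    launch = {k: entry[k] for k in ("num_warps", "num_stages") if k in entry}
    return blocks, launch


# The "1" entry of the config file carries only block sizes:
blocks, launch = split_config(
    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1})
assert launch == {}  # Triton will use its default num_warps / num_stages
```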
Showing 2 changed files with 211 additions and 38 deletions.
140 changes: 140 additions & 0 deletions
...r/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
@@ -0,0 +1,140 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1
    },
    "8": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 2
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    }
}
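For context, a minimal sketch of how a per-device tuning file like the one above might be consumed: load it once, then pick the entry whose batch-size key is nearest to the current token count. The function names and local file path are illustrative assumptions, not vLLM's actual API.

```python
# Hedged sketch: nearest-batch-size lookup into a fused_moe tuning JSON file.
import json
from functools import lru_cache
from typing import Any, Dict


@lru_cache(maxsize=None)
def load_moe_configs(path: str) -> Dict[int, Dict[str, Any]]:
    # The file maps stringified batch sizes ("1", "8", ...) to kernel
    # parameters; convert the keys to ints for numeric comparison.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(path: str, num_tokens: int) -> Dict[str, Any]:
    # Return the tuned entry whose batch-size key is closest to num_tokens.
    configs = load_moe_configs(path)
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]


# Example: 18 tokens is closest to the "16" entry above, i.e. BLOCK_SIZE_M=64,
# BLOCK_SIZE_N=128, BLOCK_SIZE_K=64, GROUP_SIZE_M=16, num_warps=4, num_stages=5.
# cfg = pick_config(
#     "E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json", 18)
```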