Commit e3ccca0

test: reduce redundant test cases for TRTLLM Gen FP8 MoE (#5845)
Signed-off-by: Dom Brown <[email protected]>
1 parent bb5b16f commit e3ccca0
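
A rough case count from the parametrize lists in the diff below (inferred from the change itself, not stated in the commit message):

    before: 4 num_tokens x 4 expert_info x 1 hidden_size x 1 intermediate_size x 2 autotune modes = 32 cases
    after:  4 x 4 x 1 x 1 (test_autotune) + 1 x 1 x 1 x 1 (test_no_autotune) = 17 cases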

File tree

1 file changed: +126 -90 lines changed

tests/unittest/_torch/thop/test_moe.py

Lines changed: 126 additions & 90 deletions
@@ -15,6 +15,7 @@
 
 import os
 import sys
+from typing import Tuple
 
 import pytest
 import torch
@@ -570,100 +571,135 @@ def quant_dequant_per_tensor_fp8(a):
     reason="The kernel only supports Blackwell. Current SM is %d." %
     getSMVersion(),
 )
-@pytest.mark.parametrize("num_tokens", [16, 64, 1024, 4096])
-@pytest.mark.parametrize("expert_info", [(32, 8, 4, 8), (32, 1, 1, 5),
-                                         (72, 1, 1, 6), (256, 8, 4, 8)])
-@pytest.mark.parametrize("hidden_size", [512])
-@pytest.mark.parametrize("intermediate_size", [512])
-@pytest.mark.parametrize("use_autotune", [True, False],
-                         ids=["autotune", "no_autotune"])
-def test_moe_fp8(num_tokens, expert_info, hidden_size, intermediate_size,
-                 use_autotune):
-    torch.random.manual_seed(0)
-
-    #
-    # Data Generation
-    #
-    num_experts, n_groups, top_k_groups, top_k = expert_info
-    padding = 8
-    routed_scaling = 2.5
-    routing_method_type = RoutingMethodType.DeepSeekV3
-    tile_tokens_dim = 8 if num_tokens < 1024 else 32
-
-    assert top_k <= num_experts
-    assert top_k <= 8
-    assert top_k_groups <= 4
-    assert num_experts > n_groups
-    assert num_experts % n_groups == 0
-    assert num_experts % 4 == 0
-    assert top_k < (top_k_groups * num_experts / n_groups)
-
-    expert_logits = torch.randn((num_tokens, num_experts),
-                                device='cuda').to(torch.float)
-    routing_bias = torch.randn(num_experts, device='cuda', dtype=torch.bfloat16)
-
-    hidden_states = torch.randn((num_tokens, hidden_size),
-                                device='cuda').to(torch.float8_e4m3fn)
-    hidden_states_scale = 2 * torch.rand(
-        (hidden_size // 128, num_tokens), device='cuda').to(torch.float)
-
-    gemm1_weights = torch.randn(
-        (num_experts, 2 * intermediate_size, hidden_size),
-        device='cuda').to(torch.float8_e4m3fn)
-    gemm1_scales = 2 * torch.rand(
-        (num_experts, 2 * intermediate_size // 128, hidden_size // 128),
-        device='cuda').to(torch.float)
-    gemm2_weights = torch.randn((num_experts, hidden_size, intermediate_size),
-                                device='cuda').to(torch.float8_e4m3fn)
-    gemm2_scales = 2 * torch.rand(
-        (num_experts, hidden_size // 128, intermediate_size // 128),
-        device='cuda').to(torch.float)
-
-    permute_info, scores = routing_reference_no_aux(expert_logits, routing_bias,
-                                                    top_k, n_groups,
-                                                    top_k_groups,
-                                                    routed_scaling, padding)
-
-    args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
-                    top_k, padding, hidden_states, hidden_states_scale, None,
-                    scores, gemm1_weights, gemm1_scales, None, gemm2_weights,
-                    gemm2_scales, None, permute_info, False)
+class TestMoeFP8:
+    """
+    Test the FP8 MoE. As autotune also covers the actual MoE, we can run the test
+    with autotune by default. We add a separate test for no autotune to ensure that
+    the default tactic selection works. This reduces unnecessary test runs for CI
+    """
+
+    @pytest.mark.parametrize("num_tokens", [16, 64, 1024, 4096])
+    @pytest.mark.parametrize("expert_info", [(32, 8, 4, 8), (32, 1, 1, 5),
+                                             (72, 1, 1, 6), (256, 8, 4, 8)])
+    @pytest.mark.parametrize("hidden_size", [512])
+    @pytest.mark.parametrize("intermediate_size", [512])
+    def test_autotune(self, num_tokens: int,
+                      expert_info: Tuple[int, int, int, int],
+                      hidden_size: int, intermediate_size: int):
+
+        self.run_moe_fp8_test(num_tokens,
+                              expert_info,
+                              hidden_size,
+                              intermediate_size,
+                              use_autotune=True)
+
+    @pytest.mark.parametrize("num_tokens", [16])
+    @pytest.mark.parametrize("expert_info", [(32, 8, 4, 8)])
+    @pytest.mark.parametrize("hidden_size", [512])
+    @pytest.mark.parametrize("intermediate_size", [512])
+    def test_no_autotune(self, num_tokens: int,
+                         expert_info: Tuple[int, int, int, int],
+                         hidden_size: int, intermediate_size: int):
+
+        self.run_moe_fp8_test(num_tokens,
+                              expert_info,
+                              hidden_size,
+                              intermediate_size,
+                              use_autotune=False)
+
+    def run_moe_fp8_test(self, num_tokens: int,
+                         expert_info: Tuple[int, int, int, int],
+                         hidden_size: int, intermediate_size: int,
+                         use_autotune: bool):
+        torch.random.manual_seed(0)
+
+        #
+        # Data Generation
+        #
+        num_experts, n_groups, top_k_groups, top_k = expert_info
+        padding = 8
+        routed_scaling = 2.5
+        routing_method_type = RoutingMethodType.DeepSeekV3
+        tile_tokens_dim = 8 if num_tokens < 1024 else 32
+
+        assert top_k <= num_experts
+        assert top_k <= 8
+        assert top_k_groups <= 4
+        assert num_experts > n_groups
+        assert num_experts % n_groups == 0
+        assert num_experts % 4 == 0
+        assert top_k < (top_k_groups * num_experts / n_groups)
 
-    with autotune(use_autotune):
-        output = torch.ops.trtllm.fp8_block_scale_moe_runner(
-            expert_logits, routing_bias, hidden_states, hidden_states_scale,
-            gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales,
-            num_experts, top_k, n_groups, top_k_groups, intermediate_size, 0,
-            num_experts, routed_scaling, tile_tokens_dim, routing_method_type)
+        expert_logits = torch.randn((num_tokens, num_experts),
+                                    device='cuda').to(torch.float)
+        routing_bias = torch.randn(num_experts,
+                                   device='cuda',
+                                   dtype=torch.bfloat16)
 
-    output_dequant_actual = output.to(torch.float)
-    #
-    # Run the reference implementations
-    #
-    output_dequant_reference, _ = run_moe_reference_dsfp8(args)
+        hidden_states = torch.randn((num_tokens, hidden_size),
+                                    device='cuda').to(torch.float8_e4m3fn)
+        hidden_states_scale = 2 * torch.rand(
+            (hidden_size // 128, num_tokens), device='cuda').to(torch.float)
+
+        gemm1_weights = torch.randn(
+            (num_experts, 2 * intermediate_size, hidden_size),
+            device='cuda').to(torch.float8_e4m3fn)
+        gemm1_scales = 2 * torch.rand(
+            (num_experts, 2 * intermediate_size // 128, hidden_size // 128),
+            device='cuda').to(torch.float)
+        gemm2_weights = torch.randn(
+            (num_experts, hidden_size, intermediate_size),
+            device='cuda').to(torch.float8_e4m3fn)
+        gemm2_scales = 2 * torch.rand(
+            (num_experts, hidden_size // 128, intermediate_size // 128),
+            device='cuda').to(torch.float)
 
-    #
-    # Check the results
-    #
-    def check_accuracy(a, b, atol, rtol, percent):
-        if torch.any(torch.isnan(a)):
-            raise Exception("NaN in a")
-        if torch.any(torch.isnan(b)):
-            raise Exception("NaN in b")
-        assert a.shape == b.shape
-        left = torch.abs(a - b)
-        right = atol + rtol * torch.abs(b)
-        count = torch.sum(left > right)
-        mismatch_percent = count / a.numel()
-        if mismatch_percent > 1 - percent:
-            raise Exception("Mismatch percentage is %f for rtol %f" %
-                            (mismatch_percent, rtol))
+        permute_info, scores = routing_reference_no_aux(expert_logits,
+                                                        routing_bias, top_k,
+                                                        n_groups, top_k_groups,
+                                                        routed_scaling, padding)
 
-    check_accuracy(output_dequant_reference,
-                   output_dequant_actual,
-                   atol=0.1,
-                   rtol=0.85,
-                   percent=0.925)
+        args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
+                        top_k, padding, hidden_states, hidden_states_scale,
+                        None, scores, gemm1_weights, gemm1_scales, None,
+                        gemm2_weights, gemm2_scales, None, permute_info, False)
+
+        with autotune(use_autotune):
+            output = torch.ops.trtllm.fp8_block_scale_moe_runner(
+                expert_logits, routing_bias, hidden_states, hidden_states_scale,
+                gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales,
+                num_experts, top_k, n_groups, top_k_groups, intermediate_size,
+                0, num_experts, routed_scaling, tile_tokens_dim,
+                routing_method_type)
+
+        output_dequant_actual = output.to(torch.float)
+        #
+        # Run the reference implementations
+        #
+        output_dequant_reference, _ = run_moe_reference_dsfp8(args)
+
+        #
+        # Check the results
+        #
+        def check_accuracy(a, b, atol, rtol, percent):
+            if torch.any(torch.isnan(a)):
+                raise Exception("NaN in a")
+            if torch.any(torch.isnan(b)):
+                raise Exception("NaN in b")
+            assert a.shape == b.shape
+            left = torch.abs(a - b)
+            right = atol + rtol * torch.abs(b)
+            count = torch.sum(left > right)
+            mismatch_percent = count / a.numel()
+            if mismatch_percent > 1 - percent:
+                raise Exception("Mismatch percentage is %f for rtol %f" %
+                                (mismatch_percent, rtol))
+
+        check_accuracy(output_dequant_reference,
+                       output_dequant_actual,
+                       atol=0.1,
+                       rtol=0.85,
+                       percent=0.925)
 
 
@pytest.mark.skipif(
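
Both new test methods funnel into run_moe_fp8_test, and the accuracy criterion is unchanged: with percent=0.925, the test tolerates up to 7.5% of output elements falling outside the element-wise bound atol + rtol * |b|. As a minimal, illustrative sketch (not part of this commit), the reorganized tests could be selected by pytest node ID, using the file path from the file tree above and the class and method names introduced here:

    # Illustrative only: run the reduced FP8 MoE suite by node ID.
    import pytest

    if __name__ == "__main__":
        pytest.main([
            # Autotuned sweep: 4 num_tokens values x 4 expert_info tuples.
            "tests/unittest/_torch/thop/test_moe.py::TestMoeFP8::test_autotune",
            # Single smoke case covering default (non-autotuned) tactic selection.
            "tests/unittest/_torch/thop/test_moe.py::TestMoeFP8::test_no_autotune",
            "-v",
        ])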
