|
 
 import os
 import sys
+from typing import Tuple
 
 import pytest
 import torch
@@ -570,100 +571,135 @@ def quant_dequant_per_tensor_fp8(a):
     reason="The kernel only supports Blackwell. Current SM is %d." %
     getSMVersion(),
 )
-@pytest.mark.parametrize("num_tokens", [16, 64, 1024, 4096])
-@pytest.mark.parametrize("expert_info", [(32, 8, 4, 8), (32, 1, 1, 5),
-                                         (72, 1, 1, 6), (256, 8, 4, 8)])
-@pytest.mark.parametrize("hidden_size", [512])
-@pytest.mark.parametrize("intermediate_size", [512])
-@pytest.mark.parametrize("use_autotune", [True, False],
-                         ids=["autotune", "no_autotune"])
-def test_moe_fp8(num_tokens, expert_info, hidden_size, intermediate_size,
-                 use_autotune):
-    torch.random.manual_seed(0)
-
-    #
-    # Data Generation
-    #
-    num_experts, n_groups, top_k_groups, top_k = expert_info
-    padding = 8
-    routed_scaling = 2.5
-    routing_method_type = RoutingMethodType.DeepSeekV3
-    tile_tokens_dim = 8 if num_tokens < 1024 else 32
-
-    assert top_k <= num_experts
-    assert top_k <= 8
-    assert top_k_groups <= 4
-    assert num_experts > n_groups
-    assert num_experts % n_groups == 0
-    assert num_experts % 4 == 0
-    assert top_k < (top_k_groups * num_experts / n_groups)
-
-    expert_logits = torch.randn((num_tokens, num_experts),
-                                device='cuda').to(torch.float)
-    routing_bias = torch.randn(num_experts, device='cuda', dtype=torch.bfloat16)
-
-    hidden_states = torch.randn((num_tokens, hidden_size),
-                                device='cuda').to(torch.float8_e4m3fn)
-    hidden_states_scale = 2 * torch.rand(
-        (hidden_size // 128, num_tokens), device='cuda').to(torch.float)
-
-    gemm1_weights = torch.randn(
-        (num_experts, 2 * intermediate_size, hidden_size),
-        device='cuda').to(torch.float8_e4m3fn)
-    gemm1_scales = 2 * torch.rand(
-        (num_experts, 2 * intermediate_size // 128, hidden_size // 128),
-        device='cuda').to(torch.float)
-    gemm2_weights = torch.randn((num_experts, hidden_size, intermediate_size),
-                                device='cuda').to(torch.float8_e4m3fn)
-    gemm2_scales = 2 * torch.rand(
-        (num_experts, hidden_size // 128, intermediate_size // 128),
-        device='cuda').to(torch.float)
-
-    permute_info, scores = routing_reference_no_aux(expert_logits, routing_bias,
-                                                    top_k, n_groups,
-                                                    top_k_groups,
-                                                    routed_scaling, padding)
-
-    args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
-                    top_k, padding, hidden_states, hidden_states_scale, None,
-                    scores, gemm1_weights, gemm1_scales, None, gemm2_weights,
-                    gemm2_scales, None, permute_info, False)
+class TestMoeFP8:
+    """
+    Test the FP8 MoE kernel. Since the autotuned path also exercises the
+    actual MoE, the full parameter sweep runs with autotuning enabled. A
+    single test without autotuning verifies that the default tactic selection
+    works. This avoids unnecessary test runs in CI.
+    """
+
+    @pytest.mark.parametrize("num_tokens", [16, 64, 1024, 4096])
+    @pytest.mark.parametrize("expert_info", [(32, 8, 4, 8), (32, 1, 1, 5),
+                                             (72, 1, 1, 6), (256, 8, 4, 8)])
+    @pytest.mark.parametrize("hidden_size", [512])
+    @pytest.mark.parametrize("intermediate_size", [512])
+    def test_autotune(self, num_tokens: int,
+                      expert_info: Tuple[int, int, int, int],
+                      hidden_size: int, intermediate_size: int):
+        self.run_moe_fp8_test(num_tokens,
+                              expert_info,
+                              hidden_size,
+                              intermediate_size,
+                              use_autotune=True)
+
+    @pytest.mark.parametrize("num_tokens", [16])
+    @pytest.mark.parametrize("expert_info", [(32, 8, 4, 8)])
+    @pytest.mark.parametrize("hidden_size", [512])
+    @pytest.mark.parametrize("intermediate_size", [512])
+    def test_no_autotune(self, num_tokens: int,
+                         expert_info: Tuple[int, int, int, int],
+                         hidden_size: int, intermediate_size: int):
+        self.run_moe_fp8_test(num_tokens,
+                              expert_info,
+                              hidden_size,
+                              intermediate_size,
+                              use_autotune=False)
+
+    def run_moe_fp8_test(self, num_tokens: int,
+                         expert_info: Tuple[int, int, int, int],
+                         hidden_size: int, intermediate_size: int,
+                         use_autotune: bool):
+        torch.random.manual_seed(0)
+
+        #
+        # Data Generation
+        #
+        num_experts, n_groups, top_k_groups, top_k = expert_info
+        padding = 8
+        routed_scaling = 2.5
+        routing_method_type = RoutingMethodType.DeepSeekV3
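+        # Use small token tiles for short batches and larger tiles once there
+        # are enough tokens to fill them.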
+        tile_tokens_dim = 8 if num_tokens < 1024 else 32
+
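+        # Sanity-check that the routing parametrization is consistent before
+        # generating any data.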
+        assert top_k <= num_experts
+        assert top_k <= 8
+        assert top_k_groups <= 4
+        assert num_experts > n_groups
+        assert num_experts % n_groups == 0
+        assert num_experts % 4 == 0
+        assert top_k < (top_k_groups * num_experts / n_groups)
|
-    with autotune(use_autotune):
-        output = torch.ops.trtllm.fp8_block_scale_moe_runner(
-            expert_logits, routing_bias, hidden_states, hidden_states_scale,
-            gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales,
-            num_experts, top_k, n_groups, top_k_groups, intermediate_size, 0,
-            num_experts, routed_scaling, tile_tokens_dim, routing_method_type)
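+        # Random routing inputs: per-token expert logits plus a per-expert
+        # bias.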
+        expert_logits = torch.randn((num_tokens, num_experts),
+                                    device='cuda').to(torch.float)
+        routing_bias = torch.randn(num_experts,
+                                   device='cuda',
+                                   dtype=torch.bfloat16)
|
-    output_dequant_actual = output.to(torch.float)
-    #
-    # Run the reference implementations
-    #
-    output_dequant_reference, _ = run_moe_reference_dsfp8(args)
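+        # FP8 activations with one scale per 128-element block of the hidden
+        # dimension for each token (stored transposed).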
+        hidden_states = torch.randn((num_tokens, hidden_size),
+                                    device='cuda').to(torch.float8_e4m3fn)
+        hidden_states_scale = 2 * torch.rand(
+            (hidden_size // 128, num_tokens), device='cuda').to(torch.float)
+
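+        # Per-expert FP8 weights with 128x128 block scales; gemm1 has
+        # 2 * intermediate_size output rows (the fused gate/up projections of
+        # the gated FFN).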
+        gemm1_weights = torch.randn(
+            (num_experts, 2 * intermediate_size, hidden_size),
+            device='cuda').to(torch.float8_e4m3fn)
+        gemm1_scales = 2 * torch.rand(
+            (num_experts, 2 * intermediate_size // 128, hidden_size // 128),
+            device='cuda').to(torch.float)
+        gemm2_weights = torch.randn(
+            (num_experts, hidden_size, intermediate_size),
+            device='cuda').to(torch.float8_e4m3fn)
+        gemm2_scales = 2 * torch.rand(
+            (num_experts, hidden_size // 128, intermediate_size // 128),
+            device='cuda').to(torch.float)
|
-    #
-    # Check the results
-    #
-    def check_accuracy(a, b, atol, rtol, percent):
-        if torch.any(torch.isnan(a)):
-            raise Exception("NaN in a")
-        if torch.any(torch.isnan(b)):
-            raise Exception("NaN in b")
-        assert a.shape == b.shape
-        left = torch.abs(a - b)
-        right = atol + rtol * torch.abs(b)
-        count = torch.sum(left > right)
-        mismatch_percent = count / a.numel()
-        if mismatch_percent > 1 - percent:
-            raise Exception("Mismatch percentage is %f for rtol %f" %
-                            (mismatch_percent, rtol))
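+        # Reference aux-loss-free routing: yields the token-to-expert
+        # permutation and the per-token routing scores.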
+        permute_info, scores = routing_reference_no_aux(expert_logits,
+                                                        routing_bias, top_k,
+                                                        n_groups, top_k_groups,
+                                                        routed_scaling, padding)
|
-    check_accuracy(output_dequant_reference,
-                   output_dequant_actual,
-                   atol=0.1,
-                   rtol=0.85,
-                   percent=0.925)
+        args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
+                        top_k, padding, hidden_states, hidden_states_scale,
+                        None, scores, gemm1_weights, gemm1_scales, None,
+                        gemm2_weights, gemm2_scales, None, permute_info, False)
+
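+        # Run the kernel under test. With autotuning enabled the runner sweeps
+        # its tactics; with it disabled, the default tactic selection is used.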
+        with autotune(use_autotune):
+            output = torch.ops.trtllm.fp8_block_scale_moe_runner(
+                expert_logits, routing_bias, hidden_states, hidden_states_scale,
+                gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales,
+                num_experts, top_k, n_groups, top_k_groups, intermediate_size,
+                0, num_experts, routed_scaling, tile_tokens_dim,
+                routing_method_type)
+
+        output_dequant_actual = output.to(torch.float)
+        #
+        # Run the reference implementations
+        #
+        output_dequant_reference, _ = run_moe_reference_dsfp8(args)
+
+        #
+        # Check the results
+        #
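+        # An element counts as a match when |a - b| <= atol + rtol * |b|; the
+        # check passes as long as at least `percent` of all elements match.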
+        def check_accuracy(a, b, atol, rtol, percent):
+            if torch.any(torch.isnan(a)):
+                raise Exception("NaN in a")
+            if torch.any(torch.isnan(b)):
+                raise Exception("NaN in b")
+            assert a.shape == b.shape
+            left = torch.abs(a - b)
+            right = atol + rtol * torch.abs(b)
+            count = torch.sum(left > right)
+            mismatch_percent = count / a.numel()
+            if mismatch_percent > 1 - percent:
+                raise Exception("Mismatch percentage is %f for rtol %f" %
+                                (mismatch_percent, rtol))
+
+        check_accuracy(output_dequant_reference,
+                       output_dequant_actual,
+                       atol=0.1,
+                       rtol=0.85,
+                       percent=0.925)
|
|
 @pytest.mark.skipif(
|