import os
import random
import time
import torch
import torch_npu  # imported for its side effect of registering the Ascend NPU backend
import torch.distributed as dist

from deep_ep import Buffer
from functools import partial
from utils import bench, calc_diff, hash_tensor

def test(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: Buffer, seed: int = 0):
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)

    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: integers greater than 256 exceed the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='npu') * (rank - rank_offset)
    x[:, -128:] = torch.arange(num_tokens, device='npu').to(torch.bfloat16).view(-1, 1)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='npu').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='npu').abs()

    # Check dispatch correctness
    do_check = True
    return_recv_hook = False
    hash_value, num_times = 0, 0

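    # Low-latency dispatch: each token is sent to the ranks owning its top-k experts;
    # per-local-expert receive counts are accumulated into cumulative_local_expert_recv_stats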
    cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='npu')
    packed_recv_x, packed_recv_count, handle, event, hook = \
        buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                    use_fp8=False, round_scale=False, use_ue8m0=False,
                                    cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                    async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
    simulated_gemm_x = packed_recv_x.clone()
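    # Gather every rank's top-k indices so each rank can verify how many tokens
    # each of its local experts should have received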
    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='npu')
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)

    for i in range(num_local_experts if do_check else 0):
        expert_id = rank * num_local_experts + i
        temp = num_tokens / num_local_experts
        recv_x = packed_recv_x[i : int((i + 1) * temp)]
        recv_count = packed_recv_count[i]
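        # handle[1] stores cumulative layout ranges; the delta across this expert's rank block
        # is expected to carry its received-token count in the low 32 bits (checked below)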
        if i == 0:
            recv_layout_range = handle[1][(i + 1) * num_ranks - 1]
        else:
            recv_layout_range = handle[1][(i + 1) * num_ranks - 1] - handle[1][i * num_ranks - 1]

        # Check expert indices
        int_mask = (2 ** 32) - 1
        num_valid_tokens = recv_count.item()
        assert num_valid_tokens == (recv_layout_range & int_mask).item(), f'{num_valid_tokens} != {(recv_layout_range & int_mask).item()}'
        assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'

        if num_valid_tokens == 0:
            continue
        # Check received data
        recv_x = recv_x[:num_valid_tokens]
        recv_x_amin = recv_x[:, :-128].amin(dim=-1)
        assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
        hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

    # Check combine correctness
    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='npu')
    combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                         async_finish=not return_recv_hook, zero_copy=False,
                                                         return_recv_hook=return_recv_hook, out=out)

    if do_check:
        diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
        assert torch.isnan(combined_x).sum().item() == 0
        assert diff < 1e-5, f'Error: {diff=}'
        hash_value ^= hash_tensor(combined_x)

    # noinspection PyShadowingNames
    def test_func(zero_copy: bool, return_recv_hook: bool):
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                        cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                        use_fp8=False, async_finish=False, return_recv_hook=return_recv_hook)
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                             zero_copy=zero_copy, return_recv_hook=return_recv_hook)

    # Calculate bandwidth
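    # Per-token payload sizes: FP8 dispatch = hidden bytes of data + one 4-byte scale per
    # 128 elements + 16 bytes of per-token overhead; BF16 combine = 2 bytes per element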
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
    for i in range(num_tokens):
        num_selections = (topk_idx[i] != -1).sum().item()
        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
        num_combine_comm_bytes += num_bf16_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)

    return hash_value

def test_main():
    ip = os.getenv('MASTER_ADDR', '127.0.0.1')
    port = int(os.getenv('MASTER_PORT', '17621'))
    world_size = int(os.getenv('WORLD_SIZE', '16'))
    rank = int(os.getenv('RANK', '0'))
    shared_expert_rank_num = int(os.getenv('MOE_SHARED_EXPERT_RANK_NUM', '0'))

    dist.init_process_group(
        backend="hccl",
        init_method=f'tcp://{ip}:{port}',
        world_size=world_size,
        rank=rank
    )
    torch.npu.set_device(rank)
    group = dist.new_group(list(range(world_size)))
    print(f'Process group size: {group.size()}', flush=True)
    if shared_expert_rank_num == 0:
        num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, world_size, num_experts)
        buffer = Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
                        num_qps_per_rank=num_experts // world_size)

        use_experts = num_experts
        use_ranks = world_size
    else:
        num_tokens, hidden, num_topk, num_experts = 1, 7168, 8, 31
        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, world_size, num_experts)
        buffer = Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
                        num_qps_per_rank=num_experts // world_size)

        use_experts = num_experts - 1
        use_ranks = world_size - shared_expert_rank_num

    do_pressure_test = False
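    # With pressure testing enabled, the whole flow is rerun for many seeds; every seed must
    # reproduce the same hash across 20 repeated runs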
    for seed in range(int(1e9) if do_pressure_test else 1):
        if rank == 0:
            print(f'Testing with seed {seed} ...', flush=True)
        ref_hash = test(num_tokens, hidden, use_experts, num_topk, rank, use_ranks, group, buffer, seed)
        for i in range(20):
            assert test(num_tokens, hidden, use_experts, num_topk, rank, use_ranks, group, buffer, seed) == ref_hash, f'Error: seed={seed}'
        dist.barrier()
    dist.destroy_process_group()

if __name__ == '__main__':
    test_main()