Commit 176b8ab

luanyundu committed
Add test for low-latency-mode
1 parent 6446516 · commit 176b8ab

File tree

5 files changed, +161 -87 lines

csrc/deepep/deep_ep.cpp

Lines changed: 1 addition & 2 deletions

@@ -165,7 +165,7 @@ std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Te
     auto device = new_x.device();
     auto packed_recv_x = at::empty({num_max_tokens, hidden}, new_x.options().dtype(use_fp8 ? at::kChar : at::kBFloat16));
     auto packed_recv_x_scales = at::empty({num_max_tokens}, at::dtype(at::kFloat).device(device));
-    auto assist_info_for_combine = at::empty({num_tokens * num_topk}, at::dtype(at::kInt).device(device));
+    auto assist_info_for_combine = at::empty({max_size}, at::dtype(at::kInt).device(device));
     auto packed_recv_count = at::empty({num_local_experts * num_ranks}, at::dtype(at::kInt).device(device));
     auto tp_recv_count = at::empty({1}, at::dtype(at::kInt).device(device));
     auto expertTokenNumsOut = at::empty({num_local_experts}, at::dtype(at::kLong).device(device));
@@ -179,7 +179,6 @@ std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Te
     int64_t expert_shard_type = 0;
     int64_t expert_token_nums_type = 1;
     int64_t global_bs = num_max_dispatch_tokens_per_rank * num_ranks;
-    int64_t shared_expert_rank_num = 0;
     std::string comm_log = "0";
     char *comm_log_ptr = const_cast<char *>(comm_log.c_str());
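
For orientation, the allocations in this hunk are the tensors the Python test later receives from low_latency_dispatch. The sketch below mirrors their shapes and dtypes with made-up sizes (the 288-expert / 16-rank numbers simply match the new test's default branch); max_size is only a placeholder, since its computation lies outside this hunk, and the BF16 branch is shown because the test calls dispatch with use_fp8=False.

import torch

# Hypothetical sizes for illustration only; the real values come from the dispatch arguments.
num_max_tokens, hidden = 2048, 7168
num_local_experts, num_ranks = 18, 16    # 288 experts over 16 ranks, as in the new test
max_size = 4096                          # placeholder: the commit sizes this buffer by max_size
                                         # rather than num_tokens * num_topk

packed_recv_x = torch.empty((num_max_tokens, hidden), dtype=torch.bfloat16)   # at::kChar instead when use_fp8 is true
packed_recv_x_scales = torch.empty((num_max_tokens,), dtype=torch.float32)
assist_info_for_combine = torch.empty((max_size,), dtype=torch.int32)
packed_recv_count = torch.empty((num_local_experts * num_ranks,), dtype=torch.int32)
tp_recv_count = torch.empty((1,), dtype=torch.int32)
expert_token_nums = torch.empty((num_local_experts,), dtype=torch.int64)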

tests/python/deepep/test_low_latency.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

File renamed without changes.

tests/python/test_low_latency.py

Lines changed: 149 additions & 0 deletions

@@ -0,0 +1,149 @@
import os
import random
import time
import torch
import torch_npu
import torch.distributed as dist

from deep_ep import Buffer
from functools import partial
from utils import bench, calc_diff, hash_tensor

def test(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: Buffer, seed: int = 0):
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)

    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: integers greater than 256 exceed the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='npu') * (rank - rank_offset)
    x[:, -128:] = torch.arange(num_tokens, device='npu').to(torch.bfloat16).view(-1, 1)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='npu').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='npu').abs()

    # Check dispatch correctness
    do_check = True
    return_recv_hook = False
    hash_value, num_times = 0, 0

    cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='npu')
    packed_recv_x, packed_recv_count, handle, event, hook = \
        buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                    use_fp8=False, round_scale=False, use_ue8m0=False,
                                    cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                    async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
    simulated_gemm_x = packed_recv_x.clone()
    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='npu')
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)

    for i in range(num_local_experts if do_check else 0):
        expert_id = rank * num_local_experts + i
        temp = num_tokens / num_local_experts
        recv_x = packed_recv_x[i : int((i + 1) * temp)]
        recv_count = packed_recv_count[i]
        if i == 0:
            recv_layout_range = handle[1][(i + 1) * num_ranks - 1]
        else:
            recv_layout_range = handle[1][(i + 1) * num_ranks - 1] - handle[1][i * num_ranks - 1]

        # Check expert indices
        int_mask = (2 ** 32) - 1
        num_valid_tokens = recv_count.item()
        assert num_valid_tokens == (recv_layout_range & int_mask).item(), f'{num_valid_tokens} != {(recv_layout_range & int_mask).item()}'
        assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'

        if num_valid_tokens == 0:
            continue
        # Check received data
        recv_x = recv_x[:num_valid_tokens]
        recv_x_amin = recv_x[:, :-128].amin(dim=-1)
        assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
        hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])

    # Check combine correctness
    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='npu')
    combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                         async_finish=not return_recv_hook, zero_copy=False,
                                                         return_recv_hook=return_recv_hook, out=out)

    if do_check:
        diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
        assert torch.isnan(combined_x).sum().item() == 0
        assert diff < 1e-5, f'Error: {diff=}'
        hash_value ^= hash_tensor(combined_x)

    # noinspection PyShadowingNames
    def test_func(zero_copy: bool, return_recv_hook: bool):
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                        cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                        use_fp8=False, async_finish=False, return_recv_hook=return_recv_hook)
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                             zero_copy=zero_copy, return_recv_hook=return_recv_hook)

    # Calculate bandwidth
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
    for i in range(num_tokens):
        num_selections = (topk_idx[i] != -1).sum().item()
        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
        num_combine_comm_bytes += num_bf16_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)

    return hash_value

def test_main():
    ip = os.getenv('MASTER_ADDR', '127.0.0.1')
    port = int(os.getenv('MASTER_PORT', '17621'))
    world_size = int(os.getenv('WORLD_SIZE', 16))
    rank = int(os.getenv('RANK', 0))
    shared_expert_rank_num = int(os.getenv('MOE_SHARED_EXPERT_RANK_NUM', 0))

    dist.init_process_group(
        backend="hccl",
        init_method=f'tcp://{ip}:{port}',
        world_size=world_size,
        rank=rank
    )
    torch.npu.set_device(rank)
    group = dist.new_group(list(range(world_size)))
    print("===========group", group.size())
    if shared_expert_rank_num == 0:
        num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, world_size, num_experts)
        buffer = Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
                        num_qps_per_rank=num_experts // world_size)

        use_experts = num_experts
        use_ranks = world_size
    else:
        num_tokens, hidden, num_topk, num_experts = 1, 7168, 8, 31
        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, world_size, num_experts)
        buffer = Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
                        num_qps_per_rank=num_experts // world_size)

        use_experts = num_experts - 1
        use_ranks = world_size - shared_expert_rank_num

    do_pressure_test = False
    for seed in range(int(1e9) if do_pressure_test else 1):
        if rank == 0:
            print(f'Testing with seed {seed} ...', flush=True)
        ref_hash = test(num_tokens, hidden, use_experts, num_topk, rank, use_ranks, group, buffer, seed)
        for i in range(20):
            assert test(num_tokens, hidden, use_experts, num_topk, rank, use_ranks, group, buffer, seed) == ref_hash, f'Error: seed={seed}'
    dist.barrier()
    dist.destroy_process_group()

if __name__ == '__main__':
    test_main()
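
test_main() reads its rendezvous settings from the environment (MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK), plus MOE_SHARED_EXPERT_RANK_NUM to pick the expert layout. The launcher below is a minimal single-host sketch for running one process per rank; it is not part of the commit, and it assumes 16 visible NPUs (mirroring the test's WORLD_SIZE default), a working deep_ep/torch_npu install, and that it sits next to test_low_latency.py.

import os
import subprocess
import sys

def launch(world_size: int = 16, port: int = 17621):
    # Spawn one test process per rank with the environment variables test_main() expects.
    procs = []
    for rank in range(world_size):
        env = dict(os.environ,
                   MASTER_ADDR='127.0.0.1',
                   MASTER_PORT=str(port),
                   WORLD_SIZE=str(world_size),
                   RANK=str(rank),
                   MOE_SHARED_EXPERT_RANK_NUM='0')  # 0 selects the 288-expert branch of test_main()
        procs.append(subprocess.Popen([sys.executable, 'test_low_latency.py'], env=env))
    for p in procs:
        assert p.wait() == 0, 'a rank exited with a non-zero status'

if __name__ == '__main__':
    launch()

Setting MOE_SHARED_EXPERT_RANK_NUM to a non-zero value instead exercises the shared-expert branch of the test (1 token, 31 experts).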

tests/python/utils.py

Lines changed: 11 additions & 1 deletion

@@ -37,6 +37,13 @@ def init_dist(local_rank: int, num_local_ranks: int):
     return dist.get_rank(), dist.get_world_size(), group


+def calc_diff(x: torch.Tensor, y: torch.Tensor):
+    x, y = x.double() + 1, y.double() + 1
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return (1 - sim).item()
+
+
 def inplace_unique(x: torch.Tensor, num_slots: int):
     assert x.dim() == 2
     mask = x < 0
@@ -85,4 +92,7 @@ def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None):
         times.append(elapsed_time)

     times = np.array(times[1:])  # Remove the first timing
-    return np.average(times), np.min(times), np.max(times)
+    return np.average(times), np.min(times), np.max(times)
+
+
+def hash_tensor(t: torch.Tensor):
+    return t.view(torch.int8).sum().item()
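
The two helpers added here back the new test: calc_diff returns one minus a cosine-similarity-style ratio (so roughly 0 for matching tensors), and hash_tensor collapses a tensor to a byte sum, which the test XORs across experts and across dispatch/combine as a cheap reproducibility fingerprint. The snippet below is a CPU-only sketch, not part of the commit, with both helpers copied inline so it runs without the repo.

import torch

def calc_diff(x: torch.Tensor, y: torch.Tensor):
    # Copied from tests/python/utils.py above: 1 - 2<x, y> / (|x|^2 + |y|^2) on shifted doubles.
    x, y = x.double() + 1, y.double() + 1
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return (1 - sim).item()

def hash_tensor(t: torch.Tensor):
    # Copied from tests/python/utils.py above: reinterpret the raw bytes as int8 and sum them.
    return t.view(torch.int8).sum().item()

x = torch.randn(4, 8, dtype=torch.bfloat16)
print(calc_diff(x, x))                            # ~0.0 for identical tensors
print(calc_diff(x, x * 1.01))                     # small positive value for a 1% perturbation
print(hash_tensor(x) == hash_tensor(x.clone()))   # True: the hash depends only on the bytes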
