diff --git a/csrc/deepep/deep_ep.cpp b/csrc/deepep/deep_ep.cpp
index b37a618e..6565e7f1 100644
--- a/csrc/deepep/deep_ep.cpp
+++ b/csrc/deepep/deep_ep.cpp
@@ -127,15 +127,30 @@ std::tuple, at::Tensor, at::Tensor, at::Te
     EP_HOST_ASSERT(low_latency_mode);
     at::Tensor new_x = x;
     this->new_topk_idx = topk_idx;
-    if (topk_idx.size(0) == 0) {
+    if (topk_idx.size(0) <= 2) {
         this->is_padding = true;
-        this->ori_x = x.clone();
-        new_x = torch::ones({1, 7168}, x.options());
-        this->new_topk_idx = torch::arange(0, 8, topk_idx.options()).reshape({1, 8});
+        this->padding_cnt = 3 - topk_idx.size(0);
+        std::vector<at::Tensor> x_blocks;
+        std::vector<at::Tensor> topk_blocks;
+        if (topk_idx.size(0) != 0) {
+            x_blocks.emplace_back(x);
+            topk_blocks.emplace_back(topk_idx);
+        } else {
+            this->ori_x = x.clone();
+        }
+        for (int i = 0; i < this->padding_cnt; i++) {
+            at::Tensor tmp_x = torch::ones({1, 7168}, x.options());
+            at::Tensor tmp_topk = torch::arange(0, 8, topk_idx.options()).reshape({1, 8});
+            x_blocks.emplace_back(tmp_x);
+            topk_blocks.emplace_back(tmp_topk);
+        }
+        new_x = torch::cat(x_blocks, 0);
+        this->new_topk_idx = torch::cat(topk_blocks, 0);
     }
     auto num_tokens = static_cast<int>(new_x.size(0)), hidden = static_cast<int>(new_x.size(1));
     auto num_scales = hidden / 128, num_topk = static_cast<int>(new_topk_idx.size(1));
+    auto num_local_experts = num_experts / (num_ranks - shared_expert_rank_num);
     auto num_max_tokens = 0;
     if (rank < shared_expert_rank_num) {
@@ -144,12 +159,13 @@ std::tuple, at::Tensor, at::Tensor, at::Te
     } else { // moe expert
         num_max_tokens = num_max_dispatch_tokens_per_rank * num_ranks * num_local_experts;
     }
+    auto max_size = num_tokens * num_topk > num_max_tokens * 128 ? num_tokens * num_topk : num_max_tokens * 128;

     // Allocate packed tensors
     auto device = new_x.device();
     auto packed_recv_x = at::empty({num_max_tokens, hidden}, new_x.options().dtype(use_fp8 ? at::kChar : at::kBFloat16));
     auto packed_recv_x_scales = at::empty({num_max_tokens}, at::dtype(at::kFloat).device(device));
-    auto expandIdx = at::empty({num_tokens * num_topk}, at::dtype(at::kInt).device(device));
+    auto assist_info_for_combine = at::empty({max_size}, at::dtype(at::kInt).device(device));
     auto packed_recv_count = at::empty({num_local_experts * num_ranks}, at::dtype(at::kInt).device(device));
     auto tp_recv_count = at::empty({1}, at::dtype(at::kInt).device(device));
     auto expertTokenNumsOut = at::empty({num_local_experts}, at::dtype(at::kLong).device(device));
@@ -163,6 +179,9 @@ std::tuple, at::Tensor, at::Tensor, at::Te
     int64_t expert_shard_type = 0;
     int64_t expert_token_nums_type = 1;
     int64_t global_bs = num_max_dispatch_tokens_per_rank * num_ranks;
+    std::string comm_log = "0";
+    char *comm_log_ptr = const_cast<char *>(comm_log.c_str());
+
     // get ep & tp name
     char hcom_ep_name[128];
@@ -172,7 +191,7 @@ std::tuple, at::Tensor, at::Tensor, at::Te
         HCCL_CHECK(HcclGetCommName(ep_comm, hcom_ep_name));
     }

-    EXEC_NPU_CMD(aclnnMoeDistributeDispatch,
+    EXEC_NPU_CMD(aclnnMoeDistributeDispatchV2,
         new_x,
         new_topk_idx,
         scales, // smooth scales,
@@ -191,9 +210,10 @@ std::tuple, at::Tensor, at::Tensor, at::Te
         quant_mode,
         global_bs, // global_bs
         expert_token_nums_type, // expert_token_nums_type
+        comm_log_ptr,
         packed_recv_x,
         packed_recv_x_scales, // dynamicScalesOut
-        expandIdx,
+        assist_info_for_combine,
         expertTokenNumsOut,
         packed_recv_count,
         tp_recv_count,
@@ -203,7 +223,7 @@ std::tuple, at::Tensor, at::Tensor, at::Te
     std::optional<EventHandle> event;

     // Return values
-    return {packed_recv_x, packed_recv_x_scales, packed_recv_count, expandIdx, expertTokenNumsOut, event, std::function<void()>([]{})};
+    return {packed_recv_x, packed_recv_x_scales, packed_recv_count, assist_info_for_combine, expertTokenNumsOut, event, std::function<void()>([]{})};
 }

 int Buffer::get_rdma_rank() const {
@@ -219,8 +239,16 @@ std::tuple, std::optional
     if (this->is_padding) {
+        std::vector<at::Tensor> scales_blocks;
+        if (this->padding_cnt != 3) {
+            scales_blocks.emplace_back(topk_weights);
+        }
+        for (int i = 0; i < this->padding_cnt; i++) {
+            at::Tensor tmp_scales = torch::zeros({1, 8}, topk_weights.options());
+            scales_blocks.emplace_back(tmp_scales);
+        }
         new_idx = this->new_topk_idx;
-        this->new_scales = torch::zeros({1, 8}, topk_weights.options());
+        this->new_scales = torch::cat(scales_blocks, 0);
         new_scales = this->new_scales;
     }

     // Tensor checks
@@ -238,7 +266,7 @@ std::tuple, std::optional, std::optional
     auto num_combined_tokens = static_cast<int>(new_scales.size(0));
     auto hidden = static_cast<int>(x.size(1));
+    at::Tensor shared_expert_x{nullptr};
     at::Tensor combined_x = at::empty({num_combined_tokens, hidden}, x.options());
     std::optional<EventHandle> event;
+    std::string comm_log = "0";
+    char *comm_log_ptr = const_cast<char *>(comm_log.c_str());

-    EXEC_NPU_CMD(aclnnMoeDistributeCombine,
+    EXEC_NPU_CMD(aclnnMoeDistributeCombineV2,
         expand_x,
         expert_ids,
-        expand_idx,
+        assist_info_for_combine,
         ep_send_counts,
         expert_scales,
         tp_send_counts,
@@ -269,6 +300,7 @@ std::tuple, std::optional, std::optional
     if (this->is_padding) {
-        combined_x = this->ori_x;
+        if (this->padding_cnt == 3) {
+            combined_x = this->ori_x;
+        } else {
+            combined_x = combined_x.slice(0, 0, 3 - this->padding_cnt);
+        }
     }
     return {combined_x, event, std::function<void()>([]{})};
 }
diff --git a/csrc/deepep/deep_ep.hpp b/csrc/deepep/deep_ep.hpp
index 51922f06..6c55264b 100644
--- a/csrc/deepep/deep_ep.hpp
+++ b/csrc/deepep/deep_ep.hpp
@@ -23,6 +23,7 @@ struct Buffer {
     bool low_latency_mode = false;
     bool is_padding = false;
+    int padding_cnt = 0;
     at::Tensor ori_x;
     at::Tensor new_topk_idx;
     at::Tensor new_scales;
diff --git a/csrc/utils/whatsup/666 b/csrc/utils/whatsup/666
new file mode 100644
index 00000000..7cc86ad1
--- /dev/null
+++ b/csrc/utils/whatsup/666
@@ -0,0 +1 @@
+666
diff --git a/tests/python/deepep/test_low_latency.py b/tests/python/deepep/test_low_latency.py
deleted file mode 100644
index efa61942..00000000
--- a/tests/python/deepep/test_low_latency.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import os
-import random
-import time
-import torch
-import torch_npu
-import torch.distributed as dist
-
-from deep_ep import Buffer
-
-def test(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
-         rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: Buffer, seed: int = 0):
-    torch.manual_seed(seed + rank)
-    random.seed(seed + rank)
-
-    assert num_experts % num_ranks == 0
-    num_local_experts = num_experts // num_ranks
-
-    # NOTES: the integers greater than 256 exceeds the BF16 precision limit
-    rank_offset = 128
-    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
-
-    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='npu') * (rank - rank_offset)
-    x[:, -128:] = torch.arange(num_tokens, device='npu').to(torch.bfloat16).view(-1, 1)
-    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='npu').abs() + 1
-    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
-    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='npu').abs()
-
-    # Check dispatch correctness
-    do_check = True
-    return_recv_hook = False
-    hash_value, num_times = 0, 0
-
-    cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='npu')
-    packed_recv_x, packed_recv_count, handle, event, hook = \
-        buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
-                                    use_fp8=False, round_scale=False, use_ue8m0=False,
-                                    cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
-                                    async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
-    simulated_gemm_x = packed_recv_x.clone()
-
-    # Check combine correctness
-    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='npu')
-    combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                         async_finish=not return_recv_hook, zero_copy=False,
-                                                         return_recv_hook=return_recv_hook, out=out)
-
-    return hash_value
-
-def test_main():
-    ip = os.getenv('MASTER_ADDR', '127.0.0.1')
-    port = int(os.getenv('MASTER_PORT', '17621'))
-    world_size = int(os.getenv('WORLD_SIZE', 16))
-    rank = int(os.getenv('RANK', 0))
-    shared_expert_rank_num = int(os.getenv('MOE_SHARED_EXPERT_RANK_NUM', 0))
-
-    dist.init_process_group(
-        backend="hccl",
-        init_method=f'tcp://{ip}:{port}',
-        world_size=world_size,
-        rank=rank
-    )
-    torch.npu.set_device(rank)
-    group = dist.new_group(list(range(world_size)))
-    print("===========group", group.size())
-    if shared_expert_rank_num == 0:
-        num_tokens, hidden, num_topk, num_experts = 1, 7168, 8, 16
-        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, world_size, num_experts)
-        buffer = Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
-                        num_qps_per_rank=num_experts // world_size)
-
-        test(num_tokens, hidden, num_experts, num_topk, rank, world_size, group, buffer, seed=1)
-    else:
-        num_tokens, hidden, num_topk, num_experts = 1, 7168, 8, 31
-        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, world_size, num_experts)
-        buffer = Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
-                        num_qps_per_rank=num_experts // world_size)
-
-        test(num_tokens, hidden, num_experts - 1, num_topk, rank, world_size - shared_expert_rank_num,
-             group, buffer, seed=1)
-    dist.barrier()
-    dist.destroy_process_group()
-
-if __name__ == '__main__':
-    test_main()
diff --git a/tests/python/deepep/run_test.sh b/tests/python/run_test.sh
similarity index 100%
rename from tests/python/deepep/run_test.sh
rename to tests/python/run_test.sh
diff --git a/tests/python/test_low_latency.py b/tests/python/test_low_latency.py
new file mode 100644
index 00000000..ad52d91f
--- /dev/null
+++ b/tests/python/test_low_latency.py
@@ -0,0 +1,149 @@
+import os
+import random
+import time
+import torch
+import torch_npu
+import torch.distributed as dist
+
+from deep_ep import Buffer
+from functools import partial
+from utils import bench, calc_diff, hash_tensor
+
+def test(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
+         rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: Buffer, seed: int = 0):
+    torch.manual_seed(seed + rank)
+    random.seed(seed + rank)
+
+    assert num_experts % num_ranks == 0
+    num_local_experts = num_experts // num_ranks
+
+    # NOTES: the integers greater than 256 exceeds the BF16 precision limit
+    rank_offset = 128
+    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
+
+    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='npu') * (rank - rank_offset)
+    x[:, -128:] = torch.arange(num_tokens, device='npu').to(torch.bfloat16).view(-1, 1)
+    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='npu').abs() + 1
+    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
+    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='npu').abs()
+
+    # Check dispatch correctness
+    do_check = True
+    return_recv_hook = False
+    hash_value, num_times = 0, 0
+
+    cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='npu')
+    packed_recv_x, packed_recv_count, handle, event, hook = \
+        buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
+                                    use_fp8=False, round_scale=False, use_ue8m0=False,
+                                    cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
+                                    async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
+    simulated_gemm_x = packed_recv_x.clone()
+    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='npu')
+    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
+
+    for i in range(num_local_experts if do_check else 0):
+        expert_id = rank * num_local_experts + i
+        temp = num_tokens / num_local_experts
+        recv_x = packed_recv_x[i : int((i + 1) * temp)]
+        recv_count = packed_recv_count[i]
+        if i == 0:
+            recv_layout_range = handle[1][(i + 1) * num_ranks - 1]
+        else:
+            recv_layout_range = handle[1][(i + 1) * num_ranks - 1] - handle[1][i * num_ranks - 1]
+
+        # Check expert indices
+        int_mask = (2 ** 32) - 1
+        num_valid_tokens = recv_count.item()
+        assert num_valid_tokens == (recv_layout_range & int_mask).item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.item()'
+        assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
+
+        if num_valid_tokens == 0:
+            continue
+        # Check received data
+        recv_x = recv_x[:num_valid_tokens]
+        recv_x_amin = recv_x[:, :-128].amin(dim=-1)
+        assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
+        hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
+
+    # Check combine correctness
+    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='npu')
+    combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
+                                                         async_finish=not return_recv_hook, zero_copy=False,
+                                                         return_recv_hook=return_recv_hook, out=out)
+
+    if do_check:
+        diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
+        assert torch.isnan(combined_x).sum().item() == 0
+        assert diff < 1e-5, f'Error: {diff=}, {zero_copy=}'
+        hash_value ^= hash_tensor(combined_x)
+
+    # noinspection PyShadowingNames
+    def test_func(zero_copy: bool, return_recv_hook: bool):
+        recv_x, recv_count, handle, event, hook = \
+            buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
+                                        cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
+                                        use_fp8=False, async_finish=False, return_recv_hook=return_recv_hook)
+        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
+                                                             zero_copy=zero_copy, return_recv_hook=return_recv_hook)
+
+    # Calculate bandwidth
+    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
+    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
+    for i in range(num_tokens):
+        num_selections = (topk_idx[i] != -1).sum().item()
+        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
+        num_combine_comm_bytes += num_bf16_bytes * num_selections
+
+    # Dispatch + combine testing
+    avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
+    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
+          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
+
+    return hash_value
+
+def test_main():
+    ip = os.getenv('MASTER_ADDR', '127.0.0.1')
+    port = int(os.getenv('MASTER_PORT', '17621'))
+    world_size = int(os.getenv('WORLD_SIZE', 16))
+    rank = int(os.getenv('RANK', 0))
+    shared_expert_rank_num = int(os.getenv('MOE_SHARED_EXPERT_RANK_NUM', 0))
+
+    dist.init_process_group(
+        backend="hccl",
+        init_method=f'tcp://{ip}:{port}',
+        world_size=world_size,
+        rank=rank
+    )
+    torch.npu.set_device(rank)
+    group = dist.new_group(list(range(world_size)))
+    print("===========group", group.size())
+    if shared_expert_rank_num == 0:
+        num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
+        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, world_size, num_experts)
+        buffer = Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
+                        num_qps_per_rank=num_experts // world_size)
+
+        use_experts = num_experts
+        use_ranks = world_size
+    else:
+        num_tokens, hidden, num_topk, num_experts = 1, 7168, 8, 31
+        num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, world_size, num_experts)
+        buffer = Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
+                        num_qps_per_rank=num_experts // world_size)
+
+        use_experts = num_experts - 1
+        use_ranks = world_size - shared_expert_rank_num
+
+    do_pressure_test = False
+    for seed in range(int(1e9) if do_pressure_test else 1):
+        if rank == 0:
+            print(f'Testing with seed {seed} ...', flush=True)
+        ref_hash = test(num_tokens, hidden, use_experts, num_topk, rank, use_ranks, group, buffer, seed)
+        for i in range(20):
+            assert test(num_tokens, hidden, use_experts, num_topk, rank, use_ranks, group, buffer, seed) == ref_hash, f'Error: seed={seed}'
+    dist.barrier()
+    dist.destroy_process_group()
+
+if __name__ == '__main__':
+    test_main()
diff --git a/tests/python/utils.py b/tests/python/utils.py
index 426fa4d5..74cb304f 100644
--- a/tests/python/utils.py
+++ b/tests/python/utils.py
@@ -37,6 +37,13 @@ def init_dist(local_rank: int, num_local_ranks: int):
     return dist.get_rank(), dist.get_world_size(), group
 
 
+def calc_diff(x: torch.Tensor, y: torch.Tensor):
+    x, y = x.double() + 1, y.double() + 1
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return (1 - sim).item()
+
+
 def inplace_unique(x: torch.Tensor, num_slots: int):
     assert x.dim() == 2
     mask = x < 0
@@ -85,4 +92,7 @@ def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None):
         times.append(elapsed_time)
 
     times = np.array(times[1:])  # Remove the first timing
-    return np.average(times), np.min(times), np.max(times)
\ No newline at end of file
+    return np.average(times), np.min(times), np.max(times)
+
+def hash_tensor(t: torch.Tensor):
+    return t.view(torch.int8).sum().item()
\ No newline at end of file