From d51155c874807531f1fffba8d7b7fbf641e1fceb Mon Sep 17 00:00:00 2001
From: zhoushenglong
Date: Wed, 18 Dec 2024 02:26:54 +0000
Subject: [PATCH 1/2] fix moe op for maca.

---
 dlinfer/ops/llm.py              |  9 ++++++++-
 dlinfer/vendor/maca/maca_ops.py | 25 ++++++++-----------------
 2 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py
index 6434acbe..7daddbae 100644
--- a/dlinfer/ops/llm.py
+++ b/dlinfer/ops/llm.py
@@ -539,6 +539,7 @@ def fused_moe(
     topk_weights: Tensor,
     gate_up_weights: Tensor,
     down_weights: Tensor,
+    renormalize: bool = False,
 ) -> Tensor:
     """
     Implement the Fused Mixture of Experts (MoE) model.
@@ -556,7 +557,13 @@

     """
     return vendor_ops_registry["fused_moe"](
-        hidden_states, top_k, topk_ids, topk_weights, gate_up_weights, down_weights
+        hidden_states,
+        top_k,
+        topk_ids,
+        topk_weights,
+        gate_up_weights,
+        down_weights,
+        renormalize,
     )


diff --git a/dlinfer/vendor/maca/maca_ops.py b/dlinfer/vendor/maca/maca_ops.py
index 2838babe..73df01c7 100644
--- a/dlinfer/vendor/maca/maca_ops.py
+++ b/dlinfer/vendor/maca/maca_ops.py
@@ -354,25 +354,16 @@ def fused_moe(
     topk_weights: torch.Tensor,
     gate_up_weights: torch.Tensor,
     down_weights: torch.Tensor,
+    renormalize: bool = False,
 ):
-    N, D = hidden_states.shape
-    hidden_states = hidden_states.view(N, -1, D).repeat(1, top_k, 1).reshape(-1, D)
-    out = torch.zeros(
-        N * top_k,
-        down_weights.shape[1],
-        dtype=hidden_states.dtype,
-        device=hidden_states.device,
+    N = hidden_states.size(0)
+    topk_weights = topk_weights.reshape(N, top_k)
+    topk_ids = topk_ids.reshape(N, top_k)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return vllm.model_executor.layers.fused_moe.fused_experts(
+        hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids
     )
-    for i in range(gate_up_weights.shape[0]):
-        mask = topk_ids == i
-        if mask.sum():
-            out[mask] = silu_and_mul(
-                hidden_states[mask] @ gate_up_weights[i].transpose(0, 1)
-            ) @ down_weights[i].transpose(0, 1)
-    return (
-        out.view(N, -1, down_weights.shape[1])
-        * topk_weights.view(N, -1, 1).to(out.dtype)
-    ).sum(dim=1)


 @register_ops(vendor_ops_registry)

From 7ec27912ffba1bd9c31f500534922074e3d3c0ff Mon Sep 17 00:00:00 2001
From: zhoushenglong
Date: Wed, 18 Dec 2024 08:23:55 +0000
Subject: [PATCH 2/2] refine code.

---
 dlinfer/ops/llm.py                     | 16 ++++++++--------
 dlinfer/vendor/ascend/torch_npu_ops.py |  7 ++++---
 dlinfer/vendor/maca/maca_ops.py        | 14 +++++++-------
 3 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py
index 7daddbae..59f207f2 100644
--- a/dlinfer/ops/llm.py
+++ b/dlinfer/ops/llm.py
@@ -534,12 +534,12 @@ def weight_quant_matmul(
 @register_custom_op("dlinfer::fused_moe", ["hidden_states"])
 def fused_moe(
     hidden_states: Tensor,
-    top_k: int,
-    topk_ids: Tensor,
-    topk_weights: Tensor,
     gate_up_weights: Tensor,
     down_weights: Tensor,
-    renormalize: bool = False,
+    topk_weights: Tensor,
+    topk_ids: Tensor,
+    topk: int,
+    renormalize: bool,
 ) -> Tensor:
     """
     Implement the Fused Mixture of Experts (MoE) model.
@@ -558,12 +558,12 @@ def fused_moe(
     """
     return vendor_ops_registry["fused_moe"](
         hidden_states,
-        top_k,
-        topk_ids,
-        topk_weights,
         gate_up_weights,
         down_weights,
-        renormalize,
+        topk_weights,
+        topk_ids,
+        topk,
+        renormalize
     )


diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py
index 05444df6..9a952505 100644
--- a/dlinfer/vendor/ascend/torch_npu_ops.py
+++ b/dlinfer/vendor/ascend/torch_npu_ops.py
@@ -406,11 +406,12 @@ def weight_quant_matmul(
 @register_ops(vendor_ops_registry)
 def fused_moe(
     hidden_states: Tensor,
-    top_k: int,
-    topk_ids: Tensor,
-    topk_weights: Tensor,
     gate_up_weights: Tensor,
     down_weights: Tensor,
+    topk_weights: Tensor,
+    topk_ids: Tensor,
+    top_k: int,
+    renormalize: bool,
 ) -> Tensor:
     seq_length = hidden_states.size(0)
     moe_output = torch.zeros_like(hidden_states)

diff --git a/dlinfer/vendor/maca/maca_ops.py b/dlinfer/vendor/maca/maca_ops.py
index 73df01c7..3eb9fb95 100644
--- a/dlinfer/vendor/maca/maca_ops.py
+++ b/dlinfer/vendor/maca/maca_ops.py
@@ -348,14 +348,14 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:

 @register_ops(vendor_ops_registry)
 def fused_moe(
-    hidden_states: torch.Tensor,
+    hidden_states: Tensor,
+    gate_up_weights: Tensor,
+    down_weights: Tensor,
+    topk_weights: Tensor,
+    topk_ids: Tensor,
     top_k: int,
-    topk_ids: torch.LongTensor,
-    topk_weights: torch.Tensor,
-    gate_up_weights: torch.Tensor,
-    down_weights: torch.Tensor,
-    renormalize: bool = False,
-):
+    renormalize: bool,
+) -> Tensor:
     N = hidden_states.size(0)
     topk_weights = topk_weights.reshape(N, top_k)
     topk_ids = topk_ids.reshape(N, top_k)
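
Note on verifying the new dispatch path: the per-expert loop that patch 1 removes from maca_ops.py can still serve as a plain-PyTorch reference for what fused_moe is expected to compute. The sketch below is a minimal reconstruction of that reference, not part of the patch; the name ref_fused_moe is hypothetical, and the weight layouts (gate_up_weights of shape [num_experts, 2 * intermediate, hidden], down_weights of shape [num_experts, hidden, intermediate]) and the explicit SiLU-and-mul split are assumptions inferred from the removed code.

    import torch

    def ref_fused_moe(hidden_states, gate_up_weights, down_weights,
                      topk_weights, topk_ids, top_k, renormalize):
        # Reference loop equivalent to the implementation removed in patch 1.
        N, D = hidden_states.shape
        topk_weights = topk_weights.reshape(N, top_k)
        topk_ids = topk_ids.reshape(N, top_k)
        if renormalize:
            # Same renormalization the new maca path applies before dispatch.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        # Expand each token top_k times so every (token, expert) pair gets a row.
        x = hidden_states.view(N, 1, D).repeat(1, top_k, 1).reshape(-1, D)
        flat_ids = topk_ids.reshape(-1)
        out = torch.zeros(N * top_k, down_weights.shape[1],
                          dtype=hidden_states.dtype, device=hidden_states.device)
        for e in range(gate_up_weights.shape[0]):
            mask = flat_ids == e
            if mask.any():
                # gate_up projection, SiLU-and-mul, then down projection per expert.
                gate_up = x[mask] @ gate_up_weights[e].transpose(0, 1)
                gate, up = gate_up.chunk(2, dim=-1)
                out[mask] = (torch.nn.functional.silu(gate) * up) @ down_weights[e].transpose(0, 1)
        # Combine the per-expert outputs with the routing weights.
        return (out.view(N, top_k, -1) * topk_weights.view(N, top_k, 1).to(out.dtype)).sum(dim=1)

With the argument order introduced in patch 2, the vendor-agnostic entry point is called as fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, topk, renormalize). On maca this now reshapes topk_weights/topk_ids to (N, top_k), optionally renormalizes them, and dispatches to vllm.model_executor.layers.fused_moe.fused_experts; its output can be compared against the sketch above with torch.allclose under a relaxed tolerance.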