From d51155c874807531f1fffba8d7b7fbf641e1fceb Mon Sep 17 00:00:00 2001 From: zhoushenglong Date: Wed, 18 Dec 2024 02:26:54 +0000 Subject: [PATCH] fix moe op for maca. --- dlinfer/ops/llm.py | 9 ++++++++- dlinfer/vendor/maca/maca_ops.py | 25 ++++++++----------------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py index 6434acbe..7daddbae 100644 --- a/dlinfer/ops/llm.py +++ b/dlinfer/ops/llm.py @@ -539,6 +539,7 @@ def fused_moe( topk_weights: Tensor, gate_up_weights: Tensor, down_weights: Tensor, + renormalize: bool = False, ) -> Tensor: """ Implement the Fused Mixture of Experts (MoE) model. @@ -556,7 +557,13 @@ def fused_moe( """ return vendor_ops_registry["fused_moe"]( - hidden_states, top_k, topk_ids, topk_weights, gate_up_weights, down_weights + hidden_states, + top_k, + topk_ids, + topk_weights, + gate_up_weights, + down_weights, + renormalize, ) diff --git a/dlinfer/vendor/maca/maca_ops.py b/dlinfer/vendor/maca/maca_ops.py index 2838babe..73df01c7 100644 --- a/dlinfer/vendor/maca/maca_ops.py +++ b/dlinfer/vendor/maca/maca_ops.py @@ -354,25 +354,16 @@ def fused_moe( topk_weights: torch.Tensor, gate_up_weights: torch.Tensor, down_weights: torch.Tensor, + renormalize: bool = False, ): - N, D = hidden_states.shape - hidden_states = hidden_states.view(N, -1, D).repeat(1, top_k, 1).reshape(-1, D) - out = torch.zeros( - N * top_k, - down_weights.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, + N = hidden_states.size(0) + topk_weights = topk_weights.reshape(N, top_k) + topk_ids = topk_ids.reshape(N, top_k) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return vllm.model_executor.layers.fused_moe.fused_experts( + hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids ) - for i in range(gate_up_weights.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = silu_and_mul( - hidden_states[mask] @ gate_up_weights[i].transpose(0, 1) - ) @ down_weights[i].transpose(0, 1) - return ( - out.view(N, -1, down_weights.shape[1]) - * topk_weights.view(N, -1, 1).to(out.dtype) - ).sum(dim=1) @register_ops(vendor_ops_registry)