diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py index 6434acb..59f207f 100644 --- a/dlinfer/ops/llm.py +++ b/dlinfer/ops/llm.py @@ -534,11 +534,12 @@ def weight_quant_matmul( @register_custom_op("dlinfer::fused_moe", ["hidden_states"]) def fused_moe( hidden_states: Tensor, - top_k: int, - topk_ids: Tensor, - topk_weights: Tensor, gate_up_weights: Tensor, down_weights: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + topk: int, + renormalize: bool, ) -> Tensor: """ Implement the Fused Mixture of Experts (MoE) model. @@ -556,7 +557,13 @@ def fused_moe( """ return vendor_ops_registry["fused_moe"]( - hidden_states, top_k, topk_ids, topk_weights, gate_up_weights, down_weights + hidden_states, + gate_up_weights, + down_weights, + topk_weights, + topk_ids, + topk, + renormalize ) diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index 05444df..9a95250 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -406,11 +406,12 @@ def weight_quant_matmul( @register_ops(vendor_ops_registry) def fused_moe( hidden_states: Tensor, - top_k: int, - topk_ids: Tensor, - topk_weights: Tensor, gate_up_weights: Tensor, down_weights: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + top_k: int, + renormalize: bool, ) -> Tensor: seq_length = hidden_states.size(0) moe_output = torch.zeros_like(hidden_states) diff --git a/dlinfer/vendor/maca/maca_ops.py b/dlinfer/vendor/maca/maca_ops.py index 2838bab..3eb9fb9 100644 --- a/dlinfer/vendor/maca/maca_ops.py +++ b/dlinfer/vendor/maca/maca_ops.py @@ -348,31 +348,22 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor: @register_ops(vendor_ops_registry) def fused_moe( - hidden_states: torch.Tensor, + hidden_states: Tensor, + gate_up_weights: Tensor, + down_weights: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, top_k: int, - topk_ids: torch.LongTensor, - topk_weights: torch.Tensor, - gate_up_weights: torch.Tensor, - down_weights: torch.Tensor, -): - N, D = hidden_states.shape - hidden_states = hidden_states.view(N, -1, D).repeat(1, top_k, 1).reshape(-1, D) - out = torch.zeros( - N * top_k, - down_weights.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, + renormalize: bool, +) -> Tensor: + N = hidden_states.size(0) + topk_weights = topk_weights.reshape(N, top_k) + topk_ids = topk_ids.reshape(N, top_k) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return vllm.model_executor.layers.fused_moe.fused_experts( + hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids ) - for i in range(gate_up_weights.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = silu_and_mul( - hidden_states[mask] @ gate_up_weights[i].transpose(0, 1) - ) @ down_weights[i].transpose(0, 1) - return ( - out.view(N, -1, down_weights.shape[1]) - * topk_weights.view(N, -1, 1).to(out.dtype) - ).sum(dim=1) @register_ops(vendor_ops_registry)