diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py index 7daddba..139065e 100644 --- a/dlinfer/ops/llm.py +++ b/dlinfer/ops/llm.py @@ -534,12 +534,12 @@ def weight_quant_matmul( @register_custom_op("dlinfer::fused_moe", ["hidden_states"]) def fused_moe( hidden_states: Tensor, - top_k: int, - topk_ids: Tensor, - topk_weights: Tensor, gate_up_weights: Tensor, down_weights: Tensor, - renormalize: bool = False, + topk_weights: Tensor, + topk_ids: Tensor, + topk: int, + renormalize: bool, ) -> Tensor: """ Implement the Fused Mixture of Experts (MoE) model. @@ -558,11 +558,11 @@ def fused_moe( """ return vendor_ops_registry["fused_moe"]( hidden_states, - top_k, - topk_ids, - topk_weights, gate_up_weights, down_weights, + topk_weights, + topk_ids, + topk, renormalize, ) diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index 05444df..9a95250 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -406,11 +406,12 @@ def weight_quant_matmul( @register_ops(vendor_ops_registry) def fused_moe( hidden_states: Tensor, - top_k: int, - topk_ids: Tensor, - topk_weights: Tensor, gate_up_weights: Tensor, down_weights: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, + top_k: int, + renormalize: bool, ) -> Tensor: seq_length = hidden_states.size(0) moe_output = torch.zeros_like(hidden_states) diff --git a/dlinfer/vendor/maca/maca_ops.py b/dlinfer/vendor/maca/maca_ops.py index 73df01c..3eb9fb9 100644 --- a/dlinfer/vendor/maca/maca_ops.py +++ b/dlinfer/vendor/maca/maca_ops.py @@ -348,14 +348,14 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor: @register_ops(vendor_ops_registry) def fused_moe( - hidden_states: torch.Tensor, + hidden_states: Tensor, + gate_up_weights: Tensor, + down_weights: Tensor, + topk_weights: Tensor, + topk_ids: Tensor, top_k: int, - topk_ids: torch.LongTensor, - topk_weights: torch.Tensor, - gate_up_weights: torch.Tensor, - down_weights: torch.Tensor, - renormalize: bool = False, -): + renormalize: bool, +) -> Tensor: N = hidden_states.size(0) topk_weights = topk_weights.reshape(N, top_k) topk_ids = topk_ids.reshape(N, top_k)