Skip to content

Commit

Permalink
refine code.
Browse files Browse the repository at this point in the history
  • Loading branch information
Reinerzhou committed Dec 18, 2024
1 parent d51155c commit 982a5a5
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
14 changes: 7 additions & 7 deletions dlinfer/ops/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,12 +534,12 @@ def weight_quant_matmul(
@register_custom_op("dlinfer::fused_moe", ["hidden_states"])
def fused_moe(
hidden_states: Tensor,
- top_k: int,
- topk_ids: Tensor,
- topk_weights: Tensor,
  gate_up_weights: Tensor,
  down_weights: Tensor,
- renormalize: bool = False,
+ topk_weights: Tensor,
+ topk_ids: Tensor,
+ topk: int,
+ renormalize: bool,
) -> Tensor:
"""
Implement the Fused Mixture of Experts (MoE) model.
Expand All @@ -558,11 +558,11 @@ def fused_moe(
"""
return vendor_ops_registry["fused_moe"](
hidden_states,
- top_k,
- topk_ids,
- topk_weights,
  gate_up_weights,
  down_weights,
+ topk_weights,
+ topk_ids,
+ topk,
renormalize,
)

Expand Down
7 changes: 4 additions & 3 deletions dlinfer/vendor/ascend/torch_npu_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,12 @@ def weight_quant_matmul(
@register_ops(vendor_ops_registry)
def fused_moe(
hidden_states: Tensor,
- top_k: int,
- topk_ids: Tensor,
- topk_weights: Tensor,
  gate_up_weights: Tensor,
  down_weights: Tensor,
+ topk_weights: Tensor,
+ topk_ids: Tensor,
+ top_k: int,
+ renormalize: bool,
) -> Tensor:
seq_length = hidden_states.size(0)
moe_output = torch.zeros_like(hidden_states)
Expand Down
14 changes: 7 additions & 7 deletions dlinfer/vendor/maca/maca_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,14 +348,14 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:

@register_ops(vendor_ops_registry)
def fused_moe(
- hidden_states: torch.Tensor,
+ hidden_states: Tensor,
+ gate_up_weights: Tensor,
+ down_weights: Tensor,
+ topk_weights: Tensor,
+ topk_ids: Tensor,
  top_k: int,
- topk_ids: torch.LongTensor,
- topk_weights: torch.Tensor,
- gate_up_weights: torch.Tensor,
- down_weights: torch.Tensor,
- renormalize: bool = False,
- ):
+ renormalize: bool,
+ ) -> Tensor:
N = hidden_states.size(0)
topk_weights = topk_weights.reshape(N, top_k)
topk_ids = topk_ids.reshape(N, top_k)
Expand Down

0 comments on commit 982a5a5

Please sign in to comment.