Skip to content

Commit

Permalink
fix moe op for maca.
Browse files Browse the repository at this point in the history
  • Loading branch information
Reinerzhou committed Dec 18, 2024
1 parent 119130c commit d51155c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 18 deletions.
9 changes: 8 additions & 1 deletion dlinfer/ops/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def fused_moe(
topk_weights: Tensor,
gate_up_weights: Tensor,
down_weights: Tensor,
renormalize: bool = False,
) -> Tensor:
"""
Implement the Fused Mixture of Experts (MoE) model.
Expand All @@ -556,7 +557,13 @@ def fused_moe(
"""
return vendor_ops_registry["fused_moe"](
hidden_states, top_k, topk_ids, topk_weights, gate_up_weights, down_weights
hidden_states,
top_k,
topk_ids,
topk_weights,
gate_up_weights,
down_weights,
renormalize,
)


Expand Down
25 changes: 8 additions & 17 deletions dlinfer/vendor/maca/maca_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,25 +354,16 @@ def fused_moe(
topk_weights: torch.Tensor,
gate_up_weights: torch.Tensor,
down_weights: torch.Tensor,
renormalize: bool = False,
):
N, D = hidden_states.shape
hidden_states = hidden_states.view(N, -1, D).repeat(1, top_k, 1).reshape(-1, D)
out = torch.zeros(
N * top_k,
down_weights.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
N = hidden_states.size(0)
topk_weights = topk_weights.reshape(N, top_k)
topk_ids = topk_ids.reshape(N, top_k)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return vllm.model_executor.layers.fused_moe.fused_experts(
hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids
)
for i in range(gate_up_weights.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = silu_and_mul(
hidden_states[mask] @ gate_up_weights[i].transpose(0, 1)
) @ down_weights[i].transpose(0, 1)
return (
out.view(N, -1, down_weights.shape[1])
* topk_weights.view(N, -1, 1).to(out.dtype)
).sum(dim=1)


@register_ops(vendor_ops_registry)
Expand Down

0 comments on commit d51155c

Please sign in to comment.