refactor fused_moe on ascend platform (#2613)
yao-fengchen authored Oct 21, 2024
1 parent c918669 commit 77be205
Showing 3 changed files with 21 additions and 25 deletions.
28 changes: 3 additions & 25 deletions lmdeploy/pytorch/backends/dlinfer/moe.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from lmdeploy.pytorch.kernels.dlinfer import moe_gating_topk_softmax
+from lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax
 
 from ..moe import (FusedMoEBuilder, FusedMoEImpl, SoftmaxTopKBuilder,
                    SoftmaxTopKImpl)
@@ -42,30 +42,8 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
                 topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor,
                 down_weights: torch.Tensor):
         """forward."""
-        seq_length = hidden_states.size(0)
-        moe_output = torch.zeros_like(hidden_states)
-
-        for i in range(seq_length):
-            current_hidden_state = hidden_states[i]
-
-            # faster than remove the for loop
-            for j in range(self.top_k):
-                expert_id = topk_ids[i][j]
-                weight = topk_weights[i][j]
-
-                up_weight = gate_up_weights[expert_id]
-                up_proj = torch.matmul(up_weight, current_hidden_state)
-
-                gate_cache, up_cache = up_proj.chunk(2, -1)
-                gate_cache = torch.nn.functional.silu(gate_cache,
-                                                      inplace=True) * up_cache
-
-                down_weight = down_weights[expert_id]
-                down_proj = torch.matmul(down_weight, gate_cache)
-
-                moe_output[i] += weight * down_proj
-
-        return moe_output
+        return fused_moe(hidden_states, self.top_k, topk_ids, topk_weights,
+                         gate_up_weights, down_weights)
 
 
 class DlinferFusedMoEBuilder(FusedMoEBuilder):
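For readers comparing the two paths: below is a minimal PyTorch sketch (not the dlinfer kernel) of the per-token MoE math that the removed loop implemented and that fused_moe is expected to reproduce. The weight layouts are inferred from the matmul order in the old code and should be treated as assumptions.

# Reference sketch of the semantics replaced by fused_moe. Assumed layouts
# (inferred from the removed loop): gate_up_weights is
# (num_experts, 2 * intermediate, hidden), down_weights is
# (num_experts, hidden, intermediate). This is the plain math, not the
# Ascend kernel.
import torch


def moe_reference(hidden_states: torch.Tensor, top_k: int,
                  topk_ids: torch.LongTensor, topk_weights: torch.Tensor,
                  gate_up_weights: torch.Tensor,
                  down_weights: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(hidden_states)
    for i in range(hidden_states.size(0)):  # each token
        for j in range(top_k):              # each selected expert
            expert = topk_ids[i, j]
            # gate/up projection, SiLU gating, then down projection
            gate, up = torch.matmul(gate_up_weights[expert],
                                    hidden_states[i]).chunk(2, -1)
            act = torch.nn.functional.silu(gate) * up
            out[i] += topk_weights[i, j] * torch.matmul(
                down_weights[expert], act)
    return out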
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -2,6 +2,7 @@
 from ..default import multinomial_sampling
 from .apply_rotary_pos_emb import apply_rotary_pos_emb
 from .fill_kv_cache import fill_kv_cache
+from .fused_moe import fused_moe
 from .moe_gating_topk_softmax import moe_gating_topk_softmax
 from .pagedattention import paged_attention_fwd
 from .rms_norm import rms_norm
@@ -10,6 +11,7 @@
     'rms_norm',
     'apply_rotary_pos_emb',
     'fill_kv_cache',
+    'fused_moe',
     'paged_attention_fwd',
     'moe_gating_topk_softmax',
     'multinomial_sampling',
16 changes: 16 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/fused_moe.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+from torch import Tensor
+
+
+def fused_moe(
+    hidden_states: Tensor,
+    top_k: int,
+    topk_ids: Tensor,
+    topk_weights: Tensor,
+    gate_up_weights: Tensor,
+    down_weights: Tensor,
+):
+    """ascend fused moe."""
+    return ext_ops.fused_moe(hidden_states, top_k, topk_ids, topk_weights,
+                             gate_up_weights, down_weights)
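A minimal usage sketch for the new wrapper. The sizes, dtype, and routing step are illustrative assumptions, not part of the commit; running it requires dlinfer and an Ascend-capable backend, which may additionally constrain dtype and device placement.

import torch

from lmdeploy.pytorch.kernels.dlinfer import fused_moe

# Hypothetical sizes chosen only for illustration.
num_tokens, hidden, intermediate, num_experts, top_k = 4, 128, 256, 8, 2

hidden_states = torch.randn(num_tokens, hidden)
router_logits = torch.randn(num_tokens, num_experts)
gate_up_weights = torch.randn(num_experts, 2 * intermediate, hidden)
down_weights = torch.randn(num_experts, hidden, intermediate)

# Routing done with plain torch.topk/softmax to keep the sketch
# self-contained; the backend normally gets these from its gating module.
topk_weights, topk_ids = torch.topk(
    torch.softmax(router_logits, dim=-1), top_k, dim=-1)

out = fused_moe(hidden_states, top_k, topk_ids, topk_weights,
                gate_up_weights, down_weights)  # -> (num_tokens, hidden)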
