Optimize code of expert dispatch

wtmlon committed Feb 28, 2025
1 parent 8671e0a commit c4960e4

Showing 5 changed files with 82 additions and 103 deletions.
13 changes: 10 additions & 3 deletions paddlenlp/transformers/conversion_utils.py
@@ -1205,10 +1205,17 @@ def _get_name_mappings(cls, config: PretrainedConfig) -> List[StateDictNameMappi

@classmethod
def get_tensor_parallel_convert_actions(
cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False, base_model_prefix=None
cls,
config: PretrainedConfig,
loaded_state_dict_keys,
is_split=True,
ignore_error=False,
base_model_prefix=None,
):
name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=is_split)
state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, base_model_prefix=base_model_prefix)
state_keys_map = cls._resolve_prefix_keys(
name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, base_model_prefix=base_model_prefix
)
for k, v in state_keys_map.items():
if k not in name_action_mappings:
continue
@@ -1319,7 +1326,7 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False, b
if k.startswith("lm_head."):
continue
# remove real key name `base_model_prefix` + '.'
state_keys_map[k[len(base_model_prefix + '.'):]] = k
state_keys_map[k[len(base_model_prefix + ".") :]] = k
return state_keys_map

    # sorted by length, match from long to short for A.key B.key ...
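
A minimal sketch of the prefix handling that `_resolve_prefix_keys` performs with the new `base_model_prefix` argument, assuming checkpoint keys of the form `deepseek_v3.<param>`; `strip_base_model_prefix` is a hypothetical stand-in, not the library function:

# Hypothetical helper mirroring the snippet above: map unprefixed parameter
# names to the real checkpoint keys by dropping `base_model_prefix + "."`.
def strip_base_model_prefix(real_keys, base_model_prefix="deepseek_v3"):
    state_keys_map = {}
    for k in real_keys:
        # lm_head weights sit outside the base model and are skipped
        if k.startswith("lm_head."):
            continue
        if k.startswith(base_model_prefix + "."):
            state_keys_map[k[len(base_model_prefix + ".") :]] = k
    return state_keys_map

print(strip_base_model_prefix(["deepseek_v3.layers.0.mlp.gate.weight", "lm_head.weight"]))
# -> {'layers.0.mlp.gate.weight': 'deepseek_v3.layers.0.mlp.gate.weight'}
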
8 changes: 6 additions & 2 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -25,9 +25,9 @@
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union
import paddle.distributed as dist

import paddle
import paddle.distributed as dist
import paddle.distributed.fleet.meta_parallel as mpu
import paddle.nn.functional as F
from paddle import Tensor, nn
@@ -759,7 +759,11 @@ def __init__(self, config: DeepseekV2Config):
config=config,
moe_num_experts=config.n_routed_experts,
expert_class=DeepseekV2MLP,
expert_kwargs={"config": config, "intermediate_size": config.moe_intermediate_size, "is_moe": not act_tp_shard},
expert_kwargs={
"config": config,
"intermediate_size": config.moe_intermediate_size,
"is_moe": not act_tp_shard,
},
gate=gate,
capacity=2.0,
)
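
A hedged sketch of how an `expert_class` / `expert_kwargs` pair like the one above is typically consumed: one expert module is instantiated per routed expert, all with the same keyword arguments. `TinyMoELayer` is illustrative only, not the MoELayer implementation:

import paddle.nn as nn

class TinyMoELayer(nn.Layer):  # hypothetical stand-in for MoELayer
    def __init__(self, moe_num_experts, expert_class, expert_kwargs):
        super().__init__()
        # build one expert per routed expert, sharing the same kwargs
        self.experts = nn.LayerList(
            [expert_class(**expert_kwargs) for _ in range(moe_num_experts)]
        )

# toy usage with a plain Linear expert instead of DeepseekV2MLP
layer = TinyMoELayer(4, nn.Linear, {"in_features": 8, "out_features": 8})
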
8 changes: 6 additions & 2 deletions paddlenlp/transformers/model_utils.py
@@ -2078,7 +2078,9 @@ def _fuse_or_split_keys(
before_fuse_keys = list(state_dict.keys())
if pre_tensor_parallel_split:
print("xxxxxsdf: ", prefix)
tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3")
tp_actions = cls.get_tensor_parallel_convert_actions(
config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3"
)
else:
tp_actions = None
state_dict, resume_state_dict = cls.convert_fuse_and_split(config, state_dict, tp_actions)
@@ -2145,7 +2147,9 @@ def _fuse_or_split_keys(
pre_tensor_parallel_split = True
assert loaded_keys is not None, "loaded_keys is not None."
print("xxxxxsdf: ", prefix)
tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3")
tp_actions = cls.get_tensor_parallel_convert_actions(
config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3"
)
# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
filter_dict_keys = set(expected_keys)
fuse_actions, _ = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True)
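
For context, a hedged sketch of how a `tp_actions` mapping of this kind is usually applied: each entry maps a checkpoint key to a callable that splits (or merges) that tensor for tensor parallelism. The split function and parameter name below are illustrative assumptions, not the library's exact API:

import numpy as np

def split_last_axis(weight, rank=0, degree=2):
    # keep only this rank's shard along the last axis
    return np.split(weight, degree, axis=-1)[rank]

# assumed shape of the mapping: real checkpoint key -> split callable
tp_actions = {"deepseek_v3.layers.0.mlp.up_proj.weight": split_last_axis}
state_dict = {"deepseek_v3.layers.0.mlp.up_proj.weight": np.ones((4, 8))}

for name, action in tp_actions.items():
    if name in state_dict:
        state_dict[name] = action(state_dict[name])  # (4, 8) -> (4, 4) shard
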
19 changes: 4 additions & 15 deletions paddlenlp/transformers/moe_gate.py
@@ -218,6 +218,8 @@ def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor:
Returns:
paddle.Tensor: cumsum locations
"""
        # (LiuTing) this function could be refined further.

# Make num_selected_experts the leading axis to ensure that top-1 choices
# have priority over top-2 choices, which have priority over top-3 choices,
# etc.
@@ -245,17 +247,7 @@ def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor:
# Shape: [tokens_per_group, num_experts].
token_priority = paddle.max(token_priority, axis=1)

# Token T can only be routed to expert E if its priority is positive and
# less than the expert capacity. One-hot matrix will ignore indices outside
# the range [0, expert_capacity).
# Shape: [tokens_per_group, num_experts, expert_capacity].
valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity)
token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0)
dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.int32)
valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity])
dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0)

return dispatch_mask
return token_priority

def _topk_greedy(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""_summary_
@@ -562,7 +554,4 @@ def topkgating(
if self.norm_topk_prob:
gates_masked = gates_masked / denom_s

combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype()))
dispatch_mask = combine_weights.astype(paddle.bool)

return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss
return capacity, gates_masked, token_priority, exp_counts, l_aux, l_zloss
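
With this change `_priority` returns the `token_priority` tensor directly (shape [tokens, num_experts]) instead of expanding it into a one-hot [tokens, num_experts, capacity] dispatch mask, and `topkgating` likewise returns `gates_masked` and `token_priority`. A hedged numpy sketch of what `token_priority` encodes, assuming a simple first-come-first-served slot assignment; the library's masking and capacity-clipping details may differ:

import numpy as np

def token_priority_from_topk(topk_idx, num_experts):
    # topk_idx: [tokens, k] expert indices chosen per token.
    # Returns [tokens, num_experts]; entry (t, e) is token t's slot inside
    # expert e's buffer, or -1 when t is not routed to e.
    tokens, k = topk_idx.shape
    priority = np.full((tokens, num_experts), -1, dtype=np.int64)
    slots = np.zeros(num_experts, dtype=np.int64)
    for choice in range(k):          # top-1 choices claim slots before top-2, etc.
        for t in range(tokens):
            e = topk_idx[t, choice]
            priority[t, e] = slots[e]
            slots[e] += 1
    return priority

print(token_priority_from_topk(np.array([[0, 1], [0, 2], [1, 2]]), num_experts=3))
# [[ 0  1 -1]
#  [ 1 -1  0]
#  [-1  0  1]]
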
137 changes: 56 additions & 81 deletions paddlenlp/transformers/moe_layer.py
@@ -15,13 +15,9 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Tuple

import paddle
import paddle.distributed as dist
from paddle import Tensor, nn
from paddle.distributed.communication import stream
from paddle.distributed.communication.group import Group
from paddle import nn

from .moe_gate import PretrainedMoEGate

@@ -90,50 +86,6 @@ def combining(x, combine_weights, scatter_index):
return paddle.matmul(combine_weights, x).squeeze(1) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim]


class _AllToAll(paddle.autograd.PyLayer):
@staticmethod
def forward(
ctx: Any,
input: Tensor,
group: Group,
) -> Tensor: # type: ignore
"""
All-to-all communication in the group.
Args:
ctx (Any): Context object.
input (Tensor): Input tensor.
group (Group): The group object.
Returns:
Tensor: Output tensor.
"""

ctx.group = group
# return input
if dist.get_world_size(group) <= 1:
return input
output = paddle.empty_like(input)
stream.alltoall_single(output, input, None, None, group, True, True)
return output

@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]:
"""
Aggregates gradient information from all input tensors into a single tensor.
Args:
ctx (Any): The context object used to store information that needs to be passed.
*grad_output (Tensor): A list of input tensors whose gradients are to be aggregated.
Returns:
Tuple[Tensor]: A tuple containing a tensor that holds the gradients of all input tensors.
"""
# return grad_output
return _AllToAll.apply(*grad_output, ctx.group)


class MoELayer(nn.Layer):
def __init__(
self,
@@ -212,20 +164,23 @@ def _post_init(self):
p.no_sync = not self.is_dummy_moe
# logger.info(f"expert param={p.name}, no-sync={p.no_sync}")

def expert_forward(self, dispatched_input):
def expert_forward(self, dispatched_input, exp_token_idx):
true_experts = self.experts[
self.moe_rank * self.moe_num_experts_per_device : (self.moe_rank + 1) * self.moe_num_experts_per_device
]
expert_outputs = []
chunks = dispatched_input.unbind(1)
assert len(chunks) == len(true_experts), (len(chunks), len(true_experts))
for chunk, expert in zip(chunks, true_experts):
chunk = chunk.contiguous()
print("11", chunk.shape) # [ecm]
expert_outputs += [expert(chunk)]
expert_output = paddle.stack(expert_outputs, axis=1) # [ecm]
print("22", expert_output.shape) # [ecm]
return expert_output

for idx in range(len(dispatched_input)):
# print(dispatched_input[idx])
# print(exp_token_idx[idx])
            # (LiuTing) paddle.stack could be used here.
expert_outputs.append(
true_experts[idx % self.moe_num_experts_per_device](dispatched_input[idx])
if exp_token_idx[idx] is not None
else dispatched_input[idx]
)

return expert_outputs

def forward(
self,
@@ -249,32 +204,52 @@ def forward(
# group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
reshaped_input = hidden_state.reshape([-1, d_model])

capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(hidden_state)

# self.l_aux :
# combine_weights : sec
# dispatch_mask : sec
# gates_masked : se
# token_priority : se
# self.exp_counts :
dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input)
capacity, gates_masked, token_priority, exp_counts, l_aux, l_zloss = self.gate(hidden_state)

dispatched_input = []
# print("token p: ", token_priority)
# print("token pick exp num: ", (token_priority>=0).astype('bfloat16').sum(axis=1))
# print("less pick exp: ", (token_priority>=0).astype('bfloat16').sum(axis=0).min())
# print("gates masked: ", gates_masked)
exp_token_idx = []
for e_i in range(self.moe_num_experts):
expert_tokens_idx = (token_priority[:, e_i] >= 0).nonzero().squeeze(-1)
            # (LiuTing) this expert does not handle any tokens.
if expert_tokens_idx.shape[0] == 0:
exp_token_idx.append(None)
dispatched_input.append(None)
else:
exp_token_idx.append(expert_tokens_idx)
dispatched_input.append(paddle.gather(reshaped_input, exp_token_idx[e_i], axis=0))

if self.expert_parallel_degree > 1:
dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group)
# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(
[self.expert_parallel_degree, self.moe_num_experts_per_device, -1, d_model]
)
expert_output = self.expert_forward(dispatched_input)
# Re-shape before drop_tokens: gecm -> ecm
expert_output = expert_output.reshape(
[self.expert_parallel_degree * self.moe_num_experts_per_device, -1, d_model]
)
output = []
dist.alltoall(output, dispatched_input, self.moe_group)
dispatched_input = output

expert_output = self.expert_forward(dispatched_input, exp_token_idx)
if self.expert_parallel_degree > 1:
expert_output = _AllToAll.apply(expert_output, self.moe_group)

        # combine with expert weights
combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output)

a = combined_output.reshape(hidden_state.shape)

output = []
dist.alltoall(output, expert_output, self.moe_group)
expert_output = output

# reformat output
a = paddle.zeros_like(reshaped_input)

for e_i in range(self.moe_num_experts):
if exp_token_idx[e_i] is not None:
updated = expert_output[e_i] * gates_masked[exp_token_idx[e_i].unsqueeze(-1), e_i]
# print("e_i: ", e_i)
# print("act: ", token_priority[:, e_i])
# print("expert_output: ", expert_output[e_i])
# print("exp token idxx: ", exp_token_idx[e_i])
# print("gate weight: ", gates_masked[exp_token_idx[e_i].unsqueeze(-1), e_i])
# print("gates value: ", updated)
a[exp_token_idx[e_i]] += updated.astype(a.dtype)

# print("moe output: ", a, a.mean(), a.max(), a._md5sum())
return a, l_aux, l_zloss
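
A condensed, single-process sketch (plain numpy, no expert parallelism or alltoall) of the gather-then-combine loop introduced above: tokens with a non-negative priority for expert `e_i` are gathered, run through that expert, and scattered back weighted by `gates_masked`. The toy experts and shapes are illustrative, not the library code:

import numpy as np

def moe_dispatch_combine(reshaped_input, gates_masked, token_priority, experts):
    output = np.zeros_like(reshaped_input)
    for e_i, expert in enumerate(experts):
        # tokens routed to expert e_i have a non-negative priority slot
        idx = np.nonzero(token_priority[:, e_i] >= 0)[0]
        if idx.shape[0] == 0:
            continue  # this expert received no tokens
        expert_out = expert(reshaped_input[idx])
        # weight each expert output by the token's masked gate value
        output[idx] += expert_out * gates_masked[idx, e_i][:, None]
    return output

# toy usage: two "experts" that just scale their input
experts = [lambda x: 2.0 * x, lambda x: -1.0 * x]
x = np.ones((3, 4), dtype=np.float32)
gates = np.array([[0.7, 0.3], [1.0, 0.0], [0.4, 0.6]], dtype=np.float32)
prio = np.array([[0, 0], [1, -1], [2, 1]])
print(moe_dispatch_combine(x, gates, prio, experts))
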
