diff --git a/paddlenlp/transformers/conversion_utils.py b/paddlenlp/transformers/conversion_utils.py
index d5a34ed3cba4..b7a80d617d08 100644
--- a/paddlenlp/transformers/conversion_utils.py
+++ b/paddlenlp/transformers/conversion_utils.py
@@ -1205,10 +1205,17 @@ def _get_name_mappings(cls, config: PretrainedConfig) -> List[StateDictNameMappi
 
     @classmethod
     def get_tensor_parallel_convert_actions(
-        cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False, base_model_prefix=None
+        cls,
+        config: PretrainedConfig,
+        loaded_state_dict_keys,
+        is_split=True,
+        ignore_error=False,
+        base_model_prefix=None,
     ):
         name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=is_split)
-        state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, base_model_prefix=base_model_prefix)
+        state_keys_map = cls._resolve_prefix_keys(
+            name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, base_model_prefix=base_model_prefix
+        )
         for k, v in state_keys_map.items():
             if k not in name_action_mappings:
                 continue
@@ -1319,7 +1326,7 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False, b
                 if k.startswith("lm_head."):
                     continue
                 # remove real key name `base_model_prefix` + '.'
-                state_keys_map[k[len(base_model_prefix + '.'):]] = k
+                state_keys_map[k[len(base_model_prefix + ".") :]] = k
             return state_keys_map
 
         # sorted by length,match from long to short for A.key B.key ...
diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index 092202e149b7..24f0dd1ca4d3 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -25,9 +25,9 @@
 import warnings
 from functools import partial
 from typing import List, Optional, Tuple, Union
 
-import paddle.distributed as dist
 import paddle
+import paddle.distributed as dist
 import paddle.distributed.fleet.meta_parallel as mpu
 import paddle.nn.functional as F
 from paddle import Tensor, nn
@@ -759,7 +759,11 @@ def __init__(self, config: DeepseekV2Config):
             config=config,
             moe_num_experts=config.n_routed_experts,
             expert_class=DeepseekV2MLP,
-            expert_kwargs={"config": config, "intermediate_size": config.moe_intermediate_size, "is_moe": not act_tp_shard},
+            expert_kwargs={
+                "config": config,
+                "intermediate_size": config.moe_intermediate_size,
+                "is_moe": not act_tp_shard,
+            },
             gate=gate,
             capacity=2.0,
         )
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index a7b5964f76a0..737dd943bec3 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -2078,7 +2078,9 @@ def _fuse_or_split_keys(
             before_fuse_keys = list(state_dict.keys())
             if pre_tensor_parallel_split:
                 print("xxxxxsdf: ", prefix)
-                tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3")
+                tp_actions = cls.get_tensor_parallel_convert_actions(
+                    config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3"
+                )
             else:
                 tp_actions = None
             state_dict, resume_state_dict = cls.convert_fuse_and_split(config, state_dict, tp_actions)
@@ -2145,7 +2147,9 @@ def _fuse_or_split_keys(
             pre_tensor_parallel_split = True
             assert loaded_keys is not None, "loaded_keys is not None."
             print("xxxxxsdf: ", prefix)
-            tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3")
+            tp_actions = cls.get_tensor_parallel_convert_actions(
+                config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3"
+            )
             # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
             filter_dict_keys = set(expected_keys)
             fuse_actions, _ = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True)
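
A minimal sketch of what the base_model_prefix threading above ends up doing: _resolve_prefix_keys strips the prefix from the real checkpoint keys so tensor-parallel actions can be matched. The helper name and the example key list are assumptions for illustration; only the hard-coded "deepseek_v3" prefix comes from the patch.

import copy  # stdlib only; no paddle needed for this sketch

def strip_base_model_prefix(loaded_keys, base_model_prefix="deepseek_v3"):
    # mimics the prefix-stripping branch of _resolve_prefix_keys
    state_keys_map = {}
    for k in loaded_keys:
        if k.startswith("lm_head."):
            continue
        if k.startswith(base_model_prefix + "."):
            # map the un-prefixed name back to the real key name
            state_keys_map[k[len(base_model_prefix + ".") :]] = k
    return state_keys_map

keys = ["deepseek_v3.layers.0.mlp.gate.weight", "lm_head.weight"]
print(strip_base_model_prefix(keys))
# {'layers.0.mlp.gate.weight': 'deepseek_v3.layers.0.mlp.gate.weight'}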
print("xxxxxsdf: ", prefix) - tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3") + tp_actions = cls.get_tensor_parallel_convert_actions( + config, loaded_keys, ignore_error=True, base_model_prefix="deepseek_v3" + ) # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors filter_dict_keys = set(expected_keys) fuse_actions, _ = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True) diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index 28fed62d4eaa..8184b34c7454 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -218,6 +218,8 @@ def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: Returns: paddle.Tensor: cumsum locations """ + # (LiuTing) this func can further code refine. + # Make num_selected_experts the leading axis to ensure that top-1 choices # have priority over top-2 choices, which have priority over top-3 choices, # etc. @@ -245,17 +247,7 @@ def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: # Shape: [tokens_per_group, num_experts]. token_priority = paddle.max(token_priority, axis=1) - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [tokens_per_group, num_experts, expert_capacity]. - valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity) - token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0) - dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.int32) - valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity]) - dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0) - - return dispatch_mask + return token_priority def _topk_greedy(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: """_summary_ @@ -562,7 +554,4 @@ def topkgating( if self.norm_topk_prob: gates_masked = gates_masked / denom_s - combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype())) - dispatch_mask = combine_weights.astype(paddle.bool) - - return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + return capacity, gates_masked, token_priority, exp_counts, l_aux, l_zloss diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 32e54245eeed..2006eeba2f73 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -15,13 +15,9 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Tuple - import paddle import paddle.distributed as dist -from paddle import Tensor, nn -from paddle.distributed.communication import stream -from paddle.distributed.communication.group import Group +from paddle import nn from .moe_gate import PretrainedMoEGate @@ -90,50 +86,6 @@ def combining(x, combine_weights, scatter_index): return paddle.matmul(combine_weights, x).squeeze(1) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim] -class _AllToAll(paddle.autograd.PyLayer): - @staticmethod - def forward( - ctx: Any, - input: Tensor, - group: Group, - ) -> Tensor: # type: ignore - """ - All-to-all communication in the group. - - Args: - ctx (Any): Context object. - input (Tensor): Input tensor. - group (Group): The group object. 
diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py
index 32e54245eeed..2006eeba2f73 100644
--- a/paddlenlp/transformers/moe_layer.py
+++ b/paddlenlp/transformers/moe_layer.py
@@ -15,13 +15,9 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Tuple
-
 import paddle
 import paddle.distributed as dist
-from paddle import Tensor, nn
-from paddle.distributed.communication import stream
-from paddle.distributed.communication.group import Group
+from paddle import nn
 
 from .moe_gate import PretrainedMoEGate
 
@@ -90,50 +86,6 @@ def combining(x, combine_weights, scatter_index):
     return paddle.matmul(combine_weights, x).squeeze(1)  # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim]
 
 
-class _AllToAll(paddle.autograd.PyLayer):
-    @staticmethod
-    def forward(
-        ctx: Any,
-        input: Tensor,
-        group: Group,
-    ) -> Tensor:  # type: ignore
-        """
-        All-to-all communication in the group.
-
-        Args:
-            ctx (Any): Context object.
-            input (Tensor): Input tensor.
-            group (Group): The group object.
-
-        Returns:
-            Tensor: Output tensor.
-        """
-
-        ctx.group = group
-        # return input
-        if dist.get_world_size(group) <= 1:
-            return input
-        output = paddle.empty_like(input)
-        stream.alltoall_single(output, input, None, None, group, True, True)
-        return output
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]:
-        """
-        Aggregates gradient information from all input tensors into a single tensor.
-
-        Args:
-            ctx (Any): The context object used to store information that needs to be passed.
-            *grad_output (Tensor): A list of input tensors whose gradients are to be aggregated.
-
-        Returns:
-            Tuple[Tensor]: A tuple containing a tensor that holds the gradients of all input tensors.
-
-        """
-        # return grad_output
-        return _AllToAll.apply(*grad_output, ctx.group)
-
-
 class MoELayer(nn.Layer):
     def __init__(
         self,
@@ -212,20 +164,23 @@ def _post_init(self):
                 p.no_sync = not self.is_dummy_moe
                 # logger.info(f"expert param={p.name}, no-sync={p.no_sync}")
 
-    def expert_forward(self, dispatched_input):
+    def expert_forward(self, dispatched_input, exp_token_idx):
         true_experts = self.experts[
             self.moe_rank * self.moe_num_experts_per_device : (self.moe_rank + 1) * self.moe_num_experts_per_device
         ]
         expert_outputs = []
-        chunks = dispatched_input.unbind(1)
-        assert len(chunks) == len(true_experts), (len(chunks), len(true_experts))
-        for chunk, expert in zip(chunks, true_experts):
-            chunk = chunk.contiguous()
-            print("11", chunk.shape)  # [ecm]
-            expert_outputs += [expert(chunk)]
-        expert_output = paddle.stack(expert_outputs, axis=1)  # [ecm]
-        print("22", expert_output.shape)  # [ecm]
-        return expert_output
+
+        for idx in range(len(dispatched_input)):
+            # print(dispatched_input[idx])
+            # print(exp_token_idx[idx])
+            # (LiuTing) could use paddle.stack here.
+            expert_outputs.append(
+                true_experts[idx % self.moe_num_experts_per_device](dispatched_input[idx])
+                if exp_token_idx[idx] is not None
+                else dispatched_input[idx]
+            )
+
+        return expert_outputs
 
     def forward(
         self,
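
A small sketch of the contract the new expert_forward expects, assuming a single device and toy nn.Linear experts (both assumptions, not in the patch): entries of the ragged list are None for experts that received no tokens, otherwise [num_tokens_e, d_model] tensors.

import paddle
from paddle import nn

d_model = 4
experts = [nn.Linear(d_model, d_model), nn.Linear(d_model, d_model)]
dispatched_input = [paddle.randn([3, d_model]), None]  # expert 1 received no tokens

# apply the local expert only where tokens exist, pass None through otherwise
expert_outputs = [
    experts[i](chunk) if chunk is not None else None
    for i, chunk in enumerate(dispatched_input)
]
print([o.shape if o is not None else None for o in expert_outputs])  # [[3, 4], None]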
@@ -249,32 +204,52 @@ def forward(
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
         reshaped_input = hidden_state.reshape([-1, d_model])
 
-        capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(hidden_state)
-
         # self.l_aux :
-        # combine_weights : sec
-        # dispatch_mask : sec
+        # gates_masked : se
+        # token_priority : se
         # self.exp_counts :
-        dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input)
+        capacity, gates_masked, token_priority, exp_counts, l_aux, l_zloss = self.gate(hidden_state)
+
+        dispatched_input = []
+        # print("token p: ", token_priority)
+        # print("token pick exp num: ", (token_priority>=0).astype('bfloat16').sum(axis=1))
+        # print("less pick exp: ", (token_priority>=0).astype('bfloat16').sum(axis=0).min())
+        # print("gates masked: ", gates_masked)
+        exp_token_idx = []
+        for e_i in range(self.moe_num_experts):
+            expert_tokens_idx = (token_priority[:, e_i] >= 0).nonzero().squeeze(-1)
+            # (LiuTing) this expert does not handle any tokens.
+            if expert_tokens_idx.shape[0] == 0:
+                exp_token_idx.append(None)
+                dispatched_input.append(None)
+            else:
+                exp_token_idx.append(expert_tokens_idx)
+                dispatched_input.append(paddle.gather(reshaped_input, exp_token_idx[e_i], axis=0))
 
         if self.expert_parallel_degree > 1:
-            dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group)
-            # Re-shape after all-to-all: ecm -> gecm
-            dispatched_input = dispatched_input.reshape(
-                [self.expert_parallel_degree, self.moe_num_experts_per_device, -1, d_model]
-            )
-        expert_output = self.expert_forward(dispatched_input)
-        # Re-shape before drop_tokens: gecm -> ecm
-        expert_output = expert_output.reshape(
-            [self.expert_parallel_degree * self.moe_num_experts_per_device, -1, d_model]
-        )
+            output = []
+            dist.alltoall(output, dispatched_input, self.moe_group)
+            dispatched_input = output
+        expert_output = self.expert_forward(dispatched_input, exp_token_idx)
         if self.expert_parallel_degree > 1:
-            expert_output = _AllToAll.apply(expert_output, self.moe_group)
-
-        # combine with expert weights
-        combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output)
-
-        a = combined_output.reshape(hidden_state.shape)
-
+            output = []
+            dist.alltoall(output, expert_output, self.moe_group)
+            expert_output = output
+
+        # reformat output
+        a = paddle.zeros_like(reshaped_input)
+
+        for e_i in range(self.moe_num_experts):
+            if exp_token_idx[e_i] is not None:
+                updated = expert_output[e_i] * gates_masked[exp_token_idx[e_i].unsqueeze(-1), e_i]
+                # print("e_i: ", e_i)
+                # print("act: ", token_priority[:, e_i])
+                # print("expert_output: ", expert_output[e_i])
+                # print("exp token idxx: ", exp_token_idx[e_i])
+                # print("gate weight: ", gates_masked[exp_token_idx[e_i].unsqueeze(-1), e_i])
+                # print("gates value: ", updated)
+                a[exp_token_idx[e_i]] += updated.astype(a.dtype)
+
+        # print("moe output: ", a, a.mean(), a.max(), a._md5sum())
         return a, l_aux, l_zloss
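
A minimal sketch of the index-based combine step above, which replaces the removed einsum("sec,ecm->sm"): for each token routed to expert e, the expert output row is scaled by gates_masked[t, e] and scatter-added back into the token's position. The toy shapes below and the gather used as a stand-in for the real expert MLP are assumptions for illustration.

import paddle

num_tokens, num_experts, d_model = 4, 2, 3
reshaped_input = paddle.randn([num_tokens, d_model])
gates_masked = paddle.rand([num_tokens, num_experts])
token_priority = paddle.to_tensor([[0, -1], [-1, 0], [1, -1], [-1, 1]])

out = paddle.zeros_like(reshaped_input)
for e_i in range(num_experts):
    idx = (token_priority[:, e_i] >= 0).nonzero().squeeze(-1)
    if idx.shape[0] == 0:
        continue
    expert_out = paddle.gather(reshaped_input, idx, axis=0)  # stand-in for the expert MLP
    # weight each routed token's output and scatter-add it back to its row
    out[idx] += expert_out * gates_masked[idx.unsqueeze(-1), e_i]
print(out.shape)  # [4, 3]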