[feature] support ep for deepseek v3
Showing 7 changed files with 461 additions and 2 deletions.
@@ -0,0 +1,175 @@
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_forward_split_backward,
    linear_with_async_comm,
    split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


class EpDeepseekV3MoE(ParallelModule):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_process_groups(
        self,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
    ):
        assert moe_dp_group is not None
        assert ep_group is not None

        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.num_experts = self.config.n_routed_experts
        assert self.num_experts % self.ep_size == 0

        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.experts_per_rank = self.num_experts_per_ep
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        set_tensors_to_none(self.experts, exclude=set(held_experts))

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = dist.get_world_size(moe_dp_group)

        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

    @staticmethod
    def from_native_module(
        module,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EpDeepseekV3MoE":
        LazyInitContext.materialize(module)
        if module.__class__.__name__ == "DeepseekV3MLP":
            return module
        module.__class__ = EpDeepseekV3MoE
        module.setup_process_groups(moe_dp_group, ep_group)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        # count how many token copies each expert receives and group the copies by expert id
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        if self.ep_size > 1:
            # exchange per-expert token counts so each rank knows how many copies it will receive
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)

            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
            input_split_sizes = tokens_per_ep_rank.tolist()

            # dispatch token copies to the ranks holding their experts, then regroup them by local expert
            gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather

            # moe-dp related code
            activate_experts = tokens_per_expert_post_gather > 0
            activate_experts = activate_experts.int()
            dist.all_reduce(activate_experts, group=self.moe_dp_group)

            # ep related code
            sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)

        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            # moe-dp related code
            tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
            expert_out = expert(tokens_for_this_expert)
            # moe-dp related code
            expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
            outputs.append(expert_out)
            start_idx = end_idx

        if len(outputs) > 0:
            outs = torch.cat(outputs, dim=0)
        else:
            assert sorted_tokens.numel() == 0, f"sorted_tokens: should be empty, but got {sorted_tokens.shape}"
            outs = sorted_tokens

        if self.ep_size > 1:
            outs = EPGradScalerOut.apply(outs, self.ep_size)
            # restore the pre-dispatch order and send expert outputs back to the source ranks
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
            outs = gathered_tokens

        # un-permute to the original token order, then apply the routing weights
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )

        return final_out
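For orientation, the dispatch bookkeeping at the top of `moe_forward` (counting token copies per expert, then grouping the copies by expert id before the all-to-all) can be traced on a toy example. The sketch below is illustrative only, uses made-up router outputs, and is not part of this commit:

import torch

# Toy, single-process trace of the dispatch bookkeeping in `moe_forward`.
num_tokens, num_experts, top_k = 4, 8, 2
topk_ids = torch.tensor([[1, 3], [0, 1], [7, 2], [1, 7]])  # hypothetical router picks

cnts = topk_ids.new_zeros((num_tokens, num_experts))
cnts.scatter_(1, topk_ids, 1)         # one-hot of chosen experts per token
tokens_per_expert = cnts.sum(dim=0)   # tensor([1, 3, 1, 1, 0, 0, 0, 2])
idxs = topk_ids.view(-1).argsort()    # flat (token, slot) positions ordered by expert id
token_of_each_copy = idxs // top_k    # which token each grouped copy comes from
# x[token_of_each_copy] replicates every token once per chosen expert, already
# grouped by expert id; this is exactly what `sorted_tokens` holds above.

With expert parallelism enabled, `tokens_per_expert` is then exchanged with `dist.all_to_all_single` so each rank learns how many copies it will receive for its local experts, `all_to_all_uneven` moves the copies themselves, and the reverse `all_to_all_uneven` near the end of `moe_forward` routes the expert outputs back.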
@@ -0,0 +1,80 @@
from typing import Callable, Dict, List, Union

import torch.nn as nn

from colossalai.shardformer.layer import FusedRMSNorm
from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["DeepseekV3Policy", "DeepseekV3ModelPolicy", "DeepseekV3ForCausalLMPolicy"]


class DeepseekV3Policy(Policy):
    def config_sanity_check(self):
        assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
        assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism"
        assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"

    def preprocess(self):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

        policy = {}

        if self.shard_config.ep_group:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EpDeepseekV3MoE,
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                        },
                    )
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # TODO: prevent casting to fp32
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekV3Model",
            )

        return policy

    def postprocess(self):
        return self.model


class DeepseekV3ModelPolicy(DeepseekV3Policy):
    pass


class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
    pass
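The `suffix="mlp"` replacement above is what turns expert parallelism on: for every `DeepseekV3DecoderLayer`, shardformer swaps the MoE block for `EpDeepseekV3MoE`, passing along the EP and MoE-DP process groups carried on `ShardConfig`. As a rough mental model only (the real rewrite is driven by the policy machinery, and the `model.model.layers` attribute path below is the usual Hugging Face layout rather than something this diff guarantees), the effect is roughly:

from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE

def replace_moe_layers(model, moe_dp_group, ep_group):
    # Hedged sketch, not from the commit: apply the "mlp" submodule replacement by hand.
    for layer in model.model.layers:  # DeepseekV3DecoderLayer modules, assuming the usual HF layout
        # Dense layers (the first `first_k_dense_replace` ones) hold a plain DeepseekV3MLP;
        # `from_native_module` returns those unchanged and only converts the MoE blocks.
        layer.mlp = EpDeepseekV3MoE.from_native_module(layer.mlp, moe_dp_group=moe_dp_group, ep_group=ep_group)
    return model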
@@ -0,0 +1,82 @@
# modified from tests/kit/model_zoo/transformers/mistral.py
import torch
import transformers
from transformers import AutoConfig

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence DeepSeek-V3
# ===============================


def data_gen():
    # Generated from following code snippet
    #
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
    # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
    # tokenized_input = tokenizer([input], return_tensors="pt")
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)

def data_gen_for_lm():
    # LM data gen
    # since there is no padding, reuse `input_ids` as `labels` for the LM objective
    data = data_gen()
    data["labels"] = data["input_ids"].clone()
    return data


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn = lambda x: x[0].mean()
loss_fn_for_lm = lambda x: x.loss


def init_deepseek():

    config = AutoConfig.from_pretrained(
        "deepseek-ai/DeepSeek-V3",
        hidden_size=128,
        intermediate_size=320,
        kv_lora_rank=4,
        moe_intermediate_size=32,
        num_attention_heads=4,
        num_experts_per_tok=4,
        n_group=4,
        num_hidden_layers=3,
        num_key_value_heads=4,
        first_k_dense_replace=1,
        q_lora_rank=8,
        torch_dtype="bfloat16",
        n_routed_experts=16,
        topk_group=2,
        v_head_dim=32,
        qk_nope_head_dim=32,
        qk_rope_head_dim=32,
        trust_remote_code=True,
        vocab_size=2048,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id
    model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    return model


model_zoo.register(
    name="transformers_deepseek_v3",
    model_fn=init_deepseek,
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)
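The registration above ties a deliberately tiny DeepSeek-V3 configuration (3 layers, 16 routed experts, 2048-token vocabulary) to a data generator and loss function so parallelism tests can build and step the model cheaply. Below is a minimal smoke-test sketch using only the helpers defined in this file; it is illustrative, requires the `deepseek-ai/DeepSeek-V3` config and its remote modeling code to be reachable, and is not how the test suite itself consumes the registry:

# Hypothetical usage of the pieces registered above; not part of the commit.
model = init_deepseek()                              # tiny random-weight DeepSeek-V3 LM
data = data_gen_for_lm()                             # input_ids + attention_mask + labels
output = model(**data)                               # forward pass; `labels` makes the model return a loss
loss = loss_fn_for_lm(output_transform_fn(output))
loss.backward()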