[feature] support ep for deepseek v3
ver217 committed Feb 6, 2025
1 parent 17062c8 commit 3f84584
Showing 7 changed files with 461 additions and 2 deletions.
175 changes: 175 additions & 0 deletions colossalai/shardformer/modeling/deepseek_v3.py
@@ -0,0 +1,175 @@
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_forward_split_backward,
    linear_with_async_comm,
    split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


class EpDeepseekV3MoE(ParallelModule):
    """
    A mixed expert module containing shared experts, with the routed experts sharded across the expert-parallel (EP) group.
    """

    def __init__(self, config):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_process_groups(
        self,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
    ):
        assert moe_dp_group is not None
        assert ep_group is not None

        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.num_experts = self.config.n_routed_experts
        assert self.num_experts % self.ep_size == 0

        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.experts_per_rank = self.num_experts_per_ep
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        set_tensors_to_none(self.experts, exclude=set(held_experts))

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = dist.get_world_size(moe_dp_group)

        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

    @staticmethod
    def from_native_module(
        module,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EpDeepseekV3MoE":
        LazyInitContext.materialize(module)
        if module.__class__.__name__ == "DeepseekV3MLP":
            return module
        module.__class__ = EpDeepseekV3MoE
        module.setup_process_groups(moe_dp_group, ep_group)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
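        # expert-parallel dispatch: exchange per-expert token counts across EP ranks,
        # then all-to-all the expert-sorted tokens so that each rank receives every
        # token routed to one of its locally held experts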
        if self.ep_size > 1:
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)

            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
            input_split_sizes = tokens_per_ep_rank.tolist()

            gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
            gathered_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gathered_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gathered_idxs = gathered_idxs.argsort()
            sorted_tokens = gathered_tokens[gathered_idxs]
            tokens_per_expert = tokens_per_expert_post_gather

            # moe-dp related code: count, across the MoE-DP group, how many replicas routed
            # tokens to each local expert (used by the DP grad scalers below)
            activate_experts = tokens_per_expert_post_gather > 0
            activate_experts = activate_experts.int()
            dist.all_reduce(activate_experts, group=self.moe_dp_group)

            # ep related code: rescale gradients of the dispatched tokens to account for the EP group size
            sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)

        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            # moe-dp related code: rescale gradients by how many MoE-DP replicas activated this expert
            tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
            expert_out = expert(tokens_for_this_expert)
            # moe-dp related code
            expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
            outputs.append(expert_out)
            start_idx = end_idx

        if len(outputs) > 0:
            outs = torch.cat(outputs, dim=0)
        else:
            assert sorted_tokens.numel() == 0, f"sorted_tokens: should be empty, but got {sorted_tokens.shape}"
            outs = sorted_tokens

        if self.ep_size > 1:
            outs = EPGradScalerOut.apply(outs, self.ep_size)
            new_x = torch.empty_like(outs)
            new_x[gathered_idxs] = outs
            gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
            outs = gathered_tokens

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )

        return final_out
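
For reference, a minimal single-process sketch of the pre-dispatch bookkeeping in `moe_forward` (pure PyTorch, no process groups; the gate is faked and all sizes are illustrative):

import torch

num_experts, top_k, num_tokens, hidden = 4, 2, 6, 8
x = torch.randn(num_tokens, hidden)
# fake gate: pick top_k distinct experts per token
topk_ids = torch.rand(num_tokens, num_experts).topk(top_k, dim=-1).indices

# count how many routed (token, expert) assignments hit each expert
cnts = topk_ids.new_zeros((num_tokens, num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)            # (num_experts,)

# sort the flattened assignments by expert id; idxs // top_k recovers the source token,
# so sorted_tokens holds each token once per expert it was routed to, grouped by expert
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // top_k]               # (num_tokens * top_k, hidden)

assert sorted_tokens.shape[0] == int(tokens_per_expert.sum())

With expert parallelism enabled, these per-expert counts are what gets exchanged with `all_to_all_single`, and `sorted_tokens` is what gets redistributed with `all_to_all_uneven`.
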
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
Original file line number Diff line number Diff line change
@@ -167,6 +167,13 @@ class PolicyLocation:
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
),
# DeepseekV3
"transformers_modules.modeling_deepseek.DeepseekV3Model": PolicyLocation(
file_name="deepseek_v3", class_name="DeepseekV3ModelPolicy"
),
"transformers_modules.modeling_deepseek.DeepseekV3ForCausalLM": PolicyLocation(
file_name="deepseek_v3", class_name="DeepseekV3ForCausalLMPolicy"
),
# Falcon
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
file_name="falcon", class_name="FalconModelPolicy"
80 changes: 80 additions & 0 deletions colossalai/shardformer/policies/deepseek_v3.py
@@ -0,0 +1,80 @@
from typing import Callable, Dict, List, Union

import torch.nn as nn

from colossalai.shardformer.layer import FusedRMSNorm
from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["DeepseekV3Policy", "DeepseekV3ModelPolicy", "DeepseekV3ForCausalLMPolicy"]


class DeepseekV3Policy(Policy):
    def config_sanity_check(self):
        assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
        assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism"
        assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"

    def preprocess(self):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

        policy = {}

        if self.shard_config.ep_group:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EpDeepseekV3MoE,
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                        },
                    )
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # TODO: prevent casting to fp32
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekV3Model",
            )

        return policy

    def postprocess(self):
        return self.model


class DeepseekV3ModelPolicy(DeepseekV3Policy):
    pass


class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
    pass
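
Conceptually, each `SubModuleReplacementDescription` tells the sharder to find every module whose class name matches `target_key` and swap the child named `suffix` for the parallel wrapper. A simplified, self-contained sketch of that mechanic (not ColossalAI's actual sharder; `replace_submodules` and `wrapper_cls` are illustrative names, with `wrapper_cls` standing in for modules like `EpDeepseekV3MoE`):

import torch.nn as nn

def replace_submodules(model: nn.Module, target_key: str, suffix: str, wrapper_cls, **kwargs):
    """Swap the `suffix` child of every module whose class name equals `target_key`."""
    for module in model.modules():
        if type(module).__name__ == target_key:
            native_child = getattr(module, suffix)
            setattr(module, suffix, wrapper_cls.from_native_module(native_child, **kwargs))

# e.g. replace_submodules(model, "DeepseekV3DecoderLayer", "mlp", EpDeepseekV3MoE,
#                         ep_group=ep_group, moe_dp_group=moe_dp_group)
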
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py
@@ -5,6 +5,7 @@
from .chatglm2 import *
from .command import *
from .deepseek import *
from .deepseek_v3 import *
from .falcon import *
from .gpt import *
from .gptj import *
82 changes: 82 additions & 0 deletions tests/kit/model_zoo/transformers/deepseek_v3.py
@@ -0,0 +1,82 @@
# modified from tests/kit/model_zoo/transformers/mistral.py
import torch
import transformers
from transformers import AutoConfig

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence DeepSeek-V3
# ===============================


def data_gen():
    # Generated from following code snippet
    #
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
    # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
    # tokenized_input = tokenizer([input], return_tensors="pt")
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_lm():
    # LM data gen
    # since there is no padding, the `labels` for the LM are simply a copy of `input_ids`
    data = data_gen()
    data["labels"] = data["input_ids"].clone()
    return data


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn = lambda x: x[0].mean()
loss_fn_for_lm = lambda x: x.loss


def init_deepseek():

    config = AutoConfig.from_pretrained(
        "deepseek-ai/DeepSeek-V3",
        hidden_size=128,
        intermediate_size=320,
        kv_lora_rank=4,
        moe_intermediate_size=32,
        num_attention_heads=4,
        num_experts_per_tok=4,
        n_group=4,
        num_hidden_layers=3,
        num_key_value_heads=4,
        first_k_dense_replace=1,
        q_lora_rank=8,
        torch_dtype="bfloat16",
        n_routed_experts=16,
        topk_group=2,
        v_head_dim=32,
        qk_nope_head_dim=32,
        qk_rope_head_dim=32,
        trust_remote_code=True,
        vocab_size=2048,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id
    model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    return model


model_zoo.register(
    name="transformers_deepseek_v3",
    model_fn=init_deepseek,
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)
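
A rough smoke-test sketch of how the registered pieces compose (assuming the DeepSeek-V3 config can be fetched from the Hub with `trust_remote_code`; this mirrors how the shardformer tests consume a zoo entry, not an exact reproduction):

model = init_deepseek()            # tiny DeepSeek-V3 built from the downscaled config
data = data_gen_for_lm()           # input_ids, attention_mask, labels
outputs = output_transform_fn(model(**data))
loss = loss_fn_for_lm(outputs)     # CausalLMOutputWithPast.loss
loss.backward()
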
2 changes: 0 additions & 2 deletions tests/test_shardformer/test_model/_utils.py
@@ -223,7 +223,6 @@ def _criterion(outputs, inputs):
    for k, v in data.items():
        unshard_test_data[k] = data[k].clone()

-   sharded_model.train()
    if booster.plugin.stage_manager is not None:
        for k, v in shard_test_data.items():
            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
@@ -248,7 +247,6 @@ def _criterion(outputs, inputs):
        sharded_loss = criterion(sharded_output)
        sharded_optimizer.backward(sharded_loss)

-   org_model.train()
    if booster.plugin.stage_manager is not None:
        for k, v in unshard_test_data.items():
            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: