support deepseek-v3/loraGA/ on xpu
wtmlon committed Feb 24, 2025
1 parent 0c65ce7 commit 5d13984
Showing 5 changed files with 101 additions and 23 deletions.
4 changes: 3 additions & 1 deletion paddlenlp/peft/lora/lora_model.py
@@ -359,7 +359,9 @@ def process_split_and_assign(name, concat_tensor, init_dict, state_dict):
base_name = name.replace("lora_A", "weight")
if not self.reinit_base_model:
# Reinit base model
offset = init_loraA.cuda() @ init_loraB.cuda()
offset = init_loraA._copy_to(
paddle.framework._current_expected_place(), False
) @ init_loraB._copy_to(paddle.framework._current_expected_place(), False)
ori_weight = model_state_dict[base_name]
model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset)
del model_state_dict
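Note on the hunk above: the offset that re-initializes the base weight used to be computed with Tensor.cuda(), which assumes a CUDA device; it now copies both LoRA factors to whatever place Paddle is currently targeting, so the same code runs on XPU. A minimal, self-contained sketch of that device-agnostic pattern (the helper name is hypothetical, not part of the repo):

import paddle

def to_current_place(t):
    # Copy a tensor to the framework's currently expected place (GPU, XPU, or CPU)
    # instead of hard-coding CUDA via Tensor.cuda().
    return t._copy_to(paddle.framework._current_expected_place(), False)

# usage sketch, mirroring the hunk above:
# offset = to_current_place(init_loraA) @ to_current_place(init_loraB)
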
11 changes: 8 additions & 3 deletions paddlenlp/peft/lora/loraga_utils.py
@@ -167,7 +167,7 @@ def get_module_gradient(

rank_suffix = "_" + str(local_rank)
local_grad_name = ".".join(grad_name.split(".")[1:]) + ".weight" + rank_suffix
gradient = gradient_dict.pop(local_grad_name).cuda()
gradient = gradient_dict.pop(local_grad_name)._copy_to(paddle.framework._current_expected_place(), False)

is_fleet_init = True
try:
@@ -187,7 +187,7 @@
dist.all_gather(output_tensors, gradient, group=model_parallel_group)

output_tensors = [t if len(t.shape) > 0 else t.reshape_([-1]) for t in output_tensors]
gradient = merge_func(output_tensors).cuda()
gradient = merge_func(output_tensors)._copy_to(paddle.framework._current_expected_place(), False)

# sharding
if sharding_degree > 1:
@@ -381,7 +381,12 @@ def record_gradient_hook(*_):
gradient_dict[local_grad_name] = grad.clone() / self.loraga_init_iters
else:
if self.gradient_offload:
new_grad = gradient_dict[local_grad_name].cuda() + grad / self.loraga_init_iters
new_grad = (
gradient_dict[local_grad_name]._copy_to(
paddle.framework._current_expected_place(), False
)
+ grad / self.loraga_init_iters
)
gradient_dict[local_grad_name] = new_grad.cpu()
else:
gradient_dict[local_grad_name] += grad / self.loraga_init_iters
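The hook above keeps the accumulated LoRA-GA gradients on the CPU when gradient_offload is enabled and only moves them back to the active device for the addition, again via _copy_to instead of .cuda(). A rough, self-contained sketch of that accumulate-on-CPU pattern (the function and dictionary names are illustrative, not the trainer's real API):

import paddle

def accumulate_offloaded(gradient_dict, name, grad, init_iters):
    # Maintain a running average of gradients over init_iters steps,
    # parking the accumulator on CPU between steps to save device memory.
    contribution = grad / init_iters
    if name not in gradient_dict:
        gradient_dict[name] = contribution.cpu()
    else:
        on_device = gradient_dict[name]._copy_to(paddle.framework._current_expected_place(), False)
        gradient_dict[name] = (on_device + contribution).cpu()
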
70 changes: 61 additions & 9 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -32,7 +32,7 @@
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute
from paddle.distributed.fleet.recompute.recompute import recompute
from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

try:
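The hunk above moves the recompute import from paddle.distributed.fleet.utils to paddle.distributed.fleet.recompute.recompute. If both Paddle layouts had to be supported, a guarded import would look roughly like this (an assumption for illustration; the commit itself switches unconditionally):

try:
    from paddle.distributed.fleet.recompute.recompute import recompute
except ImportError:
    # older Paddle releases expose recompute under fleet.utils
    from paddle.distributed.fleet.utils import recompute
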
@@ -691,7 +691,9 @@ def forward(self, hidden_states):
_, h_dim = hidden_states.shape

# compute gating score
print("linear input: ", hidden_states._md5sum())
logits = F.linear(hidden_states, self.weight, None)
print("linear output: ", logits._md5sum())

with paddle.amp.auto_cast(False):
scores = self.gate_score_func(logits=logits)
@@ -831,16 +833,16 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
else:
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False)
self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False)
self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True)

self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False)
self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=False)
self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=True)

self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=True)
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=False)

assert self.num_heads % config.tensor_parallel_degree == 0, f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
self.num_heads = self.num_heads // config.tensor_parallel_degree
# self.num_heads = self.num_heads // config.tensor_parallel_degree

else:
# for without tensor parallel
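The tensor-parallel branch above now builds q_b_proj and kv_b_proj with gather_output=True, feeds o_proj with input_is_parallel=False, and drops the num_heads division by the tensor-parallel degree: once the column-parallel outputs are gathered, every rank sees all heads. A small shape-bookkeeping sketch with illustrative numbers (not taken from the DeepSeek-V2 config):

# Illustrative head/width arithmetic for the two layouts.
num_heads, q_head_dim, tp_degree = 16, 192, 4

# gather_output=False: each rank keeps its shard, so attention runs on
# num_heads // tp_degree local heads and o_proj takes an already-split input.
local_heads = num_heads // tp_degree          # 4 heads per rank
local_width = local_heads * q_head_dim        # 768 columns per rank

# gather_output=True: the shards are all-gathered, every rank sees all heads,
# so num_heads stays global and o_proj is fed a full-width, non-parallel input.
full_width = num_heads * q_head_dim           # 3072 columns on every rank
print(local_width, full_width)
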
@@ -941,12 +943,20 @@ def forward(
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
print("qa input: ", hidden_states._md5sum())
print("qa weight: ", self.q_a_proj.weight._md5sum())
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
print("qb weight shape: ", self.q_b_proj.weight.shape)
print("qb weight reshape: ", [bsz, q_len, self.num_heads, self.q_head_dim])
print("q output shape: ", q.shape)
print(self.q_a_proj, self.q_b_proj)
q = q.reshape([bsz, q_len, self.num_heads, self.q_head_dim])
q_nope, q_pe = paddle.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

# DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64
print("kva weight: ", self.kv_a_proj_with_mqa.weight._md5sum())
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
print(self.kv_a_proj_with_mqa, self.kv_b_proj)
compressed_kv, k_pe = paddle.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
k_pe = k_pe.reshape([bsz, q_len, 1, self.qk_rope_head_dim])

@@ -1022,11 +1032,23 @@ def forward(
# if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim]
# else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism.
attn_output = self.o_proj(attn_output)
print("o ouput: ", attn_output._md5sum())

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
outputs = (attn_output,)

if output_attentions:
outputs += (attn_weights,)

if use_cache:
outputs += (past_key_value,)

if type(outputs) is tuple and len(outputs) == 1:
outputs = outputs[0]

return outputs


class DeepseekV2DecoderLayer(nn.Layer):
@@ -1095,7 +1117,7 @@ def forward(
and has_gradient
and self.recompute_granularity == "full_attn"
):
hidden_states, self_attn_weights, present_key_value = recompute(
outputs = recompute(
self.self_attn,
hidden_states=hidden_states,
position_ids=position_ids,
@@ -1107,7 +1129,7 @@
**kwargs,
)
else:
hidden_states, self_attn_weights, present_key_value = self.self_attn(
outputs = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
@@ -1117,6 +1139,18 @@
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
**kwargs,
)

if type(outputs) is tuple:
hidden_states = outputs[0]
else:
hidden_states = outputs

if output_attentions:
self_attn_weights = outputs[1]

if use_cache:
present_key_value = outputs[2 if output_attentions else 1]

hidden_states = residual + hidden_states

# Fully Connected
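The attention layer now returns either a bare tensor or a tuple whose length depends on output_attentions and use_cache, and the decoder layer above unpacks it accordingly. A compact sketch of that unpacking written as a standalone helper (hypothetical; the repo does the same thing inline):

def unpack_attn_outputs(outputs, output_attentions, use_cache):
    # Mirror the new return convention: (attn_output[, attn_weights][, past_key_value]),
    # collapsed to a bare tensor when only attn_output is present.
    if not isinstance(outputs, tuple):
        return outputs, None, None
    hidden_states = outputs[0]
    attn_weights = outputs[1] if output_attentions else None
    present_key_value = (outputs[2] if output_attentions else outputs[1]) if use_cache else None
    return hidden_states, attn_weights, present_key_value
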
@@ -1676,6 +1710,15 @@ def __init__(self, config: DeepseekV2Config):
)
# Must set distributed attr for Tensor Parallel !
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
if get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ( # noqa: F401
parallel_matmul as xpu_parallel_matmul,
)

self.xpu_parallel_matmul = xpu_parallel_matmul()
except ImportError:
self.xpu_parallel_matmul = None

def forward(self, hidden_states, tensor_parallel_output=None):
if self.config.sequence_parallel:
@@ -1686,7 +1729,16 @@ def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output

logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
logits = self.xpu_parallel_matmul(
hidden_states,
self.weight,
transpose_y=False,
tensor_parallel_output=tensor_parallel_output,
training=self.training,
)
else:
logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
return logits


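The LM head above routes the logits matmul through paddle_xpu's parallel_matmul when running on XPU and the kernel is importable, and otherwise falls back to the stock parallel_matmul. A stripped-down sketch of that optional-kernel dispatch (paraphrasing the diff; not a drop-in replacement for the class):

def build_lm_head_matmul():
    # Prefer the XPU-optimized kernel when paddle_xpu is installed; otherwise
    # return None so the caller falls back to the generic parallel_matmul.
    try:
        from paddle_xpu.layers.nn import parallel_matmul as xpu_parallel_matmul
        return xpu_parallel_matmul()
    except ImportError:
        return None

# in forward(), roughly:
# if self.xpu_parallel_matmul is not None:
#     logits = self.xpu_parallel_matmul(hidden_states, self.weight, transpose_y=False,
#                                       tensor_parallel_output=tensor_parallel_output,
#                                       training=self.training)
# else:
#     logits = parallel_matmul(hidden_states, self.weight,
#                              tensor_parallel_output=tensor_parallel_output)
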
33 changes: 24 additions & 9 deletions paddlenlp/transformers/moe_gate.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# = paddle.topk(tmp_scores, k=k, axis=-1, sorted=Trui Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -231,11 +231,21 @@ def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor:
# less than the expert capacity. One-hot matrix will ignore indices outside
# the range [0, expert_capacity).
# Shape: [tokens_per_group, num_experts, expert_capacity].
valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity)
valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity).astype(paddle.bool)
token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0)
dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.int32)
dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.bool)
valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity])
dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0)
# p = paddle.topk(tmp_scores, k=k, axis=-1, sorted=Truirint('1', valid_mask)
# print('2', ~valid_mask)
# print('3', dispatch_mask)
# dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0)
dispatch_mask = dispatch_mask * (~valid_mask)

# valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity)
# token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0)
# dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.int32)
# valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity])
# dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0)

return dispatch_mask

@@ -276,13 +286,13 @@ def _topk_group_limited_greedy(
assert n_experts % n_group == 0, "n_experts must be divisible by n_groups"

group_scores = scores.reshape([0, n_group, -1]).max(axis=-1) # [n, n_group]
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=False)[1] # [n, top_k_group]
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group]
group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip
score_mask = (
group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1])
) # [n, e]
tmp_scores = scores * score_mask # [n, e]
topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=False)
topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True)

return topk_weight, topk_idx
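_topk_group_limited_greedy above first picks the top-scoring expert groups, masks out every other group, and only then takes the per-token top-k of experts (now with sorted=True). A toy walk-through with one token and 8 experts in 4 groups (random values; shapes are the point, not the numbers):

import paddle

scores = paddle.rand([1, 8])                     # [n_tokens, n_experts]
n_group, topk_group, k = 4, 2, 2

# score of a group = max score of its experts
group_scores = scores.reshape([0, n_group, -1]).max(axis=-1)                # [1, 4]
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1]
group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1)

# broadcast the group mask back to expert granularity, then top-k on the masked scores
score_mask = group_mask.unsqueeze(-1).expand([1, n_group, 8 // n_group]).reshape([1, -1])
topk_weight, topk_idx = paddle.topk(scores * score_mask, k=k, axis=-1, sorted=True)
print(topk_idx)                                  # chosen experts, confined to the top groups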

@@ -312,13 +322,13 @@ def _topk_noaux_tc(
group_scores = (
scores_for_choice.reshape([bsz_seq_len, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
) # fmt:skip [n, n_group]
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=False)[1] # [n, top_k_group]
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group]
group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip
score_mask = (
group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1])
) # [n, e]
tmp_scores = scores_for_choice * score_mask # [n, e]
topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=False)
topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True)
topk_weight = scores.take_along_axis(topk_idx, axis=1) if not self.training else topk_weight

return topk_weight, topk_idx
@@ -545,6 +555,11 @@ def topkgating(
"se,sec->sec", topk_masked_gates, token_priority.cast(paddle.get_default_dtype())
)

dispatch_mask = combine_weights.cast(paddle.bool)
# print(gates_masked)
# gates_masked = gates_masked.astype("bool").unsqueeze(-1).expand(gates_masked.shape + token_priority.shape[-1:])
# print(gates_masked)
combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype()))
# combine_weights = gates_masked * token_priority
dispatch_mask = combine_weights.astype(paddle.bool)

return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss
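In topkgating above, the combine weights come from an einsum between the per-token gates ("se") and the token-priority one-hot ("sec"), and the boolean dispatch mask is just the non-zero pattern of those weights. A toy shape check with illustrative sizes (not the model's real dimensions):

import paddle

s, e, c = 4, 2, 3                                # tokens, experts, capacity slots
gates_masked = paddle.rand([s, e])
token_priority = paddle.randint(0, 2, [s, e, c]).cast(paddle.get_default_dtype())

combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority)
dispatch_mask = combine_weights.astype(paddle.bool)   # True where a token occupies a slot
print(combine_weights.shape, dispatch_mask.shape)     # [4, 2, 3] [4, 2, 3]
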
6 changes: 5 additions & 1 deletion paddlenlp/transformers/moe_layer.py
@@ -215,8 +215,10 @@ def expert_forward(self, dispatched_input):
assert len(chunks) == len(true_experts), (len(chunks), len(true_experts))
for chunk, expert in zip(chunks, true_experts):
chunk = chunk.contiguous()
print("11", chunk.shape) # [ecm]
expert_outputs += [expert(chunk)]
expert_output = paddle.stack(expert_outputs, axis=1) # [ecm]
print("22", expert_output.shape) # [ecm]
return expert_output

def forward(
@@ -248,11 +250,13 @@ def forward(
# dispatch_mask : sec
# self.exp_counts :
dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input)
# dispatched_input = paddle.masked_fill_(reshaped_input, dispatch_mask)

if self.expert_parallel_degree > 1:
print(dispatched_input, self.moe_group)
dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group)

# Re-shape after all-to-all: ecm -> gecm
print(dispatched_input.shape)
dispatched_input = dispatched_input.reshape(
[self.expert_parallel_degree, self.moe_num_experts_per_device, -1, d_model]
)
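The forward pass above scatters tokens into per-expert capacity slots with einsum("sec,sm->ecm", ...), optionally exchanges them across expert-parallel ranks with an all-to-all, and then reshapes ecm -> gecm. A single-process toy sketch of the shape flow (illustrative sizes, no distributed group involved):

import paddle

s, e, c, m = 4, 2, 3, 8                          # tokens, experts, capacity, model dim
dispatch_mask = paddle.randint(0, 2, [s, e, c]).cast("float32")
reshaped_input = paddle.rand([s, m])

# route each token's features into its (expert, capacity-slot) bucket
dispatched = paddle.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)   # [2, 3, 8]

# with expert_parallel_degree g > 1 an all-to-all would run here, then ecm -> gecm
g = 1
dispatched = dispatched.reshape([g, e // g, -1, m])                        # [1, 2, 3, 8]
print(dispatched.shape)
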
