Commit caebc83
fix ds3 eval
wtmlon committed Feb 25, 2025
1 parent d0e834e commit caebc83
Showing 2 changed files with 7 additions and 18 deletions.
6 changes: 1 addition & 5 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -829,7 +829,7 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
RowParallelLinear = linear_utils.RowParallelLinear

if self.q_lora_rank is None:
self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False)
self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True)
else:
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False)
@@ -840,10 +840,6 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=True)

self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=False)

assert self.num_heads % config.tensor_parallel_degree == 0, f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
# self.num_heads = self.num_heads // config.tensor_parallel_degree

else:
# for without tensor parallel
if self.q_lora_rank is None:
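For reference on the q_proj change: with gather_output=True, a Megatron-style column-parallel linear all-gathers the per-rank output shards, so the attention code sees the full num_heads * q_head_dim projection instead of a local slice, which is consistent with dropping the divisibility assert and the commented-out per-rank num_heads split. A minimal NumPy sketch of that idea, purely as an illustration; the helper name and shapes below are assumptions, not the Paddle ColumnParallelLinear API:

    import numpy as np

    # Toy column-parallel linear: the weight is split along the output dim,
    # one shard per tensor-parallel rank; each rank computes its own slice.
    def column_parallel_forward(x, weight_shards, gather_output):
        partials = [x @ w for w in weight_shards]      # each block: [batch, out_dim / tp]
        if gather_output:
            return np.concatenate(partials, axis=-1)   # all-gather: every rank sees the full output
        return partials                                # each rank keeps only its local shard

    hidden_size, num_heads, q_head_dim, tp = 8, 4, 2, 2
    x = np.random.randn(3, hidden_size)
    weight = np.random.randn(hidden_size, num_heads * q_head_dim)
    shards = np.split(weight, tp, axis=-1)

    q = column_parallel_forward(x, shards, gather_output=True)
    assert q.shape == (3, num_heads * q_head_dim)      # full head dimension, no per-rank num_heads split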
19 changes: 6 additions & 13 deletions paddlenlp/transformers/moe_gate.py
@@ -1,4 +1,4 @@
# = paddle.topk(tmp_scores, k=k, axis=-1, sorted=Trui Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -231,11 +231,11 @@ def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor:
# less than the expert capacity. One-hot matrix will ignore indices outside
# the range [0, expert_capacity).
# Shape: [tokens_per_group, num_experts, expert_capacity].
valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity).astype(paddle.bool)
valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity)
token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0)
dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.bool)
dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.int32)
valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity])
dispatch_mask = dispatch_mask * (~valid_mask)
dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0)

return dispatch_mask
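
In the _priority hunk above, token_priority stays an integer tensor: out-of-capacity entries are zeroed with masked_fill, the one-hot over capacity slots is built as int32, and the validity mask is re-applied with masked_fill rather than by multiplying with ~valid_mask. A small NumPy sketch of that dispatch-mask construction, as an illustration only; the shapes and values below are made up:

    import numpy as np

    tokens, experts, capacity = 4, 2, 2
    # Per-token slot index within its routed expert; -1 or >= capacity means the token is dropped.
    token_priority = np.array([[0, -1],
                               [1, -1],
                               [2, -1],
                               [-1, 0]])

    valid_mask = (token_priority >= 0) & (token_priority < capacity)
    clipped = np.where(valid_mask, token_priority, 0)            # masked_fill(token_priority, ~valid_mask, 0)
    one_hot = np.eye(capacity, dtype=np.int32)[clipped]          # [tokens, experts, capacity]
    dispatch_mask = one_hot * valid_mask[..., None]              # masked_fill(dispatch_mask, ~valid_mask, 0)

    print(dispatch_mask[2].sum())   # token 2 overflowed capacity, so it is not dispatched -> prints 0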

@@ -530,20 +530,13 @@ def topkgating(
token_priority = self._priority(top_idx, capacity)

# normalize gates
gates_masked = gates * mask
if self.training:
gates_masked = gates * mask
gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
if self.norm_topk_prob:
gates_masked = gates_masked / denom_s
combine_weights = paddle.einsum(
"se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype())
)
else:
topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
combine_weights = paddle.einsum(
"se,sec->sec", topk_masked_gates, token_priority.cast(paddle.get_default_dtype())
)
combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype()))

combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype()))
dispatch_mask = combine_weights.astype(paddle.bool)
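The topkgating hunk above computes gates_masked = gates * mask unconditionally instead of only under self.training, normalizes it by the per-token sum of the selected gates (clipped to eps), and folds it into the per-slot dispatch tensor with einsum("se,sec->sec"). A NumPy sketch of that normalization and combine step, with assumed shapes, purely as an illustration:

    import numpy as np

    tokens, experts, capacity = 3, 4, 2
    eps = np.finfo(np.float32).eps

    gates = np.random.rand(tokens, experts).astype(np.float32)
    mask = np.zeros_like(gates)
    mask[:, :2] = 1.0                                            # pretend experts 0 and 1 were the top-k picks
    token_priority = np.random.randint(0, 2, (tokens, experts, capacity)).astype(np.float32)

    gates_masked = gates * mask                                  # keep only the selected experts' gates
    denom_s = np.clip(gates_masked.sum(axis=-1, keepdims=True), eps, None)
    gates_masked = gates_masked / denom_s                        # the norm_topk_prob branch
    combine_weights = np.einsum("se,sec->sec", gates_masked, token_priority)
    dispatch_mask = combine_weights.astype(bool)

    print(combine_weights.shape)                                 # (3, 4, 2): [tokens, experts, capacity]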
