From c01a387ad403437198084133104e32c93a9fd81d Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 16 Sep 2025 17:38:28 +0800
Subject: [PATCH] update using_post_norm_recompute

---
 .../transformers/deepseek_v2/modeling.py    | 30 +--------
 .../transformers/deepseek_v2/modeling_pp.py | 64 +++----------------
 paddlenlp/transformers/fp8_utils.py         | 13 ----
 3 files changed, 11 insertions(+), 96 deletions(-)

diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index 33fcd4520411..3c4a60bd1fa3 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -798,19 +798,6 @@ def forward(self, x):
 class FusedNormGateFunc(paddle.autograd.PyLayer):
     """recompute of postnorm and gate"""
 
-    _current_norm_output = None
-    _current_invar = None
-
-    @classmethod
-    def set_temporary_vars(cls, norm_output, invar):
-        FusedNormGateFunc._current_norm_output = norm_output
-        FusedNormGateFunc._current_invar = invar
-
-    @classmethod
-    def clear_temporary_vars(cls):
-        FusedNormGateFunc._current_norm_output = None
-        FusedNormGateFunc._current_invar = None
-
     @staticmethod
     def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
         ctx.dtype = paddle.float32
@@ -825,10 +812,7 @@ def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
     def backward(ctx, d_gate_logits, d_norm_output):
         x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
         # recompute rmsnorm
-        norm_output = FusedNormGateFunc._current_norm_output
-        invar = FusedNormGateFunc._current_invar
-        if norm_output is None or invar is None:
-            norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
         d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
             cast_if_needed(norm_output, ctx.dtype),
             cast_if_needed(moe_gate_weight, ctx.dtype),
@@ -846,18 +830,6 @@ def backward(ctx, d_gate_logits, d_norm_output):
         return dx, d_rms_norm_weight, d_moe_gate_weight
 
 
-class TemporaryVarContext:
-    def __init__(self, norm_output, invar):
-        self.norm_output = norm_output
-        self.invar = invar
-
-    def __enter__(self):
-        FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar)
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        FusedNormGateFunc.clear_temporary_vars()
-
-
 def balance_expert_assignment(n, m, k):
     assert k * n % m == 0
     matrix = paddle.zeros((n, m), dtype=paddle.int32)
diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index ab1db3621cf4..dae012efb95d 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -50,7 +50,6 @@
     DeepseekV2PretrainedModel,
     DeepseekV2PretrainingCriterion,
     DeepseekV2RMSNorm,
-    TemporaryVarContext,
     set_global_step,
 )
 
@@ -274,19 +273,9 @@ def backward(self, output_grad):
             hidden_states_grad = do3
             inputs_embeds_mtp_grad = None
 
-        if self.using_post_norm_recompute:
-            dx, norm_out, invar = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
-                hidden_states_grad,
-                self.x,
-                self.shared_experts.norm_weight,
-                self.shared_experts.norm_eps,
-                self.shared_experts.w1,
-                self.shared_experts.w2,
-            )
-        else:
-            dx = FP8LinearFunctionBase.fp8_mlp_bwd(
-                hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True
-            )
+        dx = FP8LinearFunctionBase.fp8_mlp_bwd(
+            hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True
+        )
 
         self.x = None
 
@@ -296,17 +285,9 @@ def backward(self, output_grad):
 
         if self.using_post_norm_recompute:
             if self.send_mtp_embed:
-                return (
-                    inputs_embeds_mtp_grad,
-                    dx,
-                    residual_grad,
-                    l_aux_grad,
-                    final_hidden_states_grad,
-                    norm_out,
-                    invar,
-                )
+                return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
             else:
-                return (dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar)
+                return (dx, residual_grad, l_aux_grad, final_hidden_states_grad)
         else:
             if self.send_mtp_embed:
                 return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
@@ -779,11 +760,9 @@ def post_process_backward(self, output_grad, event_to_wait=None):
                     residual_grad,
                     l_aux_grad,
                     final_hidden_states_grad,
-                    norm_out,
-                    invar,
                 ) = grad
             else:
-                hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad
+                hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad
         else:
             if self.send_mtp_embed:
                 inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad
@@ -796,7 +775,6 @@ def post_process_backward(self, output_grad, event_to_wait=None):
 
         ret = (hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
-        ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
         return ret
 
     def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
@@ -809,19 +787,9 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
                     l_aux_grad,
                     output_combine_grad,
                     quant_event,
-                    norm_out,
-                    invar,
                 ) = output_grad
             else:
-                (
-                    hidden_states_grad,
-                    residual_grad,
-                    l_aux_grad,
-                    output_combine_grad,
-                    quant_event,
-                    norm_out,
-                    invar,
-                ) = output_grad
+                (hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event) = output_grad
         else:
             if self.send_mtp_embed:
                 (
@@ -854,7 +822,6 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
 
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
-        ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
         return ret
 
     def mlp_backward(self, output_grad):
@@ -866,11 +833,9 @@ def mlp_backward(self, output_grad):
                     residual_grad,
                     l_aux_grad,
                     hidden_states_out_grad,
-                    norm_out,
-                    invar,
                 ) = output_grad
             else:
-                hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad
+                hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
         else:
             if self.send_mtp_embed:
                 (
@@ -886,7 +851,6 @@ def mlp_backward(self, output_grad):
 
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
-        ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
         return ret
 
     def dispatch_backward(self, output_grad, async_finish=False, previous_event=None, allocate_on_comm_stream=False):
@@ -899,8 +863,6 @@ def dispatch_backward(self, output_grad, async_finish=False, previous_event=None
                     l_aux_grad,
                     hs_dispatched_grad,
                     dispatched_probs_grad,
-                    norm_out,
-                    invar,
                 ) = output_grad
             else:
                 (
                     hidden_states_grad,
                     residual_grad,
                     l_aux_grad,
                     hs_dispatched_grad,
                     dispatched_probs_grad,
-                    norm_out,
-                    invar,
                 ) = output_grad
         else:
             if self.send_mtp_embed:
                 (
@@ -935,7 +895,6 @@ def dispatch_backward(self, output_grad, async_finish=False, previous_event=None
 
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
-        ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
         return ret
 
     def attn_backward(self, output_grad):
@@ -948,14 +907,12 @@ def attn_backward(self, output_grad):
                     l_aux_grad,
                     hs_grad,
                     token_probs_grad,
-                    norm_out,
-                    invar,
                 ) = output_grad
                 inputs_embeds_mtp_grad_shape = hidden_states_grad.shape
                 inputs_embeds_mtp_grad_shape[-1] = -1
                 inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape)
             else:
-                hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad, norm_out, invar = output_grad
+                hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad
         else:
             if self.send_mtp_embed:
                 (
@@ -986,8 +943,7 @@ def attn_backward(self, output_grad):
         output_grad = (inputs_embeds_mtp_grad, *output_grad) if self.send_mtp_embed else output_grad
 
         if self.using_post_norm_recompute:
-            with TemporaryVarContext(norm_out, invar):
-                output_grad = self.attn_and_gate_node.backward(output_grad)
+            output_grad = self.attn_and_gate_node.backward(output_grad)
         else:
             output_grad = self.attn_and_gate_node.backward(output_grad)
         return output_grad
diff --git a/paddlenlp/transformers/fp8_utils.py b/paddlenlp/transformers/fp8_utils.py
index 506bcca3b75a..e4f35d31f626 100644
--- a/paddlenlp/transformers/fp8_utils.py
+++ b/paddlenlp/transformers/fp8_utils.py
@@ -508,19 +508,6 @@ def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):
 
         return dx, dw1, dw2
 
-    @staticmethod
-    def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2):
-        # ===== recompute norm_output =====
-        norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
-
-        # ===== compute fp8_mlp_fwd =====
-        d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True)
-
-        if hasattr(norm_w, "_apply_backward_hook"):
-            norm_w._apply_backward_hook()
-
-        return d_norm_output, norm_output, invar
-
 
 class FP8LinearFunction(paddle.autograd.PyLayer):
     @staticmethod