From 1852f7cc4be7f3d4c80a341b59c01fe266229bad Mon Sep 17 00:00:00 2001 From: Liang Shuhao Date: Wed, 17 Sep 2025 06:40:48 +0000 Subject: [PATCH] Fix FusedRMSLinear backward compute --- paddlenlp/transformers/deepseek_v2/modeling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 33fcd4520411..ea5c9673b74a 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -1922,9 +1922,10 @@ def backward(ctx, d_q, d_kv): quant_method="1x128", input_transpose=True, ) - FP8LinearFunctionBase.compute_fp8_linear( - (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]]) + h_grad_0 = FP8LinearFunctionBase.compute_fp8_linear( + (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False ) + h_grad = h_grad + h_grad_0 def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight): FP8LinearFunctionBase.kitchen_gemm(