diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index 33fcd4520411..ea5c9673b74a 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -1922,9 +1922,10 @@ def backward(ctx, d_q, d_kv):
                 quant_method="1x128",
                 input_transpose=True,
             )
-            FP8LinearFunctionBase.compute_fp8_linear(
-                (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]])
+            h_grad_0 = FP8LinearFunctionBase.compute_fp8_linear(
+                (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False
             )
+            h_grad = h_grad + h_grad_0
 
         def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight):
             FP8LinearFunctionBase.kitchen_gemm(