Skip to content

Commit 1b10b57

Browse files
authored
Fix FusedRMSLinear backward compute (#11095)
1 parent da47d99 commit 1b10b57

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1996,9 +1996,10 @@ def backward(ctx, d_q, d_kv):
1996 | 1996 |   | quant_method="1x128",
1997 | 1997 |   | input_transpose=True,
1998 | 1998 |   | )
1999 |      | - | FP8LinearFunctionBase.compute_fp8_linear(
2000 |      | - | (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]])
     | 1999 | + | h_grad_0 = FP8LinearFunctionBase.compute_fp8_linear(
     | 2000 | + | (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False
2001 | 2001 |   | )
     | 2002 | + | h_grad = h_grad + h_grad_0
2002 | 2003 |   |
2003 | 2004 |   | def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight):
2004 | 2005 |   | FP8LinearFunctionBase.kitchen_gemm(

0 commit comments

Comments
 (0)