30 changes: 1 addition & 29 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -798,19 +798,6 @@ def forward(self, x):
class FusedNormGateFunc(paddle.autograd.PyLayer):
"""recompute of postnorm and gate"""

_current_norm_output = None
_current_invar = None

@classmethod
def set_temporary_vars(cls, norm_output, invar):
FusedNormGateFunc._current_norm_output = norm_output
FusedNormGateFunc._current_invar = invar

@classmethod
def clear_temporary_vars(cls):
FusedNormGateFunc._current_norm_output = None
FusedNormGateFunc._current_invar = None

@staticmethod
def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
ctx.dtype = paddle.float32
@@ -825,10 +812,7 @@ def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
def backward(ctx, d_gate_logits, d_norm_output):
x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
# recompute rmsnorm
norm_output = FusedNormGateFunc._current_norm_output
invar = FusedNormGateFunc._current_invar
if norm_output is None or invar is None:
norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
cast_if_needed(norm_output, ctx.dtype),
cast_if_needed(moe_gate_weight, ctx.dtype),
@@ -846,18 +830,6 @@ def backward(ctx, d_gate_logits, d_norm_output):
return dx, d_rms_norm_weight, d_moe_gate_weight


class TemporaryVarContext:
def __init__(self, norm_output, invar):
self.norm_output = norm_output
self.invar = invar

def __enter__(self):
FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar)

def __exit__(self, exc_type, exc_val, exc_tb):
FusedNormGateFunc.clear_temporary_vars()


def balance_expert_assignment(n, m, k):
assert k * n % m == 0
matrix = paddle.zeros((n, m), dtype=paddle.int32)
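As context for the change above, here is a minimal, hypothetical sketch of the pattern FusedNormGateFunc now follows unconditionally: save only the raw inputs in forward and rebuild the post-norm activation inside backward, instead of threading a cached norm_output/invar pair through class-level temporaries. The names RecomputeNormGate and rms_norm are illustrative and not part of the repository; the real code uses fused_ln.fused_rms_norm and hand-written matmul gradients, while paddle.enable_grad/paddle.grad stand in here for that manual backward.

```python
import paddle
from paddle.autograd import PyLayer


def rms_norm(x, weight, eps):
    # Plain float32 RMSNorm; stands in for fused_ln.fused_rms_norm in this sketch.
    variance = x.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    return (x.astype("float32") * paddle.rsqrt(variance + eps)).astype(x.dtype) * weight


class RecomputeNormGate(PyLayer):
    """Save only the inputs; recompute the post-norm activation in backward."""

    @staticmethod
    def forward(ctx, x, norm_weight, gate_weight, eps):
        # Nothing derived from the norm is cached; only the small inputs are kept.
        ctx.eps = eps
        ctx.save_for_backward(x, norm_weight, gate_weight)
        norm_out = rms_norm(x, norm_weight, eps)
        gate_logits = paddle.matmul(norm_out, gate_weight)
        return gate_logits, norm_out

    @staticmethod
    def backward(ctx, d_gate_logits, d_norm_out):
        # Assumes both outputs received gradients.
        x, norm_weight, gate_weight = ctx.saved_tensor()
        # Rebuild the graph from the saved inputs and differentiate it, trading
        # one extra RMSNorm per backward for not caching norm_out/invar anywhere.
        x_, w_, g_ = x.detach(), norm_weight.detach(), gate_weight.detach()
        for t in (x_, w_, g_):
            t.stop_gradient = False
        with paddle.enable_grad():
            norm_out = rms_norm(x_, w_, ctx.eps)
            gate_logits = paddle.matmul(norm_out, g_)
            dx, dw_norm, dw_gate = paddle.grad(
                outputs=[gate_logits, norm_out],
                inputs=[x_, w_, g_],
                grad_outputs=[d_gate_logits, d_norm_out],
            )
        return dx, dw_norm, dw_gate
```

Usage would be along the lines of `gate_logits, norm_out = RecomputeNormGate.apply(x, norm_w, gate_w, 1e-6)`. The memory saved by not caching norm_out is paid back as one extra RMSNorm in backward, which is exactly the recompute the removed set_temporary_vars/clear_temporary_vars machinery tried to skip.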
64 changes: 10 additions & 54 deletions paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -50,7 +50,6 @@
DeepseekV2PretrainedModel,
DeepseekV2PretrainingCriterion,
DeepseekV2RMSNorm,
TemporaryVarContext,
set_global_step,
)

@@ -274,19 +273,9 @@ def backward(self, output_grad):
hidden_states_grad = do3
inputs_embeds_mtp_grad = None

if self.using_post_norm_recompute:
dx, norm_out, invar = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
hidden_states_grad,
self.x,
self.shared_experts.norm_weight,
self.shared_experts.norm_eps,
self.shared_experts.w1,
self.shared_experts.w2,
)
else:
dx = FP8LinearFunctionBase.fp8_mlp_bwd(
hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True
)
dx = FP8LinearFunctionBase.fp8_mlp_bwd(
hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True
)

self.x = None

@@ -296,17 +285,9 @@

if self.using_post_norm_recompute:
if self.send_mtp_embed:
return (
inputs_embeds_mtp_grad,
dx,
residual_grad,
l_aux_grad,
final_hidden_states_grad,
norm_out,
invar,
)
return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
else:
return (dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar)
return (dx, residual_grad, l_aux_grad, final_hidden_states_grad)
else:
if self.send_mtp_embed:
return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
@@ -779,11 +760,9 @@ def post_process_backward(self, output_grad, event_to_wait=None):
residual_grad,
l_aux_grad,
final_hidden_states_grad,
norm_out,
invar,
) = grad
else:
hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad
hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad
else:
if self.send_mtp_embed:
inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad
@@ -796,7 +775,6 @@

ret = (hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event)
ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
return ret

def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
@@ -809,19 +787,9 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
l_aux_grad,
output_combine_grad,
quant_event,
norm_out,
invar,
) = output_grad
else:
(
hidden_states_grad,
residual_grad,
l_aux_grad,
output_combine_grad,
quant_event,
norm_out,
invar,
) = output_grad
(hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event) = output_grad
else:
if self.send_mtp_embed:
(
@@ -854,7 +822,6 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,

ret = (hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad)
ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
return ret

def mlp_backward(self, output_grad):
@@ -866,11 +833,9 @@ def mlp_backward(self, output_grad):
residual_grad,
l_aux_grad,
hidden_states_out_grad,
norm_out,
invar,
) = output_grad
else:
hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad
hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
else:
if self.send_mtp_embed:
(
@@ -886,7 +851,6 @@ def mlp_backward(self, output_grad):

ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad)
ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
return ret

def dispatch_backward(self, output_grad, async_finish=False, previous_event=None, allocate_on_comm_stream=False):
@@ -899,8 +863,6 @@ def dispatch_backward(self, output_grad, async_finish=False, previous_event=None
l_aux_grad,
hs_dispatched_grad,
dispatched_probs_grad,
norm_out,
invar,
) = output_grad
else:
(
@@ -909,8 +871,6 @@ def dispatch_backward(self, output_grad, async_finish=False, previous_event=None
l_aux_grad,
hs_dispatched_grad,
dispatched_probs_grad,
norm_out,
invar,
) = output_grad
else:
if self.send_mtp_embed:
@@ -935,7 +895,6 @@ def dispatch_backward(self, output_grad, async_finish=False, previous_event=None

ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad)
ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
return ret

def attn_backward(self, output_grad):
@@ -948,14 +907,12 @@ def attn_backward(self, output_grad):
l_aux_grad,
hs_grad,
token_probs_grad,
norm_out,
invar,
) = output_grad
inputs_embeds_mtp_grad_shape = hidden_states_grad.shape
inputs_embeds_mtp_grad_shape[-1] = -1
inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape)
else:
hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad, norm_out, invar = output_grad
hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad
else:
if self.send_mtp_embed:
(
@@ -986,8 +943,7 @@ def attn_backward(self, output_grad):
output_grad = (inputs_embeds_mtp_grad, *output_grad) if self.send_mtp_embed else output_grad

if self.using_post_norm_recompute:
with TemporaryVarContext(norm_out, invar):
output_grad = self.attn_and_gate_node.backward(output_grad)
output_grad = self.attn_and_gate_node.backward(output_grad)
else:
output_grad = self.attn_and_gate_node.backward(output_grad)
return output_grad
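The backward stages above (post_process_backward, combine_backward, mlp_backward, dispatch_backward, attn_backward) now share one gradient-tuple convention: a fixed base tuple, optionally prefixed by the MTP-embedding gradient when send_mtp_embed is set, with no norm_out/invar tail in the using_post_norm_recompute branch. A small, hypothetical helper pair illustrating that convention (pack_grads/unpack_grads are illustrative names, not in the repository):

```python
from typing import Optional, Tuple

# "Grad" stands in for a paddle.Tensor gradient in this sketch.
Grad = object


def pack_grads(base: Tuple[Grad, ...], send_mtp_embed: bool,
               mtp_grad: Optional[Grad] = None) -> Tuple[Grad, ...]:
    # Mirrors `ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret`
    # at the end of each stage.
    return (mtp_grad, *base) if send_mtp_embed else base


def unpack_grads(grads: Tuple[Grad, ...],
                 send_mtp_embed: bool) -> Tuple[Optional[Grad], Tuple[Grad, ...]]:
    # Each stage strips the optional MTP prefix first, then unpacks a
    # fixed-length tail, so the tuple length no longer depends on
    # using_post_norm_recompute.
    if send_mtp_embed:
        mtp_grad, *base = grads
        return mtp_grad, tuple(base)
    return None, grads
```

Keeping the tail fixed is what lets the send_mtp_embed and using_post_norm_recompute branches in each stage collapse to the same unpacking code.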
13 changes: 0 additions & 13 deletions paddlenlp/transformers/fp8_utils.py
@@ -508,19 +508,6 @@ def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):

return dx, dw1, dw2

@staticmethod
def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2):
# ===== recompute norm_output =====
norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps)

# ===== compute fp8_mlp_fwd =====
d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True)

if hasattr(norm_w, "_apply_backward_hook"):
norm_w._apply_backward_hook()

return d_norm_output, norm_output, invar


class FP8LinearFunction(paddle.autograd.PyLayer):
@staticmethod