diff --git a/paddlenlp/transformers/deepseek_v2/configuration.py b/paddlenlp/transformers/deepseek_v2/configuration.py
index d21afc20780f..5dba2990df38 100644
--- a/paddlenlp/transformers/deepseek_v2/configuration.py
+++ b/paddlenlp/transformers/deepseek_v2/configuration.py
@@ -179,6 +179,7 @@ def __init__(
         attention_dropout=0.0,
         speculate_model_type=False,
         using_flex_token=False,
+        decoderlayer_act_offload_settings={},
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -227,7 +228,7 @@ def __init__(
         self.speculate_model_type = speculate_model_type
         self.use_fp8 = False
         self.using_flex_token = using_flex_token
-
+        self.decoderlayer_act_offload_settings = decoderlayer_act_offload_settings
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index fde4fd576200..9613cdaab919 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -1660,6 +1660,33 @@ def recompute_training_full(
         use_cache: bool,
         attn_mask_startend_row_indices: Optional[Tensor] = None,
     ):
+        decoderlayer_act_offload_settings = self.config.get(
+            "decoderlayer_act_offload_settings", {"type": "", "value": ""}
+        )
+
+        setting_type = decoderlayer_act_offload_settings["type"]
+        offload_value = decoderlayer_act_offload_settings["value"]
+
+        def get_offload_kwargs(layer_idx, setting_type, offload_value):
+            offload_kwargs = {}
+            if "mod" == setting_type:
+                assert isinstance(offload_value, (list, tuple))
+                v1, v2 = offload_value
+                offload_kwargs["offload_indices"] = [0] if layer_idx % v1 == v2 else []
+            elif setting_type is not None and setting_type != "":
+                raise ValueError(
+                    f"decoderlayer_act_offload_settings only supports type == 'mod', but got type {setting_type}"
+                )
+            return offload_kwargs
+
+        layer_idx = layer_module.layer_idx
+        # NOTE: the first layer's inputs will be used by MTP, so do not offload them
+        if layer_idx == 0:
+            offload_kwargs = {}
+        else:
+            offload_kwargs = get_offload_kwargs(layer_idx, setting_type, offload_value)
+        print("recompute offload ", offload_kwargs)
+
         def create_custom_forward(module):
             def custom_forward(*inputs):
                 return module(*inputs)
@@ -1676,6 +1703,7 @@ def custom_forward(*inputs):
             use_cache,
             attn_mask_startend_row_indices,
             use_reentrant=self.config.recompute_use_reentrant,
+            **offload_kwargs,
         )
         return hidden_states

diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index 4361023f6bbc..284f5c3e18d1 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -93,6 +93,27 @@ def get_attr(layer, name):
         return get_attr(layer._layer, name)


+def get_offload_kwargs(layer_idx, decoderlayer_act_offload_settings):
+    setting_type = decoderlayer_act_offload_settings["type"]
+    offload_value = decoderlayer_act_offload_settings["value"]
+
+    # NOTE: the first layer's inputs will be used by MTP, so do not offload them
+    if layer_idx == 0:
+        offload_kwargs = {}
+    else:
+        offload_kwargs = {}
+        if "mod" == setting_type:
+            assert isinstance(offload_value, (list, tuple))
+            v1, v2 = offload_value
+            offload_kwargs["offload_indices"] = [0] if layer_idx % v1 == v2 else []
+        elif setting_type is not None and setting_type != "":
+            raise ValueError(
+                f"decoderlayer_act_offload_settings only supports type == 'mod', but got type {setting_type}"
+            )
+    print("offload_kwargs ", offload_kwargs)
+    return offload_kwargs
+
+
 class DeepseekV2EmbeddingPipe(nn.Layer):
     def __init__(self, config: DeepseekV2Config):
         super(DeepseekV2EmbeddingPipe, self).__init__()
@@ -132,12 +153,14 @@ def forward(self, args):
                 attention_mask = attention_mask[
                     :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers
                 ]
-
+            # attn_mask_startend_row_indices: [b, num_head, seq_len] or [b, num_head, seq_len, C], C is 2 or 4
             if attn_mask_startend_row_indices is not None:
                 if attn_mask_startend_row_indices.ndim == 3:
                     attn_mask_startend_row_indices = attn_mask_startend_row_indices[
-                        :, :, : -self.config.num_nextn_predict_layers,
+                        :,
+                        :,
+                        : -self.config.num_nextn_predict_layers,
                     ]
                 elif attn_mask_startend_row_indices.ndim == 4:
                     attn_mask_startend_row_indices = attn_mask_startend_row_indices[
@@ -222,6 +245,10 @@ def forward(self, args):
             attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices

         if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+            decoderlayer_act_offload_settings = self.config.get(
+                "decoderlayer_act_offload_settings", {"type": "", "value": ""}
+            )
+            offload_kwargs = get_offload_kwargs(self.layer_idx, decoderlayer_act_offload_settings)
             if attention_mask is not None or attn_mask_startend_row_indices is not None:
                 hidden_states = recompute(
                     super().forward,
@@ -230,6 +257,7 @@ def forward(self, args):
                     attention_mask=attention_mask,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=False,
+                    **offload_kwargs,
                 )
             else:
                 # for pretrain
@@ -239,6 +267,7 @@ def forward(self, args):
                     position_ids=position_ids,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=self.config.recompute_use_reentrant,
+                    **offload_kwargs,
                 )
         else:
             hidden_states = super().forward(
@@ -279,6 +308,10 @@ def forward(self, args):
         for depth in range(self.config.num_nextn_predict_layers):
             inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth]
             if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+                decoderlayer_act_offload_settings = self.config.get(
+                    "decoderlayer_act_offload_settings", {"type": "", "value": ""}
+                )
+                offload_kwargs = get_offload_kwargs(self.layer_idx, decoderlayer_act_offload_settings)
                 if attention_mask is not None or attn_mask_startend_row_indices is not None:
                     hidden_states = recompute(
                         super().forward,
@@ -288,6 +321,7 @@ def forward(self, args):
                         attention_mask=attention_mask,
                         attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                         use_reentrant=False,
+                        **offload_kwargs,
                     )
                 else:
                     # for pretrain
@@ -298,6 +332,7 @@ def forward(self, args):
                         position_ids=position_ids,
                         attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                         use_reentrant=self.config.recompute_use_reentrant,
+                        **offload_kwargs,
                     )
             else:
                 hidden_states = super().forward(
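
For reference, the sketch below is not part of the diff; it is a minimal, self-contained illustration of how the new decoderlayer_act_offload_settings knob is expected to behave, assuming the "mod" rule added above: with value == [v1, v2], input index 0 (the layer's hidden_states) is marked for offload on layers where layer_idx % v1 == v2, layer 0 is always skipped because its inputs are reused by MTP, and an empty or missing type disables offloading. The helper name resolve_offload_indices and the example value [2, 0] are hypothetical.

# Illustrative sketch only; the helper name and the example value are not from the diff.
def resolve_offload_indices(layer_idx, settings):
    setting_type = settings.get("type", "")
    offload_value = settings.get("value", "")
    # Layer 0 is never offloaded: its inputs are reused by the MTP layers.
    if layer_idx == 0 or setting_type in (None, ""):
        return {}
    if setting_type == "mod":
        v1, v2 = offload_value
        return {"offload_indices": [0] if layer_idx % v1 == v2 else []}
    raise ValueError(f"decoderlayer_act_offload_settings only supports type == 'mod', but got type {setting_type}")


if __name__ == "__main__":
    # Hypothetical config: offload the hidden_states of every even-indexed decoder layer (except layer 0).
    settings = {"type": "mod", "value": [2, 0]}
    for idx in range(6):
        print(idx, resolve_offload_indices(idx, settings))

The resulting offload_indices entry is simply forwarded to recompute(...) via **offload_kwargs, as in the hunks above; whether the installed Paddle recompute accepts that keyword depends on the Paddle version, so the key name here only mirrors the diff.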