From 9084462061f466d1546f735b82c9409301fe1fba Mon Sep 17 00:00:00 2001
From: FlyHorse24
Date: Wed, 14 Aug 2024 18:01:59 +0800
Subject: [PATCH] pack&isp

---
 .../meta_llama/Meta_Llama_3_1_8B/train.py     |  8 ---
 .../Meta_Llama_3_1_8B/modeling_llama.py       | 61 +++++++++++--------
 2 files changed, 37 insertions(+), 32 deletions(-)

diff --git a/examples/meta_llama/Meta_Llama_3_1_8B/train.py b/examples/meta_llama/Meta_Llama_3_1_8B/train.py
index b86ee4e..600e6e9 100644
--- a/examples/meta_llama/Meta_Llama_3_1_8B/train.py
+++ b/examples/meta_llama/Meta_Llama_3_1_8B/train.py
@@ -26,14 +26,6 @@ def main(args):
     # register huggingface model and config for InternEvo
     model_initializer.register_module(gpc.config.model_type, LlamaForCausalLM)
     hf_config_initializer.register_module(gpc.config.model_type, LlamaConfig)
-    # if gpc.config.model_type == "HF":
-    #     hf_config_builder = hf_config_initializer.get_module(module_name=gpc.config.model_type)
-    #     hf_cfg = hf_config_builder(return_dict=False)
-    #     gpc.config.model.num_layers = hf_cfg.num_hidden_layers
-    #     gpc.config.model.hidden_size = hf_cfg.hidden_size
-    #     gpc.config.model.num_attention_heads = hf_cfg.num_attention_heads
-    #     gpc.config.model.mlp_ratio = hf_cfg.intermediate_size / hf_cfg.hidden_size
-    #     gpc.config.model.vocab_size = hf_cfg.vocab_size

     # initialize model
     model = initialize_model()
diff --git a/huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py b/huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py
index 6d7f57e..7778f58 100644
--- a/huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py
+++ b/huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py
@@ -23,9 +23,20 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.core.context.parallel_context import (
+    IS_REPLICA_ZERO_PARALLEL,
+    IS_WEIGHT_ZERO_PARALLEL,
+)
+from internlm.model.modules.embedding import Embedding1D
+from internlm.model.modules.linear import new_linear
+from internlm.model.ops.attention import (
+    hf_q_k_v_with_cu_seqlens,
+    hf_q_k_v_without_cu_seqlens,
+)
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -46,13 +57,8 @@
     logging,
     replace_return_docstrings,
 )
+
 from .configuration_llama import LlamaConfig
-from internlm.core.context import ParallelMode
-from internlm.core.context import global_context as gpc
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens
-from internlm.core.context.parallel_context import IS_REPLICA_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL


 logger = logging.get_logger(__name__)
@@ -105,9 +111,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
             padding_mask = padding_mask == 0
-            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                padding_mask, min_dtype
-            )
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)

     return causal_mask

@@ -303,9 +307,7 @@ def forward(self, x):
             up_proj_slices = self.up_proj.weight.split(slice, dim=0)
             down_proj_slices = self.down_proj.weight.split(slice, dim=1)

-            gate_proj = torch.cat(
-                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
-            )
+            gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
             up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

             intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
@@ -362,8 +364,12 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
             )

         self.q_proj = new_linear("wq", self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
-        self.k_proj = new_linear("wk", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.v_proj = new_linear("wv", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = new_linear(
+            "wk", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.v_proj = new_linear(
+            "wv", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
         self.o_proj = new_linear("wo", self.hidden_size, self.hidden_size, bias=config.attention_bias)

         # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
@@ -500,7 +506,7 @@ def forward(
         bsz, q_len, _ = hidden_states.size()

         use_packed_dataset = gpc.config.data.get("use_packed_dataset", False)
-
+
         if use_packed_dataset:
             assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
             cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
@@ -516,7 +522,7 @@ def forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        #import pdb;pdb.set_trace()
+        # import pdb;pdb.set_trace()
         if position_embeddings is None:
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -528,8 +534,8 @@ def forward(
         else:
             cos, sin = position_embeddings

-        cos_temp = cos.transpose(0,1).unsqueeze(0)
-        sin_temp = sin.transpose(0,1).unsqueeze(0)
+        cos_temp = cos.transpose(0, 1).unsqueeze(0)
+        sin_temp = sin.transpose(0, 1).unsqueeze(0)
         query_states = (query_states * cos_temp) + (rotate_half(query_states) * sin_temp)
         key_states = (key_states * cos_temp) + (rotate_half(key_states) * sin_temp)

@@ -571,7 +577,7 @@ def forward(
             query_states = query_states.to(target_dtype)
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
-
+
         if use_packed_dataset:
             attn_output = hf_q_k_v_with_cu_seqlens(
                 query_states,
@@ -583,9 +589,14 @@ def forward(
             )
         else:
             attn_output = hf_q_k_v_without_cu_seqlens(
-                query_states, key_states, value_states, dropout_p=dropout_rate, softmax_scale=None, causal=True,
+                query_states,
+                key_states,
+                value_states,
+                dropout_p=dropout_rate,
+                softmax_scale=None,
+                causal=True,
             )
-
+
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)

@@ -641,7 +652,9 @@ def forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        import pdb;pdb.set_trace()
+        import pdb
+
+        pdb.set_trace()
         if position_embeddings is None:
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -652,7 +665,7 @@ def forward(
             cos, sin = self.rotary_emb(value_states, position_ids)
         else:
             cos, sin = position_embeddings
-
+
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
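
Reviewer note (not part of the patch): the core of the "pack" change is that, when use_packed_dataset is set, several samples are flattened into a single bsz=1 row and attention is restricted to each sample via cumulative sequence lengths (cu_seqlens), which the patch passes to hf_q_k_v_with_cu_seqlens. The sketch below is a rough pure-PyTorch illustration of that cu_seqlens semantics only; the helper name packed_causal_attention, the shapes, and the per-segment loop are illustrative assumptions, not InternEvo's actual fused-kernel API.

# Conceptual sketch of packed (cu_seqlens-based) causal attention.
# Assumption: q, k, v are (total_tokens, num_heads, head_dim) for a packed bsz=1 batch.
import torch
import torch.nn.functional as F


def packed_causal_attention(q, k, v, cu_seqlens):
    # cu_seqlens: cumulative sequence lengths, e.g. tensor([0, 5, 8]) for two
    # samples of lengths 5 and 3 packed back to back.
    out = torch.empty_like(q)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # Slice out one original sample and run ordinary causal attention on it,
        # so tokens never attend across sample boundaries.
        qs = q[start:end].transpose(0, 1)  # (num_heads, seqlen, head_dim)
        ks = k[start:end].transpose(0, 1)
        vs = v[start:end].transpose(0, 1)
        seg = F.scaled_dot_product_attention(qs, ks, vs, is_causal=True)
        out[start:end] = seg.transpose(0, 1)
    return out


# Two samples (lengths 5 and 3) packed into one row of 8 tokens, 4 heads, head_dim 16.
q = k = v = torch.randn(8, 4, 16)
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32)
print(packed_causal_attention(q, k, v, cu_seqlens).shape)  # torch.Size([8, 4, 16])

In the patch itself, cu_seqlens is read from gpc.config.data[f"cu_seqlens_data_rank{...}"] per data-parallel rank, which is why the attention forward asserts bsz == 1 when use_packed_dataset=True.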