From 87ef82e8f076687940e042b9f5184590e3821a5e Mon Sep 17 00:00:00 2001
From: FlyHorse24
Date: Fri, 9 Aug 2024 19:14:22 +0800
Subject: [PATCH] packAndisp

---
 .../meta_llama/Meta_Llama_3_1_8B/config.py    |  4 +--
 .../Meta_Llama_3_1_8B/modeling_llama.py       | 25 +++++++++----------
 2 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/examples/meta_llama/Meta_Llama_3_1_8B/config.py b/examples/meta_llama/Meta_Llama_3_1_8B/config.py
index b473066..9a7ee37 100644
--- a/examples/meta_llama/Meta_Llama_3_1_8B/config.py
+++ b/examples/meta_llama/Meta_Llama_3_1_8B/config.py
@@ -185,9 +185,9 @@
 """
 parallel = dict(
     zero1=dict(size=-1),
-    tensor=dict(size=4, mode="isp"),
+    tensor=dict(size=2, mode="isp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=4, overlap=False, memory_pool=True),
+    weight=dict(size=2, overlap=False, memory_pool=True),
 )
 
 cudnn_deterministic = False
diff --git a/huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py b/huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py
index c6b9e04..6d7f57e 100644
--- a/huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py
+++ b/huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py
@@ -29,7 +29,6 @@
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-#from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -50,10 +49,6 @@
 from .configuration_llama import LlamaConfig
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-# try:
-#     from flash_attn import flash_attn_func, flash_attn_varlen_func
-# except:
-#     pass
 from internlm.model.modules.embedding import Embedding1D
 from internlm.model.modules.linear import new_linear
 from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens
@@ -214,7 +209,7 @@ def forward(self, x, position_ids):
 
         # Core RoPE block
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids.unsqueeze(0)[:, None, :].float()
+        position_ids_expanded = position_ids.unsqueeze(1)[:, None, :].float()
         # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
         device_type = x.device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
@@ -510,7 +505,6 @@ def forward(
             assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
             cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
             max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
-            #position_ids = position_ids.unsqueeze(0)
 
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
@@ -522,7 +516,7 @@ def forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
+        #import pdb;pdb.set_trace()
         if position_embeddings is None:
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -533,7 +527,11 @@ def forward(
             cos, sin = self.rotary_emb(value_states, position_ids)
         else:
             cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        cos_temp = cos.transpose(0,1).unsqueeze(0)
+        sin_temp = sin.transpose(0,1).unsqueeze(0)
+        query_states = (query_states * cos_temp) + (rotate_half(query_states) * sin_temp)
+        key_states = (key_states * cos_temp) + (rotate_half(key_states) * sin_temp)
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -573,7 +571,7 @@ def forward(
             query_states = query_states.to(target_dtype)
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
-
+
         if use_packed_dataset:
             attn_output = hf_q_k_v_with_cu_seqlens(
                 query_states,
@@ -582,12 +580,12 @@ def forward(
                 key_states,
                 value_states,
                 cumulative_len=cu_seqlens,
                 max_seqlen=max_seqlen,
                 dropout_p=dropout_rate,
-            ).unsqueeze(0)
+            )
         else:
             attn_output = hf_q_k_v_without_cu_seqlens(
                 query_states, key_states, value_states, dropout_p=dropout_rate, softmax_scale=None, causal=True,
             )
-
+
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
@@ -643,7 +641,7 @@ def forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
+        #import pdb;pdb.set_trace()
         if position_embeddings is None:
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -654,6 +652,7 @@ def forward(
             cos, sin = self.rotary_emb(value_states, position_ids)
         else:
             cos, sin = position_embeddings
+
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
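
Note on the RoPE hunks above (@@ -214,7 +209,7 @@ and @@ -533,7 +527,11 @@): with use_packed_dataset the attention layer sees bsz=1 and a 1-D position_ids tensor, so the rotary module now emits cos/sin of shape (seq_len, 1, head_dim) rather than the stock (bsz, seq_len, head_dim), and the transpose(0,1).unsqueeze(0) in the patch brings them to (1, 1, seq_len, head_dim) before the element-wise rotation. The self-contained sketch below is not code from this repository: it uses toy sizes, re-implements the rotary math inline instead of importing modeling_llama, and only checks that the packed layout reproduces what the stock apply_rotary_pos_emb path computes from 2-D position_ids.

import torch

def rotate_half(x):
    # Same helper as in modeling_llama.py: rotate half of the last dimension.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, n_heads, seq_len = 8, 2, 6                      # toy sizes, not the 8B config
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids_1d = torch.tensor([0, 1, 2, 0, 1, 2])        # two packed samples of length 3
q = torch.randn(1, n_heads, seq_len, head_dim)

# Patched rotary forward: 1-D position_ids -> cos/sin of shape (seq_len, 1, head_dim).
inv_freq_exp = inv_freq[None, :, None].expand(position_ids_1d.shape[0], -1, 1)   # (S, d/2, 1)
pos_exp = position_ids_1d.unsqueeze(1)[:, None, :].float()                       # (S, 1, 1)
freqs = (inv_freq_exp @ pos_exp).transpose(1, 2)                                 # (S, 1, d/2)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()                                                  # (S, 1, d)

# Patched application in the attention layer: (S, 1, d) -> (1, 1, S, d), then rotate.
cos_t, sin_t = cos.transpose(0, 1).unsqueeze(0), sin.transpose(0, 1).unsqueeze(0)
q_packed = q * cos_t + rotate_half(q) * sin_t

# Stock path for comparison: 2-D position_ids, cos/sin of shape (1, S, d), unsqueezed
# on the head dimension the way apply_rotary_pos_emb does with unsqueeze_dim=1.
pos_2d = position_ids_1d.unsqueeze(0)
freqs_ref = (inv_freq[None, :, None].float() @ pos_2d[:, None, :].float()).transpose(1, 2)
emb_ref = torch.cat((freqs_ref, freqs_ref), dim=-1)
cos_ref, sin_ref = emb_ref.cos().unsqueeze(1), emb_ref.sin().unsqueeze(1)        # (1, 1, S, d)
q_ref = q * cos_ref + rotate_half(q) * sin_ref

assert torch.allclose(q_packed, q_ref, atol=1e-6)
print("packed-layout RoPE matches the stock apply_rotary_pos_emb result")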
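
A second illustrative sketch, again not repository code, of the packed-data convention the attention hunks rely on: samples are concatenated into a single bsz=1 sequence, and hf_q_k_v_with_cu_seqlens receives cumulative boundaries plus the longest individual length instead of an attention mask. In the patch these values come from gpc.config.data[f"cu_seqlens_data_rank{...}"] and gpc.config.data[f"max_seqlen_data_rank{...}"]; the hypothetical helper below only shows how such metadata is typically built from per-sample lengths, assuming the usual flash-attention varlen convention for cu_seqlens.

import torch

def pack_metadata(seq_lens):
    """Return (cu_seqlens, max_seqlen) for a list of per-sample lengths."""
    lens = torch.tensor(seq_lens, dtype=torch.int32)
    # cu_seqlens[i] is the start offset of sample i in the packed sequence;
    # the last entry equals the total packed length.
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), lens.cumsum(0).to(torch.int32)])
    return cu_seqlens, int(lens.max())

cu_seqlens, max_seqlen = pack_metadata([5, 3, 8])
print(cu_seqlens)   # tensor([ 0,  5,  8, 16], dtype=torch.int32)
print(max_seqlen)   # 8

With metadata like this the packed hidden_states keep bsz == 1 (hence the assert retained in the @@ -510,7 +505,6 @@ hunk), the varlen attention kernel keeps tokens of different samples from attending to each other, and position_ids typically restart at 0 at every boundary, which is why a 1-D position_ids tensor reaches the rotary module in the first place.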